| from typing import Dict, Any
|
| import torch
|
| import base64
|
| import io
|
| from PIL import Image
|
| from diffusers import AutoPipelineForImage2Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EndpointHandler:

    def __init__(self, path=""):
        """Load the image-to-image diffusion pipeline.

        Args:
            path: Kept for compatibility with the hosting contract; the
                checkpoint id is currently hard-coded below.
        """
        self.pipeline = AutoPipelineForImage2Image.from_pretrained(
            "cjwalch/kandinsky-endpoint",
            torch_dtype=torch.float16,
            use_safetensors=True
        )
        if torch.cuda.is_available():
            # Model CPU offload manages device placement itself (and needs
            # CUDA).  Do NOT also call .to("cuda") afterwards — that undoes
            # the accelerate offload hooks and newer diffusers versions
            # raise an error for exactly that sequence.
            self.pipeline.enable_model_cpu_offload()
        # On a CPU-only host the pipeline simply stays on CPU; calling
        # enable_model_cpu_offload() there would fail.

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run img2img inference and return a base64-encoded PNG.

        Expected keys in ``data``:
            inputs: text prompt (default "").
            init_image: REQUIRED base64-encoded source image.
            strength: float, default 0.6.
            guidance_scale: float, default 7.0.
            negative_prompt: default "blurry, ugly, deformed".

        Returns:
            {"generated_image": <base64 PNG>} on success,
            {"error": <message>} on any failure (endpoint boundary — all
            exceptions are converted to an error payload, never raised).
        """
        try:
            prompt = data.get("inputs", "")
            strength = float(data.get("strength", 0.6))
            guidance_scale = float(data.get("guidance_scale", 7.0))
            negative_prompt = data.get("negative_prompt", "blurry, ugly, deformed")

            init_image_b64 = data.get("init_image", None)
            if not init_image_b64:
                return {"error": "Missing 'init_image' in input data"}

            image_bytes = base64.b64decode(init_image_b64)
            init_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

            output_image = self.pipeline(
                prompt=prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt
            ).images[0]

            buffered = io.BytesIO()
            output_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Drop references BEFORE emptying the CUDA cache; otherwise the
            # still-referenced tensors cannot be released by empty_cache().
            del output_image
            del init_image
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return {"generated_image": img_str}

        except Exception as e:
            # Endpoint boundary: report the failure to the caller as JSON.
            return {"error": str(e)}
|
|
|