Files
dreamtail/worker/generator.py
Vixy e4294b57e6 DreamTail v1.0.0 with IP-Adapter FaceID support
- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
2026-01-01 19:54:59 -06:00

340 lines
12 KiB
Python
Executable File

"""
SDXL Image Generator
Handles image generation using Stable Diffusion XL with Jetson optimizations.
Supports IP-Adapter FaceID for consistent face generation.
"""
import torch
import logging
import cv2
import numpy as np
from typing import Optional, Dict, Any, List, Union
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
import asyncio
import config
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin.

    Wraps a diffusers ``StableDiffusionXLPipeline`` built from ``config`` and,
    lazily on the first face-locked request, the IP-Adapter FaceID components
    (InsightFace for face embeddings plus ``IPAdapterFaceIDXL``).

    Heavy blocking work (model loading, face embedding extraction, diffusion)
    is run in the event loop's default executor so the loop stays responsive.
    """

    def __init__(self):
        self.pipeline = None
        self.device = None
        self.model_loaded = False
        # Serializes model loading (SDXL pipeline and IP-Adapter) across concurrent callers.
        self._load_lock = asyncio.Lock()
        # IP-Adapter FaceID components, populated by initialize_ip_adapter().
        self.ip_model = None
        self.face_app = None
        # Cache for precomputed face embeddings (not read or written elsewhere in this class).
        self.face_embeds_cache = {}
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load the SDXL model with the configured Jetson optimizations.

        Idempotent: once the model is loaded, later calls return immediately.
        The blocking model load runs in the default executor.

        Raises:
            Exception: Any error raised while building the pipeline is logged
                and re-raised.
        """
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                # total_memory is the device's total capacity, not what is currently free.
                total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
                logger.info(f"VRAM total: {total_gb:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            loop = asyncio.get_running_loop()
            try:
                # from_pretrained can take minutes; keep it off the event loop.
                self.pipeline = await loop.run_in_executor(None, self._build_pipeline)
            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

            self.model_loaded = True
            logger.info("SDXL model loaded successfully!")

    def _build_pipeline(self) -> StableDiffusionXLPipeline:
        """Blocking: build the SDXL pipeline, apply optimizations, place it on ``self.device``."""
        dtype = torch.float16 if config.USE_FP16 else torch.float32

        # Use DDIM scheduler for IP-Adapter compatibility
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        pipeline = StableDiffusionXLPipeline.from_pretrained(
            config.SDXL_MODEL_ID,
            torch_dtype=dtype,
            scheduler=noise_scheduler,
            use_safetensors=True,
            cache_dir=str(config.MODELS_DIR),
            add_watermarker=False,
        )

        if config.ENABLE_ATTENTION_SLICING:
            pipeline.enable_attention_slicing()
            logger.info("Attention slicing enabled")

        if config.ENABLE_VAE_SLICING:
            pipeline.enable_vae_slicing()
            logger.info("VAE slicing enabled")

        if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
            # Model offload moves sub-models onto the GPU on demand. Moving the
            # whole pipeline to the GPU first would load everything at once and
            # defeat the memory savings.
            pipeline.enable_model_cpu_offload()
            logger.info("CPU offload enabled")
        else:
            pipeline = pipeline.to(self.device)

        return pipeline

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading).

        If the optional ``ip_adapter`` / ``insightface`` packages are missing,
        a warning is logged and ``ip_adapter_loaded`` stays ``False``.

        Raises:
            RuntimeError: If the SDXL pipeline has not been loaded yet.
            Exception: Any non-import error while loading is logged and re-raised.
        """
        if self.ip_adapter_loaded:
            return
        if self.pipeline is None:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        async with self._load_lock:
            # Another caller may have finished loading while we waited for the lock.
            if self.ip_adapter_loaded:
                return

            logger.info("Initializing IP-Adapter FaceID...")
            loop = asyncio.get_running_loop()
            try:
                await loop.run_in_executor(None, self._load_ip_adapter)
            except ImportError as e:
                logger.warning(f"IP-Adapter dependencies not available: {e}")
                logger.warning("Face-locked generation will not be available")
            except Exception as e:
                logger.error(f"Failed to load IP-Adapter FaceID: {e}")
                raise

    def _load_ip_adapter(self) -> None:
        """Blocking: initialize InsightFace and wrap the pipeline with IPAdapterFaceIDXL."""
        # Imported lazily: these are optional dependencies.
        from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
        from insightface.app import FaceAnalysis

        # Initialize InsightFace for face detection/embedding
        self.face_app = FaceAnalysis(
            name="buffalo_l",
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
        )
        self.face_app.prepare(ctx_id=0, det_size=(640, 640))
        logger.info("InsightFace initialized")

        # Load IP-Adapter FaceID model
        ip_ckpt = str(config.IP_ADAPTER_PATH)
        self.ip_model = IPAdapterFaceIDXL(
            self.pipeline,
            ip_ckpt,
            self.device
        )

        self.ip_adapter_loaded = True
        logger.info("IP-Adapter FaceID loaded successfully!")

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract a face embedding from an image.

        Args:
            image: Path to an image, a PIL Image, or a numpy array (passed to
                InsightFace as-is; presumably BGR as produced by cv2 - confirm).

        Returns:
            The first detected face's ``normed_embedding`` with a leading batch
            dimension added.

        Raises:
            RuntimeError: If InsightFace has not been initialized.
            ValueError: If the image file cannot be read or no face is detected.
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
            if img_cv is None:
                # cv2.imread signals a missing/unreadable file by returning None.
                raise ValueError(f"Could not read image: {image}")
        elif isinstance(image, Image.Image):
            # Normalize RGBA/grayscale/palette images to 3-channel RGB before the BGR conversion.
            img_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        faces = self.face_app.get(img_cv)
        if not faces:
            raise ValueError("No face detected in image")

        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")
        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Compute and average face embeddings from multiple reference images.

        Images from which no face can be extracted are skipped with a warning.

        Args:
            face_images: Paths to face reference images.

        Returns:
            Element-wise mean of the extracted embeddings (same shape as a
            single embedding).

        Raises:
            ValueError: If no embedding could be extracted from any image.
        """
        embeddings = []
        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if not embeddings:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")
        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback = None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional callback(percent) invoked after each
                denoising step with an int in 0-100. If it returns a coroutine,
                that coroutine is scheduled on this event loop. Only used by
                the standard (non face-locked) path.
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: IP-Adapter conditioning scale for the face (default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Check if face-locked generation requested
        use_face_id = face_image is not None
        if use_face_id:
            await self.initialize_ip_adapter()
            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")

        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

        loop = asyncio.get_running_loop()
        try:
            if use_face_id:
                # Embedding extraction and diffusion are both blocking, so both run in the executor.
                image = await loop.run_in_executor(
                    None,
                    lambda: self._generate_with_face(
                        face_image=face_image,
                        face_strength=face_strength,
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        seed=seed,
                    ),
                )
            else:
                # NOTE(review): once initialize_ip_adapter() has run, the UNet's
                # attention processors have presumably been replaced by the
                # IP-Adapter ones, which may affect this plain path too - confirm.
                step_kwargs = {}
                if progress_callback is not None:
                    step_kwargs = {
                        "callback": self._make_step_callback(
                            progress_callback, num_inference_steps, loop
                        ),
                        "callback_steps": 1,
                    }

                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        **step_kwargs,
                    ).images[0],
                )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def _generate_with_face(
        self,
        face_image: Union[str, Path, List[str]],
        face_strength: float,
        prompt: str,
        negative_prompt: Optional[str],
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        seed: Optional[int],
    ) -> Image.Image:
        """Blocking: extract the face embedding(s) and run IP-Adapter FaceID generation."""
        if isinstance(face_image, list):
            face_embed = self.precompute_face_embeddings(face_image)
        else:
            face_embed = self.extract_face_embedding(face_image)

        images = self.ip_model.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            faceid_embeds=face_embed,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_samples=1,
            seed=seed,
            # IPAdapterFaceIDXL takes the adapter strength as ``scale``;
            # ``s_scale`` only exists on the FaceID-Plus variants.
            scale=face_strength,
        )
        return images[0]

    @staticmethod
    def _make_step_callback(progress_callback, total_steps: int, loop: asyncio.AbstractEventLoop):
        """Build a diffusers per-step callback that reports percent progress to ``loop``.

        The returned function runs inside the executor thread, where there is
        no running event loop, so coroutines are handed back to ``loop`` with
        ``run_coroutine_threadsafe`` instead of ``create_task``.
        """
        total = max(total_steps, 1)

        def on_step(step: int, timestep: int, latents: torch.Tensor) -> None:
            # diffusers passes a 0-based step index; +1 so the final step reports 100.
            percent = min(100, int((step + 1) * 100 / total))
            try:
                result = progress_callback(percent)
                if asyncio.iscoroutine(result):
                    asyncio.run_coroutine_threadsafe(result, loop)
            except Exception:
                # Progress reporting is best-effort; never abort generation over it.
                logger.debug("Progress callback failed", exc_info=True)

        return on_step

    def get_model_info(self) -> Dict[str, Any]:
        """Return the configured model settings and current load state."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }
# Global generator instance, created once at import time.
# Construction only sets attributes (no model is loaded); call
# ``await generator.initialize()`` before generating.
generator: SDXLGenerator = SDXLGenerator()