"""
SDXL Image Generator

Handles image generation using Stable Diffusion XL with Jetson optimizations.
Supports IP-Adapter FaceID for consistent face generation.
"""

import torch
import logging
import cv2
import numpy as np
from typing import Optional, Dict, Any, List, Union
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
import asyncio

import config

logger = logging.getLogger(__name__)


class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin."""

    def __init__(self):
        self.pipeline = None
        self.device = None
        self.model_loaded = False
        self._load_lock = asyncio.Lock()

        # IP-Adapter FaceID components
        self.ip_model = None
        self.face_app = None
        self.face_embeds_cache = {}  # Cache for precomputed face embeddings
        self.ip_adapter_loaded = False
        # Serializes lazy IP-Adapter loading: the load runs in an executor, so
        # two concurrent requests could otherwise both start loading it.
        self._ip_adapter_lock = asyncio.Lock()
        # Set after the optional IP-Adapter dependencies failed to import, so
        # later requests skip the retry (and the repeated warning).
        self._ip_adapter_unavailable = False

    async def initialize(self):
        """Load the SDXL pipeline with Jetson optimizations.

        Idempotent: once the model is loaded, further calls return immediately.
        The blocking load runs in the default executor so the event loop stays
        responsive while weights are read from disk.

        Raises:
            Exception: Any error from loading or placing the pipeline; it is
                logged and re-raised, and ``self.pipeline`` is left unchanged.
        """
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(
                    f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
                )
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            try:
                loop = asyncio.get_running_loop()
                self.pipeline = await loop.run_in_executor(None, self._load_pipeline)
                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")
            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    def _load_pipeline(self) -> StableDiffusionXLPipeline:
        """Build the SDXL pipeline, place it on ``self.device`` and apply optimizations (blocking)."""
        dtype = torch.float16 if config.USE_FP16 else torch.float32
        pretrained_kwargs: Dict[str, Any] = dict(
            torch_dtype=dtype,
            use_safetensors=True,
            cache_dir=str(config.MODELS_DIR),
            add_watermarker=False,
        )

        # Only use DDIM scheduler when IP-Adapter is enabled (it requires DDIM).
        # Otherwise let the model use its default scheduler for best quality.
        if config.ENABLE_IP_ADAPTER:
            pretrained_kwargs["scheduler"] = DDIMScheduler(
                num_train_timesteps=1000,
                beta_start=0.00085,
                beta_end=0.012,
                beta_schedule="scaled_linear",
                clip_sample=False,
                set_alpha_to_one=False,
                steps_offset=1,
            )
            logger.info("Using DDIM scheduler for IP-Adapter compatibility")

        pipeline = StableDiffusionXLPipeline.from_pretrained(
            config.SDXL_MODEL_ID, **pretrained_kwargs
        )

        use_cpu_offload = config.ENABLE_CPU_OFFLOAD and self.device == "cuda"
        if not use_cpu_offload:
            # With model CPU offload, diffusers moves each sub-model to the GPU
            # on demand; moving the whole pipeline to CUDA first would put every
            # component in GPU memory at once, which is what offload avoids.
            pipeline = pipeline.to(self.device)

        # Apply optimizations
        if config.ENABLE_ATTENTION_SLICING:
            pipeline.enable_attention_slicing()
            logger.info("Attention slicing enabled")

        if config.ENABLE_VAE_SLICING:
            pipeline.enable_vae_slicing()
            logger.info("VAE slicing enabled")

        if use_cpu_offload:
            pipeline.enable_model_cpu_offload()
            logger.info("CPU offload enabled")

        return pipeline

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading).

        On success sets ``face_app``, ``ip_model`` and ``ip_adapter_loaded``.
        If the optional dependencies cannot be imported, a warning is logged
        once, ``ip_adapter_loaded`` stays False and later calls return
        immediately, so callers fall back to plain generation.

        Raises:
            Exception: Any non-import error while loading; logged and re-raised.
        """
        if self.ip_adapter_loaded or self._ip_adapter_unavailable:
            return

        async with self._ip_adapter_lock:
            # Re-check: another request may have finished while we waited.
            if self.ip_adapter_loaded or self._ip_adapter_unavailable:
                return

            logger.info("Initializing IP-Adapter FaceID...")
            try:
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(None, self._load_ip_adapter)
                self.ip_adapter_loaded = True
                logger.info("IP-Adapter FaceID loaded successfully!")
            except ImportError as e:
                self._ip_adapter_unavailable = True
                logger.warning(f"IP-Adapter dependencies not available: {e}")
                logger.warning("Face-locked generation will not be available")
            except Exception as e:
                logger.error(f"Failed to load IP-Adapter FaceID: {e}")
                raise

    def _load_ip_adapter(self) -> None:
        """Import and construct InsightFace and IPAdapterFaceIDXL (blocking)."""
        # Import IP-Adapter components (optional dependencies)
        from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
        from insightface.app import FaceAnalysis

        # Initialize InsightFace for face detection/embedding
        self.face_app = FaceAnalysis(
            name="buffalo_l",
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        # A negative ctx_id keeps InsightFace on CPU when no GPU is in use.
        self.face_app.prepare(
            ctx_id=0 if self.device == "cuda" else -1, det_size=(640, 640)
        )
        logger.info("InsightFace initialized")

        # Load IP-Adapter FaceID model on top of the already-loaded pipeline
        ip_ckpt = str(config.IP_ADAPTER_PATH)
        self.ip_model = IPAdapterFaceIDXL(self.pipeline, ip_ckpt, self.device)

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract face embedding from an image.

        Args:
            image: Path to image, PIL Image, or numpy array (assumed to be in
                OpenCV BGR channel order, as ``cv2.imread`` produces)

        Returns:
            Face embedding tensor: InsightFace's ``normed_embedding`` for the
            largest detected face, with a leading batch dimension of 1

        Raises:
            RuntimeError: If InsightFace has not been initialized.
            ValueError: If the image file cannot be read or no face is found.
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        # Convert to a BGR numpy array if needed
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
            if img_cv is None:
                # cv2.imread reports missing/unreadable files by returning None.
                raise ValueError(f"Could not read image: {image}")
        elif isinstance(image, Image.Image):
            # Normalize RGBA / grayscale / palette images to 3-channel RGB so
            # the RGB->BGR conversion receives the channel count it expects.
            img_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)
        if len(faces) == 0:
            raise ValueError("No face detected in image")

        # Prefer the largest face (bbox is x1, y1, x2, y2) so a small
        # background face cannot be picked over the reference subject.
        face = max(
            faces,
            key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]),
        )
        face_embed = torch.from_numpy(face.normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")

        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Precompute and average face embeddings from multiple reference images.

        Images whose face cannot be extracted are skipped with a warning.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Element-wise mean of the per-image embeddings

        Raises:
            ValueError: If no face could be extracted from any image.
        """
        embeddings = []
        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if len(embeddings) == 0:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency.
        # NOTE(review): the mean of unit-norm embeddings is generally not
        # unit-norm; renormalize if the adapter expects normed input — confirm.
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")

        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback=None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional ``callback(progress)`` receiving an int
                percentage (0-100) after each denoising step. Only used for
                plain generation (not the IP-Adapter paths). A coroutine
                function is scheduled on the calling event loop; a plain
                function is called directly on the executor thread.
            face_image: Optional path (or list/tuple of paths) to face
                reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Check if face-locked generation requested
        use_face_id = face_image is not None and config.ENABLE_IP_ADAPTER
        if face_image is not None and not config.ENABLE_IP_ADAPTER:
            logger.info("IP-Adapter disabled in config, generating without face lock")

        if use_face_id:
            # Initialize IP-Adapter if needed
            await self.initialize_ip_adapter()
            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")

        # Set random seed if provided (named to avoid shadowing the
        # module-level `generator` instance)
        torch_generator = None
        if seed is not None:
            torch_generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

        try:
            loop = asyncio.get_running_loop()

            if use_face_id:
                # InsightFace detection is blocking work; keep it off the loop.
                if isinstance(face_image, (list, tuple)):
                    face_embed = await loop.run_in_executor(
                        None, self.precompute_face_embeddings, list(face_image)
                    )
                else:
                    face_embed = await loop.run_in_executor(
                        None, self.extract_face_embedding, face_image
                    )

                # Generate with IP-Adapter FaceID
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0],
                )
            elif self.ip_adapter_loaded:
                # IP-Adapter is loaded - we must use it with s_scale=0
                # to avoid corrupted output from dangling adapter layers
                logger.info("IP-Adapter loaded but no face requested, using s_scale=0")
                # Create zero embedding (512-dim for FaceID)
                zero_embed = torch.zeros((1, 512), device=self.device, dtype=torch.float16)
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=zero_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=0.0,  # Disable face conditioning
                    )[0],
                )
            else:
                # Standard generation - IP-Adapter not loaded.
                # The pipeline (and so this callback) runs on an executor
                # thread with no running event loop, so coroutines must be
                # handed back to `loop` thread-safely (create_task would fail).
                def callback_wrapper(step: int, timestep: int, latents: torch.FloatTensor):
                    if not progress_callback:
                        return
                    # `step` is 0-based; +1 so the final step reports 100%.
                    progress = min(100, int(((step + 1) / num_inference_steps) * 100))
                    try:
                        result = progress_callback(progress)
                        if asyncio.iscoroutine(result):
                            asyncio.run_coroutine_threadsafe(result, loop)
                    except Exception as cb_err:
                        # Progress reporting is best-effort; never abort generation.
                        logger.debug(f"Progress callback failed: {cb_err}")

                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=torch_generator,
                        callback=callback_wrapper,
                        callback_steps=1,
                    ).images[0],
                )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }


# Global generator instance
generator = SDXLGenerator()