# Changelog: the DDIM scheduler was previously always active, causing softer
# output. DDIM is now used only when ENABLE_IP_ADAPTER=True; otherwise the
# model's default scheduler is used for best quality.
"""
|
|
SDXL Image Generator
|
|
|
|
Handles image generation using Stable Diffusion XL with Jetson optimizations.
|
|
Supports IP-Adapter FaceID for consistent face generation.
|
|
"""

import asyncio
import logging
from pathlib import Path
from typing import Optional, Dict, Any, List, Union

import cv2
import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image

import config

logger = logging.getLogger(__name__)

|
class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin.

    The heavyweight SDXL pipeline is loaded once via :meth:`initialize`
    (serialized by an asyncio lock). IP-Adapter FaceID components are
    lazy-loaded on demand to enable identity-consistent ("face-locked")
    generation from reference face images.
    """

    def __init__(self):
        self.pipeline = None           # StableDiffusionXLPipeline once loaded
        self.device = None             # "cuda" or "cpu", chosen at load time
        self.model_loaded = False
        self._load_lock = asyncio.Lock()

        # IP-Adapter FaceID components (lazy-loaded by initialize_ip_adapter())
        self.ip_model = None           # IPAdapterFaceIDXL wrapper around the pipeline
        self.face_app = None           # insightface FaceAnalysis instance
        self.face_embeds_cache = {}    # Cache for precomputed face embeddings
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load the SDXL model with Jetson optimizations.

        Safe to call concurrently and repeatedly: the load is serialized by
        a lock and later calls return immediately once the model is loaded.

        Raises:
            Exception: Re-raises any failure from pipeline loading.
        """
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            # Load pipeline
            try:
                dtype = torch.float16 if config.USE_FP16 else torch.float32

                # Shared loading options for both scheduler configurations;
                # keeping them in one place prevents the two paths drifting.
                pipeline_kwargs = dict(
                    torch_dtype=dtype,
                    use_safetensors=True,
                    cache_dir=str(config.MODELS_DIR),
                    add_watermarker=False,
                )

                # Only use DDIM scheduler when IP-Adapter is enabled (it
                # requires DDIM). Otherwise let the model use its default
                # scheduler for best quality.
                if config.ENABLE_IP_ADAPTER:
                    pipeline_kwargs["scheduler"] = DDIMScheduler(
                        num_train_timesteps=1000,
                        beta_start=0.00085,
                        beta_end=0.012,
                        beta_schedule="scaled_linear",
                        clip_sample=False,
                        set_alpha_to_one=False,
                        steps_offset=1,
                    )
                    logger.info("Using DDIM scheduler for IP-Adapter compatibility")

                self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                    config.SDXL_MODEL_ID,
                    **pipeline_kwargs,
                )

                # Move to device
                self.pipeline = self.pipeline.to(self.device)

                # Apply memory/VRAM optimizations
                if config.ENABLE_ATTENTION_SLICING:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")

                if config.ENABLE_VAE_SLICING:
                    self.pipeline.enable_vae_slicing()
                    logger.info("VAE slicing enabled")

                if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
                    self.pipeline.enable_model_cpu_offload()
                    logger.info("CPU offload enabled")

                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")

            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading).

        Missing optional dependencies are tolerated: a warning is logged and
        face-locked generation simply stays unavailable. Any other load
        failure is re-raised.
        """
        if self.ip_adapter_loaded:
            return

        logger.info("Initializing IP-Adapter FaceID...")

        try:
            # Optional heavy dependencies, imported lazily so the base
            # generator works without them installed.
            from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
            from insightface.app import FaceAnalysis

            # Initialize InsightFace for face detection/embedding
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            logger.info("InsightFace initialized")

            # Load IP-Adapter FaceID model on top of the existing pipeline
            ip_ckpt = str(config.IP_ADAPTER_PATH)
            self.ip_model = IPAdapterFaceIDXL(
                self.pipeline,
                ip_ckpt,
                self.device
            )

            self.ip_adapter_loaded = True
            logger.info("IP-Adapter FaceID loaded successfully!")

        except ImportError as e:
            logger.warning(f"IP-Adapter dependencies not available: {e}")
            logger.warning("Face-locked generation will not be available")
        except Exception as e:
            logger.error(f"Failed to load IP-Adapter FaceID: {e}")
            raise

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract a face embedding from an image.

        Args:
            image: Path to an image, PIL Image, or numpy array
                (assumed to be in OpenCV BGR layout — confirm with callers)

        Returns:
            Face embedding tensor of shape (1, embed_dim)

        Raises:
            RuntimeError: If InsightFace has not been initialized.
            ValueError: If the image file cannot be read or no face is detected.
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        # Convert to an OpenCV-style BGR array if needed
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
            # cv2.imread returns None (not an exception) on unreadable files;
            # fail loudly here instead of crashing inside face detection.
            if img_cv is None:
                raise ValueError(f"Could not read image file: {image}")
        elif isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)

        if len(faces) == 0:
            raise ValueError("No face detected in image")

        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")

        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Precompute and average face embeddings from multiple reference images.

        Images from which no face can be extracted are skipped with a warning.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Averaged face embedding tensor

        Raises:
            ValueError: If no face could be extracted from any reference image.
        """
        embeddings = []

        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if len(embeddings) == 0:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")

        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback = None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional async callback(progress_pct) scheduled
                on the event loop as steps complete (non-IP-Adapter path only)
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Check if face-locked generation requested
        use_face_id = face_image is not None and config.ENABLE_IP_ADAPTER

        if face_image is not None and not config.ENABLE_IP_ADAPTER:
            logger.info("IP-Adapter disabled in config, generating without face lock")

        if use_face_id:
            # Initialize IP-Adapter if needed
            await self.initialize_ip_adapter()

            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")

        # Set random seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

        try:
            # get_running_loop() is the supported way to obtain the loop from
            # inside a coroutine (get_event_loop() is deprecated for this use).
            loop = asyncio.get_running_loop()

            if use_face_id:
                # Extract face embedding(s)
                if isinstance(face_image, list):
                    face_embed = self.precompute_face_embeddings(face_image)
                else:
                    face_embed = self.extract_face_embedding(face_image)

                # Generate with IP-Adapter FaceID in a worker thread so the
                # event loop stays responsive during diffusion.
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0]
                )
            elif self.ip_adapter_loaded:
                # IP-Adapter is loaded but no face was requested: we must
                # still route through it with s_scale=0 to avoid corrupted
                # output from dangling adapter layers.
                logger.info("IP-Adapter loaded but no face requested, using s_scale=0")
                # Zero embedding (512-dim for FaceID). Match the configured
                # pipeline dtype instead of hard-coding float16 so FP32/CPU
                # configurations do not hit a dtype mismatch.
                embed_dtype = torch.float16 if config.USE_FP16 else torch.float32
                zero_embed = torch.zeros((1, 512), device=self.device, dtype=embed_dtype)

                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=zero_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=0.0,  # Disable face conditioning
                    )[0]
                )
            else:
                # Standard generation - IP-Adapter not loaded
                def callback_wrapper(step: int, timestep: int, latents: torch.FloatTensor):
                    # This runs in the executor thread, where no event loop is
                    # running: asyncio.create_task() would always raise there
                    # (silently killing progress updates), so hand the
                    # coroutine to the main loop thread-safely instead.
                    if progress_callback:
                        progress = int((step / num_inference_steps) * 100)
                        try:
                            asyncio.run_coroutine_threadsafe(progress_callback(progress), loop)
                        except RuntimeError:
                            # Loop already closed; progress is best-effort.
                            pass

                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        # NOTE(review): callback/callback_steps are deprecated
                        # in newer diffusers (callback_on_step_end) — confirm
                        # against the pinned diffusers version before migrating.
                        callback=callback_wrapper,
                        callback_steps=1
                    ).images[0]
                )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return a snapshot of the current model configuration and load state."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }
|
|
|
|
|
|
# Global generator instance: module-level singleton shared by all importers.
# Construction is cheap (no model loading) — call initialize() before use.
generator = SDXLGenerator()