"""
|
|
SDXL Image Generator
|
|
|
|
Handles image generation using Stable Diffusion XL with Jetson optimizations.
|
|
Supports IP-Adapter FaceID for consistent face generation.
|
|
"""
|
|
|
|
import torch
|
|
import logging
|
|
import cv2
|
|
import numpy as np
|
|
from typing import Optional, Dict, Any, List, Union
|
|
from pathlib import Path
|
|
from diffusers import StableDiffusionXLPipeline, DDIMScheduler, DPMSolverMultistepScheduler
|
|
from PIL import Image
|
|
import asyncio
|
|
|
|
import config
|
|
|
|
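# For reference, the attributes this module reads from `config` (the names are
# taken from the code below; the example values are illustrative assumptions,
# not the project's actual defaults):
#
#   SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"   # assumed example
#   MODELS_DIR = Path("models")                                   # HF cache dir
#   IP_ADAPTER_PATH = Path("models/ip-adapter-faceid_sdxl.bin")   # assumed example
#   USE_FP16 = True
#   ENABLE_ATTENTION_SLICING = True
#   ENABLE_VAE_SLICING = True
#   ENABLE_CPU_OFFLOAD = False
#   ENABLE_IP_ADAPTER = True
#   DEFAULT_WIDTH = 1024
#   DEFAULT_HEIGHT = 1024
#   DEFAULT_STEPS = 30
#   DEFAULT_GUIDANCE_SCALE = 7.5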
logger = logging.getLogger(__name__)


class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin."""

    def __init__(self):
        self.pipeline = None
        self.device = None
        self.model_loaded = False
        self._load_lock = asyncio.Lock()

        # IP-Adapter FaceID components
        self.ip_model = None
        self.face_app = None
        self.face_embeds_cache = {}  # Cache for precomputed face embeddings
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load SDXL model with Jetson optimizations."""
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            # Load pipeline
            try:
                dtype = torch.float16 if config.USE_FP16 else torch.float32

                # Only use DDIM scheduler when IP-Adapter is enabled (it requires DDIM)
                # Otherwise let the model use its default scheduler for best quality
                if config.ENABLE_IP_ADAPTER:
                    noise_scheduler = DDIMScheduler(
                        num_train_timesteps=1000,
                        beta_start=0.00085,
                        beta_end=0.012,
                        beta_schedule="scaled_linear",
                        clip_sample=False,
                        set_alpha_to_one=False,
                        steps_offset=1,
                    )
                    logger.info("Using DDIM scheduler for IP-Adapter compatibility")

                    self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                        config.SDXL_MODEL_ID,
                        torch_dtype=dtype,
                        scheduler=noise_scheduler,
                        use_safetensors=True,
                        cache_dir=str(config.MODELS_DIR),
                        add_watermarker=False,
                    )
                else:
                    # Use model's default scheduler, then switch to DPM++ for best quality
                    self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                        config.SDXL_MODEL_ID,
                        torch_dtype=dtype,
                        use_safetensors=True,
                        cache_dir=str(config.MODELS_DIR),
                        add_watermarker=False,
                    )
                    # Use DPM++ scheduler for better quality/speed
                    self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
                        self.pipeline.scheduler.config
                    )
                    logger.info("Using DPM++ scheduler for best quality")

                # Move to device
                self.pipeline = self.pipeline.to(self.device)

                # Apply optimizations
                if config.ENABLE_ATTENTION_SLICING:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")

                if config.ENABLE_VAE_SLICING:
                    self.pipeline.enable_vae_slicing()
                    logger.info("VAE slicing enabled")

                if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
                    self.pipeline.enable_model_cpu_offload()
                    logger.info("CPU offload enabled")

                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")

            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading)."""
        if self.ip_adapter_loaded:
            return

        logger.info("Initializing IP-Adapter FaceID...")

        try:
            # Import IP-Adapter components
            from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
            from insightface.app import FaceAnalysis

            # Initialize InsightFace for face detection/embedding
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            logger.info("InsightFace initialized")

            # Load IP-Adapter FaceID model
            ip_ckpt = str(config.IP_ADAPTER_PATH)
            self.ip_model = IPAdapterFaceIDXL(
                self.pipeline,
                ip_ckpt,
                self.device
            )

            self.ip_adapter_loaded = True
            logger.info("IP-Adapter FaceID loaded successfully!")

        except ImportError as e:
            logger.warning(f"IP-Adapter dependencies not available: {e}")
            logger.warning("Face-locked generation will not be available")
        except Exception as e:
            logger.error(f"Failed to load IP-Adapter FaceID: {e}")
            raise

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract face embedding from an image.

        Args:
            image: Path to image, PIL Image, or numpy array

        Returns:
            Face embedding tensor
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        # Convert to numpy array if needed
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
        elif isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)

        if len(faces) == 0:
            raise ValueError("No face detected in image")

        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")

        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Precompute and average face embeddings from multiple reference images.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Averaged face embedding tensor
        """
        embeddings = []

        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if len(embeddings) == 0:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")

        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback=None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional async callback(progress_percent) invoked with the
                completion percentage (0-100) after each diffusion step
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

logger.info(f"Generating image: '{prompt[:50]}...'")
|
|
logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")
|
|
|
|
# Check if face-locked generation requested
|
|
use_face_id = face_image is not None and config.ENABLE_IP_ADAPTER
|
|
|
|
if face_image is not None and not config.ENABLE_IP_ADAPTER:
|
|
logger.info("IP-Adapter disabled in config, generating without face lock")
|
|
|
|
if use_face_id:
|
|
# Initialize IP-Adapter if needed
|
|
await self.initialize_ip_adapter()
|
|
|
|
if not self.ip_adapter_loaded:
|
|
logger.warning("IP-Adapter not available, generating without face lock")
|
|
use_face_id = False
|
|
else:
|
|
logger.info(f"Face-locked generation enabled, strength={face_strength}")
|
|
|
|
# Set random seed if provided
|
|
generator = None
|
|
if seed is not None:
|
|
generator = torch.Generator(device=self.device).manual_seed(seed)
|
|
logger.info(f"Using seed: {seed}")
|
|
|
|
        try:
            # The blocking pipeline calls run in an executor; keep a handle to the
            # running loop so progress can be reported back to it
            loop = asyncio.get_running_loop()

            if use_face_id:
                # Extract face embedding(s)
                if isinstance(face_image, list):
                    face_embed = self.precompute_face_embeddings(face_image)
                else:
                    face_embed = self.extract_face_embedding(face_image)

                # Generate with IP-Adapter FaceID
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0]
                )
            else:
                # Check if IP-Adapter is loaded - if so, we must use it with s_scale=0
                # to avoid corrupted output from dangling adapter layers
                if self.ip_adapter_loaded:
                    logger.info("IP-Adapter loaded but no face requested, using s_scale=0")
                    # Create zero embedding (512-dim for FaceID), matching the pipeline dtype
                    zero_dtype = torch.float16 if config.USE_FP16 else torch.float32
                    zero_embed = torch.zeros((1, 512), device=self.device, dtype=zero_dtype)

                    image = await loop.run_in_executor(
                        None,
                        lambda: self.ip_model.generate(
                            prompt=prompt,
                            negative_prompt=negative_prompt,
                            faceid_embeds=zero_embed,
                            width=width,
                            height=height,
                            num_inference_steps=num_inference_steps,
                            guidance_scale=guidance_scale,
                            num_samples=1,
                            seed=seed,
                            s_scale=0.0,  # Disable face conditioning
                        )[0]
                    )
                else:
                    # Standard generation - IP-Adapter not loaded
                    def callback_wrapper(step: int, timestep: int, latents: torch.FloatTensor):
                        # The pipeline calls this from the executor thread, so schedule
                        # the async progress callback on the main event loop
                        if progress_callback:
                            progress = int((step / num_inference_steps) * 100)
                            try:
                                asyncio.run_coroutine_threadsafe(progress_callback(progress), loop)
                            except Exception:
                                pass

                    image = await loop.run_in_executor(
                        None,
                        lambda: self.pipeline(
                            prompt=prompt,
                            negative_prompt=negative_prompt,
                            width=width,
                            height=height,
                            num_inference_steps=num_inference_steps,
                            guidance_scale=guidance_scale,
                            generator=generator,
                            callback=callback_wrapper,
                            callback_steps=1
                        ).images[0]
                    )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }


# Global generator instance
generator = SDXLGenerator()
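

# Illustrative usage sketch: demonstrates the intended call pattern for the async
# API above. The prompt text, seed, output filenames, and the face reference paths
# ("face1.jpg", "face2.jpg") are placeholder assumptions; face-locked generation
# additionally requires config.ENABLE_IP_ADAPTER and the IP-Adapter weights at
# config.IP_ADAPTER_PATH.
if __name__ == "__main__":
    async def _demo():
        # Load the SDXL pipeline once, then reuse the module-level generator
        await generator.initialize()

        # Plain text-to-image generation
        img = await generator.generate_image(
            prompt="a cinematic portrait, golden hour lighting",
            negative_prompt="blurry, low quality",
            seed=42,
        )
        img.save("output.png")

        # Face-locked generation from one or more reference images (embeddings are averaged)
        img = await generator.generate_image(
            prompt="a cinematic portrait, golden hour lighting",
            face_image=["face1.jpg", "face2.jpg"],
            face_strength=0.6,
        )
        img.save("output_face.png")

    asyncio.run(_demo())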