DreamTail v1.0.0 with IP-Adapter FaceID support

- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
This commit is contained in:
2026-01-01 19:54:59 -06:00
commit e4294b57e6
18 changed files with 1895 additions and 0 deletions

1
worker/__init__.py Executable file
View File

@@ -0,0 +1 @@
"""Worker modules for DreamTail."""

339
worker/generator.py Executable file
View File

@@ -0,0 +1,339 @@
"""
SDXL Image Generator
Handles image generation using Stable Diffusion XL with Jetson optimizations.
Supports IP-Adapter FaceID for consistent face generation.
"""
import torch
import logging
import cv2
import numpy as np
from typing import Optional, Dict, Any, List, Union
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
import asyncio
import config
logger = logging.getLogger(__name__)
class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin.

    Wraps a diffusers ``StableDiffusionXLPipeline`` and a lazily-loaded
    IP-Adapter FaceID model for face-locked generation. All loading is
    serialized through an asyncio lock so concurrent callers cannot
    double-load the (very large) models.
    """

    def __init__(self):
        self.pipeline = None            # StableDiffusionXLPipeline, set by initialize()
        self.device = None              # "cuda" or "cpu", chosen in initialize()
        self.model_loaded = False
        self._load_lock = asyncio.Lock()  # serializes model / adapter loading
        # IP-Adapter FaceID components (lazy-loaded on first face request)
        self.ip_model = None
        self.face_app = None            # InsightFace FaceAnalysis instance
        self.face_embeds_cache = {}     # averaged embeddings keyed by reference image paths
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load the SDXL model with Jetson optimizations (idempotent)."""
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return
            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")
            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")
            # Load pipeline
            try:
                dtype = torch.float16 if config.USE_FP16 else torch.float32
                # DDIM scheduler is required for IP-Adapter compatibility
                noise_scheduler = DDIMScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    clip_sample=False,
                    set_alpha_to_one=False,
                    steps_offset=1,
                )
                self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                    config.SDXL_MODEL_ID,
                    torch_dtype=dtype,
                    scheduler=noise_scheduler,
                    use_safetensors=True,
                    cache_dir=str(config.MODELS_DIR),
                    add_watermarker=False,
                )
                # Move to device
                self.pipeline = self.pipeline.to(self.device)
                # Memory optimizations for the Jetson's limited VRAM
                if config.ENABLE_ATTENTION_SLICING:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")
                if config.ENABLE_VAE_SLICING:
                    self.pipeline.enable_vae_slicing()
                    logger.info("VAE slicing enabled")
                if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
                    self.pipeline.enable_model_cpu_offload()
                    logger.info("CPU offload enabled")
                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")
            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy, idempotent).

        Missing optional dependencies are logged as a warning and leave
        ``ip_adapter_loaded`` False so callers can fall back to standard
        generation; any other load failure re-raises.
        """
        # Same lock as initialize(): prevents two concurrent face requests
        # from both trying to load the adapter.
        async with self._load_lock:
            if self.ip_adapter_loaded:
                return
            logger.info("Initializing IP-Adapter FaceID...")
            try:
                # Optional deps imported here so the base pipeline works without them
                from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
                from insightface.app import FaceAnalysis
                # InsightFace handles face detection + embedding extraction
                self.face_app = FaceAnalysis(
                    name="buffalo_l",
                    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
                )
                self.face_app.prepare(ctx_id=0, det_size=(640, 640))
                logger.info("InsightFace initialized")
                # Load IP-Adapter FaceID model on top of the existing pipeline
                ip_ckpt = str(config.IP_ADAPTER_PATH)
                self.ip_model = IPAdapterFaceIDXL(
                    self.pipeline,
                    ip_ckpt,
                    self.device
                )
                self.ip_adapter_loaded = True
                logger.info("IP-Adapter FaceID loaded successfully!")
            except ImportError as e:
                logger.warning(f"IP-Adapter dependencies not available: {e}")
                logger.warning("Face-locked generation will not be available")
            except Exception as e:
                logger.error(f"Failed to load IP-Adapter FaceID: {e}")
                raise

    def extract_face_embedding(self, image: Union[str, Path, "Image.Image", np.ndarray]) -> "torch.Tensor":
        """
        Extract a face embedding from a single image.

        Args:
            image: Path to an image file, PIL Image, or BGR numpy array

        Returns:
            Face embedding tensor of shape (1, embed_dim)

        Raises:
            RuntimeError: If InsightFace has not been initialized
            ValueError: If no face is detected in the image
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")
        # Normalize all input kinds to an OpenCV-style BGR numpy array
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
        elif isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image
        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)
        if len(faces) == 0:
            raise ValueError("No face detected in image")
        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")
        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> "torch.Tensor":
        """
        Precompute and average face embeddings from multiple reference images.

        Results are memoized in ``face_embeds_cache`` keyed by the ordered
        reference paths, so repeated requests with the same references skip
        re-detection. NOTE(review): the cache does not notice edits to the
        underlying image files — confirm references are immutable.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Averaged face embedding tensor

        Raises:
            ValueError: If no face can be extracted from any reference image
        """
        cache_key = tuple(str(p) for p in face_images)
        cached = self.face_embeds_cache.get(cache_key)
        if cached is not None:
            logger.info(f"Using cached face embeddings for {len(face_images)} references")
            return cached
        embeddings = []
        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                # Best-effort: skip unreadable / faceless references
                logger.warning(f"Failed to extract face from {img_path}: {e}")
        if not embeddings:
            raise ValueError("No faces could be extracted from any reference images")
        # Averaging embeddings improves identity consistency across references
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")
        self.face_embeds_cache[cache_key] = avg_embedding
        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        num_inference_steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        seed: Optional[int] = None,
        progress_callback = None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> "Image.Image":
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width (defaults to config.DEFAULT_WIDTH)
            height: Image height (defaults to config.DEFAULT_HEIGHT)
            num_inference_steps: Number of diffusion steps (defaults to config.DEFAULT_STEPS)
            guidance_scale: Guidance scale (defaults to config.DEFAULT_GUIDANCE_SCALE)
            seed: Random seed for reproducibility
            progress_callback: Optional async callback(progress_percent) for progress updates
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")
        # Late-bind defaults from config so config changes after import take effect
        if width is None:
            width = config.DEFAULT_WIDTH
        if height is None:
            height = config.DEFAULT_HEIGHT
        if num_inference_steps is None:
            num_inference_steps = config.DEFAULT_STEPS
        if guidance_scale is None:
            guidance_scale = config.DEFAULT_GUIDANCE_SCALE
        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")
        # Check if face-locked generation requested
        use_face_id = face_image is not None
        if use_face_id:
            # Initialize IP-Adapter if needed
            await self.initialize_ip_adapter()
            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")
        # Set random seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")
        try:
            loop = asyncio.get_running_loop()
            if use_face_id:
                # Extract face embedding(s): averaged for a list, single otherwise
                if isinstance(face_image, list):
                    face_embed = self.precompute_face_embeddings(face_image)
                else:
                    face_embed = self.extract_face_embedding(face_image)
                # Generate with IP-Adapter FaceID (blocking call off the event loop)
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0]
                )
            else:
                # Progress callback wrapper (only for the standard pipeline).
                # This runs inside the executor thread, where no event loop is
                # running, so the coroutine must be scheduled back onto the
                # captured loop thread-safely (create_task would raise here).
                def callback_wrapper(step: int, timestep: int, latents: "torch.FloatTensor"):
                    if progress_callback is None:
                        return
                    progress = int((step / num_inference_steps) * 100)
                    try:
                        asyncio.run_coroutine_threadsafe(progress_callback(progress), loop)
                    except RuntimeError:
                        # Loop already closed — drop the progress update
                        pass
                # Standard generation without face lock (blocking call off the event loop)
                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        callback=callback_wrapper,
                        callback_steps=1
                    ).images[0]
                )
            logger.info("Image generated successfully")
            return image
        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model and its optimization flags."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }
# Module-level singleton: all importers share one SDXLGenerator (and thus
# one loaded pipeline) across the worker process.
generator = SDXLGenerator()

163
worker/queue_manager.py Executable file
View File

@@ -0,0 +1,163 @@
"""
Job Queue Manager
In-memory job queue for managing image generation requests.
"""
import asyncio
import uuid
from datetime import datetime
from typing import Dict, Optional, Literal
from dataclasses import dataclass, field
import logging
import config
logger = logging.getLogger(__name__)
@dataclass
class Job:
    """A single generation job and its lifecycle state.

    Status moves queued -> processing -> completed | failed; progress is a
    0-100 percentage updated while the job is processing.
    """

    job_id: str
    prompt: str
    negative_prompt: Optional[str]
    params: Dict
    status: Literal["queued", "processing", "completed", "failed"]
    progress: int = 0
    # NOTE(review): timestamps are naive UTC (datetime.utcnow). They are
    # consistent within this module, but confirm no external caller mixes
    # them with timezone-aware datetimes.
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    result_path: Optional[str] = None


class QueueManager:
    """Manages the job queue and job lifecycle.

    Jobs are kept in an in-memory dict for status lookups; job IDs flow
    through a bounded asyncio.Queue consumed by the worker loop.
    """

    def __init__(self):
        self.jobs: Dict[str, Job] = {}
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=config.MAX_QUEUE_SIZE)
        self.active_jobs: int = 0
        self._lock = asyncio.Lock()  # guards self.jobs and active_jobs counters

    async def submit_job(
        self,
        prompt: str,
        negative_prompt: Optional[str],
        params: Dict
    ) -> str:
        """
        Submit a new generation job.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            params: Generation parameters

        Returns:
            job_id: Unique job identifier

        Raises:
            asyncio.QueueFull: If queue is at capacity
        """
        job_id = str(uuid.uuid4())
        job = Job(
            job_id=job_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            params=params,
            status="queued"
        )
        async with self._lock:
            # put_nowait so a full queue raises QueueFull immediately (as
            # documented) instead of silently blocking the caller. Enqueue
            # BEFORE recording the job so a rejected submission leaves no
            # phantom entry in self.jobs.
            self.queue.put_nowait(job_id)
            self.jobs[job_id] = job
        logger.info(f"Job {job_id} submitted: '{prompt[:50]}...'")
        return job_id

    async def get_next_job(self) -> Optional[str]:
        """
        Get the next job from the queue (blocks until available).

        Returns:
            job_id, or None if the wait was cancelled (e.g. on shutdown)
        """
        try:
            return await self.queue.get()
        except asyncio.CancelledError:
            return None

    async def start_job(self, job_id: str):
        """Mark a job as started and bump the active-job counter."""
        async with self._lock:
            job = self.jobs.get(job_id)
            if job is not None:
                job.status = "processing"
                job.started_at = datetime.utcnow()
                self.active_jobs += 1
                logger.info(f"Job {job_id} started processing")

    async def update_progress(self, job_id: str, progress: int):
        """Update job progress, clamped to the 0-100 range."""
        async with self._lock:
            job = self.jobs.get(job_id)
            if job is not None:
                job.progress = min(100, max(0, progress))

    async def complete_job(self, job_id: str, result_path: str):
        """Mark a job as completed successfully and record its output path."""
        async with self._lock:
            job = self.jobs.get(job_id)
            if job is not None:
                job.status = "completed"
                job.completed_at = datetime.utcnow()
                job.progress = 100
                job.result_path = result_path
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.info(f"Job {job_id} completed successfully")

    async def fail_job(self, job_id: str, error: str):
        """Mark a job as failed and record the error message."""
        async with self._lock:
            job = self.jobs.get(job_id)
            if job is not None:
                job.status = "failed"
                job.completed_at = datetime.utcnow()
                job.error = error
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.error(f"Job {job_id} failed: {error}")

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get job by ID, or None if unknown."""
        return self.jobs.get(job_id)

    def get_queue_size(self) -> int:
        """Get current number of queued (not yet started) jobs."""
        return self.queue.qsize()

    def get_active_jobs(self) -> int:
        """Get number of currently processing jobs."""
        return self.active_jobs

    async def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Remove completed/failed jobs older than max_age_hours from memory."""
        # Naive-UTC timestamps on both sides of the comparison, so the
        # (local-time) offset introduced by .timestamp() cancels out.
        cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600)
        async with self._lock:
            to_remove = [
                job_id
                for job_id, job in self.jobs.items()
                if job.status in ("completed", "failed")
                and job.completed_at
                and job.completed_at.timestamp() < cutoff
            ]
            for job_id in to_remove:
                del self.jobs[job_id]
        if to_remove:
            logger.info(f"Cleaned up {len(to_remove)} old jobs from memory")
# Module-level singleton: the API layer and worker loop share this one
# QueueManager (and therefore one jobs dict and one queue).
queue_manager = QueueManager()