DreamTail v1.0.0 with IP-Adapter FaceID support
- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
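As a rough illustration of the new face-lock parameters (the endpoint name and payload shape are assumptions; the FastAPI layer itself is not shown in this diff):

# Hypothetical client call -- endpoint and fields are assumed,
# not taken from this commit.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "vixy, walking on a beach at sunset",  # 'vixy' face-lock shortcut
        "face_strength": 0.6,
    },
)
print(resp.json())  # expected to contain a job_id for status polling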
worker/__init__.py (new executable file, 1 line)
@@ -0,0 +1 @@
"""Worker modules for DreamTail."""
worker/generator.py (new executable file, 339 lines)
@@ -0,0 +1,339 @@
"""
SDXL Image Generator

Handles image generation using Stable Diffusion XL with Jetson optimizations.
Supports IP-Adapter FaceID for consistent face generation.
"""

import torch
import logging
import cv2
import numpy as np
from typing import Optional, Dict, Any, List, Union
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
import asyncio

import config

logger = logging.getLogger(__name__)


class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin."""

    def __init__(self):
        self.pipeline = None
        self.device = None
        self.model_loaded = False
        self._load_lock = asyncio.Lock()

        # IP-Adapter FaceID components
        self.ip_model = None
        self.face_app = None
        self.face_embeds_cache = {}  # Cache for precomputed face embeddings
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load the SDXL model with Jetson optimizations."""
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            # Load pipeline
            try:
                dtype = torch.float16 if config.USE_FP16 else torch.float32

                # Use DDIM scheduler for IP-Adapter compatibility
                noise_scheduler = DDIMScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    clip_sample=False,
                    set_alpha_to_one=False,
                    steps_offset=1,
                )

                self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                    config.SDXL_MODEL_ID,
                    torch_dtype=dtype,
                    scheduler=noise_scheduler,
                    use_safetensors=True,
                    cache_dir=str(config.MODELS_DIR),
                    add_watermarker=False,
                )

                # Move to device
                self.pipeline = self.pipeline.to(self.device)

                # Apply optimizations
                if config.ENABLE_ATTENTION_SLICING:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")

                if config.ENABLE_VAE_SLICING:
                    self.pipeline.enable_vae_slicing()
                    logger.info("VAE slicing enabled")

                if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
                    self.pipeline.enable_model_cpu_offload()
                    logger.info("CPU offload enabled")

                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")

            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading)."""
        if self.ip_adapter_loaded:
            return

        logger.info("Initializing IP-Adapter FaceID...")

        try:
            # Import IP-Adapter components
            from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
            from insightface.app import FaceAnalysis

            # Initialize InsightFace for face detection/embedding
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            logger.info("InsightFace initialized")

            # Load IP-Adapter FaceID model
            ip_ckpt = str(config.IP_ADAPTER_PATH)
            self.ip_model = IPAdapterFaceIDXL(
                self.pipeline,
                ip_ckpt,
                self.device
            )

            self.ip_adapter_loaded = True
            logger.info("IP-Adapter FaceID loaded successfully!")

        except ImportError as e:
            logger.warning(f"IP-Adapter dependencies not available: {e}")
            logger.warning("Face-locked generation will not be available")
        except Exception as e:
            logger.error(f"Failed to load IP-Adapter FaceID: {e}")
            raise

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract a face embedding from an image.

        Args:
            image: Path to image, PIL Image, or numpy array

        Returns:
            Face embedding tensor
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        # Convert to a BGR numpy array if needed
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
            if img_cv is None:
                raise ValueError(f"Could not read image file: {image}")
        elif isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)

        if len(faces) == 0:
            raise ValueError("No face detected in image")

        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")

        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Precompute and average face embeddings from multiple reference images.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Averaged face embedding tensor
        """
        embeddings = []

        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if len(embeddings) == 0:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")

        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback=None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional async callback(progress) receiving percent complete (0-100)
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Check if face-locked generation was requested
        use_face_id = face_image is not None

        if use_face_id:
            # Initialize IP-Adapter if needed
            await self.initialize_ip_adapter()

            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")

        # Set random seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

        try:
            loop = asyncio.get_running_loop()

            if use_face_id:
                # Extract face embedding(s)
                if isinstance(face_image, list):
                    face_embed = self.precompute_face_embeddings(face_image)
                else:
                    face_embed = self.extract_face_embedding(face_image)

                # Generate with IP-Adapter FaceID
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0]
                )
            else:
                # Progress callback wrapper (only for the standard pipeline).
                # The pipeline runs in an executor thread, so the async
                # callback must be scheduled on the event loop thread-safely;
                # asyncio.create_task() would raise here because the worker
                # thread has no running loop.
                def callback_wrapper(step: int, timestep: int, latents: torch.FloatTensor):
                    if progress_callback:
                        progress = int((step / num_inference_steps) * 100)
                        try:
                            asyncio.run_coroutine_threadsafe(progress_callback(progress), loop)
                        except RuntimeError:
                            pass  # Event loop already closed; drop this update

                # Standard generation without face lock
                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        callback=callback_wrapper,
                        callback_steps=1
                    ).images[0]
                )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }


# Global generator instance
generator = SDXLGenerator()
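A minimal usage sketch for the generator module above (not part of this commit; the prompt, reference path, and output file are illustrative assumptions):

# Hypothetical usage sketch -- file paths and prompt are assumptions.
import asyncio

from worker.generator import generator

async def main():
    await generator.initialize()
    image = await generator.generate_image(
        prompt="portrait photo, studio lighting, 85mm",
        negative_prompt="blurry, low quality",
        seed=42,
        face_image="refs/face_01.jpg",  # enables the IP-Adapter FaceID path
        face_strength=0.6,
    )
    image.save("output.png")

asyncio.run(main())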
worker/queue_manager.py (new executable file, 163 lines)
@@ -0,0 +1,163 @@
"""
Job Queue Manager

In-memory job queue for managing image generation requests.
"""

import asyncio
import uuid
from datetime import datetime
from typing import Dict, Optional, Literal
from dataclasses import dataclass, field
import logging

import config

logger = logging.getLogger(__name__)


@dataclass
class Job:
    """Represents a single generation job."""
    job_id: str
    prompt: str
    negative_prompt: Optional[str]
    params: Dict
    status: Literal["queued", "processing", "completed", "failed"]
    progress: int = 0
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    result_path: Optional[str] = None


class QueueManager:
    """Manages the job queue and job lifecycle."""

    def __init__(self):
        self.jobs: Dict[str, Job] = {}
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=config.MAX_QUEUE_SIZE)
        self.active_jobs: int = 0
        self._lock = asyncio.Lock()

    async def submit_job(
        self,
        prompt: str,
        negative_prompt: Optional[str],
        params: Dict
    ) -> str:
        """
        Submit a new generation job.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            params: Generation parameters

        Returns:
            job_id: Unique job identifier

        Raises:
            asyncio.QueueFull: If queue is at capacity
        """
        job_id = str(uuid.uuid4())

        job = Job(
            job_id=job_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            params=params,
            status="queued"
        )

        async with self._lock:
            self.jobs[job_id] = job

        try:
            # put_nowait raises asyncio.QueueFull at capacity, matching the
            # documented behavior (await queue.put() would block instead of
            # raising)
            self.queue.put_nowait(job_id)
        except asyncio.QueueFull:
            # Roll back the registration so no orphaned job remains
            async with self._lock:
                del self.jobs[job_id]
            raise

        logger.info(f"Job {job_id} submitted: '{prompt[:50]}...'")
        return job_id

    async def get_next_job(self) -> Optional[str]:
        """
        Get the next job from the queue (blocks until one is available).

        Returns:
            job_id, or None if the wait was cancelled
        """
        try:
            job_id = await self.queue.get()
            return job_id
        except asyncio.CancelledError:
            return None

    async def start_job(self, job_id: str):
        """Mark a job as started."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "processing"
                self.jobs[job_id].started_at = datetime.utcnow()
                self.active_jobs += 1
                logger.info(f"Job {job_id} started processing")

    async def update_progress(self, job_id: str, progress: int):
        """Update job progress (0-100)."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].progress = min(100, max(0, progress))

    async def complete_job(self, job_id: str, result_path: str):
        """Mark a job as completed successfully."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "completed"
                self.jobs[job_id].completed_at = datetime.utcnow()
                self.jobs[job_id].progress = 100
                self.jobs[job_id].result_path = result_path
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.info(f"Job {job_id} completed successfully")

    async def fail_job(self, job_id: str, error: str):
        """Mark a job as failed."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "failed"
                self.jobs[job_id].completed_at = datetime.utcnow()
                self.jobs[job_id].error = error
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.error(f"Job {job_id} failed: {error}")

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get a job by ID."""
        return self.jobs.get(job_id)

    def get_queue_size(self) -> int:
        """Get the current queue size."""
        return self.queue.qsize()

    def get_active_jobs(self) -> int:
        """Get the number of currently processing jobs."""
        return self.active_jobs

    async def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Remove old completed/failed jobs from memory."""
        cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600)

        async with self._lock:
            to_remove = []
            for job_id, job in self.jobs.items():
                if job.status in ["completed", "failed"] and job.completed_at:
                    if job.completed_at.timestamp() < cutoff:
                        to_remove.append(job_id)

            for job_id in to_remove:
                del self.jobs[job_id]

            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} old jobs from memory")


# Global queue manager instance
queue_manager = QueueManager()
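A sketch of how a worker task might tie the queue manager to the generator above (the actual service wiring is not part of this diff; the output path and the use of job.params as keyword arguments are assumptions):

# Hypothetical worker loop -- the real FastAPI wiring is not shown in
# this commit; the output path and job.params usage are assumptions.
import asyncio

from worker.generator import generator
from worker.queue_manager import queue_manager

async def worker_loop():
    await generator.initialize()
    while True:
        job_id = await queue_manager.get_next_job()
        if job_id is None:
            break  # queue wait was cancelled; shut down
        job = queue_manager.get_job(job_id)
        await queue_manager.start_job(job_id)
        try:
            image = await generator.generate_image(
                prompt=job.prompt,
                negative_prompt=job.negative_prompt,
                progress_callback=lambda p: queue_manager.update_progress(job_id, p),
                **job.params,  # e.g. width, height, seed, face_image, face_strength
            )
            result_path = f"outputs/{job_id}.png"
            image.save(result_path)
            await queue_manager.complete_job(job_id, result_path)
        except Exception as e:
            await queue_manager.fail_job(job_id, str(e))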