DreamTail v1.0.0 with IP-Adapter FaceID support
- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
.gitignore (vendored, 34 lines)
@@ -0,0 +1,34 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
env/
.venv/

# IDE
.idea/
.vscode/
*.swp
*.swo

# OS
.DS_Store
Thumbs.db

# DreamTail specific
dreamtail_storage/images/
*.png
*.jpg
*.jpeg

# Models (too large for git)
models/
*.bin
*.safetensors
*.ckpt

# Logs
*.log
Dockerfile (executable, 48 lines)
@@ -0,0 +1,48 @@
# DreamTail - SDXL Image Generation Service for NVIDIA Jetson AGX Orin
# Based on NVIDIA L4T PyTorch container optimized for Jetson

# Try the jetson-containers format (alternative: nvcr.io/nvidia/l4t-pytorch:r35.2.1-pth2.0-py3)
FROM dustynv/pytorch:2.1-r36.2.0

# Set working directory
WORKDIR /app

# Install system dependencies (curl is needed by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y \
    git \
    wget \
    curl \
    libgl1-mesa-glx \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt /app/

# Install Python dependencies
# Note: torch and torchvision are already in the base image
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code
COPY config.py /app/
COPY main.py /app/
COPY api/ /app/api/
COPY worker/ /app/worker/
COPY dreamtail_storage/ /app/dreamtail_storage/

# Create storage directories
RUN mkdir -p /app/storage/images /app/models

# Expose API port
EXPOSE 8765

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV DREAMTAIL_STORAGE=/app/storage
ENV DREAMTAIL_MODELS=/app/models

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8765/health || exit 1

# Run the FastAPI application
CMD ["python3", "main.py"]
README.md (executable, 386 lines)
@@ -0,0 +1,386 @@
# 🎨 DreamTail

**SDXL Image Generation Service for NVIDIA Jetson AGX Orin**

DreamTail is a standalone FastAPI service that provides high-quality image generation using Stable Diffusion XL (SDXL), optimized for the NVIDIA Jetson AGX Orin. It is designed to serve multiple clients (Lyra, Vixy, etc.) through a simple REST API with job queue management.

## Features

- ✨ **SDXL (Stable Diffusion XL)** for high-quality 1024x1024 image generation
- 🚀 **Jetson-optimized** with FP16, attention slicing, and VAE slicing
- 📋 **Job queue system** with async processing
- 🔄 **Multi-client support** (Lyra, Vixy, and more)
- 💾 **Automatic cleanup** (images deleted after 10 days)
- 🔍 **Progress tracking** via REST API
- 🏥 **Health monitoring** and statistics
## Architecture

```
        ┌─────────────┐
        │   Clients   │  (Lyra, Vixy, etc.)
        └──────┬──────┘
               │ HTTP/REST
               ▼
   ┌──────────────────────┐
   │   FastAPI Server     │
   │   (Port 8765)        │
   └──────┬───────────────┘
          │
   ┌──────▼─────┬─────────┬──────────┐
   │ Job Queue  │  SDXL   │ Storage  │
   │  Manager   │ Worker  │ Manager  │
   └─────┬──────┴────┬────┴────┬─────┘
         │           │         │
         │      ┌────▼────┐    │
         │      │   GPU   │    │
         │      │ (Orin)  │    │
         │      └─────────┘    │
         ▼                     ▼
   /app/storage          /data/models
```
## Requirements

### Hardware
- **NVIDIA Jetson AGX Orin** (32GB or 64GB recommended)
- ~8-12GB VRAM for SDXL
- ~50GB storage for models and generated images

### Software
- Docker with NVIDIA Container Runtime
- JetPack 6.0+ (L4T R36.2.0+)
## Installation

### 1. Download SDXL Models (First Time Only)

```bash
# Download models to the shared cache (takes ~30 minutes, 13GB download)
export DREAMTAIL_MODELS=/data/models
./scripts/download-models.sh
```

### 2. Build Docker Image

```bash
# Build on bigorin (AGX Orin)
./scripts/build.sh
```

### 3. Run DreamTail

```bash
# Start the service
./scripts/run.sh
```

The service will be available at `http://bigorin:8765`.
## API Documentation

### POST /generate

Submit an image generation job.

**Request:**
```json
{
  "prompt": "a serene landscape with mountains at sunset",
  "negative_prompt": "blurry, low quality, distorted",
  "params": {
    "width": 1024,
    "height": 1024,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
    "seed": 42
  }
}
```

**Response (202 Accepted):**
```json
{
  "job_id": "550e8400-e29b-41d4-a716-446655440000",
  "status": "queued",
  "created_at": "2025-11-06T12:00:00Z",
  "message": "Job queued. Queue position: 0"
}
```

### GET /status/{job_id}

Check job status and progress.

**Response:**
```json
{
  "job_id": "550e8400-e29b-41d4-a716-446655440000",
  "status": "processing",
  "progress": 67,
  "created_at": "2025-11-06T12:00:00Z",
  "started_at": "2025-11-06T12:00:05Z",
  "completed_at": null,
  "error": null,
  "prompt": "a serene landscape..."
}
```

**Status values:** `queued`, `processing`, `completed`, `failed`

### GET /result/{job_id}

Download the generated image (only when status is `completed`).

**Response:** PNG image file

### GET /health

Service health check.

**Response:**
```json
{
  "status": "healthy",
  "version": "1.0.0",
  "model_loaded": true,
  "queue_size": 2,
  "active_jobs": 1,
  "uptime_seconds": 3600.5
}
```

### GET /models

Model configuration information.

**Response:**
```json
{
  "base_model": "SG161222/RealVisXL_V4.0",
  "refiner_model": null,
  "refiner_enabled": false,
  "device": "cuda",
  "fp16_enabled": true
}
```
## Usage Examples

### Python Client

```python
import requests
import time

# 1. Submit generation job
response = requests.post("http://bigorin:8765/generate", json={
    "prompt": "a futuristic city at night with neon lights",
    "params": {
        "width": 1024,
        "height": 1024,
        "num_inference_steps": 30
    }
})

job = response.json()
job_id = job["job_id"]
print(f"Job submitted: {job_id}")

# 2. Poll for completion
while True:
    status = requests.get(f"http://bigorin:8765/status/{job_id}").json()
    print(f"Status: {status['status']} - Progress: {status['progress']}%")

    if status["status"] == "completed":
        break
    elif status["status"] == "failed":
        print(f"Error: {status['error']}")
        break

    time.sleep(2)

# 3. Download result
image = requests.get(f"http://bigorin:8765/result/{job_id}")
with open(f"{job_id}.png", "wb") as f:
    f.write(image.content)
print(f"Image saved: {job_id}.png")
```

### cURL Examples

```bash
# Generate image
curl -X POST http://bigorin:8765/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "a cat wearing a wizard hat"}'

# Check status
curl http://bigorin:8765/status/YOUR_JOB_ID

# Download result
curl http://bigorin:8765/result/YOUR_JOB_ID -o image.png

# Health check
curl http://bigorin:8765/health
```
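The polling loop in the Python client can be wrapped in a small helper with a timeout. A minimal sketch; the `wait_for_job` name and the injectable `get_status`/`sleep` callables are illustrative, not part of the service API:

```python
import time

def wait_for_job(get_status, timeout=300.0, poll=2.0, sleep=time.sleep):
    """Poll get_status() until the job is completed or failed.

    get_status should return a dict shaped like GET /status/{job_id}.
    Raises TimeoutError if the job does not finish within `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = get_status()
        if status["status"] in ("completed", "failed"):
            return status
        sleep(poll)
    raise TimeoutError("job did not finish in time")
```

With `requests`, `get_status` could be `lambda: requests.get(f"http://bigorin:8765/status/{job_id}").json()`.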
## Configuration

### Environment Variables

- `DREAMTAIL_STORAGE` - Storage directory (default: `/app/storage`)
- `DREAMTAIL_MODELS` - Models cache directory (default: `/app/models`)
- `LOG_LEVEL` - Logging level (default: `INFO`)

### config.py Settings

Key configuration parameters in `config.py`:

- `DEFAULT_STEPS`: 30 (20-50 recommended for SDXL)
- `MAX_CONCURRENT_JOBS`: 1 (Orin handles 1 SDXL job at a time)
- `IMAGE_RETENTION_DAYS`: 10 (auto-cleanup after 10 days)
- `USE_FP16`: True (reduces VRAM to ~8GB)
- `ENABLE_ATTENTION_SLICING`: True (memory optimization)
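The environment-variable lookup mirrors what `config.py` does at import time: the variable wins if set, otherwise the container default applies. A dependency-free sketch of that resolution rule (the `resolve_dir` helper is illustrative):

```python
import os
from pathlib import Path

def resolve_dir(var, default, env=None):
    """Return the directory from env[var] if set, else the default path."""
    env = os.environ if env is None else env
    return Path(env.get(var, default))

# Unset variable falls back to the container default
storage = resolve_dir("DREAMTAIL_STORAGE", "/app/storage", env={})
# A set variable overrides it, e.g. pointing at the shared model cache
models = resolve_dir("DREAMTAIL_MODELS", "/app/models",
                     env={"DREAMTAIL_MODELS": "/data/models"})
```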
## Performance

**Typical generation time on AGX Orin:**
- 1024x1024, 30 steps: **~45-60 seconds**
- 1024x1024, 20 steps: **~30-40 seconds** (faster, slightly lower quality)

**Memory usage:**
- SDXL with FP16: ~8GB VRAM
- Peak with attention slicing: ~10GB VRAM
## Maintenance

### View Logs

```bash
docker logs -f dreamtail
```

### Check Storage

```bash
curl http://bigorin:8765/storage
```

**Response:**
```json
{
  "total_images": 42,
  "total_size_mb": 156.3,
  "storage_path": "/app/storage/images",
  "retention_days": 10
}
```

### Manual Cleanup

Images are automatically deleted after 10 days. To manually clean up:

```bash
docker exec dreamtail rm -rf /app/storage/images/*
```

### Restart Service

```bash
docker restart dreamtail
```

### Stop Service

```bash
docker stop dreamtail
```
## Troubleshooting

### Model not loading

**Symptom:** `"model_loaded": false` in `/health`

**Solutions:**
1. Check VRAM: `nvidia-smi` (or `tegrastats` on Jetson); you need ~10GB free
2. Check logs: `docker logs dreamtail`
3. Re-download models: `./scripts/download-models.sh`

### Out of memory errors

**Solutions:**
1. Reduce concurrent jobs to 1 (default)
2. Enable CPU offload: set `ENABLE_CPU_OFFLOAD = True` in `config.py`
3. Reduce image size: use 768x768 or 512x512

### Slow generation

**Expected:** 45-60 seconds for 1024x1024 @ 30 steps

**To speed up:**
- Reduce steps to 20-25 (minor quality loss)
- Use a smaller resolution (768x768)
- Ensure the GPU isn't thermal throttling
## Integration with Lyra

DreamTail is designed to be used by Lyra but runs independently (no NATS integration). Lyra can call DreamTail over HTTP:

```python
# In Lyra's code (sketch using httpx.AsyncClient as the HTTP client)
import httpx

async def generate_image_for_user(prompt: str) -> str:
    async with httpx.AsyncClient() as http_client:
        response = await http_client.post(
            "http://bigorin:8765/generate",
            json={"prompt": prompt},
        )
        job_id = response.json()["job_id"]

    # Poll /status/{job_id} until complete...
    # Return image to user
    return job_id
```
## Project Structure

```
dreamtail/
├── Dockerfile                  # Jetson-optimized container
├── requirements.txt            # Python dependencies
├── config.py                   # Configuration
├── main.py                     # FastAPI app + worker
├── api/
│   ├── models.py               # Pydantic schemas
│   └── routes.py               # API endpoints
├── worker/
│   ├── generator.py            # SDXL pipeline
│   └── queue_manager.py        # Job queue
├── dreamtail_storage/
│   ├── file_manager.py         # Image storage
│   └── cleanup_task.py         # Periodic cleanup
└── scripts/
    ├── build.sh                # Build Docker image
    ├── run.sh                  # Run container
    └── download-models.sh      # Download SDXL
```
## License

This project is part of the Lyra ecosystem. For internal use.

## Support

For issues or questions:
- Check logs: `docker logs -f dreamtail`
- Check health: `curl http://bigorin:8765/health`
- Review configuration in `config.py`

---

**Built with ❤️ for the Lyra project**
api/__init__.py (executable, 1 line)
@@ -0,0 +1 @@
"""API modules for DreamTail."""
api/models.py (executable, 73 lines)
@@ -0,0 +1,73 @@
"""
Pydantic models for API requests and responses.
"""

from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field, validator
from datetime import datetime
import config


class GenerationParams(BaseModel):
    """Optional generation parameters."""
    width: int = Field(default=config.DEFAULT_WIDTH, ge=512, le=2048)
    height: int = Field(default=config.DEFAULT_HEIGHT, ge=512, le=2048)
    num_inference_steps: int = Field(default=config.DEFAULT_STEPS, ge=config.MIN_STEPS, le=config.MAX_STEPS)
    guidance_scale: float = Field(default=config.DEFAULT_GUIDANCE_SCALE, ge=config.MIN_GUIDANCE, le=config.MAX_GUIDANCE)
    seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
    face_image: Optional[str] = Field(default=None, description="Face reference image name (from faces directory) or 'vixy' for default")
    face_strength: float = Field(default=config.DEFAULT_FACE_STRENGTH, ge=0.0, le=1.0, description="Face conditioning strength (0.0-1.0)")

    @validator('width', 'height')
    def must_be_multiple_of_8(cls, v):
        if v % 8 != 0:
            raise ValueError('Width and height must be multiples of 8')
        return v


class GenerateRequest(BaseModel):
    """Request to generate an image."""
    prompt: str = Field(..., min_length=1, max_length=2000, description="Text prompt for image generation")
    negative_prompt: Optional[str] = Field(default=None, max_length=2000, description="Negative prompt to avoid certain features")
    params: Optional[GenerationParams] = Field(default_factory=GenerationParams)


class JobResponse(BaseModel):
    """Response when submitting a generation job."""
    job_id: str = Field(..., description="Unique job identifier")
    status: Literal["queued", "processing", "completed", "failed"] = Field(..., description="Current job status")
    created_at: datetime = Field(..., description="Job creation timestamp")
    message: Optional[str] = Field(default=None, description="Optional message")


class JobStatus(BaseModel):
    """Detailed job status information."""
    job_id: str
    status: Literal["queued", "processing", "completed", "failed"]
    progress: int = Field(..., ge=0, le=100, description="Progress percentage (0-100)")
    created_at: datetime
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    prompt: str


class HealthResponse(BaseModel):
    """Health check response."""
    model_config = {"protected_namespaces": ()}  # Allow "model_" prefix

    status: Literal["healthy", "unhealthy"]
    version: str
    model_loaded: bool
    queue_size: int
    active_jobs: int
    uptime_seconds: float


class ModelsResponse(BaseModel):
    """Available models information."""
    base_model: str
    refiner_model: Optional[str] = None
    refiner_enabled: bool
    device: str
    fp16_enabled: bool
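The `must_be_multiple_of_8` validator combines with the `Field` bounds (`ge=512, le=2048`) to gate image dimensions. The same rule as a plain function, with no pydantic dependency; the `check_dimension` name is illustrative:

```python
def check_dimension(v, lo=512, hi=2048):
    """Validate one image dimension the way GenerationParams does."""
    if not lo <= v <= hi:
        # Mirrors the Field(ge=..., le=...) bounds
        raise ValueError(f"dimension {v} outside [{lo}, {hi}]")
    if v % 8 != 0:
        # Mirrors the must_be_multiple_of_8 validator
        raise ValueError("Width and height must be multiples of 8")
    return v
```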
api/routes.py (executable, 166 lines)
@@ -0,0 +1,166 @@
"""
API Routes

FastAPI routes for image generation service.
"""

import logging
from fastapi import APIRouter, HTTPException, Response
from fastapi.responses import FileResponse
import asyncio

from api.models import (
    GenerateRequest, JobResponse, JobStatus,
    HealthResponse, ModelsResponse
)
from worker.queue_manager import queue_manager
from worker.generator import generator
from dreamtail_storage.file_manager import file_manager
import config

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post("/generate", response_model=JobResponse, status_code=202)
async def generate_image(request: GenerateRequest):
    """
    Submit an image generation job.

    Returns immediately with a job_id. Use /status/{job_id} to check progress.
    """
    try:
        # Submit job to queue
        job_id = await queue_manager.submit_job(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            params=request.params.dict() if request.params else {}
        )

        job = queue_manager.get_job(job_id)

        return JobResponse(
            job_id=job_id,
            status=job.status,
            created_at=job.created_at,
            message=f"Job queued. Queue position: {queue_manager.get_queue_size()}"
        )

    except asyncio.QueueFull:
        raise HTTPException(
            status_code=503,
            detail=f"Queue is full (max: {config.MAX_QUEUE_SIZE}). Please try again later."
        )
    except Exception as e:
        logger.error(f"Error submitting job: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/status/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
    """
    Get the status of a generation job.

    Returns job progress, status, and timestamps.
    """
    job = queue_manager.get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    return JobStatus(
        job_id=job.job_id,
        status=job.status,
        progress=job.progress,
        created_at=job.created_at,
        started_at=job.started_at,
        completed_at=job.completed_at,
        error=job.error,
        prompt=job.prompt
    )


@router.get("/result/{job_id}")
async def get_result(job_id: str):
    """
    Download the generated image for a completed job.

    Returns the image file as PNG.
    """
    job = queue_manager.get_job(job_id)

    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    if job.status != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Job is {job.status}, not completed. Check /status/{job_id}"
        )

    # Check if image file exists
    image_path = file_manager.get_image_path(job_id)

    if not image_path:
        raise HTTPException(
            status_code=404,
            detail="Image file not found (may have been cleaned up)"
        )

    # Return image file
    return FileResponse(
        path=image_path,
        media_type="image/png",
        filename=f"{job_id}.png"
    )


@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status and basic statistics.
    """
    import time
    from main import start_time

    return HealthResponse(
        status="healthy" if generator.model_loaded else "unhealthy",
        version=config.APP_VERSION,
        model_loaded=generator.model_loaded,
        queue_size=queue_manager.get_queue_size(),
        active_jobs=queue_manager.get_active_jobs(),
        uptime_seconds=time.time() - start_time
    )


@router.get("/models", response_model=ModelsResponse)
async def get_models_info():
    """
    Get information about loaded models and configuration.
    """
    model_info = generator.get_model_info()

    return ModelsResponse(
        base_model=config.SDXL_MODEL_ID,
        refiner_model=config.SDXL_REFINER_ID if config.USE_REFINER else None,
        refiner_enabled=config.USE_REFINER,
        device=model_info.get("device", "unknown"),
        fp16_enabled=config.USE_FP16
    )


@router.get("/storage")
async def get_storage_info():
    """
    Get storage statistics (admin endpoint).
    """
    stats = file_manager.get_storage_stats()
    return {
        "total_images": stats["total_images"],
        "total_size_mb": round(stats["total_size_mb"], 2),
        "storage_path": stats["storage_path"],
        "retention_days": config.IMAGE_RETENTION_DAYS
    }
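When the queue is full, `/generate` answers 503; clients can resubmit with a growing backoff. A minimal sketch, where the `submit_with_retry` name and the injectable `post` callable returning `(status_code, body)` are assumptions for illustration:

```python
import time

def submit_with_retry(post, payload, retries=3, backoff=5.0, sleep=time.sleep):
    """Resubmit a /generate payload while the queue is full (HTTP 503)."""
    for attempt in range(retries):
        code, body = post(payload)
        if code == 202:
            # Accepted: the service returns the job id to poll
            return body["job_id"]
        if code != 503:
            raise RuntimeError(f"/generate failed with HTTP {code}")
        sleep(backoff * (attempt + 1))  # back off a little longer each time
    raise RuntimeError("queue stayed full; giving up")
```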
config.py (executable, 72 lines)
@@ -0,0 +1,72 @@
"""
DreamTail Configuration

Configuration for SDXL image generation service running on AGX Orin.
"""

import os
from pathlib import Path

# Application settings
APP_NAME = "DreamTail"
APP_VERSION = "1.0.0"
API_HOST = "0.0.0.0"
API_PORT = 8765

# Paths
BASE_DIR = Path(__file__).parent
STORAGE_DIR = Path(os.getenv("DREAMTAIL_STORAGE", "/app/storage"))
MODELS_DIR = Path(os.getenv("DREAMTAIL_MODELS", "/app/models"))
IMAGES_DIR = STORAGE_DIR / "images"

# Ensure directories exist
STORAGE_DIR.mkdir(parents=True, exist_ok=True)
IMAGES_DIR.mkdir(parents=True, exist_ok=True)

# Model settings
SDXL_MODEL_ID = "SG161222/RealVisXL_V4.0"
SDXL_REFINER_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
USE_REFINER = False  # Set to True to enable refiner (requires more VRAM)

# IP-Adapter FaceID settings
IP_ADAPTER_DIR = Path(os.getenv("DREAMTAIL_IP_ADAPTER", MODELS_DIR / "ip-adapter"))
IP_ADAPTER_PATH = IP_ADAPTER_DIR / "ip-adapter-faceid_sdxl.bin"
FACE_REFERENCE_DIR = STORAGE_DIR / "faces"  # Directory for face reference images
DEFAULT_FACE_STRENGTH = 0.6  # How strongly to apply face conditioning (0.0-1.0)

# Ensure IP-Adapter directories exist
IP_ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
FACE_REFERENCE_DIR.mkdir(parents=True, exist_ok=True)

# Generation defaults
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_STEPS = 30
DEFAULT_GUIDANCE_SCALE = 7.5
MIN_STEPS = 10
MAX_STEPS = 100
MIN_GUIDANCE = 1.0
MAX_GUIDANCE = 20.0

# Performance settings
MAX_CONCURRENT_JOBS = 1  # AGX Orin can handle 1 SDXL generation at a time
ENABLE_ATTENTION_SLICING = True
ENABLE_VAE_SLICING = True
ENABLE_CPU_OFFLOAD = False  # Only if VRAM is insufficient
USE_FP16 = True  # Half precision for reduced VRAM usage

# Queue settings
MAX_QUEUE_SIZE = 50  # Maximum queued jobs
JOB_TIMEOUT_SECONDS = 300  # 5 minutes max per job

# Storage settings
IMAGE_RETENTION_DAYS = 10
CLEANUP_INTERVAL_HOURS = 24
IMAGE_FORMAT = "PNG"
IMAGE_QUALITY = 95  # For JPEG (not used for PNG)

# Logging
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT = "[%(asctime)s] %(levelname)s [%(name)s] %(message)s"
LOG_DATE_FORMAT = "%H:%M:%S"
dreamtail_storage/__init__.py (executable, 1 line)
@@ -0,0 +1 @@
"""Storage management for DreamTail."""
dreamtail_storage/cleanup_task.py (executable, 111 lines)
@@ -0,0 +1,111 @@
"""
Cleanup Task

Periodically deletes images older than the retention period.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from pathlib import Path

import config

logger = logging.getLogger(__name__)


class CleanupTask:
    """Background task to clean up old images."""

    def __init__(self):
        self.images_dir = config.IMAGES_DIR
        self.retention_days = config.IMAGE_RETENTION_DAYS
        self.interval_hours = config.CLEANUP_INTERVAL_HOURS
        self.running = False
        self._task = None

    async def start(self):
        """Start the cleanup task."""
        if self.running:
            logger.warning("Cleanup task already running")
            return

        self.running = True
        self._task = asyncio.create_task(self._run())
        logger.info(f"Cleanup task started (retention: {self.retention_days} days, interval: {self.interval_hours}h)")

    async def stop(self):
        """Stop the cleanup task."""
        if not self.running:
            return

        self.running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

        logger.info("Cleanup task stopped")

    async def _run(self):
        """Main cleanup loop."""
        while self.running:
            try:
                await self._cleanup_old_images()
                # Sleep for the configured interval
                await asyncio.sleep(self.interval_hours * 3600)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cleanup task: {e}")
                await asyncio.sleep(300)  # Wait 5 minutes before retry

    async def _cleanup_old_images(self):
        """Delete images older than retention period."""
        try:
            cutoff_time = datetime.now() - timedelta(days=self.retention_days)
            cutoff_timestamp = cutoff_time.timestamp()

            deleted_count = 0
            deleted_size = 0

            # Find all image files
            image_files = list(self.images_dir.glob(f"*.{config.IMAGE_FORMAT.lower()}"))

            for file_path in image_files:
                try:
                    # Check file modification time
                    file_mtime = file_path.stat().st_mtime

                    if file_mtime < cutoff_timestamp:
                        file_size = file_path.stat().st_size
                        file_path.unlink()
                        deleted_count += 1
                        deleted_size += file_size

                        logger.debug(f"Deleted old image: {file_path.name}")

                except Exception as e:
                    logger.error(f"Error deleting {file_path.name}: {e}")

            if deleted_count > 0:
                logger.info(
                    f"Cleanup completed: deleted {deleted_count} images "
                    f"({deleted_size / 1024 / 1024:.1f} MB) older than {self.retention_days} days"
                )
            else:
                logger.debug(f"Cleanup completed: no images older than {self.retention_days} days")

        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

    async def cleanup_now(self):
        """Trigger immediate cleanup (for testing or manual trigger)."""
        logger.info("Manual cleanup triggered")
        await self._cleanup_old_images()


# Global cleanup task instance
cleanup_task = CleanupTask()
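The cutoff comparison in `_cleanup_old_images` boils down to a single predicate on a file's mtime. A standalone sketch of that rule (the `is_expired` name and `now` override are illustrative):

```python
from datetime import datetime, timedelta

def is_expired(mtime_ts, retention_days, now=None):
    """True if a file's mtime timestamp is older than the retention window."""
    now = now or datetime.now()
    cutoff = (now - timedelta(days=retention_days)).timestamp()
    return mtime_ts < cutoff
```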
dreamtail_storage/file_manager.py (executable, 132 lines)
@@ -0,0 +1,132 @@
"""
File Storage Manager

Handles saving and retrieving generated images.
"""

import logging
from pathlib import Path
from typing import Optional
from PIL import Image
import aiofiles
import os

import config

logger = logging.getLogger(__name__)


class FileManager:
    """Manages image file storage and retrieval."""

    def __init__(self):
        self.images_dir = config.IMAGES_DIR
        self.image_format = config.IMAGE_FORMAT

        # Ensure storage directory exists
        self.images_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Image storage directory: {self.images_dir}")

    async def save_image(self, job_id: str, image: Image.Image) -> str:
        """
        Save generated image to disk.

        Args:
            job_id: Job identifier (used as filename)
            image: PIL Image to save

        Returns:
            file_path: Absolute path to saved image

        Raises:
            IOError: If save fails
        """
        filename = f"{job_id}.{self.image_format.lower()}"
        file_path = self.images_dir / filename

        try:
            # Save in thread pool to avoid blocking
            import asyncio
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,
                lambda: image.save(
                    file_path,
                    format=self.image_format,
                    quality=config.IMAGE_QUALITY if self.image_format == "JPEG" else None
                )
            )

            logger.info(f"Image saved: {file_path} ({os.path.getsize(file_path) / 1024:.1f} KB)")
            return str(file_path)

        except Exception as e:
            logger.error(f"Failed to save image {job_id}: {e}")
            raise IOError(f"Failed to save image: {e}")
|
def get_image_path(self, job_id: str) -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
Get path to image file if it exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to image file or None if not found
|
||||||
|
"""
|
||||||
|
filename = f"{job_id}.{self.image_format.lower()}"
|
||||||
|
file_path = self.images_dir / filename
|
||||||
|
|
||||||
|
if file_path.exists():
|
||||||
|
return file_path
|
||||||
|
return None
|
||||||
|
|
||||||
|
def image_exists(self, job_id: str) -> bool:
|
||||||
|
"""Check if image file exists."""
|
||||||
|
return self.get_image_path(job_id) is not None
|
||||||
|
|
||||||
|
async def delete_image(self, job_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an image file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if not found
|
||||||
|
"""
|
||||||
|
file_path = self.get_image_path(job_id)
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
try:
|
||||||
|
file_path.unlink()
|
||||||
|
logger.info(f"Deleted image: {file_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete image {job_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_storage_stats(self) -> dict:
|
||||||
|
"""Get storage statistics."""
|
||||||
|
try:
|
||||||
|
files = list(self.images_dir.glob(f"*.{self.image_format.lower()}"))
|
||||||
|
total_size = sum(f.stat().st_size for f in files)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_images": len(files),
|
||||||
|
"total_size_mb": total_size / (1024 * 1024),
|
||||||
|
"storage_path": str(self.images_dir)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get storage stats: {e}")
|
||||||
|
return {
|
||||||
|
"total_images": 0,
|
||||||
|
"total_size_mb": 0,
|
||||||
|
"storage_path": str(self.images_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global file manager instance
|
||||||
|
file_manager = FileManager()
|
||||||
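The storage layout in `FileManager` is flat: one file per job, named `<job_id>.<format>`. A small sketch of that mapping and the existence/delete round trip, using a temporary directory in place of `config.IMAGES_DIR`:

```python
import tempfile
import uuid
from pathlib import Path

image_format = "png"  # stands in for config.IMAGE_FORMAT

with tempfile.TemporaryDirectory() as d:
    images_dir = Path(d)
    job_id = str(uuid.uuid4())

    # save_image writes to images_dir / "<job_id>.<format>"
    file_path = images_dir / f"{job_id}.{image_format.lower()}"
    file_path.write_bytes(b"\x89PNG")  # placeholder bytes, not a real image

    # get_image_path / image_exists reduce to this lookup
    exists_before = (images_dir / f"{job_id}.{image_format.lower()}").exists()
    file_path.unlink()  # what delete_image does on a hit
    exists_after = (images_dir / f"{job_id}.{image_format.lower()}").exists()

print(exists_before, exists_after)  # True False
```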
216
main.py
Executable file
@@ -0,0 +1,216 @@
"""
DreamTail - SDXL Image Generation Service

Main application entry point.
"""

import sys
import os
from pathlib import Path

# Add app directory to Python path
sys.path.insert(0, str(Path(__file__).parent))

import asyncio
import logging
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from api.routes import router
from worker.queue_manager import queue_manager
from worker.generator import generator
from dreamtail_storage.file_manager import file_manager
from dreamtail_storage.cleanup_task import cleanup_task
import config

# Configure logging
logging.basicConfig(
    level=getattr(logging, config.LOG_LEVEL),
    format=config.LOG_FORMAT,
    datefmt=config.LOG_DATE_FORMAT
)

logger = logging.getLogger(__name__)

# Track application start time
start_time = time.time()


async def process_jobs():
    """
    Background worker that processes jobs from the queue.
    """
    logger.info("Job processor started")

    while True:
        try:
            # Get next job from queue (blocks until available)
            job_id = await queue_manager.get_next_job()

            if not job_id:
                continue

            job = queue_manager.get_job(job_id)
            if not job:
                logger.warning(f"Job {job_id} not found in queue")
                continue

            # Mark job as started
            await queue_manager.start_job(job_id)

            logger.info(f"Processing job {job_id}")

            try:
                # Progress callback
                async def update_progress(progress: int):
                    await queue_manager.update_progress(job_id, progress)

                # Resolve face image path if specified
                face_image = job.params.get("face_image")
                if face_image:
                    # Handle "vixy" shortcut for default face
                    if face_image.lower() == "vixy":
                        # Use all faces in the vixy directory
                        vixy_faces = list(config.FACE_REFERENCE_DIR.glob("*.jpg")) + \
                                     list(config.FACE_REFERENCE_DIR.glob("*.png"))
                        if vixy_faces:
                            face_image = [str(f) for f in vixy_faces]
                            logger.info(f"Using {len(face_image)} Vixy reference faces")
                        else:
                            logger.warning("No Vixy faces found, generating without face lock")
                            face_image = None
                    else:
                        # Look for specific face file
                        face_path = config.FACE_REFERENCE_DIR / face_image
                        if face_path.exists():
                            face_image = str(face_path)
                        else:
                            logger.warning(f"Face image {face_image} not found, generating without face lock")
                            face_image = None

                # Generate image
                image = await generator.generate_image(
                    prompt=job.prompt,
                    negative_prompt=job.negative_prompt,
                    width=job.params.get("width", config.DEFAULT_WIDTH),
                    height=job.params.get("height", config.DEFAULT_HEIGHT),
                    num_inference_steps=job.params.get("num_inference_steps", config.DEFAULT_STEPS),
                    guidance_scale=job.params.get("guidance_scale", config.DEFAULT_GUIDANCE_SCALE),
                    seed=job.params.get("seed"),
                    progress_callback=update_progress,
                    face_image=face_image,
                    face_strength=job.params.get("face_strength", config.DEFAULT_FACE_STRENGTH),
                )

                # Save image to disk
                result_path = await file_manager.save_image(job_id, image)

                # Mark job as completed
                await queue_manager.complete_job(job_id, result_path)

                logger.info(f"Job {job_id} completed successfully")

            except Exception as e:
                logger.error(f"Job {job_id} failed: {e}")
                await queue_manager.fail_job(job_id, str(e))

        except asyncio.CancelledError:
            logger.info("Job processor cancelled")
            break
        except Exception as e:
            logger.error(f"Error in job processor: {e}")
            await asyncio.sleep(5)  # Wait before retrying


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager for startup and shutdown.
    """
    # Startup
    logger.info(f"🎨 Starting {config.APP_NAME} v{config.APP_VERSION}")
    logger.info(f"Storage: {config.STORAGE_DIR}")
    logger.info(f"Models: {config.MODELS_DIR}")

    # Initialize SDXL model
    try:
        await generator.initialize()
    except Exception as e:
        logger.error(f"Failed to initialize SDXL model: {e}")
        logger.warning("Service will start but generation will fail until model is loaded")

    # Start background tasks
    worker_task = asyncio.create_task(process_jobs())
    await cleanup_task.start()

    logger.info(f"✅ {config.APP_NAME} ready on http://{config.API_HOST}:{config.API_PORT}")

    yield

    # Shutdown
    logger.info(f"🛑 Shutting down {config.APP_NAME}...")

    # Cancel worker task
    worker_task.cancel()
    try:
        await worker_task
    except asyncio.CancelledError:
        pass

    # Stop cleanup task
    await cleanup_task.stop()

    logger.info("✅ Shutdown complete")


# Create FastAPI application
app = FastAPI(
    title=config.APP_NAME,
    version=config.APP_VERSION,
    description="SDXL Image Generation Service for Jetson AGX Orin",
    lifespan=lifespan
)

# Add CORS middleware (allow all origins for multi-client support)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include API routes
app.include_router(router)


@app.get("/")
async def root():
    """Root endpoint with service information."""
    return {
        "service": config.APP_NAME,
        "version": config.APP_VERSION,
        "status": "running",
        "model": config.SDXL_MODEL_ID,
        "endpoints": {
            "generate": "POST /generate",
            "status": "GET /status/{job_id}",
            "result": "GET /result/{job_id}",
            "health": "GET /health",
            "models": "GET /models"
        }
    }


if __name__ == "__main__":
    # Run with uvicorn
    uvicorn.run(
        "main:app",
        host=config.API_HOST,
        port=config.API_PORT,
        log_level=config.LOG_LEVEL.lower()
    )
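The `vixy` shortcut in `process_jobs` resolves to every `.jpg`/`.png` under the reference directory. A self-contained sketch of that resolution step (the directory and filenames here are invented for illustration; the real directory is `config.FACE_REFERENCE_DIR`):

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    face_dir = Path(d)  # stands in for config.FACE_REFERENCE_DIR
    (face_dir / "ref1.jpg").write_bytes(b"")
    (face_dir / "ref2.png").write_bytes(b"")
    (face_dir / "notes.txt").write_bytes(b"")  # ignored: not a face image extension

    face_image = "vixy"
    if face_image.lower() == "vixy":
        # Same glob the worker uses: all jpg + png references
        vixy_faces = list(face_dir.glob("*.jpg")) + list(face_dir.glob("*.png"))
        face_image = [str(f) for f in vixy_faces] if vixy_faces else None

print(len(face_image))  # two reference faces found, the .txt is skipped
```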
35
requirements.txt
Executable file
@@ -0,0 +1,35 @@
# DreamTail Dependencies

# Web framework
fastapi==0.109.0
uvicorn[standard]==0.27.0
python-multipart==0.0.6

# SDXL / Stable Diffusion (upgraded for compatibility)
diffusers==0.27.0
transformers==4.38.0
accelerate==0.27.0
safetensors==0.4.2
huggingface_hub==0.21.0
omegaconf==2.3.0

# PyTorch (installed in Jetson container, but listed for reference)
# torch==2.1.0+cu121 (from Jetson L4T)
# torchvision==0.16.0+cu121

# Image processing
Pillow==10.2.0
opencv-python==4.9.0.80

# IP-Adapter FaceID (for consistent face generation)
insightface==0.7.3
onnxruntime-gpu==1.17.0

# Utilities
pydantic==2.6.0
pydantic-settings==2.1.0
python-dotenv==1.0.1
aiofiles==23.2.1

# Monitoring (optional)
prometheus-client==0.19.0
19
scripts/build.sh
Executable file
@@ -0,0 +1,19 @@
#!/bin/bash
# Build DreamTail Docker image for Jetson AGX Orin
set -e

echo "🎨 Building DreamTail Docker image..."

cd "$(dirname "$0")/.."

# Build for ARM64 (Jetson architecture)
docker build \
    --platform linux/arm64 \
    -t dreamtail:latest \
    -f Dockerfile \
    .

echo "✅ Build complete!"
echo ""
echo "To run DreamTail:"
echo "  ./scripts/run.sh"
53
scripts/download-models.sh
Executable file
@@ -0,0 +1,53 @@
#!/bin/bash
# Download SDXL models for DreamTail using Docker
set -e

echo "📥 Downloading SDXL models..."
echo "This will download ~13GB of model weights"
echo ""

# Model cache directory
MODELS_DIR="${DREAMTAIL_MODELS:-/mnt/nvme/models}"

# Create directory if it doesn't exist
mkdir -p "$MODELS_DIR"

echo "Models will be cached in: $MODELS_DIR"
echo ""
echo "Using Docker container to download models..."
echo ""

# Use L4T PyTorch container to download models
docker run --rm -it \
    -v "${MODELS_DIR}:/models" \
    dustynv/l4t-pytorch:r36.2.0-pth2.1-py3 \
    bash -c "
pip3 install -q diffusers transformers accelerate safetensors &&
python3 << 'PYEOF'
from diffusers import StableDiffusionXLPipeline

model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
cache_dir = '/models'

print(f'Downloading {model_id}...')
print(f'Cache directory: {cache_dir}')
print('')

try:
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        use_safetensors=True,
        cache_dir=cache_dir
    )
    print('✅ SDXL model downloaded successfully!')
except Exception as e:
    print(f'❌ Error downloading model: {e}')
    exit(1)
PYEOF
"

echo ""
echo "✅ Model download complete!"
echo ""
echo "Models are cached in: $MODELS_DIR"
echo "You can now build and run DreamTail"
45
scripts/run.sh
Executable file
@@ -0,0 +1,45 @@
#!/bin/bash
# Run DreamTail on Jetson AGX Orin
set -e

echo "🎨 Starting DreamTail..."

# Configuration
CONTAINER_NAME="dreamtail"
PORT=8765
MODELS_DIR="/mnt/nvme/models"       # Models on NVMe
STORAGE_DIR="/mnt/nvme/dreamtail"   # DreamTail storage on NVMe

# Create storage directory if it doesn't exist
mkdir -p "$STORAGE_DIR"

# Stop existing container if running
if docker ps -a --format '{{.Names}}' | grep -q "^${CONTAINER_NAME}$"; then
    echo "Stopping existing DreamTail container..."
    docker stop "$CONTAINER_NAME" 2>/dev/null || true
    docker rm "$CONTAINER_NAME" 2>/dev/null || true
fi

# Run container
echo "Starting DreamTail container..."
docker run -d \
    --name "$CONTAINER_NAME" \
    --runtime=nvidia \
    --restart unless-stopped \
    -p ${PORT}:8765 \
    -v "${MODELS_DIR}:/app/models" \
    -v "${STORAGE_DIR}:/app/storage" \
    -e DREAMTAIL_STORAGE=/app/storage \
    -e DREAMTAIL_MODELS=/app/models \
    -e LOG_LEVEL=INFO \
    dreamtail:latest

echo "✅ DreamTail started!"
echo ""
echo "API available at: http://bigorin:${PORT}"
echo ""
echo "To check logs:"
echo "  docker logs -f ${CONTAINER_NAME}"
echo ""
echo "To stop:"
echo "  docker stop ${CONTAINER_NAME}"
1
worker/__init__.py
Executable file
@@ -0,0 +1 @@
"""Worker modules for DreamTail."""
339
worker/generator.py
Executable file
@@ -0,0 +1,339 @@
"""
SDXL Image Generator

Handles image generation using Stable Diffusion XL with Jetson optimizations.
Supports IP-Adapter FaceID for consistent face generation.
"""

import torch
import logging
import cv2
import numpy as np
from typing import Optional, Dict, Any, List, Union
from pathlib import Path
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
from PIL import Image
import asyncio

import config

logger = logging.getLogger(__name__)


class SDXLGenerator:
    """SDXL image generator with optimizations for Jetson AGX Orin."""

    def __init__(self):
        self.pipeline = None
        self.device = None
        self.model_loaded = False
        self._load_lock = asyncio.Lock()

        # IP-Adapter FaceID components
        self.ip_model = None
        self.face_app = None
        self.face_embeds_cache = {}  # Cache for precomputed face embeddings
        self.ip_adapter_loaded = False

    async def initialize(self):
        """Load SDXL model with Jetson optimizations."""
        async with self._load_lock:
            if self.model_loaded:
                logger.info("Model already loaded, skipping initialization")
                return

            logger.info("Initializing SDXL model...")
            logger.info(f"Model: {config.SDXL_MODEL_ID}")
            logger.info(f"FP16: {config.USE_FP16}")
            logger.info(f"Attention slicing: {config.ENABLE_ATTENTION_SLICING}")

            # Determine device
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                logger.warning("CUDA not available, using CPU (will be very slow)")

            # Load pipeline
            try:
                dtype = torch.float16 if config.USE_FP16 else torch.float32

                # Use DDIM scheduler for IP-Adapter compatibility
                noise_scheduler = DDIMScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    clip_sample=False,
                    set_alpha_to_one=False,
                    steps_offset=1,
                )

                self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                    config.SDXL_MODEL_ID,
                    torch_dtype=dtype,
                    scheduler=noise_scheduler,
                    use_safetensors=True,
                    cache_dir=str(config.MODELS_DIR),
                    add_watermarker=False,
                )

                # Move to device
                self.pipeline = self.pipeline.to(self.device)

                # Apply optimizations
                if config.ENABLE_ATTENTION_SLICING:
                    self.pipeline.enable_attention_slicing()
                    logger.info("Attention slicing enabled")

                if config.ENABLE_VAE_SLICING:
                    self.pipeline.enable_vae_slicing()
                    logger.info("VAE slicing enabled")

                if config.ENABLE_CPU_OFFLOAD and self.device == "cuda":
                    self.pipeline.enable_model_cpu_offload()
                    logger.info("CPU offload enabled")

                self.model_loaded = True
                logger.info("SDXL model loaded successfully!")

            except Exception as e:
                logger.error(f"Failed to load SDXL model: {e}")
                raise

    async def initialize_ip_adapter(self):
        """Load IP-Adapter FaceID components (lazy loading)."""
        if self.ip_adapter_loaded:
            return

        logger.info("Initializing IP-Adapter FaceID...")

        try:
            # Import IP-Adapter components
            from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
            from insightface.app import FaceAnalysis

            # Initialize InsightFace for face detection/embedding
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            logger.info("InsightFace initialized")

            # Load IP-Adapter FaceID model
            ip_ckpt = str(config.IP_ADAPTER_PATH)
            self.ip_model = IPAdapterFaceIDXL(
                self.pipeline,
                ip_ckpt,
                self.device
            )

            self.ip_adapter_loaded = True
            logger.info("IP-Adapter FaceID loaded successfully!")

        except ImportError as e:
            logger.warning(f"IP-Adapter dependencies not available: {e}")
            logger.warning("Face-locked generation will not be available")
        except Exception as e:
            logger.error(f"Failed to load IP-Adapter FaceID: {e}")
            raise

    def extract_face_embedding(self, image: Union[str, Path, Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Extract face embedding from an image.

        Args:
            image: Path to image, PIL Image, or numpy array

        Returns:
            Face embedding tensor
        """
        if self.face_app is None:
            raise RuntimeError("InsightFace not initialized. Call initialize_ip_adapter() first.")

        # Convert to numpy array if needed
        if isinstance(image, (str, Path)):
            img_cv = cv2.imread(str(image))
        elif isinstance(image, Image.Image):
            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        else:
            img_cv = image

        # Detect faces and extract embedding
        faces = self.face_app.get(img_cv)

        if len(faces) == 0:
            raise ValueError("No face detected in image")

        # Use first detected face
        face_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
        logger.info(f"Face embedding extracted: shape {face_embed.shape}")

        return face_embed

    def precompute_face_embeddings(self, face_images: List[Union[str, Path]]) -> torch.Tensor:
        """
        Precompute and average face embeddings from multiple reference images.

        Args:
            face_images: List of paths to face reference images

        Returns:
            Averaged face embedding tensor
        """
        embeddings = []

        for img_path in face_images:
            try:
                embed = self.extract_face_embedding(img_path)
                embeddings.append(embed)
                logger.info(f"Extracted embedding from {img_path}")
            except Exception as e:
                logger.warning(f"Failed to extract face from {img_path}: {e}")

        if len(embeddings) == 0:
            raise ValueError("No faces could be extracted from any reference images")

        # Average the embeddings for better consistency
        avg_embedding = torch.mean(torch.stack(embeddings), dim=0)
        logger.info(f"Averaged {len(embeddings)} face embeddings")

        return avg_embedding

    async def generate_image(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        width: int = config.DEFAULT_WIDTH,
        height: int = config.DEFAULT_HEIGHT,
        num_inference_steps: int = config.DEFAULT_STEPS,
        guidance_scale: float = config.DEFAULT_GUIDANCE_SCALE,
        seed: Optional[int] = None,
        progress_callback=None,
        face_image: Optional[Union[str, Path, List[str]]] = None,
        face_strength: float = 0.6,
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of diffusion steps
            guidance_scale: Guidance scale
            seed: Random seed for reproducibility
            progress_callback: Optional async callback(progress) for progress updates (0-100)
            face_image: Optional path(s) to face reference image(s) for face locking
            face_strength: Strength of face conditioning (0.0-1.0, default 0.6)

        Returns:
            PIL Image

        Raises:
            RuntimeError: If model not loaded
        """
        if not self.model_loaded:
            raise RuntimeError("Model not loaded. Call initialize() first.")

        logger.info(f"Generating image: '{prompt[:50]}...'")
        logger.info(f"Parameters: {width}x{height}, steps={num_inference_steps}, guidance={guidance_scale}")

        # Check if face-locked generation requested
        use_face_id = face_image is not None

        if use_face_id:
            # Initialize IP-Adapter if needed
            await self.initialize_ip_adapter()

            if not self.ip_adapter_loaded:
                logger.warning("IP-Adapter not available, generating without face lock")
                use_face_id = False
            else:
                logger.info(f"Face-locked generation enabled, strength={face_strength}")

        # Set random seed if provided
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info(f"Using seed: {seed}")

        try:
            loop = asyncio.get_event_loop()

            if use_face_id:
                # Extract face embedding(s)
                if isinstance(face_image, list):
                    face_embed = self.precompute_face_embeddings(face_image)
                else:
                    face_embed = self.extract_face_embedding(face_image)

                # Generate with IP-Adapter FaceID
                image = await loop.run_in_executor(
                    None,
                    lambda: self.ip_model.generate(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        faceid_embeds=face_embed,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_samples=1,
                        seed=seed,
                        s_scale=face_strength,
                    )[0]
                )
            else:
                # Progress callback wrapper (only for standard pipeline)
                def callback_wrapper(step: int, timestep: int, latents: torch.FloatTensor):
                    if progress_callback:
                        progress = int((step / num_inference_steps) * 100)
                        try:
                            asyncio.create_task(progress_callback(progress))
                        except Exception:
                            pass

                # Standard generation without face lock
                image = await loop.run_in_executor(
                    None,
                    lambda: self.pipeline(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=width,
                        height=height,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                        callback=callback_wrapper,
                        callback_steps=1
                    ).images[0]
                )

            logger.info("Image generated successfully")
            return image

        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        return {
            "model_id": config.SDXL_MODEL_ID,
            "device": self.device,
            "fp16": config.USE_FP16,
            "attention_slicing": config.ENABLE_ATTENTION_SLICING,
            "vae_slicing": config.ENABLE_VAE_SLICING,
            "cpu_offload": config.ENABLE_CPU_OFFLOAD,
            "loaded": self.model_loaded,
            "ip_adapter_loaded": self.ip_adapter_loaded,
        }


# Global generator instance
generator = SDXLGenerator()
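`precompute_face_embeddings` averages the per-image embeddings to stabilize identity across multiple references. The same operation sketched in NumPy (a stand-in for the torch version above, with random unit vectors in place of InsightFace's `normed_embedding` output):

```python
import numpy as np

rng = np.random.default_rng(0)

# Three fake 512-d face embeddings, L2-normalized like InsightFace's normed_embedding
embeds = [rng.standard_normal(512) for _ in range(3)]
embeds = [e / np.linalg.norm(e) for e in embeds]

# torch.mean(torch.stack(...), dim=0) is an element-wise average over the stack
avg = np.mean(np.stack(embeds), axis=0)

print(avg.shape)  # (512,)
```

Note the average of unit vectors is generally not itself unit-length; IP-Adapter FaceID consumes the raw averaged embedding as conditioning.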
163
worker/queue_manager.py
Executable file
163
worker/queue_manager.py
Executable file
@@ -0,0 +1,163 @@
"""
Job Queue Manager

In-memory job queue for managing image generation requests.
"""

import asyncio
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional, Literal
from dataclasses import dataclass, field
import logging

import config

logger = logging.getLogger(__name__)


@dataclass
class Job:
    """Represents a single generation job."""
    job_id: str
    prompt: str
    negative_prompt: Optional[str]
    params: Dict
    status: Literal["queued", "processing", "completed", "failed"]
    progress: int = 0
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    error: Optional[str] = None
    result_path: Optional[str] = None


class QueueManager:
    """Manages the job queue and job lifecycle."""

    def __init__(self):
        self.jobs: Dict[str, Job] = {}
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=config.MAX_QUEUE_SIZE)
        self.active_jobs: int = 0
        self._lock = asyncio.Lock()

    async def submit_job(
        self,
        prompt: str,
        negative_prompt: Optional[str],
        params: Dict
    ) -> str:
        """
        Submit a new generation job.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            params: Generation parameters

        Returns:
            job_id: Unique job identifier

        Raises:
            asyncio.QueueFull: If queue is at capacity
        """
        job_id = str(uuid.uuid4())

        job = Job(
            job_id=job_id,
            prompt=prompt,
            negative_prompt=negative_prompt,
            params=params,
            status="queued"
        )

        async with self._lock:
            self.jobs[job_id] = job

        try:
            # Enqueue without blocking so a full queue raises QueueFull
            # (await self.queue.put() would silently wait instead of raising)
            self.queue.put_nowait(job_id)
        except asyncio.QueueFull:
            async with self._lock:
                del self.jobs[job_id]
            raise

        logger.info(f"Job {job_id} submitted: '{prompt[:50]}...'")
        return job_id

    async def get_next_job(self) -> Optional[str]:
        """
        Get the next job from the queue (waits until one is available).

        Returns:
            job_id, or None if the wait was cancelled
        """
        try:
            return await self.queue.get()
        except asyncio.CancelledError:
            return None

    async def start_job(self, job_id: str):
        """Mark a job as started."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "processing"
                self.jobs[job_id].started_at = datetime.utcnow()
                self.active_jobs += 1
                logger.info(f"Job {job_id} started processing")

    async def update_progress(self, job_id: str, progress: int):
        """Update job progress (clamped to 0-100)."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].progress = min(100, max(0, progress))

    async def complete_job(self, job_id: str, result_path: str):
        """Mark a job as completed successfully."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "completed"
                self.jobs[job_id].completed_at = datetime.utcnow()
                self.jobs[job_id].progress = 100
                self.jobs[job_id].result_path = result_path
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.info(f"Job {job_id} completed successfully")

    async def fail_job(self, job_id: str, error: str):
        """Mark a job as failed."""
        async with self._lock:
            if job_id in self.jobs:
                self.jobs[job_id].status = "failed"
                self.jobs[job_id].completed_at = datetime.utcnow()
                self.jobs[job_id].error = error
                self.active_jobs = max(0, self.active_jobs - 1)
                logger.error(f"Job {job_id} failed: {error}")

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get job by ID."""
        return self.jobs.get(job_id)

    def get_queue_size(self) -> int:
        """Get current queue size."""
        return self.queue.qsize()

    def get_active_jobs(self) -> int:
        """Get number of currently processing jobs."""
        return self.active_jobs

    async def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Remove old completed/failed jobs from memory."""
        cutoff = datetime.utcnow() - timedelta(hours=max_age_hours)

        async with self._lock:
            to_remove = []
            for job_id, job in self.jobs.items():
                if job.status in ["completed", "failed"] and job.completed_at:
                    if job.completed_at < cutoff:
                        to_remove.append(job_id)

            for job_id in to_remove:
                del self.jobs[job_id]

            if to_remove:
                logger.info(f"Cleaned up {len(to_remove)} old jobs from memory")


# Global queue manager instance
queue_manager = QueueManager()
||||||
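The consumer side of this queue (the worker loop that pulls job IDs, runs the pipeline, and reports back) is not part of this file. A minimal self-contained sketch of that producer/consumer pattern, using a plain `asyncio.Queue` and a stubbed `generate` function as a hypothetical stand-in for the real SDXL call:

```python
import asyncio

async def generate(prompt: str) -> str:
    """Hypothetical stand-in for the SDXL pipeline; returns a fake image path."""
    await asyncio.sleep(0)  # simulate GPU work yielding to the event loop
    return f"/images/{prompt[:8]}.png"

async def worker(queue: asyncio.Queue, results: dict) -> None:
    """Consume jobs forever, recording each outcome."""
    while True:
        job_id, prompt = await queue.get()   # waits until a job arrives
        try:
            results[job_id] = ("completed", await generate(prompt))
        except Exception as exc:             # a real worker would call fail_job()
            results[job_id] = ("failed", str(exc))
        finally:
            queue.task_done()                # lets queue.join() observe completion

async def main() -> dict:
    queue: asyncio.Queue = asyncio.Queue(maxsize=8)
    results: dict = {}
    task = asyncio.create_task(worker(queue, results))
    await queue.put(("job-1", "a corgi astronaut"))
    await queue.put(("job-2", "neon city at dusk"))
    await queue.join()                       # block until both jobs are processed
    task.cancel()
    return results

results = asyncio.run(main())
print(results)
```

The same shape maps onto `QueueManager`: `submit_job` is the producer, `get_next_job` is the `queue.get()`, and `complete_job`/`fail_job` play the role of the `results` bookkeeping here.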