DreamTail v1.0.0 with IP-Adapter FaceID support

- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
This commit is contained in:
2026-01-01 19:54:59 -06:00
commit e4294b57e6
18 changed files with 1895 additions and 0 deletions

1
api/__init__.py Executable file
View File

@@ -0,0 +1 @@
"""API modules for DreamTail."""

73
api/models.py Executable file
View File

@@ -0,0 +1,73 @@
"""
Pydantic models for API requests and responses.
"""
from typing import Optional, Dict, Any, Literal
from pydantic import BaseModel, Field, validator
from datetime import datetime
import config
class GenerationParams(BaseModel):
"""Optional generation parameters."""
width: int = Field(default=config.DEFAULT_WIDTH, ge=512, le=2048)
height: int = Field(default=config.DEFAULT_HEIGHT, ge=512, le=2048)
num_inference_steps: int = Field(default=config.DEFAULT_STEPS, ge=config.MIN_STEPS, le=config.MAX_STEPS)
guidance_scale: float = Field(default=config.DEFAULT_GUIDANCE_SCALE, ge=config.MIN_GUIDANCE, le=config.MAX_GUIDANCE)
seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
face_image: Optional[str] = Field(default=None, description="Face reference image name (from faces directory) or 'vixy' for default")
face_strength: float = Field(default=config.DEFAULT_FACE_STRENGTH, ge=0.0, le=1.0, description="Face conditioning strength (0.0-1.0)")
@validator('width', 'height')
def must_be_multiple_of_8(cls, v):
if v % 8 != 0:
raise ValueError('Width and height must be multiples of 8')
return v
class GenerateRequest(BaseModel):
"""Request to generate an image."""
prompt: str = Field(..., min_length=1, max_length=2000, description="Text prompt for image generation")
negative_prompt: Optional[str] = Field(default=None, max_length=2000, description="Negative prompt to avoid certain features")
params: Optional[GenerationParams] = Field(default_factory=GenerationParams)
class JobResponse(BaseModel):
"""Response when submitting a generation job."""
job_id: str = Field(..., description="Unique job identifier")
status: Literal["queued", "processing", "completed", "failed"] = Field(..., description="Current job status")
created_at: datetime = Field(..., description="Job creation timestamp")
message: Optional[str] = Field(default=None, description="Optional message")
class JobStatus(BaseModel):
"""Detailed job status information."""
job_id: str
status: Literal["queued", "processing", "completed", "failed"]
progress: int = Field(..., ge=0, le=100, description="Progress percentage (0-100)")
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
error: Optional[str] = None
prompt: str
class HealthResponse(BaseModel):
"""Health check response."""
model_config = {"protected_namespaces": ()} # Allow "model_" prefix
status: Literal["healthy", "unhealthy"]
version: str
model_loaded: bool
queue_size: int
active_jobs: int
uptime_seconds: float
class ModelsResponse(BaseModel):
"""Available models information."""
base_model: str
refiner_model: Optional[str] = None
refiner_enabled: bool
device: str
fp16_enabled: bool

166
api/routes.py Executable file
View File

@@ -0,0 +1,166 @@
"""
API Routes
FastAPI routes for image generation service.
"""
import logging
from fastapi import APIRouter, HTTPException, Response
from fastapi.responses import FileResponse
import asyncio
from api.models import (
GenerateRequest, JobResponse, JobStatus,
HealthResponse, ModelsResponse
)
from worker.queue_manager import queue_manager
from worker.generator import generator
from dreamtail_storage.file_manager import file_manager
import config
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/generate", response_model=JobResponse, status_code=202)
async def generate_image(request: GenerateRequest):
"""
Submit an image generation job.
Returns immediately with a job_id. Use /status/{job_id} to check progress.
"""
try:
# Submit job to queue
job_id = await queue_manager.submit_job(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
params=request.params.dict() if request.params else {}
)
job = queue_manager.get_job(job_id)
return JobResponse(
job_id=job_id,
status=job.status,
created_at=job.created_at,
message=f"Job queued. Queue position: {queue_manager.get_queue_size()}"
)
except asyncio.QueueFull:
raise HTTPException(
status_code=503,
detail=f"Queue is full (max: {config.MAX_QUEUE_SIZE}). Please try again later."
)
except Exception as e:
logger.error(f"Error submitting job: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/status/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
"""
Get the status of a generation job.
Returns job progress, status, and timestamps.
"""
job = queue_manager.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
return JobStatus(
job_id=job.job_id,
status=job.status,
progress=job.progress,
created_at=job.created_at,
started_at=job.started_at,
completed_at=job.completed_at,
error=job.error,
prompt=job.prompt
)
@router.get("/result/{job_id}")
async def get_result(job_id: str):
"""
Download the generated image for a completed job.
Returns the image file as PNG.
"""
job = queue_manager.get_job(job_id)
if not job:
raise HTTPException(status_code=404, detail="Job not found")
if job.status != "completed":
raise HTTPException(
status_code=400,
detail=f"Job is {job.status}, not completed. Check /status/{job_id}"
)
# Check if image file exists
image_path = file_manager.get_image_path(job_id)
if not image_path:
raise HTTPException(
status_code=404,
detail="Image file not found (may have been cleaned up)"
)
# Return image file
return FileResponse(
path=image_path,
media_type="image/png",
filename=f"{job_id}.png"
)
@router.get("/health", response_model=HealthResponse)
async def health_check():
"""
Health check endpoint.
Returns service status and basic statistics.
"""
import time
from main import start_time
return HealthResponse(
status="healthy" if generator.model_loaded else "unhealthy",
version=config.APP_VERSION,
model_loaded=generator.model_loaded,
queue_size=queue_manager.get_queue_size(),
active_jobs=queue_manager.get_active_jobs(),
uptime_seconds=time.time() - start_time
)
@router.get("/models", response_model=ModelsResponse)
async def get_models_info():
"""
Get information about loaded models and configuration.
"""
model_info = generator.get_model_info()
return ModelsResponse(
base_model=config.SDXL_MODEL_ID,
refiner_model=config.SDXL_REFINER_ID if config.USE_REFINER else None,
refiner_enabled=config.USE_REFINER,
device=model_info.get("device", "unknown"),
fp16_enabled=config.USE_FP16
)
@router.get("/storage")
async def get_storage_info():
"""
Get storage statistics (admin endpoint).
"""
stats = file_manager.get_storage_stats()
return {
"total_images": stats["total_images"],
"total_size_mb": round(stats["total_size_mb"], 2),
"storage_path": stats["storage_path"],
"retention_days": config.IMAGE_RETENTION_DAYS
}