Files
dreamtail/api/routes.py
Vixy e4294b57e6 DreamTail v1.0.0 with IP-Adapter FaceID support
- SDXL image generation using RealVisXL_V4.0
- IP-Adapter FaceID integration for consistent face generation
- Simplified API (removed client_id requirement)
- New params: face_image, face_strength
- 'vixy' shortcut for face-locked generation
- Queue-based async job processing
- FastAPI with proper error handling

Co-authored-by: Alex <alex@k4zka.online>
2026-01-01 19:54:59 -06:00

167 lines
4.4 KiB
Python
Executable File

"""
API Routes
FastAPI routes for image generation service.
"""
import logging
from fastapi import APIRouter, HTTPException, Response
from fastapi.responses import FileResponse
import asyncio
from api.models import (
GenerateRequest, JobResponse, JobStatus,
HealthResponse, ModelsResponse
)
from worker.queue_manager import queue_manager
from worker.generator import generator
from dreamtail_storage.file_manager import file_manager
import config
# Logger named after this module, so its records can be filtered per component.
logger = logging.getLogger(name=__name__)
# Router that every endpoint below registers on. It is presumably included into
# the FastAPI app elsewhere (e.g. via app.include_router); that is not visible here.
router = APIRouter()
@router.post("/generate", response_model=JobResponse, status_code=202)
async def generate_image(request: GenerateRequest):
    """
    Submit an image generation job.

    Returns immediately with a job_id. Use /status/{job_id} to check progress.

    Args:
        request: Prompt, optional negative prompt and optional generation params.

    Returns:
        JobResponse with the new job's id, status, creation time and a queue message.

    Raises:
        HTTPException(503): the queue rejected the job with ``asyncio.QueueFull``.
        HTTPException(500): any other failure while submitting the job or reading
            it back from the queue manager.
    """
    try:
        # Submit job to queue
        job_id = await queue_manager.submit_job(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            params=request.params.dict() if request.params else {},
        )
        job = queue_manager.get_job(job_id)
        if job is None:
            # Without this guard a missing job surfaces as an opaque
            # "'NoneType' object has no attribute 'status'" 500 detail.
            raise RuntimeError(f"Job {job_id} was submitted but could not be found")
        return JobResponse(
            job_id=job_id,
            status=job.status,
            created_at=job.created_at,
            # NOTE(review): this reports get_queue_size() after enqueueing,
            # which only approximates this job's position. Confirm the
            # intended semantics.
            message=f"Job queued. Queue position: {queue_manager.get_queue_size()}",
        )
    except HTTPException:
        # Already carries a deliberate status code; do not re-wrap it as a 500.
        raise
    except asyncio.QueueFull:
        # Expected back-pressure signal, so the internal exception context is suppressed.
        raise HTTPException(
            status_code=503,
            detail=f"Queue is full (max: {config.MAX_QUEUE_SIZE}). Please try again later.",
        ) from None
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception("Error submitting job: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/status/{job_id}", response_model=JobStatus)
async def get_job_status(job_id: str):
    """
    Report the current state of a single generation job.

    Looks the job up in the queue manager and echoes its id, status, progress,
    created/started/completed timestamps, error and prompt.

    Raises:
        HTTPException(404): the queue manager returned no job for ``job_id``.
    """
    record = queue_manager.get_job(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")

    # These attributes are copied verbatim from the job record onto the response model.
    copied_fields = (
        "job_id",
        "status",
        "progress",
        "created_at",
        "started_at",
        "completed_at",
        "error",
        "prompt",
    )
    return JobStatus(**{field: getattr(record, field) for field in copied_fields})
@router.get("/result/{job_id}")
async def get_result(job_id: str):
    """
    Download the PNG produced by a completed generation job.

    The file is served with media type ``image/png`` and a ``<job_id>.png`` filename.

    Raises:
        HTTPException(404): unknown job, or ``file_manager.get_image_path`` found no image.
        HTTPException(400): the job exists but its status is not "completed".
    """
    record = queue_manager.get_job(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")

    if record.status != "completed":
        not_ready = f"Job is {record.status}, not completed. Check /status/{job_id}"
        raise HTTPException(status_code=400, detail=not_ready)

    # A falsy path means no image is stored for this job (presumably removed by
    # retention cleanup, as the error text suggests).
    png_path = file_manager.get_image_path(job_id)
    if not png_path:
        raise HTTPException(
            status_code=404,
            detail="Image file not found (may have been cleaned up)",
        )

    return FileResponse(path=png_path, media_type="image/png", filename=f"{job_id}.png")
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health probe for the service.

    Reports "healthy" only when the generator has its model loaded, along with
    the app version, queue size, active job count and uptime in seconds.
    """
    # Imported at call time rather than at module top. Presumably `main` imports
    # this module, so a top-level import would be circular (confirm).
    import time
    from main import start_time

    counters = dict(
        queue_size=queue_manager.get_queue_size(),
        active_jobs=queue_manager.get_active_jobs(),
        uptime_seconds=time.time() - start_time,
    )
    return HealthResponse(
        status="healthy" if generator.model_loaded else "unhealthy",
        version=config.APP_VERSION,
        model_loaded=generator.model_loaded,
        **counters,
    )
@router.get("/models", response_model=ModelsResponse)
async def get_models_info():
    """
    Describe the configured generation models and related flags.

    The base/refiner model ids and the fp16 flag come from ``config``. The
    refiner id is reported only when ``config.USE_REFINER`` is set. The device
    comes from ``generator.get_model_info()`` and falls back to "unknown".
    """
    info = generator.get_model_info()
    refiner_id = config.SDXL_REFINER_ID if config.USE_REFINER else None
    return ModelsResponse(
        base_model=config.SDXL_MODEL_ID,
        refiner_model=refiner_id,
        refiner_enabled=config.USE_REFINER,
        device=info.get("device", "unknown"),
        fp16_enabled=config.USE_FP16,
    )
@router.get("/storage")
async def get_storage_info():
    """
    Report storage usage for generated images (admin endpoint).

    Returns the image count, total size in MB rounded to two decimals, the
    storage path and the configured retention period in days.

    NOTE(review): nothing in this handler restricts access, yet the response
    exposes the server-side storage path. Confirm an auth dependency is applied
    where the router is mounted.
    """
    usage = file_manager.get_storage_stats()
    image_count = usage["total_images"]
    size_mb = round(usage["total_size_mb"], 2)
    location = usage["storage_path"]
    return {
        "total_images": image_count,
        "total_size_mb": size_mb,
        "storage_path": location,
        "retention_days": config.IMAGE_RETENTION_DAYS,
    }