🎨 MCP integration for DreamTail (SDXL on Jetson Orin) - dreamtail_generate: Create images with prompts - dreamtail_get_info: Get last generated image path - Inline display support for Claude Desktop - Configurable JPEG quality and download directory Built with love for the hardware dragon 🦊
378 lines
15 KiB
Python
Executable File
378 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
DreamTail MCP Server

Model Context Protocol server for integrating DreamTail SDXL image generation
with Claude Desktop.

Provides two tools: dreamtail_generate(), which submits a job, polls for
completion, and returns the image inline (or a message with the saved file
paths), and dreamtail_get_info(), which returns the path of the last saved
image.
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import logging
|
|
from typing import Optional, List, Union, Dict, Any
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from datetime import datetime
|
|
import httpx
|
|
from fastmcp import FastMCP
|
|
from fastmcp.utilities.types import Image as MCPImage
|
|
from PIL import Image
|
|
|
|
# Logging: DEBUG and above go to a file under /tmp and to a stream handler.
_log_handlers = [
    logging.FileHandler('/tmp/dreamtail_mcp.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# MCP server instance that the tool functions below register against.
mcp = FastMCP("DreamTail Image Generator")
|
|
|
|
# Configuration from environment


def _env_int(name: str, default: int, lowest: int, highest: int) -> int:
    """Read an integer setting from the environment.

    Returns ``default`` when the variable is unset. When it is set but is not
    a valid integer, a warning is logged and ``default`` is returned instead of
    crashing the import. Valid values are clamped to ``lowest``..``highest``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Invalid integer %r for %s, using default %d", raw, name, default
        )
        return default
    return max(lowest, min(highest, value))


# Trailing slashes are dropped so f"{DREAMTAIL_BASE_URL}/generate" never yields "//".
DREAMTAIL_BASE_URL = os.getenv("DREAMTAIL_BASE_URL", "http://bigorin.local:8765").rstrip("/")

# "~" is expanded for values coming from the environment as well as for the default.
DOWNLOAD_DIR = os.path.expanduser(os.getenv("DREAMTAIL_DOWNLOAD_DIR", "~/dreamtail_images"))

# "jpeg", "png", or "both" (case-insensitive; "jpg" is accepted as "jpeg")
DOWNLOAD_FORMAT = os.getenv("DREAMTAIL_FORMAT", "jpeg").strip().lower()
if DOWNLOAD_FORMAT == "jpg":
    DOWNLOAD_FORMAT = "jpeg"

JPEG_QUALITY = _env_int("DREAMTAIL_JPEG_QUALITY", 95, 1, 100)  # 1-100

# Enable inline display in Claude
INLINE_DISPLAY = os.getenv("DREAMTAIL_INLINE_DISPLAY", "true").strip().lower() in (
    "true",
    "1",
    "yes",
    "on",
)
INLINE_QUALITY = _env_int("DREAMTAIL_INLINE_QUALITY", 85, 1, 100)  # Quality for inline display (70-95 recommended)
MAX_INLINE_SIZE_MB = 0.90  # Leave margin below 1MB limit (accounting for dict metadata)
DEFAULT_POLL_INTERVAL = 3  # seconds
DEFAULT_TIMEOUT = 120  # seconds (2 minutes)

# Create download directory if it doesn't exist
Path(DOWNLOAD_DIR).mkdir(parents=True, exist_ok=True)

# Track last generated image info
last_generation_info: Optional[Dict[str, Any]] = None
|
|
# Report the effective configuration once, at import time.
for _startup_line in (
    f"Download directory: {DOWNLOAD_DIR}",
    f"Download format: {DOWNLOAD_FORMAT} (JPEG quality: {JPEG_QUALITY})",
    f"Inline display: {INLINE_DISPLAY} (quality: {INLINE_QUALITY})",
):
    logger.info(_startup_line)
|
|
|
|
|
|
def _flatten_for_jpeg(img: "Image.Image") -> "Image.Image":
    """Return a version of ``img`` suitable for JPEG, compositing transparency onto white."""
    has_alpha = img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info)
    if has_alpha:
        # Going through RGBA gives every transparent mode a real alpha band to
        # use as the paste mask (previously only "RGBA" images were masked).
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        flattened.paste(rgba, mask=rgba.split()[-1])
        return flattened
    if img.mode == "P":
        return img.convert("RGB")
    # Other modes are handed to the JPEG encoder unchanged.
    return img


async def download_image(client: httpx.AsyncClient, image_url: str, job_id: str) -> List[str]:
    """
    Download the generated image to local storage with format conversion.

    The image is written to DOWNLOAD_DIR as JPEG and/or PNG depending on
    DOWNLOAD_FORMAT ("jpeg", "png" or "both"). Any other value falls back to
    PNG and logs a warning.

    Args:
        client: httpx AsyncClient instance
        image_url: URL to download the image from
        job_id: Job ID to use as filename

    Returns:
        List of local file paths where images were saved

    Raises:
        httpx.HTTPStatusError: If the download request returns an error status.
    """
    logger.info(f"Downloading image from {image_url}")

    # Download the image
    response = await client.get(image_url)
    response.raise_for_status()

    # Keep only the last path component so a job ID that contains path
    # separators cannot place files outside DOWNLOAD_DIR.
    stem = Path(str(job_id)).name or "image"
    target_dir = Path(DOWNLOAD_DIR)
    saved_paths: List[str] = []

    # Open image with PIL; the context manager releases it once saving is done
    with Image.open(BytesIO(response.content)) as img:
        # Save based on format preference
        if DOWNLOAD_FORMAT in ("jpeg", "both"):
            jpeg_path = target_dir / f"{stem}.jpg"
            _flatten_for_jpeg(img).save(jpeg_path, "JPEG", quality=JPEG_QUALITY, optimize=True)
            saved_paths.append(str(jpeg_path))
            logger.info(f"JPEG saved: {jpeg_path} (quality {JPEG_QUALITY})")

        if DOWNLOAD_FORMAT in ("png", "both"):
            png_path = target_dir / f"{stem}.png"
            img.save(png_path, "PNG")
            saved_paths.append(str(png_path))
            logger.info(f"PNG saved: {png_path}")

        if not saved_paths:
            # Fallback to PNG if invalid format specified
            png_path = target_dir / f"{stem}.png"
            img.save(png_path, "PNG")
            saved_paths.append(str(png_path))
            logger.warning(f"Invalid format '{DOWNLOAD_FORMAT}', defaulting to PNG: {png_path}")

    return saved_paths
|
|
|
|
|
|
def _encode_inline_jpeg(image: "Image.Image", quality: int) -> bytes:
    """Encode ``image`` as optimized JPEG bytes at the given quality."""
    buffer = BytesIO()
    image.save(buffer, format="JPEG", quality=quality, optimize=True)
    return buffer.getvalue()


async def prepare_inline_image(client: httpx.AsyncClient, image_url: str) -> Optional[MCPImage]:
    """
    Prepare image for inline display in Claude Desktop.

    Downloads the image, flattens it to RGB on a white background and encodes
    it as JPEG. If the result exceeds MAX_INLINE_SIZE_MB, the quality is lowered
    in steps of 10 (down to 50), then the image is scaled down (80% to 50%) at
    quality 75.

    NOTE(review): the size limit is compared against the raw JPEG bytes. If the
    transport base64-encodes the payload, the on-wire size is about 4/3 larger;
    confirm that MAX_INLINE_SIZE_MB leaves enough headroom.

    Args:
        client: httpx AsyncClient instance
        image_url: URL to download the image from

    Returns:
        MCPImage object for inline display, or None if preparation fails
        (any exception is logged and swallowed) or the image cannot be
        brought under the size limit.
    """
    try:
        logger.info(f"Preparing inline image from {image_url}")

        # Download the image
        response = await client.get(image_url)
        response.raise_for_status()

        # Open with PIL
        pil_image = Image.open(BytesIO(response.content))

        # Convert to RGB (JPEG doesn't support transparency). Every transparent
        # mode is composited onto white through its alpha band; previously only
        # "RGBA" was masked and "LA"/palette transparency was dropped.
        has_alpha = pil_image.mode in ("RGBA", "LA") or (
            pil_image.mode == "P" and "transparency" in pil_image.info
        )
        if has_alpha:
            rgba = pil_image.convert("RGBA")
            flattened = Image.new("RGB", rgba.size, (255, 255, 255))
            flattened.paste(rgba, mask=rgba.split()[-1])
            pil_image = flattened
        elif pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        bytes_per_mb = 1024 * 1024

        # Try compression at configured quality first
        img_bytes = _encode_inline_jpeg(pil_image, INLINE_QUALITY)
        size_mb = len(img_bytes) / bytes_per_mb
        logger.info(f"Initial compression: {size_mb:.2f}MB at quality {INLINE_QUALITY}")

        # If still too large, try lower quality
        quality = INLINE_QUALITY - 10
        while quality >= 50 and size_mb > MAX_INLINE_SIZE_MB:
            img_bytes = _encode_inline_jpeg(pil_image, quality)
            size_mb = len(img_bytes) / bytes_per_mb
            logger.info(f"Recompressed: {size_mb:.2f}MB at quality {quality}")
            quality -= 10

        # If still too large, try resizing. Explicit scale steps avoid relying
        # on float accumulation (0.8 - 0.1 - 0.1 ...) to reach the final 0.5 step.
        for scale in (0.8, 0.7, 0.6, 0.5):
            if size_mb <= MAX_INLINE_SIZE_MB:
                break
            new_size = (
                max(1, int(pil_image.width * scale)),
                max(1, int(pil_image.height * scale)),
            )
            resized = pil_image.resize(new_size, Image.Resampling.LANCZOS)
            img_bytes = _encode_inline_jpeg(resized, 75)
            size_mb = len(img_bytes) / bytes_per_mb
            logger.info(f"Resized to {new_size}: {size_mb:.2f}MB")

        # Final check
        if size_mb > MAX_INLINE_SIZE_MB:
            logger.warning(f"Image still too large ({size_mb:.2f}MB), cannot display inline")
            return None

        logger.info(f"✓ Inline image ready: {size_mb:.2f}MB")
        return MCPImage(data=img_bytes, format="jpeg")

    except Exception as e:
        # Best-effort: inline display is optional, so failures are logged
        # (with traceback) and reported to the caller as None.
        logger.exception(f"Failed to prepare inline image: {e}")
        return None
|
|
|
|
|
|
def _remember_generation(job_id: str, local_paths: List[str], prompt: str) -> None:
    """Record the most recent generation in the module-level ``last_generation_info``."""
    global last_generation_info
    filename = Path(local_paths[0]).name if local_paths else f"{job_id}.jpg"
    last_generation_info = {
        "job_id": job_id,
        "filename": filename,
        "saved_paths": local_paths,
        "prompt": prompt[:100],  # Truncate long prompts
        "timestamp": datetime.now().isoformat(),
    }


def _saved_message(job_id: str, local_paths: List[str]) -> str:
    """Build the text result listing where the generated image was saved."""
    filename = Path(local_paths[0]).name
    if len(local_paths) == 1:
        return f"✓ Image generated successfully!\n\nFilename: {filename}\nSaved to: {local_paths[0]}\nJob ID: {job_id}"
    paths_str = "\n".join([f" - {p}" for p in local_paths])
    return f"✓ Image generated successfully!\n\nFilename: {filename}\nSaved to:\n{paths_str}\nJob ID: {job_id}"


async def _deliver_result(client: httpx.AsyncClient, job_id: str, prompt: str) -> Union[MCPImage, str]:
    """Fetch a completed job's image and return it inline or as a file-path message."""
    image_url = f"{DREAMTAIL_BASE_URL}/result/{job_id}"
    local_paths: List[str] = []

    # Try inline display if enabled
    if INLINE_DISPLAY:
        inline_image = None
        try:
            inline_image = await prepare_inline_image(client, image_url)
        except Exception as inline_error:
            logger.error(f"Inline display error: {inline_error}")

        # Download full-resolution to local storage
        try:
            local_paths = await download_image(client, image_url, job_id)
            logger.info(f"Saved full-res to: {', '.join(local_paths)}")
        except Exception as save_error:
            logger.warning(f"Failed to save full-res locally: {save_error}")

        # Return inline image if successful
        if inline_image:
            logger.info("✓ Inline image prepared successfully")
            # Only the MCPImage is returned (not wrapped in a dict);
            # dreamtail_get_info() exposes the saved path afterwards.
            _remember_generation(job_id, local_paths, prompt)
            return inline_image

        logger.warning("Inline image preparation failed, falling back to file path")

    # Fallback: return message with file paths
    try:
        # Reuse the files saved above; download only if nothing was saved yet.
        if not local_paths:
            local_paths = await download_image(client, image_url, job_id)
        message = _saved_message(job_id, local_paths)
        # Record this path too, so dreamtail_get_info() works after a fallback.
        _remember_generation(job_id, local_paths, prompt)
        return message
    except Exception as download_error:
        logger.error(f"Failed to download image: {download_error}")
        return f"Image generated successfully but download failed: {download_error}. You can download it manually from: {image_url}"


@mcp.tool()
async def dreamtail_generate(
    prompt: str,
    negative_prompt: Optional[str] = None,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 30,
    client_id: str = "claude"
) -> Union[MCPImage, str]:
    """
    Generate an image using SDXL on the Jetson Orin.

    This tool submits an image generation job to DreamTail, waits for it to
    complete (polling every 3 seconds), and returns the generated image.

    When inline display is enabled (default), the image is shown directly in
    Claude Desktop. Full-resolution images are also saved to your local folder.

    Generation typically takes 45-60 seconds for a 1024x1024 image.

    Args:
        prompt: Text description of the image to generate
        negative_prompt: Optional text describing what to avoid in the image
        width: Image width in pixels (must be multiple of 8, default 1024)
        height: Image height in pixels (must be multiple of 8, default 1024)
        num_inference_steps: Number of denoising steps (20-50, default 30)
        client_id: Client identifier (default "claude")

    Returns:
        MCPImage object for inline display (if enabled and successful),
        or text message with file paths and filename
    """
    logger.info("=" * 60)
    logger.info("TOOL CALLED: dreamtail_generate")
    logger.info(f"DreamTail base URL: {DREAMTAIL_BASE_URL}")
    logger.info(f"Prompt: {prompt[:50]}...")
    logger.info(f"Inline display: {INLINE_DISPLAY}")
    logger.info("=" * 60)

    async with httpx.AsyncClient(timeout=30.0) as client:
        try:
            # Step 1: Submit generation job
            submit_url = f"{DREAMTAIL_BASE_URL}/generate"
            logger.debug(f"Submitting to: {submit_url}")

            submit_response = await client.post(
                submit_url,
                json={
                    "prompt": prompt,
                    "client_id": client_id,
                    "negative_prompt": negative_prompt,
                    "params": {
                        "width": width,
                        "height": height,
                        "num_inference_steps": num_inference_steps,
                    },
                },
            )
            submit_response.raise_for_status()
            job_id = submit_response.json()["job_id"]
            logger.info(f"Job submitted successfully: {job_id}")

            # Step 2: Poll for completion. The deadline uses the event loop's
            # monotonic clock, so time spent in status requests counts towards
            # DEFAULT_TIMEOUT (previously only the sleep intervals were counted).
            status_url = f"{DREAMTAIL_BASE_URL}/status/{job_id}"
            loop = asyncio.get_running_loop()
            deadline = loop.time() + DEFAULT_TIMEOUT

            while loop.time() < deadline:
                # Check job status
                status_response = await client.get(status_url)
                status_response.raise_for_status()
                status_data = status_response.json()

                current_status = status_data["status"]
                logger.debug(
                    "Job %s: status=%s progress=%s",
                    job_id,
                    current_status,
                    status_data.get("progress", 0),
                )

                # Job completed successfully
                if current_status == "completed":
                    return await _deliver_result(client, job_id, prompt)

                # Job failed
                if current_status == "failed":
                    error = status_data.get("error", "Unknown error")
                    return f"❌ Image generation failed: {error}"

                # Still processing, wait and poll again
                await asyncio.sleep(DEFAULT_POLL_INTERVAL)

            # Timeout reached
            return f"⏱️ Generation exceeded {DEFAULT_TIMEOUT} second timeout. Job {job_id} is still processing. Check status at: {DREAMTAIL_BASE_URL}/status/{job_id}"

        except httpx.HTTPStatusError as e:
            logger.exception(f"HTTPStatusError in dreamtail_generate: {e}")
            return f"❌ DreamTail API error: {e.response.status_code} - {e.response.text}"
        except httpx.RequestError as e:
            logger.exception(f"RequestError in dreamtail_generate: {e}")
            return f"❌ Failed to connect to DreamTail at {DREAMTAIL_BASE_URL}: {str(e)}"
        except Exception as e:
            # Tool boundary: report any other failure as text instead of raising.
            logger.exception(f"Unexpected error in dreamtail_generate: {e}")
            return f"❌ Unexpected error: {str(e)}"
|
|
|
|
|
|
@mcp.tool()
def dreamtail_get_info() -> str:
    """
    Get the file path of the last image generated by DreamTail.

    Useful for getting the file path after dreamtail_generate() returns an image,
    so you can send it to Matrix with send_matrix_image_from_file().

    Returns:
        File path of the last generated image, or error message if none exists
        or if the last image could not be saved locally
    """
    if last_generation_info is None:
        return "No DreamTail images generated yet in this session"

    # saved_paths is empty when the image was shown inline but the local save
    # failed; report that instead of raising IndexError.
    saved_paths = last_generation_info.get("saved_paths") or []
    if not saved_paths:
        job_id = last_generation_info.get("job_id", "unknown")
        return f"Last DreamTail image (job {job_id}) was generated but could not be saved locally"

    # Return just the first saved path (primary file)
    return saved_paths[0]
|
|
|
|
|
|
def _run_server() -> None:
    """Start the MCP server (uses stdio transport by default)."""
    mcp.run()


if __name__ == "__main__":
    _run_server()
|