266 lines
9.6 KiB
Python
266 lines
9.6 KiB
Python
"""FastAPI server for the sprite generation review GUI.

Serves the REST API for browsing sprites, reviewing variants, and approving/rejecting.
Also serves raw/variant images and the static GUI build.

Usage:
    uvicorn server:app --port 5801 --reload
    # or via CLI:
    python3 cli.py review --port 5801
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
|
|
from engine.registry import SpriteRegistry
|
|
|
|
# Every on-disk location used by the server is anchored at the directory
# that contains this module.
TOOL_DIR = Path(__file__).resolve().parent
DB_PATH = TOOL_DIR.joinpath("sprites.db")
RAW_DIR = TOOL_DIR.joinpath("raw")
VARIANTS_DIR = TOOL_DIR.joinpath("variants")
GUI_DIST = TOOL_DIR.joinpath("gui", "dist")
|
|
|
|
|
|
class ApproveRequest(BaseModel):
    """Request body for the approve endpoints.

    Attributes:
        variant_id: Identifier of the variant to approve; the handlers pass
            it to ``registry.approve_variant``.
        dimension_id: Optional dimension identifier. The approve handlers in
            this module accept it but do not read it.
    """

    variant_id: int
    dimension_id: int | None = None
|
|
|
|
|
|
class RejectRequest(BaseModel):
    """Request body for the reject endpoint.

    Attributes:
        dimension_id: Optional dimension identifier forwarded to
            ``registry.reject_sprite``.
    """

    dimension_id: int | None = None
|
|
|
|
|
|
class RegenerateRequest(BaseModel):
    """Request body for the regenerate endpoint.

    Attributes:
        prompt: Optional replacement prompt; when non-empty the handler
            stores it on the sprite before rejecting it.
        dimension_id: Optional dimension identifier forwarded to
            ``registry.reject_sprite``.
        variants: Requested variant count (default 8). The regenerate handler
            in this module does not read it.
    """

    prompt: str | None = None
    dimension_id: int | None = None
    variants: int = 8
|
|
|
|
|
|
class PromptUpdate(BaseModel):
    """Request body for the prompt-update endpoint.

    Attributes:
        prompt: New prompt text written to the sprite's ``prompt`` column.
    """

    prompt: str
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for the generation trigger endpoint.

    Attributes:
        category: Optional category filter applied when listing sprites.
        sprite_id: Optional single sprite to restrict the run to.
        variants: Number of variants requested per sprite (default 8).
        priority: Priority label handed to the generator (default "normal").
        max_sprites: Optional cap on the number of candidate sprites fetched;
            the handler falls back to 10000 when it is unset or 0.
    """

    category: str | None = None
    sprite_id: str | None = None
    variants: int = 8
    priority: str = "normal"
    max_sprites: int | None = None
|
|
|
|
|
|
def create_app(
    registry: SpriteRegistry | None = None,
    raw_dir: Path = RAW_DIR,
    variants_dir: Path = VARIANTS_DIR,
) -> FastAPI:
    """Build the FastAPI application for the sprite review GUI.

    Args:
        registry: Registry backing every endpoint. When omitted, a
            ``SpriteRegistry`` is opened on ``DB_PATH``.
        raw_dir: Directory served under ``/images/raw/`` and handed to the
            sprite generator.
        variants_dir: Directory served under ``/images/variants/``.

    Returns:
        The configured application: REST API routes, image routes and, when
        ``GUI_DIST`` exists, the static GUI mounted at ``/``.
    """
    if registry is None:
        registry = SpriteRegistry(DB_PATH)

    app = FastAPI(title="Magic Civilization Sprite Generator", version="1.0.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # ── Shared helpers ────────────────────────────────────────────────────

    def _require_sprite(sprite_id: str) -> dict:
        """Return the sprite record or raise a 404 when the registry has none."""
        sprite = registry.get_sprite(sprite_id)
        if not sprite:
            raise HTTPException(404, f"Sprite not found: {sprite_id}")
        return sprite

    def _store_prompt(sprite_id: str, prompt: str) -> None:
        """Persist a new prompt for the sprite and commit immediately."""
        registry.conn.execute(
            "UPDATE sprites SET prompt = ?, updated_at = datetime('now') WHERE id = ?",
            (prompt, sprite_id),
        )
        registry.conn.commit()

    def _approve(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve ``body.variant_id`` after checking that the sprite exists."""
        _require_sprite(sprite_id)
        # NOTE(review): the variant is not checked to belong to ``sprite_id``
        # and ``body.dimension_id`` is not used here — confirm that is intended.
        registry.approve_variant(body.variant_id)
        return {"status": "approved", "variant_id": body.variant_id}

    def _serve_image(base_dir: Path, file_path: str) -> FileResponse:
        """Serve ``file_path`` from ``base_dir`` as a PNG, or raise a 404.

        ``file_path`` comes straight from the URL, so it is rejected when it
        is anchored (joining an absolute path would replace ``base_dir``) or
        contains ``..`` segments (which would climb out of ``base_dir``).
        """
        relative = Path(file_path)
        if relative.anchor or ".." in relative.parts:
            raise HTTPException(404, f"Image not found: {file_path}")
        full = base_dir / relative
        # is_file() rather than exists(): a directory cannot be sent as a file.
        if not full.is_file():
            raise HTTPException(404, f"Image not found: {file_path}")
        return FileResponse(full, media_type="image/png")

    # ── Sprites ───────────────────────────────────────────────────────────

    @app.get("/api/sprites")
    def list_sprites(
        category: Annotated[str | None, Query()] = None,
        status: Annotated[str | None, Query()] = None,
        search: Annotated[str | None, Query()] = None,
        limit: Annotated[int, Query(ge=1, le=10000)] = 200,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> list[dict]:
        """List sprites, optionally filtered and paginated."""
        return registry.get_sprites(
            category=category, status=status, search=search,
            limit=limit, offset=offset,
        )

    # Routes are matched in registration order and ``{sprite_id:path}``
    # matches any remainder, including ".../variants". The more specific
    # variants route must therefore be registered before the catch-all
    # sprite route, otherwise it can never be reached.
    @app.get("/api/sprites/{sprite_id:path}/variants")
    def get_variants(
        sprite_id: str,
        dimension_id: Annotated[int | None, Query()] = None,
    ) -> list[dict]:
        """Return the variants of one sprite, optionally for one dimension."""
        _require_sprite(sprite_id)
        return registry.get_variants(sprite_id, dimension_id=dimension_id)

    @app.get("/api/sprites/{sprite_id:path}")
    def get_sprite(sprite_id: str) -> dict:
        """Return a single sprite record, or 404 when it does not exist."""
        return _require_sprite(sprite_id)

    @app.post("/api/sprites/{sprite_id:path}/approve")
    def approve_sprite(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve one variant of an existing sprite."""
        return _approve(sprite_id, body)

    @app.post("/api/sprites/{sprite_id:path}/reject")
    def reject_sprite(sprite_id: str, body: RejectRequest) -> dict:
        """Reject a sprite, optionally scoped to one dimension."""
        _require_sprite(sprite_id)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "rejected"}

    @app.post("/api/sprites/{sprite_id:path}/skip")
    def skip_sprite(sprite_id: str) -> dict:
        """Set the sprite's status to "skip"."""
        _require_sprite(sprite_id)
        registry.update_sprite_status(sprite_id, "skip")
        return {"status": "skip"}

    @app.post("/api/sprites/{sprite_id:path}/regenerate")
    def regenerate_sprite(sprite_id: str, body: RegenerateRequest) -> dict:
        """Optionally replace the prompt, then reject the sprite for regeneration."""
        _require_sprite(sprite_id)
        # An empty or missing prompt keeps the stored one.
        if body.prompt:
            _store_prompt(sprite_id, body.prompt)
        registry.reject_sprite(sprite_id, dimension_id=body.dimension_id)
        return {"status": "needed", "message": "Ready for regeneration"}

    @app.put("/api/sprites/{sprite_id:path}/prompt")
    def update_prompt(sprite_id: str, body: PromptUpdate) -> dict:
        """Overwrite the stored prompt of an existing sprite."""
        _require_sprite(sprite_id)
        _store_prompt(sprite_id, body.prompt)
        return {"status": "updated", "prompt": body.prompt}

    # ── Progress & Queue ─────────────────────────────────────────────────

    @app.get("/api/stats")
    def get_stats() -> dict:
        """Return the registry's statistics."""
        return registry.get_stats()

    @app.get("/api/progress")
    def get_progress() -> dict:
        """Return the registry's progress report."""
        return registry.get_progress()

    @app.get("/api/queue")
    def get_review_queue(
        limit: Annotated[int, Query(ge=1, le=200)] = 50,
    ) -> list[dict]:
        """Return up to ``limit`` entries of the review queue."""
        return registry.get_review_queue(limit=limit)

    @app.post("/api/queue/{sprite_id:path}/approve")
    def approve_from_queue(sprite_id: str, body: ApproveRequest) -> dict:
        """Approve one variant of a sprite from the review queue."""
        return _approve(sprite_id, body)

    @app.get("/api/variants/recent")
    def get_recent_variants(
        limit: Annotated[int, Query(ge=1, le=100)] = 30,
    ) -> list[dict]:
        """Return up to ``limit`` recent variants."""
        return registry.get_recent_variants(limit=limit)

    @app.get("/api/stream/variants")
    async def stream_variants() -> StreamingResponse:
        """Stream newly created variants as server-sent events.

        The registry is polled every three seconds; each poll emits either a
        ``data:`` event carrying the new variants as JSON or a keepalive
        comment line.
        """

        async def event_generator():
            # NOTE(review): this is an ISO-8601 string with a UTC offset; confirm
            # it compares correctly against the registry's ``created_at`` format.
            last_check = datetime.now(timezone.utc).isoformat()
            while True:
                await asyncio.sleep(3)
                # NOTE(review): synchronous registry call inside the event loop —
                # confirm it is fast enough not to stall other requests.
                new_variants = registry.get_recent_variants(limit=10, since=last_check)
                if new_variants:
                    # Presumably results are newest-first, so the first entry
                    # is the high-water mark — verify against the registry.
                    last_check = new_variants[0]["created_at"]
                    yield f"data: {json.dumps(new_variants)}\n\n"
                else:
                    yield ": keepalive\n\n"

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    @app.get("/api/runs")
    def get_runs() -> list[dict]:
        """Return the registry's generation runs."""
        return registry.get_runs()

    # ── Generation trigger ────────────────────────────────────────────────

    @app.post("/api/generate")
    def trigger_generate(body: GenerateRequest) -> dict:
        """Submit sprites in 'needed' status to the generator.

        Returns a dict with the generator's ``submitted`` result and the
        number of sprites selected, or ``submitted: 0`` with a message when
        nothing matches.
        """
        # Imported lazily so the generator is only loaded when a run is triggered.
        from engine.generator import SpriteGenerator

        config_text = (TOOL_DIR / "sprite-config.json").read_text(encoding="utf-8")
        gen = SpriteGenerator(
            config=json.loads(config_text), registry=registry, raw_dir=raw_dir,
        )

        # When one specific sprite is requested, do not apply max_sprites before
        # filtering: the requested sprite could be cut off by the cap.
        limit = 10000 if body.sprite_id else (body.max_sprites or 10000)
        sprites = registry.get_sprites(
            category=body.category,
            status="needed",
            limit=limit,
        )
        if body.sprite_id:
            sprites = [s for s in sprites if s["id"] == body.sprite_id]

        if not sprites:
            return {"submitted": 0, "message": "No sprites in 'needed' status"}

        sprite_ids = [s["id"] for s in sprites]
        submitted = gen.generate_sprites(
            sprite_ids=sprite_ids,
            variants_per=body.variants,
            priority=body.priority,
        )
        return {"submitted": submitted, "sprites": len(sprite_ids)}

    # ── Image serving ─────────────────────────────────────────────────────

    @app.get("/images/raw/{file_path:path}")
    def serve_raw(file_path: str) -> FileResponse:
        """Serve a raw image from ``raw_dir``."""
        return _serve_image(raw_dir, file_path)

    @app.get("/images/variants/{file_path:path}")
    def serve_variant(file_path: str) -> FileResponse:
        """Serve a variant image from ``variants_dir``."""
        return _serve_image(variants_dir, file_path)

    # ── Static GUI ────────────────────────────────────────────────────────

    # Mounted last so the catch-all "/" mount cannot shadow the routes above.
    if GUI_DIST.exists():
        app.mount("/", StaticFiles(directory=str(GUI_DIST), html=True))

    return app
|
|
|
|
|
|
# Direct run support
|
|
# Module-level ASGI application built with the default registry and image
# directories; this is the object the `uvicorn server:app` command imports.
app = create_app()
|