"""Sprite sheet generation driver for ComfyUI API."""
import http.client
import json
import os
import shutil
import sys
import time
import urllib.error
import urllib.parse
import urllib.request
import uuid

SERVER = "127.0.0.1:8188"
OUT_DIR = "/home/uri/pa/display"


def post_prompt(workflow):
    """POST *workflow* to the server's ``/prompt`` endpoint.

    A freshly generated random ``client_id`` accompanies every submission.

    Returns:
        The server's JSON reply, decoded into Python objects.
    """
    payload = {"prompt": workflow, "client_id": str(uuid.uuid4())}
    body = json.dumps(payload).encode()
    request = urllib.request.Request(
        f"http://{SERVER}/prompt",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        raw = response.read()
    return json.loads(raw)


def wait_for(prompt_id, timeout=600, poll_interval=2.0):
    """Poll the server's history endpoint until an entry for *prompt_id* appears.

    Args:
        prompt_id: id of a submitted prompt (``run_attempt`` passes the
            ``"prompt_id"`` from ``post_prompt``'s reply).
        timeout: overall time budget in seconds.
        poll_interval: seconds to sleep between polls.

    Returns:
        The value stored under *prompt_id* in the ``/history/<prompt_id>``
        JSON response.

    Raises:
        TimeoutError: if no entry appears within *timeout* seconds. The last
            transient error seen while polling, if any, is chained as the
            cause so connection problems are not hidden.
    """
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    last_error = None
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            # Bound each request by the remaining budget so a single hung
            # connection cannot stall the caller past the overall deadline.
            with urllib.request.urlopen(
                f"http://{SERVER}/history/{prompt_id}", timeout=remaining
            ) as r:
                history = json.loads(r.read())
        except (OSError, ValueError, http.client.HTTPException) as exc:
            # Network errors (URLError/HTTPError are OSError subclasses),
            # socket timeouts, truncated responses and malformed JSON are
            # treated as transient and retried. Anything else is a real bug
            # and propagates instead of being silently swallowed.
            last_error = exc
        else:
            if isinstance(history, dict) and prompt_id in history:
                return history[prompt_id]
        # Never sleep past the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    raise TimeoutError(f"prompt {prompt_id} timed out") from last_error


def fetch_output(result, dest):
    """Download the first image listed in a prompt's history entry to *dest*.

    Args:
        result: a history entry (as returned by ``wait_for``). Its
            ``"outputs"`` mapping is scanned node by node for an
            ``"images"`` list.
        dest: local file path to write; an existing file is overwritten.

    Returns:
        ``dest`` once the first image has been written, or ``None`` when no
        node output lists any image (nothing is written in that case).
    """
    for node_out in result.get("outputs", {}).values():
        for img in node_out.get("images", []):
            # Percent-encode the query: filenames and subfolders can contain
            # spaces, '&', '#' or non-ASCII characters that would otherwise
            # corrupt the request URL.
            query = urllib.parse.urlencode(
                {
                    "filename": img["filename"],
                    "subfolder": img.get("subfolder", ""),
                    "type": img.get("type", "output"),
                },
                quote_via=urllib.parse.quote,
            )
            url = f"http://{SERVER}/view?{query}"
            with urllib.request.urlopen(url) as r:
                with open(dest, "wb") as f:
                    shutil.copyfileobj(r, f)
            # Only the first image is downloaded.
            return dest
    return None


def sdxl_workflow(prompt, negative, ckpt, seed, w=1024, h=1024, steps=28, cfg=7.0,
                  lora=None, lora_strength=0.8, sampler_name="dpmpp_2m_sde_gpu",
                  scheduler="karras", filename_prefix="sdxl_out"):
    """Build a ComfyUI API-format text-to-image graph.

    The graph is: checkpoint -> (optional LoRA) -> positive/negative text
    encodes -> empty latent -> KSampler -> VAE decode -> SaveImage.

    Args:
        prompt: positive prompt text.
        negative: negative prompt text.
        ckpt: checkpoint name passed to CheckpointLoaderSimple.
        seed: sampler seed.
        w, h: latent width and height (batch size is fixed at 1).
        steps, cfg: KSampler steps and CFG scale.
        lora: optional LoRA name. When truthy, a LoraLoader node is inserted
            between the checkpoint and the text encoders/sampler.
        lora_strength: used for both the model and clip LoRA strengths.
        sampler_name, scheduler: KSampler sampler and scheduler names.
        filename_prefix: SaveImage output filename prefix.

    Returns:
        A dict mapping node-id strings to ``{"class_type", "inputs"}`` dicts.
        Links are ``[source_node_id, output_index]`` pairs.
    """
    nodes = {
        # Outputs as wired below: 0 = model, 1 = clip, 2 = vae.
        "1": {"class_type": "CheckpointLoaderSimple",
              "inputs": {"ckpt_name": ckpt}},
        "4": {"class_type": "CLIPTextEncode",
              "inputs": {"text": prompt, "clip": ["1", 1]}},
        "5": {"class_type": "CLIPTextEncode",
              "inputs": {"text": negative, "clip": ["1", 1]}},
        "6": {"class_type": "EmptyLatentImage",
              "inputs": {"width": w, "height": h, "batch_size": 1}},
        "7": {"class_type": "KSampler",
              "inputs": {"model": ["1", 0], "positive": ["4", 0], "negative": ["5", 0],
                         "latent_image": ["6", 0], "seed": seed, "steps": steps,
                         "cfg": cfg, "sampler_name": sampler_name,
                         "scheduler": scheduler, "denoise": 1.0}},
        "8": {"class_type": "VAEDecode",
              "inputs": {"samples": ["7", 0], "vae": ["1", 2]}},
        "9": {"class_type": "SaveImage",
              "inputs": {"images": ["8", 0], "filename_prefix": filename_prefix}},
    }
    if lora:
        nodes["2"] = {"class_type": "LoraLoader",
                      "inputs": {"model": ["1", 0], "clip": ["1", 1],
                                 "lora_name": lora, "strength_model": lora_strength,
                                 "strength_clip": lora_strength}}
        # Re-route the encoders and sampler through the LoRA-patched clip/model;
        # the VAE still comes straight from the checkpoint.
        nodes["4"]["inputs"]["clip"] = ["2", 1]
        nodes["5"]["inputs"]["clip"] = ["2", 1]
        nodes["7"]["inputs"]["model"] = ["2", 0]
    return nodes


def run_attempt(name, workflow):
    """Submit *workflow*, wait for it to finish, and save its first image.

    The image is written to ``os.path.join(OUT_DIR, name)``. *name* is used
    verbatim as the file name, so include any extension in it. OUT_DIR is
    created if it does not exist yet.

    Returns:
        The path of the saved image.

    Raises:
        RuntimeError: if the finished prompt's history lists no images (for
            example, when execution failed), instead of reporting a file that
            was never written.
        TimeoutError: propagated from ``wait_for``.
    """
    print(f"[gen] {name} submitting...")
    r = post_prompt(workflow)
    pid = r["prompt_id"]
    print(f"[gen] {name} prompt_id={pid}, waiting...")
    result = wait_for(pid)
    os.makedirs(OUT_DIR, exist_ok=True)
    out = os.path.join(OUT_DIR, name)
    if fetch_output(result, out) is None:
        raise RuntimeError(f"[gen] {name} prompt_id={pid} produced no images")
    print(f"[gen] {name} saved -> {out}")
    return out


if __name__ == "__main__":
    print("loaded")
