# ==========================================================
# Stable Diffusion v1-4 — CPU Diffusion Visualizer (256x256)
# - Runs on HF CPU
# - Real images (not blurry)
# - Step-by-step latents
# - PCA trajectory + latent norm plots
# ==========================================================
import logging
import time
import warnings

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
from PIL import Image
from sklearn.decomposition import PCA

warnings.filterwarnings("ignore")

logger = logging.getLogger(__name__)

# ------------------- CPU SETTINGS -------------------
DEVICE = "cpu"
# Sometimes MKLDNN causes weird matmul errors with SD on some CPUs, disable to be safe.
torch.backends.mkldnn.enabled = False

MODEL_ID = "CompVis/stable-diffusion-v1-4"
PIPE_CACHE = None  # lazily-populated module-level pipeline cache (see get_pipe)


# ------------------- LOAD SD MODEL -------------------
def get_pipe():
    """Load and cache the Stable Diffusion v1-4 pipeline on CPU.

    The safety checker is disabled by passing ``safety_checker=None`` together
    with ``requires_safety_checker=False``. The scheduler is swapped for DDIM.

    Returns:
        The cached ``StableDiffusionPipeline``. The first call loads it and
        later calls reuse ``PIPE_CACHE``.
    """
    global PIPE_CACHE
    if PIPE_CACHE is not None:
        return PIPE_CACHE

    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        safety_checker=None,  # <--- disable safety checker properly
        requires_safety_checker=False,
    )
    # Use DDIM so we have clear, predictable timesteps for visualization.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)

    PIPE_CACHE = pipe
    return PIPE_CACHE


# ------------------- PCA + NORM -------------------
def compute_pca(latents):
    """Project the per-step latents onto their first two principal components.

    Args:
        latents: list of (C, H, W) numpy arrays, one per recorded step.

    Returns:
        An (N, 2) array of PCA coordinates, one row per step. With fewer than
        two latents, or if PCA fails, it returns all zeros of that shape.
    """
    if not latents:
        return np.zeros((0, 2))

    X = np.stack([x.ravel() for x in latents])
    if X.shape[0] < 2:
        return np.zeros((X.shape[0], 2))

    try:
        return PCA(n_components=2).fit_transform(X)
    except Exception:
        # The trajectory plot is optional: log the failure and degrade to zeros
        # instead of failing the whole run.
        logger.exception("PCA over latents failed")
        return np.zeros((X.shape[0], 2))


def compute_norm(latents):
    """Return the L2 norm of each latent, taken over all of its elements.

    Args:
        latents: list of numpy arrays (may be empty or None).

    Returns:
        A list of Python floats, one per latent. Empty input gives ``[]``.
    """
    if not latents:
        return []
    return [float(np.linalg.norm(x.ravel())) for x in latents]


# ------------------- LATENT DECODER -------------------
def decode_latent(pipe, latent_np):
    """Decode a single (C, H, W) numpy latent into a 256x256 RGB PIL image.

    The latent is divided by the VAE's ``scaling_factor`` before decoding. The
    VAE output is mapped from [-1, 1] to [0, 1], clamped, and then converted
    to uint8.
    """
    latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    scale = pipe.vae.config.scaling_factor
    with torch.no_grad():
        image = pipe.vae.decode(latent / scale).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    np_img = (image[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(np_img)


# ------------------- MAIN DIFFUSION RUN -------------------
def run_diffusion(prompt, steps, guidance, seed, simple):
    """Run SD v1-4 at 256x256 and capture the latent after EVERY step.

    Args:
        prompt: text prompt. An empty or blank prompt returns a warning.
        steps: number of inference steps.
        guidance: classifier-free guidance scale.
        seed: integer seed. ``None`` or a negative value picks one from the clock.
        simple: if truthy, show the simple explanation; otherwise the technical one.

    Returns:
        A 7-tuple of (final image, explanation markdown, step-slider update,
        image at the current step, PCA figure, norm figure, state dict for
        slider updates).
    """
    if not prompt or not prompt.strip():
        return (
            None,
            "⚠️ Please enter a prompt.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {},
        )

    pipe = get_pipe()
    steps = int(steps)
    guidance = float(guidance)

    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []
    timesteps = []

    def callback(step: int, timestep: int, latents: torch.FloatTensor):
        # Called by the pipeline every `callback_steps` steps with the latents
        # after the scheduler update. Shape is (batch, C, H, W); keep sample 0.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps.append(int(timestep))

    t0 = time.time()
    try:
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        logger.exception("Stable Diffusion pipeline failed")
        return (
            None,
            f"❌ Diffusion error: {e}",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": str(e)},
        )

    total = time.time() - t0

    if not latents_list:
        return (
            None,
            "❌ No latents collected. Something went wrong inside the pipeline.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": "no_latents"},
        )

    final_image = result.images[0]  # PIL

    # Compute PCA trajectory and norms.
    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)

    current_idx = len(latents_list) - 1  # final step

    # Decode the image at the current step. This is best effort: the final
    # image is still shown even if this fails.
    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        logger.exception("Decoding latent at step %d failed", current_idx)
        step_image = None

    # Explanation text.
    if simple:
        explanation = (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    else:
        explanation = (
            "🔬 **Technical explanation:**\n\n"
            "- We run a DDIM diffusion process over the latent space.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
            "- We record the updated latent after every scheduler step and decode it with the VAE.\n"
            "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
            "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
        )

    explanation += f"\n⏱ **Runtime:** {total:.2f}s • **Steps:** {len(latents_list)} • Seed: {seed_val}"

    # Build plots.
    pca_fig = plot_pca(pca_pts, current_idx, timesteps) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx, timesteps) if norms else None

    # State for slider updates.
    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
        "timesteps": timesteps,
    }

    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )


# ------------------- PLOT HELPERS -------------------
def _step_label(i, timesteps=None):
    """Hover label for step ``i``, with the scheduler timestep when known."""
    if timesteps is not None and 0 <= i < len(timesteps):
        return f"step {i} (t={timesteps[i]})"
    return f"step {i}"


def plot_pca(points, idx, timesteps=None):
    """Plot the PCA trajectory over steps and highlight the current step.

    Args:
        points: (N, 2) numpy array of PCA coordinates.
        idx: index of the step to highlight. Out-of-range means no highlight.
        timesteps: optional per-step scheduler timesteps, shown in hover text.

    Returns:
        A plotly ``Figure``, or ``None`` when there are no points.
    """
    if points is None or points.shape[0] == 0:
        return None

    n_steps = points.shape[0]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=[_step_label(i, timesteps) for i in range(n_steps)],
    ))

    if 0 <= idx < n_steps:
        fig.add_trace(go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        ))

    fig.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return fig


def plot_norm(norms, idx, timesteps=None):
    """Plot the latent L2 norm against step and highlight the current step.

    Args:
        norms: list of per-step norms.
        idx: index of the step to highlight. Out-of-range means no highlight.
        timesteps: optional per-step scheduler timesteps, shown in hover text.

    Returns:
        A plotly ``Figure``, or ``None`` when ``norms`` is empty.
    """
    if not norms:
        return None

    steps = list(range(len(norms)))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
        text=[_step_label(i, timesteps) for i in steps],
    ))

    if 0 <= idx < len(steps):
        fig.add_trace(go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        ))

    fig.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return fig


# ------------------- SLIDER UPDATE -------------------
def update_step(state, idx):
    """Handle a slider move.

    Decodes the latent at the selected step and redraws both plots with that
    step highlighted.

    Args:
        state: dict produced by ``run_diffusion`` (may be empty or None).
        idx: slider value. It is clamped to the valid step range.

    Returns:
        Three ``gr.update`` objects for (step image, PCA plot, norm plot).
    """
    empty = (gr.update(value=None), gr.update(value=None), gr.update(value=None))
    if not state or "latents" not in state:
        return empty

    latents = state["latents"]
    pca_pts = state["pca"]
    norms = state["norms"]
    timesteps = state.get("timesteps")  # absent in states from older runs

    if not latents:
        return empty

    idx = max(0, min(int(idx), len(latents) - 1))

    pipe = get_pipe()
    try:
        img = decode_latent(pipe, latents[idx])
    except Exception:
        logger.exception("Decoding latent at step %d failed", idx)
        img = None

    pca_fig = plot_pca(pca_pts, idx, timesteps) if pca_pts is not None else None
    norm_fig = plot_norm(norms, idx, timesteps) if norms else None

    return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)


# ------------------- GRADIO UI -------------------
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3,
            )
            steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")

        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")

    gr.Markdown("### 🔍 Explore the denoising process step-by-step")

    step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")

    state = gr.State()

    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state],
    )

    step_slider.change(
        update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot],
    )


if __name__ == "__main__":
    # Launch only when run as a script. Importing this module (e.g. for tests
    # or reload tooling) should not start a server.
    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, pwa=True)