# (web-page residue captured with the file; kept as comments so the module parses)
# PraneshJs's picture
# Update app.py
# 265f384 verified
# ==========================================================
# Stable Diffusion v1-4 — CPU Diffusion Visualizer (256x256)
# - Runs on HF CPU
# - Real images (not blurry)
# - Step-by-step latents
# - PCA trajectory + latent norm plots
# ==========================================================
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionPipeline, DDIMScheduler
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from PIL import Image
import time
import warnings
# Silence every Python warning for the whole process (library deprecation
# notices included), so the console only shows real errors.
warnings.filterwarnings("ignore")
# ------------------- CPU SETTINGS -------------------
# Device string used for the pipeline, the RNG generator and latent decoding.
DEVICE: str = "cpu"
# Sometimes MKLDNN causes weird matmul errors with SD on some CPUs, disable to be safe.
torch.backends.mkldnn.enabled = False
# Hugging Face Hub id of the checkpoint loaded by get_pipe().
MODEL_ID: str = "CompVis/stable-diffusion-v1-4"
# Lazily-filled, module-wide pipeline instance; None until get_pipe() first runs.
PIPE_CACHE = None
# ------------------- LOAD SD MODEL -------------------
def get_pipe():
    """
    Return the shared Stable Diffusion v1-4 pipeline, building it on first use.

    The first call downloads/loads MODEL_ID in float32 with the safety checker
    turned off, swaps the scheduler for DDIM, moves the pipeline to DEVICE and
    stores it in the module-level PIPE_CACHE. Later calls return that same
    object without reloading.
    """
    global PIPE_CACHE

    if PIPE_CACHE is None:
        loaded = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            safety_checker=None,  # no safety checker module is attached
            requires_safety_checker=False,
        )
        # DDIM gives clear, predictable timesteps, which suits the visualization.
        loaded.scheduler = DDIMScheduler.from_config(loaded.scheduler.config)
        loaded.to(DEVICE)
        PIPE_CACHE = loaded

    return PIPE_CACHE
# ------------------- PCA + NORM -------------------
def compute_pca(latents):
    """
    Project per-step latents onto their first two principal components.

    latents: list of (C,H,W) numpy arrays, one per diffusion step.
    Returns an (N, 2) array with one 2D point per step. The result is all
    zeros when there are fewer than two steps or when the PCA fit raises,
    and has shape (0, 2) for an empty/None input.
    """
    if not latents:
        return np.zeros((0, 2))

    # One row per step, each latent flattened to a single feature vector.
    matrix = np.stack([latent.flatten() for latent in latents])
    n_steps = matrix.shape[0]

    # A 2-component PCA needs at least two samples.
    if n_steps < 2:
        return np.zeros((n_steps, 2))

    try:
        return PCA(n_components=2).fit_transform(matrix)
    except Exception:
        # Best effort: a failed fit must not break the UI, so plot a flat path.
        return np.zeros((n_steps, 2))
def compute_norm(latents):
    """
    Return the L2 norm of every latent, taken over all of its elements.

    latents: list of arrays (one per step). An empty or None input yields [].
    Each entry of the result is a plain Python float.
    """
    if not latents:
        return []

    magnitudes = []
    for latent in latents:
        # Collapse to 1-D (row-major) so the norm covers every element.
        as_vector = latent.reshape(-1)
        magnitudes.append(float(np.linalg.norm(as_vector)))
    return magnitudes
# ------------------- LATENT DECODER -------------------
def decode_latent(pipe, latent_np):
    """
    Turn one latent (C,H,W) numpy array into an RGB PIL image via the VAE.

    The latent is given a batch dimension, divided by the VAE's configured
    scaling factor and decoded. The decoder output is shifted with
    `x / 2 + 0.5`, clamped to [0, 1], and converted to uint8 pixels.
    The image size follows from the latent size (256x256 for the latents this
    app generates, per the pipeline call in run_diffusion).
    """
    vae = pipe.vae
    batch = torch.from_numpy(latent_np)[None].to(DEVICE)

    with torch.no_grad():
        decoded = vae.decode(batch / vae.config.scaling_factor).sample
        pixels = (decoded / 2 + 0.5).clamp(0, 1)

    # (C,H,W) -> (H,W,C) as expected by PIL, then scale to 0..255.
    hwc = pixels[0].permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((hwc * 255).astype("uint8"))
# ------------------- MAIN DIFFUSION RUN -------------------
def _failure_outputs(message, state):
    """
    Build the 7-tuple run_diffusion returns when no image could be produced.

    Only the explanation text and the state differ between failure cases; the
    images and plots are cleared and the step slider is reset to a 0..0 range.
    """
    return (
        None,
        message,
        gr.update(maximum=0, value=0),
        None,
        None,
        None,
        state,
    )


def run_diffusion(prompt, steps, guidance, seed, simple):
    """
    Run SD v1-4 at 256x256, capturing latents at EVERY step via callback.

    Args:
        prompt: text prompt; empty/whitespace-only prompts are rejected.
        steps: number of inference steps (cast to int).
        guidance: classifier-free guidance scale (cast to float).
        seed: RNG seed; None or a negative value picks one from the clock.
        simple: truthy for the beginner explanation, falsy for the technical one.

    Returns a 7-tuple matching the Gradio outputs:
        - final image
        - explanation text
        - step slider config
        - image at current step
        - PCA plot
        - norm plot
        - state dict (for slider updates)

    On failure (empty prompt, pipeline exception, no latents captured) the
    images/plots are None, the text carries the error, and the state is `{}`
    or `{"error": ...}`.
    """
    if not prompt or not prompt.strip():
        return _failure_outputs("⚠️ Please enter a prompt.", {})

    pipe = get_pipe()
    steps = int(steps)
    guidance = float(guidance)

    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []

    def callback(step: int, timestep: int, latents: torch.FloatTensor):
        # latents shape: (batch, C, H, W); keep the first batch item only.
        # On CPU, Tensor.numpy() shares storage with the tensor, so copy to get
        # a real per-step snapshot that later pipeline work cannot alter.
        latents_list.append(latents.detach().cpu().numpy()[0].copy())

    # NOTE(review): `callback` / `callback_steps` is the legacy per-step hook;
    # newer diffusers releases favor `callback_on_step_end` — confirm against
    # the pinned diffusers version before upgrading.
    t0 = time.time()
    try:
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        # UI boundary: report the failure in the explanation panel instead of crashing.
        return _failure_outputs(f"❌ Diffusion error: {e}", {"error": str(e)})
    total = time.time() - t0

    if not latents_list:
        return _failure_outputs(
            "❌ No latents collected. Something went wrong inside the pipeline.",
            {"error": "no_latents"},
        )

    final_image = result.images[0]  # PIL

    # Compute PCA trajectory and norms
    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)
    current_idx = len(latents_list) - 1  # final step

    # Decode image at current step; a decode failure just leaves the panel empty.
    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        step_image = None

    # Explanation text
    if simple:
        explanation = (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    else:
        explanation = (
            "🔬 **Technical explanation:**\n\n"
            "- We run a DDIM diffusion process over the latent space.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
            "- We record `zₜ` at every step and decode it with the VAE.\n"
            "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
            "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
        )
    explanation += f"\n⏱ **Runtime:** {total:.2f}s • **Steps:** {len(latents_list)} • Seed: {seed_val}"

    # Build plots
    pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx) if norms else None

    # State for slider updates
    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
    }
    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )
# ------------------- PLOT HELPERS -------------------
def plot_pca(points, idx):
    """
    Draw the 2D PCA path of the latents and mark the selected step.

    points: (N,2) array, one row per step.
    idx: step to highlight in red; ignored when outside 0..N-1.
    Returns a plotly Figure, or None when there are no points.
    """
    n_points = points.shape[0]
    if n_points == 0:
        return None

    fig = go.Figure()

    # Full trajectory, with a hover label per step.
    trajectory = go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=[f"step {i}" for i in range(n_points)],
    )
    fig.add_trace(trajectory)

    # Highlight marker for the currently selected step.
    if 0 <= idx < n_points:
        current = go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        )
        fig.add_trace(current)

    fig.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return fig
def plot_norm(norms, idx):
    """
    Draw latent L2 norm against step index and mark the selected step.

    norms: sequence of floats, one per step.
    idx: step to highlight in red; ignored when outside 0..len(norms)-1.
    Returns a plotly Figure, or None when norms is empty/None.
    """
    if not norms:
        return None

    step_axis = list(range(len(norms)))
    fig = go.Figure()

    # Norm curve over all steps.
    curve = go.Scatter(
        x=step_axis,
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
    )
    fig.add_trace(curve)

    # Highlight marker for the currently selected step.
    if 0 <= idx < len(step_axis):
        marker = go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        )
        fig.add_trace(marker)

    fig.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return fig
# ------------------- SLIDER UPDATE -------------------
def update_step(state, idx):
    """
    Refresh the per-step views after the step slider moves.

    state: dict saved by run_diffusion ("latents", "pca", "norms"); may be
        empty/None or an error dict without "latents".
    idx: requested step; cast to int and clamped into the valid range.

    Returns three gr.update objects, in order: the decoded image for that
    step, the PCA plot with that step highlighted, and the norm plot with that
    step highlighted. All three are cleared when there is nothing to show.
    """
    def _cleared():
        # One "clear this component" update per output.
        return tuple(gr.update(value=None) for _ in range(3))

    if not state or "latents" not in state:
        return _cleared()

    latents, pca_pts, norms = state["latents"], state["pca"], state["norms"]
    if not latents:
        return _cleared()

    # Clamp the requested step into 0..len(latents)-1.
    last_step = len(latents) - 1
    step = min(max(int(idx), 0), last_step)

    pipe = get_pipe()
    try:
        img = decode_latent(pipe, latents[step])
    except Exception:
        # A failed decode only blanks the image; the plots still update.
        img = None

    pca_fig = None if pca_pts is None else plot_pca(pca_pts, step)
    if norms:
        norm_fig = plot_norm(norms, step)
    else:
        norm_fig = None

    return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)
# ------------------- GRADIO UI -------------------
# Build the page. Components render in the order they are created, so the
# statement order below is the layout.
# NOTE(review): the nesting of Row/Column contents is reconstructed (the
# original indentation was not available) — confirm which components belong
# inside the right-hand column versus the page body.
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )
    with gr.Row():
        # Left column: generation controls.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3
            )
            steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
            # A negative seed makes run_diffusion pick a clock-based seed.
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")
        # Right column: final result and explanation text.
        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")
    gr.Markdown("### 🔍 Explore the denoising process step-by-step")
    # Starts as a 0..0 range; run_diffusion resizes it to the captured step count.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")
    # Holds the dict returned by run_diffusion (latents, PCA points, norms).
    state = gr.State()
    # The outputs list must stay in the same order as run_diffusion's 7-tuple.
    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
    )
    # The outputs list must stay in the same order as update_step's 3-tuple.
    # NOTE(review): `change` presumably also fires when run_diffusion sets the
    # slider value programmatically, causing one extra decode — confirm
    # against the Gradio event docs.
    step_slider.change(
        update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot]
    )
# Serve on all interfaces at port 7860 with debug output and PWA support enabled.
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, pwa=True)