|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging
import time
import warnings

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from sklearn.decomposition import PCA
|
|
|
|
|
# Suppress every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# Device used for the pipeline, the seeded RNG generator and latent decoding.
DEVICE = "cpu"

# Turn off PyTorch's MKL-DNN backend.
# NOTE(review): presumably a workaround for a CPU kernel problem on the
# target host -- confirm before removing.
torch.backends.mkldnn.enabled = False

# Model identifier handed to StableDiffusionPipeline.from_pretrained().
MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Lazily populated by get_pipe(); None until the first load.
PIPE_CACHE = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_pipe():
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The first call builds a ``StableDiffusionPipeline`` from ``MODEL_ID`` in
    float32 with the safety checker disabled, replaces its scheduler with a
    ``DDIMScheduler`` built from the original scheduler's config, moves the
    pipeline to ``DEVICE`` and stores it in the module-level ``PIPE_CACHE``.
    Every later call returns that cached object without reloading.

    Returns:
        The cached pipeline instance.
    """
    global PIPE_CACHE
    if PIPE_CACHE is not None:
        return PIPE_CACHE

    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # Swap in DDIM while keeping the scheduler settings shipped with the model.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    pipe.to(DEVICE)

    PIPE_CACHE = pipe
    return PIPE_CACHE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_pca(latents):
    """Project one latent per step onto two principal components.

    Args:
        latents: List of numpy arrays, one per diffusion step (documented by
            the original author as ``(C, H, W)``). Each array is flattened to
            a vector before fitting.

    Returns:
        An ``(N, 2)`` numpy array with one row per input latent. The result
        is all zeros when fewer than two latents are given (a 2-component
        PCA cannot be fitted) or when the PCA fit raises.
    """
    if not latents:
        return np.zeros((0, 2))

    X = np.stack([x.flatten() for x in latents])
    if X.shape[0] < 2:
        return np.zeros((X.shape[0], 2))

    try:
        return PCA(n_components=2).fit_transform(X)
    except Exception:
        # Best-effort: the trajectory plot is optional, so fall back to zeros,
        # but record the failure instead of hiding it.
        logging.getLogger(__name__).exception("PCA over latents failed")
        return np.zeros((X.shape[0], 2))
|
|
|
|
|
|
|
|
def compute_norm(latents):
    """Return the L2 norm of each latent, taken over all of its elements.

    Args:
        latents: List of numpy arrays, one per diffusion step.

    Returns:
        A list of Python floats in the same order as ``latents``; an empty
        list when ``latents`` is empty.
    """
    if not latents:
        return []
    # Flatten first so the norm is the plain vector 2-norm for any array rank.
    return [float(np.linalg.norm(x.flatten())) for x in latents]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_latent(pipe, latent_np):
    """Decode a single latent into an RGB PIL image using the pipeline's VAE.

    The output size is whatever the VAE produces for the given latent; this
    function does not resize.

    Args:
        pipe: Pipeline object exposing ``vae.decode`` and
            ``vae.config.scaling_factor``.
        latent_np: Numpy array holding one latent without a batch dimension
            (documented by the original author as ``(C, H, W)``).

    Returns:
        A ``PIL.Image.Image`` built from the decoded uint8 array.
    """
    # Add the batch dimension the VAE expects and move to the target device.
    latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    scale = pipe.vae.config.scaling_factor

    with torch.no_grad():
        # Undo the latent scaling before decoding.
        image = pipe.vae.decode(latent / scale).sample

    # Map the decoder output from [-1, 1] to [0, 1], clamping overshoot.
    image = (image / 2 + 0.5).clamp(0, 1)
    # Channels-first -> channels-last, then to 8-bit.
    np_img = (image[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(np_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _failed_run(message, state):
    """Build the 7-output tuple for a run that produced no image."""
    return (
        None,
        message,
        gr.update(maximum=0, value=0),
        None,
        None,
        None,
        state,
    )


def _build_explanation(simple, runtime_s, n_steps, seed_val):
    """Return the markdown explanation (simple or technical) plus run stats."""
    if simple:
        text = (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    else:
        text = (
            "🔬 **Technical explanation:**\n\n"
            "- We run a DDIM diffusion process over the latent space.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
            "- We record `zₜ` at every step and decode it with the VAE.\n"
            "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
            "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
        )
    return text + f"\n⏱ **Runtime:** {runtime_s:.2f}s • **Steps:** {n_steps} • Seed: {seed_val}"


def run_diffusion(prompt, steps, guidance, seed, simple):
    """Run the pipeline at 256x256 and capture the latent after every step.

    Args:
        prompt: Text prompt; an empty or whitespace-only prompt aborts early.
        steps: Number of inference steps (converted with ``int``).
        guidance: Guidance scale (converted with ``float``).
        seed: RNG seed; ``None`` or a negative value selects a seed derived
            from the current time.
        simple: When truthy, use the simple explanation text instead of the
            technical one.

    Returns:
        A 7-tuple matching the UI outputs:
        ``(final_image, explanation, step_slider_update, step_image,
        pca_figure, norm_figure, state)``. ``state`` is a dict with keys
        ``"latents"``, ``"pca"`` and ``"norms"`` on success. On failure the
        image and plot slots are ``None``, the slider is reset to 0 and
        ``state`` is ``{}`` or ``{"error": ...}``.
    """
    if not prompt or not prompt.strip():
        return _failed_run("⚠️ Please enter a prompt.", {})

    pipe = get_pipe()

    steps = int(steps)
    guidance = float(guidance)

    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []

    def callback(step: int, timestep: int, latents: torch.FloatTensor):
        # Keep batch item 0 as a detached CPU numpy array for later decoding.
        latents_list.append(latents.detach().cpu().numpy()[0])

    t0 = time.time()
    try:
        # NOTE(review): `callback` / `callback_steps` are the legacy per-step
        # hook; newer diffusers releases reportedly deprecate them in favour
        # of `callback_on_step_end` -- confirm against the pinned version.
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        return _failed_run(f"❌ Diffusion error: {e}", {"error": str(e)})

    total = time.time() - t0

    if not latents_list:
        return _failed_run(
            "❌ No latents collected. Something went wrong inside the pipeline.",
            {"error": "no_latents"},
        )

    final_image = result.images[0]

    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)

    # Start the step viewer on the last captured step.
    current_idx = len(latents_list) - 1

    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        # Best-effort preview: the final image is still returned, so only
        # record the failure and leave the step view empty.
        logging.getLogger(__name__).exception(
            "Decoding latent at step %d failed", current_idx
        )
        step_image = None

    explanation = _build_explanation(simple, total, len(latents_list), seed_val)

    pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx) if norms else None

    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
    }

    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_pca(points, idx):
    """Plot the PCA trajectory across steps and highlight one step.

    Args:
        points: ``(N, 2)`` numpy array of PCA coordinates, one row per step.
        idx: Index of the step to highlight; ignored when out of range.

    Returns:
        A plotly ``Figure``, or ``None`` when ``points`` has no rows.
    """
    if points.shape[0] == 0:
        return None

    steps = list(range(points.shape[0]))
    fig = go.Figure()
    # Full trajectory, with a "step i" label attached to each point.
    fig.add_trace(go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=[f"step {i}" for i in steps],
    ))
    if 0 <= idx < len(steps):
        # Red marker on the currently selected step.
        fig.add_trace(go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        ))
    fig.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return fig
|
|
|
|
|
|
|
|
def plot_norm(norms, idx):
    """Plot latent L2 norm against step index and highlight one step.

    Args:
        norms: List of per-step norm values.
        idx: Index of the step to highlight; ignored when out of range.

    Returns:
        A plotly ``Figure``, or ``None`` when ``norms`` is empty.
    """
    if not norms:
        return None

    steps = list(range(len(norms)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
    ))
    if 0 <= idx < len(steps):
        # Red marker on the currently selected step.
        fig.add_trace(go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        ))
    fig.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_step(state, idx):
    """Refresh the step view when the slider moves.

    Decodes the latent at the selected step and redraws both plots with that
    step highlighted.

    Args:
        state: Dict with keys ``"latents"``, ``"pca"`` and ``"norms"``.
            A falsy value or a dict without ``"latents"`` clears all three
            outputs.
        idx: Requested step index; converted with ``int`` and clamped to the
            valid range.

    Returns:
        A 3-tuple of ``gr.update`` objects for the step image, the PCA plot
        and the norm plot.
    """
    if not state or "latents" not in state:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    latents = state["latents"]
    pca_pts = state["pca"]
    norms = state["norms"]

    if not latents:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    idx = int(idx)
    idx = max(0, min(idx, len(latents) - 1))

    pipe = get_pipe()

    try:
        img = decode_latent(pipe, latents[idx])
    except Exception:
        # Best-effort preview: keep the plots updating even if decoding
        # fails, but record the failure instead of hiding it.
        logging.getLogger(__name__).exception(
            "Decoding latent at step %d failed", idx
        )
        img = None

    pca_fig = plot_pca(pca_pts, idx) if pca_pts is not None else None
    norm_fig = plot_norm(norms, idx) if norms else None

    return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Build the Gradio UI and wire its events.
# NOTE(review): the source had lost all indentation, so the nesting below
# (two columns inside one row, everything else directly under Blocks) is
# reconstructed from the layout intent -- confirm against the running app.
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer") as demo:

    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3,
            )
            steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")

        # Right column: final image and explanation text.
        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")

    gr.Markdown("### 🔍 Explore the denoising process step-by-step")

    # Range starts as 0..0; run_diffusion raises the maximum after a run.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")

    # Holds the dict returned by run_diffusion for use by update_step.
    state = gr.State()

    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state],
    )

    step_slider.change(
        update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot],
    )

# Listens on all interfaces, port 7860.
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, pwa=True)