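"""Gradio demo that visualizes how a Vision Transformer (google/vit-base-patch16-224) sees an image.

Simple tab: a four-step walkthrough (patch grid, patch clustering, attention arrows, focus map + top-5 guesses).
Advanced tab: per-layer/per-head attention overlays, attention rollout, and PCA of patch embeddings.
"""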
import math
import warnings
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import plotly.express as px
import torch
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    ViTForImageClassification,
    ViTModel,
)

warnings.filterwarnings("ignore")

MODEL_NAME = "google/vit-base-patch16-224"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BASE_MODEL = None
CLF_MODEL = None
PROCESSOR = None

def load_models():
    """Load and cache the ViT backbone, the classification head, and the image processor."""
    global BASE_MODEL, CLF_MODEL, PROCESSOR
    if BASE_MODEL is not None and CLF_MODEL is not None and PROCESSOR is not None:
        return BASE_MODEL, CLF_MODEL, PROCESSOR

    PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_NAME)

    # Request attentions and hidden states; eager attention is needed so the
    # per-layer attention weights are actually returned.
    cfg = AutoConfig.from_pretrained(MODEL_NAME)
    cfg.output_attentions = True
    cfg.output_hidden_states = True

    BASE_MODEL = ViTModel.from_pretrained(
        MODEL_NAME, config=cfg, attn_implementation="eager"
    )
    BASE_MODEL.to(DEVICE).eval()

    CLF_MODEL = ViTForImageClassification.from_pretrained(MODEL_NAME)
    CLF_MODEL.to(DEVICE).eval()

    return BASE_MODEL, CLF_MODEL, PROCESSOR

def patch_grid_info(image_size: int = 224, patch_size: int = 16):
    """Return the patch grid size and the pixel center of every patch (row-major order)."""
    grid_size = image_size // patch_size
    positions = []
    for i in range(grid_size):
        for j in range(grid_size):
            cx = int((j + 0.5) * patch_size)
            cy = int((i + 0.5) * patch_size)
            positions.append((cx, cy))
    return grid_size, positions

def draw_patch_grid(img: Image.Image, patch_size: int = 16, outline=(0, 180, 0)) -> Image.Image:
    """Overlay the 16x16 patch boundaries on a 224x224 copy of the image."""
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img)
    w, h = img.size
    for x in range(0, w, patch_size):
        draw.line([(x, 0), (x, h)], fill=outline, width=1)
    for y in range(0, h, patch_size):
        draw.line([(0, y), (w, y)], fill=outline, width=1)
    return img

def draw_cluster_blocks(img: Image.Image, labels: np.ndarray, n_clusters: int = 4, patch_size: int = 16):
    """
    labels: (n_patches,) cluster labels assigned to each patch index (left→right, top→bottom)
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    grid_size, positions = patch_grid_info()
    colors = [
        (255, 99, 71, 140),
        (60, 179, 113, 140),
        (65, 105, 225, 140),
        (255, 215, 0, 140),
        (199, 21, 133, 140),
        (0, 206, 209, 140),
    ]
    for idx, lab in enumerate(labels):
        i = idx // grid_size
        j = idx % grid_size
        x0 = j * patch_size
        y0 = i * patch_size
        x1 = x0 + patch_size
        y1 = y0 + patch_size
        col = colors[int(lab) % len(colors)]
        draw.rectangle([x0, y0, x1, y1], fill=col)
    return img

def draw_attention_arrows(img: Image.Image, att_matrix: np.ndarray, top_k: int = 3, query_idx: Optional[int] = None):
    """
    att_matrix: (n_patches, n_patches) attention from query -> keys (already preprocessed).
    If query_idx is None, the center patch is used as the query; otherwise 0..n_patches-1.
    Draws arrows from the query patch center to the centers of its top-k key patches.
    """
    img = img.convert("RGB").resize((224, 224))
    draw = ImageDraw.Draw(img, "RGBA")
    grid_size, positions = patch_grid_info()

    if query_idx is None:
        query_idx = (grid_size * grid_size) // 2
    qpos = positions[query_idx]

    vec = att_matrix[query_idx]
    top_idx = vec.argsort()[-top_k:][::-1]
    for t in top_idx:
        kpos = positions[t]

        # Arrow shaft from the query patch center to the key patch center.
        draw.line([qpos, kpos], fill=(255, 0, 0, 200), width=3)

        # Simple arrowhead at the key end.
        dx = kpos[0] - qpos[0]
        dy = kpos[1] - qpos[1]
        ang = math.atan2(dy, dx)
        ah = 8
        p1 = (kpos[0] - ah * math.cos(ang - 0.3), kpos[1] - ah * math.sin(ang - 0.3))
        p2 = (kpos[0] - ah * math.cos(ang + 0.3), kpos[1] - ah * math.sin(ang + 0.3))
        draw.polygon([kpos, p1, p2], fill=(255, 0, 0, 200))

    # Highlight the query patch with a blue circle.
    r = 10
    draw.ellipse([qpos[0] - r, qpos[1] - r, qpos[0] + r, qpos[1] + r], outline=(0, 0, 255, 220), width=2)
    return img

def make_focus_overlay(img: Image.Image, heat_grid: np.ndarray, alpha: float = 0.6):
    """
    heat_grid: (G, G) float map.
    Overlays a colored, transparent heatmap on the image where the heat is high.
    """
    img = img.convert("RGB").resize((224, 224))
    g = np.array(heat_grid, dtype=np.float32)
    if np.any(g):
        g = g - g.min()
        if g.max() > 0:
            g = g / g.max()
    else:
        g = np.zeros_like(g)
    heat_img = Image.fromarray((g * 255).astype("uint8"), mode="L").resize((224, 224), Image.BILINEAR)
    heat = np.array(heat_img).astype(np.float32) / 255.0
    draw = ImageDraw.Draw(img, "RGBA")

    H, W = heat.shape
    for y in range(H):
        for x in range(W):
            v = heat[y, x]
            if v > 0.05:
                r = int(255 * v)
                gcol = int(200 * (1 - v))
                draw.point((x, y), fill=(r, gcol, 40, int(255 * alpha * v)))
    return img

def compute_attention_rollout(all_attentions: List[torch.Tensor]) -> np.ndarray:
    """Attention rollout: average heads per layer, add the identity (residual), row-normalize, and multiply across layers."""
    avg_mats = []
    for a in all_attentions:
        mat = a[0].mean(dim=0).detach().cpu().numpy()
        avg_mats.append(mat)
    seq = avg_mats[0].shape[0]
    aug = []
    for A in avg_mats:
        A_hat = A + np.eye(seq)
        row_sums = A_hat.sum(axis=-1, keepdims=True)
        row_sums[row_sums == 0] = 1.0
        A_hat = A_hat / row_sums
        aug.append(A_hat)
    R = aug[0]
    for A in aug[1:]:
        R = A @ R
    return R

def pca_plot_from_hidden(hidden_states: List[torch.Tensor], layers: List[int]):
    """Scatter-plot a 2D PCA of the patch embeddings (CLS excluded) for the selected layers."""
    pts_all = []
    labels = []
    # Note: PCA is fit separately per layer, so axes are not directly comparable across layers.
    for li in layers:
        hs = hidden_states[li][0].detach().cpu().numpy()
        patches = hs[1:, :]
        pca = PCA(n_components=2)
        pts = pca.fit_transform(patches)
        pts_all.append(pts)
        labels.append(np.array([li] * pts.shape[0]))
    coords = np.vstack(pts_all)
    layer_labels = np.concatenate(labels)
    df = {"x": coords[:, 0], "y": coords[:, 1], "layer": layer_labels.astype(str)}
    fig = px.scatter(df, x="x", y="y", color="layer", title="Patch embeddings across layers (PCA)")
    fig.update_traces(marker=dict(size=6))
    fig.update_layout(height=480)
    return fig

def analyze_all(img: Optional[Image.Image], mode_simple: bool):
    """Run ViT on the image and build every output used by the Simple and Advanced tabs."""
    if img is None:
        # Return one empty value per output component (11 in total).
        return None, None, None, None, "", "", None, None, "", "", None

    base, clf, processor = load_models()

    img224 = img.convert("RGB").resize((224, 224))
    inputs = processor(images=img224, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = base(**inputs)

    attentions = outputs.attentions
    hidden_states = outputs.hidden_states

    grid_size, positions = patch_grid_info()
    seq_len = attentions[0].shape[-1]
    n_patches = seq_len - 1

    # Step 1: patch grid overlay.
    patch_grid_img = draw_patch_grid(img224.copy())

    # Step 2: cluster the final-layer patch embeddings (CLS token excluded).
    last_hidden = hidden_states[-1][0].detach().cpu().numpy()
    patch_embeddings = last_hidden[1:, :]

    n_clusters = 4
    try:
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(patch_embeddings)
        cluster_labels = kmeans.labels_
    except Exception:
        cluster_labels = np.zeros(n_patches, dtype=int)

    cluster_img = draw_cluster_blocks(img224.copy(), cluster_labels, n_clusters=n_clusters)

    # Step 3: patch-to-patch attention from the last layer, averaged over heads.
    last_att = attentions[-1][0].mean(dim=0).cpu().numpy()

    # Drop the CLS row/column so the matrix is patch-to-patch.
    if last_att.shape[0] >= n_patches + 1:
        patch_to_patch = last_att[1:, 1:]
    else:
        patch_to_patch = np.zeros((n_patches, n_patches))

    arrow_img = draw_attention_arrows(img224.copy(), patch_to_patch, top_k=4, query_idx=(n_patches // 2))

    # Step 4: attention-rollout focus map, read from the CLS row.
    rollout = compute_attention_rollout(attentions)

    rollout_cls = rollout[0, 1:]
    if rollout_cls.shape[0] != grid_size * grid_size:
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(rollout_cls), tmp.shape[0])
        tmp[:nmin] = rollout_cls[:nmin]
        rollout_cls = tmp
    rollout_grid = rollout_cls.reshape(grid_size, grid_size)
    focus_img = make_focus_overlay(img224.copy(), rollout_grid, alpha=0.6)

    # Classification: softmax over the logits, then the top-5 labels.
    with torch.no_grad():
        logits = clf(**inputs).logits[0].cpu().numpy()
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()
    top5 = probs.argsort()[-5:][::-1]
    labels = clf.config.id2label
    preds_text = "\n".join([f"{labels[int(i)]} — {probs[i]*100:.2f}%" for i in top5])

    # PCA of patch embeddings at the first, middle, and last layers.
    pca_fig = pca_plot_from_hidden(hidden_states, [0, max(0, len(hidden_states) // 2), len(hidden_states) - 1])

    # CLS -> patch attention from the last layer (heads averaged), used as the default Advanced overlay.
    att_np = attentions[-1][0].cpu().numpy()
    cls_to_patches = att_np.mean(axis=0)[0, 1:]
    if cls_to_patches.shape[0] != grid_size * grid_size:
        tmp = np.zeros(grid_size * grid_size, dtype=float)
        nmin = min(len(cls_to_patches), tmp.shape[0])
        tmp[:nmin] = cls_to_patches[:nmin]
        cls_to_patches = tmp
    cls_grid = cls_to_patches.reshape(grid_size, grid_size)

    focus_overlay_default = make_focus_overlay(img224.copy(), cls_grid, alpha=0.5)

    # Everything the Advanced-tab callbacks need, cached on the CPU.
    state = {
        "attentions": [a.cpu() for a in attentions],
        "hidden_states": [h.cpu() for h in hidden_states],
        "grid_size": grid_size,
        "num_layers": len(attentions),
        "num_heads": attentions[0].shape[1],
        "base_image": img,
    }

    simple_explain = """
**How ViT Sees — Simple Steps**

1) **Chop** — The image is chopped into small square tiles (patches), like LEGO pieces.
2) **Understand** — Each piece gets a code that describes its colors and edges; pieces that look similar are grouped.
3) **Talk** — Pieces tell each other what they see (the arrows show this).
4) **Focus & Guess** — The model merges the clues, focuses on the important areas, and guesses what the image shows.
"""

    advanced_explain = """
**Advanced View:** Explore attention per layer/head, the PCA of patch embeddings, and the model's internal focus.
Use the sliders to change layer/head and see how ViT's attention evolves.
"""

    return (
        patch_grid_img,
        cluster_img,
        arrow_img,
        focus_img,
        preds_text,
        simple_explain,
        focus_overlay_default,
        pca_fig,
        preds_text,
        advanced_explain,
        state,
    )

def advanced_update_attention(state: Dict[str, Any], layer_idx: int, head_idx: int):
    """Overlay CLS -> patch attention for the selected layer and head."""
    if not state:
        return None
    l = max(0, min(int(layer_idx), state["num_layers"] - 1))
    h = max(0, min(int(head_idx), state["num_heads"] - 1))
    att_tensor = state["attentions"][l]
    if att_tensor.ndim == 4:
        att_tensor = att_tensor[0]
    att_np = att_tensor.numpy()

    # CLS row (query index 0) attending to the patch tokens.
    vec = att_np[h, 0, 1:]
    grid = state["grid_size"]
    if vec.shape[0] != grid * grid:
        tmp = np.zeros(grid * grid, dtype=float)
        nmin = min(vec.shape[0], tmp.shape[0])
        tmp[:nmin] = vec[:nmin]
        vec = tmp
    grid_map = vec.reshape(grid, grid)
    return make_focus_overlay(state["base_image"].convert("RGB"), grid_map, alpha=0.55)

def advanced_update_rollout(state: Dict[str, Any]):
    """Recompute the attention-rollout overlay from the cached attentions."""
    if not state:
        return None
    mats = [a.unsqueeze(0) if a.ndim == 3 else a for a in state["attentions"]]
    R = compute_attention_rollout(mats)
    grid = state["grid_size"]
    rollout_cls = R[0, 1:]
    if rollout_cls.shape[0] != grid * grid:
        tmp = np.zeros(grid * grid, dtype=float)
        nmin = min(len(rollout_cls), tmp.shape[0])
        tmp[:nmin] = rollout_cls[:nmin]
        rollout_cls = tmp
    rollout_grid = rollout_cls.reshape(grid, grid)
    return make_focus_overlay(state["base_image"].convert("RGB"), rollout_grid, alpha=0.6)

def advanced_update_pca(state: Dict[str, Any], txt: str):
    """Re-plot the PCA for a comma-separated list of layer indices."""
    if not state:
        return None
    try:
        layers = [int(x.strip()) for x in txt.split(",") if x.strip() != ""]
    except Exception:
        layers = [0, max(0, state["num_layers"] - 1)]
    # Clamp to valid hidden-state indices (0 .. num_layers) so bad input cannot crash the plot.
    n_hidden = len(state["hidden_states"])
    layers = [max(0, min(li, n_hidden - 1)) for li in layers] or [0, n_hidden - 1]
    return pca_plot_from_hidden(state["hidden_states"], layers)

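# Gradio UI: a Simple tab that walks through the four steps, and an Advanced tab
# for inspecting per-layer/per-head attention, rollout, and PCA.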
with gr.Blocks(title="ViT Visualizer — Simple + Advanced") as demo:
    gr.Markdown("# 👀 How Vision Transformers (ViT) See Images\n"
                "Simple mode (story-style) + Advanced mode (inspect internals). Model: **google/vit-base-patch16-224**")

    with gr.Tabs():
        with gr.TabItem("Simple (for everyone)"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(label="Upload an image (photo / object)", type="pil")
                    run_btn = gr.Button("🔎 Explain simply")
                    gr.Markdown("Tip: clear images of objects, animals, or scenes work best.")
                with gr.Column(scale=1):
                    pass

            gr.Markdown("### Step 1 — Chopped into patches")
            step1 = gr.Image(label="Patch Grid (ViT chops the image into 16×16 patches)")

            gr.Markdown("### Step 2 — The model groups similar patches")
            step2 = gr.Image(label="Clustered patches (colored blocks)")

            gr.Markdown("### Step 3 — Patches talk to each other (simplified)")
            step3 = gr.Image(label="Patch-to-patch arrows")

            gr.Markdown("### Step 4 — Model focus map and guess")
            with gr.Row():
                step4 = gr.Image(label="Focus map (where the model looked most)")
                preds_simple = gr.Textbox(label="Model guesses (Top-5)", lines=4)

            explanation_simple = gr.Markdown()

            # Placeholder components absorb the Advanced-tab outputs that the Simple tab does not display.
            run_btn.click(
                fn=analyze_all,
                inputs=[img_input, gr.State(True)],
                outputs=[step1, step2, step3, step4, preds_simple, explanation_simple,
                         gr.State(), gr.Plot(), gr.Textbox(), gr.Markdown(), gr.State()],
            )

with gr.TabItem("Advanced (inspect internals)"): |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
img_adv = gr.Image(label="Upload image for advanced view", type="pil") |
|
|
run_adv = gr.Button("Analyze (advanced)") |
|
|
gr.Markdown("Use the sliders to explore attention per layer and head.") |
|
|
layer_slider = gr.Slider(0, 11, value=11, step=1, label="Layer (0=shallow)") |
|
|
head_slider = gr.Slider(0, 11, value=0, step=1, label="Head index") |
|
|
rollout_btn = gr.Button("Refresh Rollout Overlay") |
|
|
pca_txt = gr.Textbox(label="PCA layers (comma separated)", value="0,6,11") |
|
|
pca_btn = gr.Button("Update PCA") |
|
|
with gr.Column(scale=1): |
|
|
adv_attn = gr.Image(label="Attention overlay (layer/head CLS->patch)") |
|
|
adv_rollout = gr.Image(label="Attention rollout overlay (aggregated)") |
|
|
adv_pca = gr.Plot(label="PCA of patch embeddings") |
|
|
adv_preds = gr.Textbox(label="Top-5 predictions", lines=5) |
|
|
adv_explain = gr.Markdown() |
|
|
|
|
|
state_box = gr.State() |
|
|
|
|
|
|
|
|
run_adv.click( |
|
|
fn=analyze_all, |
|
|
inputs=[img_adv, gr.State(False)], |
|
|
outputs=[gr.Image(), gr.Image(), gr.Image(), gr.Image(), adv_preds, gr.Markdown(), |
|
|
adv_attn, adv_pca, adv_preds, adv_explain, state_box], |
|
|
) |
|
|
|
|
|
|
|
|
layer_slider.change( |
|
|
fn=advanced_update_attention, |
|
|
inputs=[state_box, layer_slider, head_slider], |
|
|
outputs=[adv_attn], |
|
|
) |
|
|
head_slider.change( |
|
|
fn=advanced_update_attention, |
|
|
inputs=[state_box, layer_slider, head_slider], |
|
|
outputs=[adv_attn], |
|
|
) |
|
|
|
|
|
rollout_btn.click( |
|
|
fn=advanced_update_rollout, |
|
|
inputs=[state_box], |
|
|
outputs=[adv_rollout], |
|
|
) |
|
|
|
|
|
pca_btn.click( |
|
|
fn=advanced_update_pca, |
|
|
inputs=[state_box, pca_txt], |
|
|
outputs=[adv_pca], |
|
|
) |
|
|
|
|
|
demo.launch() |