# Hugging Face Space page header (scraped residue): author "papasega",
# commit 36306c8 "Update app.py" (verified).
import tensorflow as tf
import gradio as gr
import numpy as np
import cv2
from scipy.ndimage import center_of_mass
import math
# 1. Load the trained Keras model once, at import time; predict() below reuses
#    this module-level object for every request.
model = tf.keras.models.load_model(filepath="mnist_cnn_v1.keras")
def get_best_shift(img):
    """Compute the shift that moves the centre of mass of ``img`` to the array centre.

    Args:
        img: 2-D array of intensities (assumed non-negative, e.g. a 28x28 digit).

    Returns:
        Tuple ``(shiftx, shifty)`` of Python ints, in pixels, along the columns
        and rows respectively. ``(0, 0)`` when the image has zero total mass,
        because its centre of mass is undefined.
    """
    rows, cols = img.shape
    if np.sum(img) == 0:
        # center_of_mass divides by the total mass: a blank image would give
        # NaN, and casting NaN to int produces a meaningless huge shift.
        return 0, 0
    cy, cx = center_of_mass(img)
    shiftx = int(np.round(cols / 2.0 - cx))
    shifty = int(np.round(rows / 2.0 - cy))
    return shiftx, shifty
def shift(img, sx, sy):
    """Translate a 2-D image by ``sx`` pixels horizontally and ``sy`` vertically.

    The output keeps the input's size; pixels moved in from outside the frame
    take OpenCV's default constant border value (0, i.e. black).
    """
    height, width = img.shape
    translation = np.array([[1, 0, sx],
                            [0, 1, sy]], dtype=np.float32)
    # warpAffine expects dsize as (width, height).
    return cv2.warpAffine(img, translation, (width, height))
def _flatten_alpha(img):
    """Composite an RGBA image over a white background and return it as RGB.

    Images without a 4th (alpha) channel are returned unchanged. For fully
    opaque pixels the result equals the original RGB values exactly.
    """
    if img.ndim != 3 or img.shape[2] != 4:
        return img
    # NOTE(review): assumes transparent canvas areas should read as white
    # (what the user sees in a light theme) rather than as black ink.
    rgb = img[..., :3].astype(np.float32)
    alpha = img[..., 3:4].astype(np.float32) / 255.0
    return (rgb * alpha + 255.0 * (1.0 - alpha)).astype(img.dtype)


def _crop_to_content(img):
    """Crop a 2-D image to the bounding box of its non-zero pixels.

    Returns ``None`` when no pixel is non-zero (blank canvas) instead of
    shrinking the array to nothing.
    """
    ys, xs = np.nonzero(img)
    if ys.size == 0:
        return None
    return img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]


def _fit_and_pad(img, box=20, size=28):
    """Scale ``img`` so its longer side is ``box`` px (aspect ratio kept), then
    zero-pad it (extra pixel on the top/left) to ``size`` x ``size``."""
    rows, cols = img.shape
    if rows > cols:
        factor = box / rows
        rows, cols = box, max(1, int(round(cols * factor)))
    else:
        factor = box / cols
        rows, cols = max(1, int(round(rows * factor))), box
    img = cv2.resize(img, (cols, rows))  # dsize is (width, height)
    pad_rows = (math.ceil((size - rows) / 2), math.floor((size - rows) / 2))
    pad_cols = (math.ceil((size - cols) / 2), math.floor((size - cols) / 2))
    # np.pad replaces np.lib.pad, which no longer exists in NumPy 2.
    return np.pad(img, (pad_rows, pad_cols), mode="constant")


def preprocess_image(input_data):
    """Turn a sketch into a model-ready MNIST-style tensor.

    Pipeline: resize to 28x28 -> grayscale -> invert to white-on-black ->
    Otsu binarisation -> crop to the ink -> fit into a 20x20 box padded to
    28x28 -> centre by centre of mass -> scale to [0, 1].

    Args:
        input_data: ``None``, a numpy image, or a dict holding the image under
            ``"composite"`` (the dict form Gradio 4 editors may send).

    Returns:
        A float array of shape ``(1, 28, 28, 1)`` with values in ``[0, 1]``,
        or ``None`` when there is nothing to classify (no input, or a canvas
        with no ink left after thresholding).
    """
    if input_data is None:
        return None
    img = input_data["composite"] if isinstance(input_data, dict) else input_data
    if img is None:
        return None

    # 1. Flatten transparency, downscale, then convert to grayscale.
    img = _flatten_alpha(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # 2. MNIST digits are white on black: invert a mostly-light drawing.
    if np.mean(img) > 127:
        img = 255 - img

    # 3. Binarise; with THRESH_OTSU OpenCV computes the threshold itself and
    #    the fixed 128 is ignored.
    _, img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # 4. Crop to the ink. A blank canvas has none: report "nothing to
    #    classify" instead of crashing on an empty array.
    img = _crop_to_content(img)
    if img is None:
        return None

    # 5. Fit into a 20x20 box centred in 28x28, then centre by mass.
    img = _fit_and_pad(img)
    shiftx, shifty = get_best_shift(img)
    img = shift(img, shiftx, shifty)

    # 6. Normalise to [0, 1] and add batch and channel axes.
    return (img / 255.0).reshape(1, 28, 28, 1)


def predict(image):
    """Classify a sketch for ``gr.Label``.

    Returns a dict mapping each class index (as a string) to the model's
    output score for it, or ``None`` (an empty label) when there is no input
    or the canvas is blank.
    """
    if image is None:
        return None
    processed_img = preprocess_image(image)
    if processed_img is None:
        return None
    prediction = model.predict(processed_img, verbose=0)[0]
    return {str(i): float(p) for i, p in enumerate(prediction)}
# --- Modern UI built with Gradio Blocks ---
# Custom CSS passed to gr.Blocks. `#title` styles the heading Markdown below,
# which is given elem_id="title".
# NOTE(review): nothing in this file sets a `container` class, so the
# `.container` rule may have no effect — confirm before relying on it.
css = """
.container {max-width: 800px; margin: auto; padding-top: 20px}
#title {text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 20px}
"""
# Components render in the order they are created inside these nested
# context managers, so statement order here defines the page layout.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🧠 MNIST Classifier: From Sketch to Prediction", elem_id="title")
    gr.Markdown("Dessinez un chiffre (0-9). L'IA le centrera automatiquement avant l'analyse.")
    with gr.Row():
        # Left column: drawing canvas and the trigger button.
        with gr.Column():
            # type="numpy": the handler receives numpy data; preprocess_image
            # also accepts the dict form carrying the image under "composite".
            # crop_size presumably fixes the editor's crop box to 200x200 px —
            # confirm against the installed Gradio version.
            input_sketch = gr.Sketchpad(label="Dessinez ici", type="numpy", crop_size=(200, 200))
            predict_btn = gr.Button("Analyser", variant="primary")
        # Right column: label widget showing the 3 highest-scoring classes.
        with gr.Column():
            label_output = gr.Label(num_top_classes=3, label="Prédictions & Probabilités")
    # Run predict() on the sketch when the button is clicked; its dict (or
    # None) result is displayed by the label component.
    predict_btn.click(fn=predict, inputs=input_sketch, outputs=label_output)
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()