"""Gradio demo: classify a hand-drawn digit with a pre-trained MNIST CNN.

The sketch is converted to an MNIST-like 28x28 image (white digit on black,
scaled into a 20x20 box and centred by centre of mass) before inference.
"""

import math

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
from scipy.ndimage import center_of_mass

# 1. Load the trained model once, at import time.
model = tf.keras.models.load_model('mnist_cnn_v1.keras')

# Side length of the final image fed to the model.
_CANVAS_SIZE = 28
# Side length of the box the cropped digit is scaled into before padding.
_DIGIT_BOX = 20


def get_best_shift(img):
    """Return the (shiftx, shifty) that moves the centre of mass to the centre.

    Args:
        img: 2-D array of pixel intensities.

    Returns:
        A ``(shiftx, shifty)`` tuple of ints. ``(0, 0)`` is returned for an
        all-zero image, whose centre of mass is undefined (NaN).
    """
    if not np.any(img):
        return 0, 0
    cy, cx = center_of_mass(img)
    rows, cols = img.shape
    shiftx = int(np.round(cols / 2.0 - cx))
    shifty = int(np.round(rows / 2.0 - cy))
    return shiftx, shifty


def shift(img, sx, sy):
    """Translate ``img`` by ``sx`` pixels horizontally and ``sy`` vertically.

    The output has the same shape as the input.
    """
    rows, cols = img.shape
    translation = np.float32([[1, 0, sx], [0, 1, sy]])
    return cv2.warpAffine(img, translation, (cols, rows))


def _flatten_alpha(img):
    """Blend a 4-channel image over a white background; pass others through.

    NOTE(review): assumes the 4th channel is alpha and that transparent pixels
    should read as white paper; confirm against the Gradio version in use.
    Fully opaque images are returned with unchanged RGB values.
    """
    if img.ndim != 3 or img.shape[2] != 4:
        return img
    alpha = img[..., 3:4].astype(np.float32) / 255.0
    rgb = img[..., :3].astype(np.float32)
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return np.round(blended).astype(np.uint8)


def _crop_to_content(img):
    """Crop a 2-D image to the bounding box of its non-zero pixels.

    Returns ``None`` when the image contains no non-zero pixel.
    """
    row_idx = np.flatnonzero(img.any(axis=1))
    col_idx = np.flatnonzero(img.any(axis=0))
    if row_idx.size == 0:
        return None
    return img[row_idx[0]:row_idx[-1] + 1, col_idx[0]:col_idx[-1] + 1]


def _fit_to_box(img):
    """Resize so the longer side equals ``_DIGIT_BOX``, keeping aspect ratio."""
    rows, cols = img.shape
    if rows > cols:
        factor = float(_DIGIT_BOX) / rows
        # Clamp to 1 so cv2.resize never receives a zero-sized target.
        rows, cols = _DIGIT_BOX, max(1, int(round(cols * factor)))
    else:
        factor = float(_DIGIT_BOX) / cols
        rows, cols = max(1, int(round(rows * factor))), _DIGIT_BOX
    return cv2.resize(img, (cols, rows))


def _pad_to_canvas(img):
    """Zero-pad a 2-D image symmetrically up to ``_CANVAS_SIZE`` square."""
    rows, cols = img.shape
    extra_rows = _CANVAS_SIZE - rows
    extra_cols = _CANVAS_SIZE - cols
    rows_padding = (int(math.ceil(extra_rows / 2.0)),
                    int(math.floor(extra_rows / 2.0)))
    cols_padding = (int(math.ceil(extra_cols / 2.0)),
                    int(math.floor(extra_cols / 2.0)))
    # np.pad is the documented public function (np.lib.pad is only an alias).
    return np.pad(img, (rows_padding, cols_padding), 'constant')


def preprocess_image(input_data):
    """Turn a sketch into a model-ready tensor.

    Pipeline: resize -> grayscale -> invert -> threshold -> crop -> scale ->
    pad -> centre -> normalise.

    Args:
        input_data: Either an image array or a dict holding the image under
            the ``"composite"`` key (Gradio 4 editor format).

    Returns:
        An array of shape ``(1, 28, 28, 1)`` with values in ``[0, 1]``, or
        ``None`` when there is no input or nothing has been drawn.
    """
    if input_data is None:
        return None

    # Gradio 4 passes a dict; older/other callers may pass the array directly.
    img = input_data["composite"] if isinstance(input_data, dict) else input_data
    if img is None:
        return None
    img = _flatten_alpha(np.asarray(img))

    # 1. Initial resize + grayscale.
    img = cv2.resize(img, (_CANVAS_SIZE, _CANVAS_SIZE),
                     interpolation=cv2.INTER_AREA)
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # 2. Inversion: MNIST digits are white on black, so a mostly-bright image
    #    (dark ink on a light background) is inverted.
    if np.mean(img) > 127:
        img = 255 - img

    # 3. Noise removal: binarise with Otsu's threshold.
    _, img = cv2.threshold(img, 128, 255,
                           cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # 4. Centring (critical for accuracy): strip the empty border, scale the
    #    digit into a 20x20 box, pad back to 28x28, then shift the centre of
    #    mass to the middle of the canvas.
    img = _crop_to_content(img)
    if img is None:
        # Blank canvas: nothing to classify.
        return None
    img = _pad_to_canvas(_fit_to_box(img))
    shiftx, shifty = get_best_shift(img)
    img = shift(img, shiftx, shifty)

    # 5. Normalisation and final reshape to (batch, height, width, channels).
    img = img / 255.0
    return img.reshape(1, _CANVAS_SIZE, _CANVAS_SIZE, 1)


def predict(image):
    """Classify a sketch and return ``{label: confidence}`` for Gradio.

    Returns ``None`` when there is no image or the canvas is blank.
    """
    if image is None:
        return None
    processed_img = preprocess_image(image)
    if processed_img is None:
        return None
    prediction = model.predict(processed_img, verbose=0)[0]
    return {str(i): float(prediction[i]) for i in range(10)}


# --- Gradio Blocks UI ---
css = """
.container {max-width: 800px; margin: auto; padding-top: 20px}
#title {text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 20px}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# 🧠 MNIST Classifier: From Sketch to Prediction", elem_id="title")
    gr.Markdown("Dessinez un chiffre (0-9). L'IA le centrera automatiquement avant l'analyse.")

    with gr.Row():
        with gr.Column():
            input_sketch = gr.Sketchpad(label="Dessinez ici", type="numpy", crop_size=(200, 200))
            predict_btn = gr.Button("Analyser", variant="primary")
        with gr.Column():
            label_output = gr.Label(num_top_classes=3, label="Prédictions & Probabilités")

    predict_btn.click(fn=predict, inputs=input_sketch, outputs=label_output)

if __name__ == "__main__":
    demo.launch()