Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from scipy.ndimage import center_of_mass | |
| import math | |
# 1. Load the trained Keras model once, at import time, so every request reuses it.
# The path is relative, so it is resolved against the current working directory.
model = tf.keras.models.load_model("mnist_cnn_v1.keras")
def get_best_shift(img):
    """Return the integer (x, y) shift that moves the centre of mass to the centre.

    The target centre is (cols / 2, rows / 2). The shifts are rounded with
    ``np.round`` (round-half-to-even).

    Args:
        img: 2-D array of non-negative pixel weights.

    Returns:
        ``(shiftx, shifty)``. An image whose pixels sum to zero has no centre of
        mass (scipy would divide by zero and return NaN), so ``(0, 0)`` is
        returned instead of a meaningless shift.
    """
    if np.sum(img) == 0:
        return 0, 0
    cy, cx = center_of_mass(img)
    rows, cols = img.shape
    shiftx = np.round(cols / 2.0 - cx).astype(int)
    shifty = np.round(rows / 2.0 - cy).astype(int)
    return shiftx, shifty
def shift(img, sx, sy):
    """Translate a 2-D image by ``sx`` columns and ``sy`` rows.

    Uses an affine warp with a pure-translation matrix. The output keeps the
    input's size, so content pushed past an edge is cut off. The uncovered area
    is filled with ``cv2.warpAffine``'s default constant border value (0).
    """
    height, width = img.shape  # unpacking also rejects non-2-D input
    translation = np.array([[1, 0, sx], [0, 1, sy]], dtype=np.float32)
    return cv2.warpAffine(img, translation, (width, height))
def _flatten_alpha(img):
    """Composite an RGBA image onto a white background and return it as RGB uint8.

    NOTE(review): presumably the Sketchpad composite can be RGBA with a
    transparent background. If alpha were simply ignored, transparent pixels
    (RGB 0) would look like black ink. Fully opaque pixels (alpha 255) are
    returned unchanged. Assumes uint8 input with alpha in 0..255 - TODO confirm.
    """
    rgb = img[..., :3].astype(np.float64)
    alpha = img[..., 3:4].astype(np.float64) / 255.0
    blended = rgb * alpha + 255.0 * (1.0 - alpha)
    return np.clip(np.round(blended), 0, 255).astype(np.uint8)


def preprocess_image(input_data):
    """Turn a Sketchpad drawing into a (1, 28, 28, 1) float tensor for the CNN.

    Pipeline: flatten transparency -> resize to 28x28 -> grayscale -> invert
    if needed -> Otsu binarization -> crop to the ink -> fit into a 20x20 box
    -> pad to 28x28 -> move the centre of mass to the centre -> scale to [0, 1].

    Args:
        input_data: ``None``; a Gradio editor dict whose ``"composite"`` entry
            is the rendered image; or the image array itself (presumably uint8
            RGB/RGBA or grayscale - TODO confirm against the Gradio version).

    Returns:
        ``None`` if ``input_data`` is ``None``. Otherwise a float64 array of
        shape (1, 28, 28, 1) with values in [0, 1]. A canvas with no ink, or a
        dict whose composite is ``None``, yields an all-zero array instead of
        raising.
    """
    if input_data is None:
        return None
    # Gradio 4 may hand over an editor dict or a plain array.
    img = input_data["composite"] if isinstance(input_data, dict) else input_data
    blank = np.zeros((1, 28, 28, 1))
    if img is None:
        return blank

    # 1. Drop transparency, shrink the whole canvas to 28x28, then grayscale.
    if img.ndim == 3 and img.shape[2] == 4:
        img = _flatten_alpha(img)
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

    # 2. MNIST digits are white on black. A mostly bright canvas means dark ink
    #    on a light background, so invert it.
    if np.mean(img) > 127:
        img = 255 - img

    # 3. Binarize to remove faint noise. With THRESH_OTSU, the 128 is ignored
    #    and the threshold is computed automatically.
    _, img = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    # 4. Crop to the bounding box of the ink. An empty canvas has no ink, and
    #    cropping it would leave a zero-size image, so return the blank input.
    ys, xs = np.nonzero(img)
    if ys.size == 0:
        return blank
    img = img[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    # 5. Scale so the longer side is 20 px, keeping the aspect ratio
    #    (the 20x20 box used by the original MNIST preprocessing).
    rows, cols = img.shape
    factor = 20.0 / max(rows, cols)
    rows = max(1, int(round(rows * factor)))
    cols = max(1, int(round(cols * factor)))
    img = cv2.resize(img, (cols, rows))

    # 6. Pad back to 28x28. When a margin is odd, the extra pixel goes top/left.
    row_pad = ((28 - rows + 1) // 2, (28 - rows) // 2)
    col_pad = ((28 - cols + 1) // 2, (28 - cols) // 2)
    img = np.pad(img, (row_pad, col_pad), "constant")

    # 7. Centre by centre of mass (critical for accuracy).
    shiftx, shifty = get_best_shift(img)
    img = shift(img, shiftx, shifty)

    # 8. Scale to [0, 1] and add batch and channel axes for the CNN.
    return (img / 255.0).reshape(1, 28, 28, 1)
def predict(image):
    """Classify a Sketchpad drawing for display in a ``gr.Label``.

    Args:
        image: The raw Sketchpad value (editor dict or array), or ``None``.

    Returns:
        ``{"0": s0, ..., "9": s9}``, mapping each digit label to the model's
        score as a float (presumably softmax confidences). Returns ``None``,
        which leaves the label empty, when there is no input or nothing was
        drawn.
    """
    if image is None:
        return None
    processed_img = preprocess_image(image)
    # Nothing drawn: preprocessing yields None or an all-zero tensor. Do not
    # present a "confident" digit for an empty canvas.
    if processed_img is None or not np.any(processed_img):
        return None
    prediction = model.predict(processed_img, verbose=0)[0]
    return {str(digit): float(prediction[digit]) for digit in range(10)}
# --- Modern UI built with Gradio Blocks ---
# "#title" targets the heading Markdown through its elem_id.
# NOTE(review): no component below is given the "container" class, so that
# rule currently has no visible effect.
css = """
.container {max-width: 800px; margin: auto; padding-top: 20px}
#title {text-align: center; font-size: 2em; font-weight: bold; margin-bottom: 20px}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown(
        "# 🧠 MNIST Classifier: From Sketch to Prediction",
        elem_id="title",
    )
    gr.Markdown(
        "Dessinez un chiffre (0-9). L'IA le centrera automatiquement avant l'analyse."
    )
    with gr.Row():
        # Left column: drawing canvas plus the button that triggers inference.
        with gr.Column():
            input_sketch = gr.Sketchpad(
                label="Dessinez ici",
                type="numpy",
                crop_size=(200, 200),
            )
            predict_btn = gr.Button("Analyser", variant="primary")
        # Right column: the three highest-scoring digits returned by `predict`.
        with gr.Column():
            label_output = gr.Label(
                num_top_classes=3,
                label="Prédictions & Probabilités",
            )
    predict_btn.click(predict, inputs=input_sketch, outputs=label_output)

if __name__ == "__main__":
    demo.launch()