MNIST Express - 5/8 - Pourquoi 97% d'accuracy ne veut rien dire
- 5 mars
- 5 min de lecture
Dernière mise à jour : 22 mars
Lien vers l'appli : https://mnist-express.streamlit.app/
Lien vers le repo du projet : https://github.com/vohorgeez/MNIST-Express
Si je vous dit que notre modèle atteint 97% d'accuracy, ça déclenche plutôt un apriori positif, non ? Ça sonne sérieux, scientifique, presque magique. On imagine une machine quasi infaillible, un modèle qui comprend les chiffres comme un enfant de CP.
Spoiler : non.
Dans cet article, je vous propose de voir pourquoi 97% d'accuracy peut être une information quasi inutile, et pourquoi les ingénieurs machine learning regardent surtout la distribution des erreurs.
Dans les articles précédents, nous avons progressivement construit notre modèle de classification de chiffres.
Nous avons vu
comment MNIST transforme des images en vecteurs (NumPy et la géométrie des données),
comment la malédiction dimensionnelle peut compliquer ces comparaisons,
et comment PCA permet de réduire les dimensions tout en gardant l'essentiel de l'information.
Tout cela, donc, nous permet d'entraîner un modèle capable de reconnaître des chiffres manuscrits.
Et lorsque l'on entraîne ce modèle... on obtient souvent un résultat comme celui-ci :
Accuracy : 97%
Ce chiffre semble excellent. Mais il cache l'essentiel.
L'illusion du chiffre magique
L'accuracy globale est la métrique la plus simple possible. C'est littéralement :
La proportion de bonnes prédictions.
Ni plus, ni moins.
Dans notre repo, nous avons :
def accuracy(y_true, y_pred) -> float:
return accuracy_score(y_true, y_pred)Simple. Propre. Efficace.
Et avec notre run actuel, nous obtenons :
accuracy_global = 0.98333333333333Donc, 98.33%.
On a envie de crier "waouh", "ça marche !", "il est vivant !", "l'IA va nous prendre notre travail !"... Sauf que...
Les faiblesses cachées du modèle
Notre pipeline d'entraînement (`mnist_express/train.py`) fait déjà ce que font les projets sérieux :
prédiction sur tout le test set
calcul des métriques
export JSON
export de la matrice de confusion
Le coeur de la machine, le voici :
y_pred_full = pipe.predict(X_test)
acc = accuracy(y_test, y_pred_full)
acc_per_class = accuracy_per_class(y_test, y_pred_full, labels=range(10))
cm = compute_confusion_matrix(y_test, y_pred_full, labels=range(10))
recall_per_class, weakest = weakest_classes(cm, labels=range(10), top_k=3)Traduction pour les moldus :
ok, on a la note finale, l'accuracy globale,
mais surtout : on va regarder comment le modèle se comporte par chiffre, et où il se trompe.
Et les résultats exportés dans `artifacts/metrics/report_plain.json` (rapport généré après un run) sont très parlants.
Accuracy par classe (vraies valeurs)
Classe (chiffre prédit) | Accuracy sur la classe |
0 | 1.00 |
1 | 1.00 |
2 | 1.00 |
3 | 1.00 |
4 | 1.00 |
5 | 0.9787 |
6 | 1.00 |
7 | 0.9706 |
8 | 0.9667 |
9 | 0.925 |
Donc : notre modèle est parfait sur certains chiffres... et clairement moins bon sur d'autres, surtout le 9.
Voilà la forêt qui se cachait derrière l'arbre de l'accuracy globale...

Le vrai sujet : la distribution des erreurs
La vraie question, finalement, c'est moins "combien le modèle réussit", que "où est-ce qu'il se trompe, et comment ces erreurs se répartissent ?".
Pour ça, l'outil incontournable = la matrice de confusion.
Dans notre code :
def compute_confusion_matrix(y_true, y_pred, labels=None) -> np.ndarray:
...
cm = confusion_matrix(y_true, y_pred, labels=list(labels))
return cmEt la matrice que l'on obtient après un run raconte une autre histoire.

En y regardant de près (j'avoue il faut regarder d'assez près avec cette charte graphique), ça ne ressemble plus trop à "le modèle est à 98% donc tout va bien", mais plutôt à "le modèle est excellent globalement, mais il a un point faible : les 9 (et quelques confusions rares)" :
un 7 a été pris pour un 9 ;
un 8 a été pris pour un 1 ;
un 9 a été pris pour un 3 ;
un 9 a été pris pour un 4 ;
un 9 a été pris pour un 5.
Ce diagnostic, l'accuracy globale seule ne pouvait pas nous le donner.
Pourquoi ces erreurs existent (et pourquoi elles sont logiques)
Je vous renvoie à mon article sur k-NN !
k-NN ne "comprend" pas un chiffre, il compare des vecteurs, et prend les voisins les plus proches.
Donc les erreurs ne sont pas vraiment le fruit du "hasard", mais suivent une logique géométrique :
un 9 peut ressembler à un 4 si la boucle est ouverte ;
un 9 peut ressembler à un 3 si la boucle est écrasée ;
un 9 peut ressembler à un 5 si le trait du haut se casse visuellement...
C'est en tout cas ce que semblent montrer les erreurs que nous avons reportés...
Donc, le modèle n'est pas si "stupide" : il est cohérent... dans un espace où certains chiffres sont proches.
C'est exactement la raison pour laquelle nous avons parlé de PCA : réduire la dimension change la géométrie, donc peut changer la répartition des erreurs (même si l'accuracy globale bouge peu).
"Weakest classes"
C'est bien pour cela aussi que nous avons pris le parti, dans ce projet, d'identifier les classes les plus faibles par recall via `weakest_classes()`.
Le principe :
on prend la matrice de confusion
pour chaque chiffre :
"combien de vrais X ont été reconnus comme X ?"
divisé par "combien de vrais X existent ?"
Notre code :
true_counts = cm.sum(axis=1)
correct = np.diag(cm)
recall = correct / true_countsEt le résultat exporté :
"weakest_classes": [
[9, 0.925],
[8, 0.9666...],
[7, 0.9705...]
]Traduction :
Notre modèle "voit" très bien 0, 1, 2, 3, 4, 6
il galère plus sur 7 et 8
et le 9 est la vraie zone fragile
On l'aura compris à ce stade, le modèle n'a donc pas une performance uniforme. Il a des points faibles. Comme tout un chacun, finalement... 🥲
And so what ? 🙄
Sur MNIST (ou ici `load_digits()`), les données sont relativement propres :
centrées
normalisées
format homogène
La machine est prise par la main.
Mais notre projet MNIST Express vise aussi un usage "dessin utilisateur". Et là, on est dans un autre système solaire :
chiffres décentrés
épaisseur variable
traits interrompus
proportions étranges
parfois même une rotation
sans parler du fait que tu écris COMME UN COCHON. 🫵🏻🫵🏻🫵🏻
C'est là que notre module `preprocessing.py` devient crucial (et c'est l'occasion de faire un peu d'"ingénierie").
Notre idée "prod-like" :
imposer une forme canonique (n, 784)
scaler systématiquement
PCA optionnelle
Par exemple :
def to_canonical_X(X):
...
if X.shape == (28, 28): return X.reshape(1, 784)
if X.shape == (784,): return X.reshape(1, 784)
if X.ndim == 2 and X.shape[1] == 784: return X
raise ValueError(...)Forme canonique qui est ensuite testée :
def test_to_canonical_X_rejects_invalid_shapes():
...
with pytest.raises(ValueError):
to_canonical_X(bad)Pourquoi parler de pré-traitement ? Parce que le modèle présente des faiblesses sur un dataset d'images normalisées. On peut en déduire que sans pré-traitement robuste, la distribution des erreurs peut très bien exploser en prod, quand bien même l'accuracy sur le dataset de test reste haute.
Il ne peut pas y avoir qu'une seule métrique utile
L'accuracy globale, au final, ce n'est jamais qu'un résumé. Un indicateur. Une note finale.
Mais la distribution des erreurs, c'est une carte complète:
Elle dit où le modèle est fort
où il est faible
quelles confusions sont structurelles
lesquelles sont rares
et surtout : ce qu'il faut améliorer (on est quand même là pour ça)
97% ou 98% ne peut donc pas être un slogan. C'est juste la porte d'entrée trompeuse vers le vrai sujet : les erreurs, et leur forme.
Maintenant que nous savons où le modèle se trompe, notre réflexion sur la performance globale nous mène naturellement sur... la vitesse ! Oui, parce qu'un modèle précis mais lent, c'est un peu comme un pompier qui met 20 minutes à enfiler sa veste : ça perd un peu de son utilité... 🧑🏻🚒





Commentaires