What digits are represented below?

MNIST samples

Easy peasy! But can you write a program to classify digits like these?

In this challenge, your task is to build a K-Nearest Neighbors model to classify handwritten digits (0 - 9) using the famous MNIST dataset.


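You can download the MNIST image files from the Kaggle dataset MNIST as PNG. After unzipping, the images are split into train/ and test/ directories. Each of those has one subdirectory per digit (0/ through 9/), so every file's parent directory name is its label.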


Every image file is a 28x28 grayscale image of a handwritten digit (0 - 9).

Your job is to predict the digit represented by each file in test/.

Need help getting started?

Here's some code to help you load the data into NumPy array format.

Starter Code
import glob
import numpy as np
from PIL import Image
# Define constants
TRAINPATH = "/Users/bgorman/Datasets/mnist-png/train"
TESTPATH = "/Users/bgorman/Datasets/mnist-png/test"
def load_images(path):
    """
    Given a path, recursively identify all PNG files below the path and load
    them into a list of PIL Images
    Args:
        path (str): Directory path like 'pics/whales/'
    Returns:
        tuple(list[Image], list[str]): Returns a 2-element tuple. The first element
        is a list of Image instances. The second is a list of corresponding filenames
    """
    images = []
    filepaths = []
    for filename in glob.glob(f"{path}/**/*.png", recursive=True):
        # Image.open() is lazy, so copy() reads the pixels into memory
        # before the with-block closes the file
        with Image.open(filename) as im:
            images.append(im.copy())
        filepaths.append(filename)
    return images, filepaths
# Load the data
trainX, trainFiles = load_images(TRAINPATH)
testX, testFiles = load_images(TESTPATH)
# Determine the labels: each file's label is the name of its parent directory
trainY = np.array([file.split("/")[-2] for file in trainFiles])
testY = np.array([file.split("/")[-2] for file in testFiles])
# Convert lists of Images to float64 arrays with shape (n_images, 28, 28)
trainX = np.array(trainX, dtype="float64")
testX = np.array(testX, dtype="float64")
# Scale pixel values from [0, 255] to [0, 1]
trainX = trainX / 255
testX = testX / 255
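The starter code pulls each label out of the file path with file.split("/")[-2], which assumes "/" separators. On Windows, glob can return paths containing backslashes. Below is a minimal, more portable sketch using pathlib, assuming the same train/<digit>/*.png layout. label_from_path() and find_pngs() are hypothetical helper names, and the demo builds a throwaway directory of empty placeholder files instead of reading real MNIST images.

```python
import tempfile
from pathlib import Path


def label_from_path(filepath):
    """Hypothetical helper: the label is the name of the file's parent directory."""
    # Path splits on the OS-native separator (on Windows it also accepts "/")
    return Path(filepath).parent.name


def find_pngs(root):
    """Hypothetical helper: recursively list PNG files under root, sorted for a stable order."""
    return sorted(str(p) for p in Path(root).rglob("*.png"))


# Demo on a throwaway tree: <tmp>/train/<digit>/<n>.png (empty placeholder files, not real images)
with tempfile.TemporaryDirectory() as tmp:
    for digit in ["3", "7"]:
        folder = Path(tmp) / "train" / digit
        folder.mkdir(parents=True)
        for n in range(2):
            (folder / f"{n}.png").touch()
    files = find_pngs(Path(tmp) / "train")
    labels = [label_from_path(f) for f in files]

print(labels)  # ['3', '3', '7', '7']
```

Sorting the file list is optional, but glob's ordering is arbitrary. A sorted list makes runs reproducible and easier to debug.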

Choosing K

If you want to keep things simple, let K = 5 and call it a day. If you want brownie points, devise and implement a strategy for finding the best value for K.

Optimizing hyperparameters is important, but it's not the focus of this challenge.
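One common strategy for choosing K is hold-out validation. Set aside part of the training data, score several candidate values of K on it, and keep the best one. Here's a minimal sketch of that idea. choose_k() is a hypothetical helper, and it assumes flattened 2-D features and integer labels (for this challenge, something like trainX.reshape(len(trainX), -1) and trainY.astype(int)). It's shown on toy data because it materializes every validation/training difference at once. On MNIST you'd run it on a random subsample or batch the distance computation, the way the solution below batches its predictions.

```python
import numpy as np


def choose_k(X, y, candidates, val_frac=0.2, seed=0):
    """Pick k for KNN by hold-out validation (hypothetical helper).

    X: 2-D array, one flattened sample per row.
    y: 1-D array of non-negative integer labels.
    Returns (best_k, {k: validation accuracy}).
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_val = int(len(X) * val_frac)
    val, fit = idx[:n_val], idx[n_val:]

    # Squared L2 distance from every validation row to every remaining training row
    diffs = X[val][:, np.newaxis, :] - X[fit][np.newaxis, :, :]
    d2 = (diffs ** 2).sum(axis=2)

    # Sort neighbors once for the largest k, then reuse prefixes for smaller k
    nearest = np.argsort(d2, axis=1)[:, : max(candidates)]

    scores = {}
    for k in candidates:
        votes = y[fit][nearest[:, :k]]
        # Majority vote; bincount().argmax() breaks ties toward the smaller label
        preds = np.array([np.bincount(v).argmax() for v in votes])
        scores[k] = float(np.mean(preds == y[val]))

    # max() keeps the first k listed among equally good candidates
    best_k = max(scores, key=scores.get)
    return best_k, scores


# Toy data: two well-separated 2-D Gaussian blobs with labels 0 and 1
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0, 1, (100, 2)), rng.normal(5, 1, (100, 2))])
y = np.array([0] * 100 + [1] * 100)
best_k, scores = choose_k(X, y, candidates=[1, 3, 5, 7, 9])
print(best_k, scores)
```

Sorting the neighbors once for the largest candidate is the main design choice. It means every K is scored from the same distance matrix instead of recomputing distances per K. A single split is noisy, so k-fold cross-validation is the natural next step if you want a more stable estimate.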


Obviously, there are many ways to code a solution. Here's mine (discussed below) 👇

import glob
import numpy as np
from PIL import Image
# Define constants
TRAINPATH = "/Users/bgorman/Datasets/mnist-png/train"
TESTPATH = "/Users/bgorman/Datasets/mnist-png/test"
# Make a numpy random number generator
RNG = np.random.default_rng()
# Helper functions
def get_mode(x):
    """
    Get the mode of a numpy array
    Args:
        x (array): The array. If 1-D, the mode is returned as a scalar. Otherwise, the mode
        is calculated for each index of the first axis, and a 1-D array of modes is returned
        whose length matches the length of x
    """
    if x.ndim == 1:
        vals, counts = np.unique(x, return_counts=True)
        # Randomly shuffle to break ties without bias
        shuffle = RNG.choice(len(vals), len(vals), replace=False)
        vals = vals[shuffle]
        counts = counts[shuffle]
        # Get the index of the mode (argmax returns the first of any tied counts)
        indexOfMode = np.argmax(counts)
        # Get the mode
        return vals[indexOfMode]
    else:
        uniques = map(lambda y: np.unique(y, return_counts=True), x)
        modes = []
        for uniq in uniques:
            vals, counts = uniq
            # If there's only one unique value, it's the mode
            if len(vals) == 1:
                modes.append(vals[0])
                continue
            # Randomly shuffle to break ties without bias
            shuffle = RNG.choice(len(vals), len(vals), replace=False)
            vals = vals[shuffle]
            counts = counts[shuffle]
            # Get the index of the mode
            indexOfMode = np.argmax(counts)
            # Get the mode
            modes.append(vals[indexOfMode])
        return np.array(modes)
def load_images(path):
    """
    Given a path, recursively identify all PNG files below the path and load
    them into a list of PIL Images
    Args:
        path (str): Directory path like 'pics/whales/'
    Returns:
        tuple(list[Image], list[str]): Returns a 2-element tuple. The first element
        is a list of Image instances. The second is a list of corresponding filenames
    """
    images = []
    filepaths = []
    for filename in glob.glob(f"{path}/**/*.png", recursive=True):
        # Image.open() is lazy, so copy() reads the pixels into memory
        # before the with-block closes the file
        with Image.open(filename) as im:
            images.append(im.copy())
        filepaths.append(filename)
    return images, filepaths
# Load the data
trainX, trainFiles = load_images(TRAINPATH)
testX, testFiles = load_images(TESTPATH)
# Determine the labels: each file's label is the name of its parent directory
trainY = np.array([file.split("/")[-2] for file in trainFiles])
testY = np.array([file.split("/")[-2] for file in testFiles])
# Convert lists of Images to float64 arrays with shape (n_images, 28, 28)
trainX = np.array(trainX, dtype="float64")
testX = np.array(testX, dtype="float64")
# Scale pixel values from [0, 255] to [0, 1]
trainX = trainX / 255
testX = testX / 255
# Define the model
class KNN:
    def __init__(self, X, y, k=5):
        """
        Initialize and train this KNN
        Args:
            X (array): 3-D array (i,j,k) where i is the ith training image in array format
            y (array): Corresponding training labels
            k (int): How many neighbors to use for prediction
        """
        self.X = X
        self.y = y
        self.k = k
    def predict(self, X):
        """
        Predict the label
        Args:
            X (array): A single 2-D image, or a 3-D array of images stacked along the first axis
        Returns:
            any: The predicted label (2-D input) or a 1-D array of predicted labels (3-D input)
        """
        if X.ndim == 2:
            # Calculate the L2 distance between this image and all train images
            distances = np.linalg.norm(self.X - X, axis=(1, 2), keepdims=False)
            # Pick out the labels of the k nearest training images
            sortedIdxs = np.argsort(distances)
            topKClasses = self.y[sortedIdxs[: self.k]]
        elif X.ndim == 3:
            # Calculate the L2 distance between every (test, train) pair
            distances = np.linalg.norm(
                self.X[np.newaxis, :, :] - X[:, np.newaxis, :, :],
                axis=(2, 3),
            )
            # Pick out the labels of the k nearest training images for each test image
            sortedIdxs = np.argsort(distances, axis=1)
            topKClasses = self.y[sortedIdxs[:, : self.k]]
        else:
            raise ValueError(f"X must be 2-D or 3-D, got {X.ndim}-D")
        # Tally the votes and pick the class with most votes
        # Break ties randomly
        return get_mode(topKClasses)
# "Train" the model (KNN just stores the training data)
knn = KNN(trainX, trainY)
# Make predictions in 100 batches of 100 test images. Each batch still builds a
# (100, 60000, 28, 28) float64 difference array (about 38 GB), so use more
# batches if memory is tight
batches = np.split(testX, 100)
preds = []
for i, batch in enumerate(batches):
    print(f"Predicting batch: {i + 1} of {len(batches)}")
    p = knn.predict(batch)
    preds.append(p)
# Concatenate the list of predictions
preds = np.concatenate(preds)
# Measure accuracy rate
np.mean(preds == testY)  # 0.969

I decided to wrap my model in a class called KNN. It's instantiated with

  • X training features in array form
  • y training labels
  • k number of neighbors to use

If you look at __init__(), you won't see anything fancy. It simply "memorizes" the training data so the predict() method can use it later.

The predict() method operates on an array, X, which can be 2-D or 3-D. If X is 2-D, KNN assumes it's a single image. If X is 3-D, KNN assumes it's a collection of images stacked along the first axis, which represents the sample dimension. (In other words, X[i] returns a 2-D array that is the ith image.)
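To see how the two cases line up with predict(), here's a toy sketch of the broadcasting. The arrays are tiny random stand-ins (4 "training images" and 3 "test images", each only 2x2 pixels), not MNIST data.

```python
import numpy as np

# Toy stand-ins: 4 "training images" and 3 "test images", each only 2x2 pixels
rng = np.random.default_rng(0)
trainX = rng.random((4, 2, 2))
testX = rng.random((3, 2, 2))

# 2-D case: one image broadcasts against the whole (4, 2, 2) training stack
single = testX[0]
print((trainX - single).shape)  # (4, 2, 2)

# 3-D case: insert new axes so every test image is paired with every training image
pairwise = trainX[np.newaxis, :, :] - testX[:, np.newaxis, :, :]
print(pairwise.shape)  # (3, 4, 2, 2)

# pairwise[i, j] holds train image j minus test image i
print(np.allclose(pairwise[2, 1], trainX[1] - testX[2]))  # True
```

The new axes are why the 3-D branch of predict() takes the norm over axes (2, 3) instead of (1, 2). The first two axes now index (test image, training image) pairs.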

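The predict() method uses L2 (Euclidean) distance as its distance metric, computed with np.linalg.norm(). Because the axis argument names the two pixel axes, np.linalg.norm() computes the Frobenius norm of each difference image. That is the same as the Euclidean distance between the two images flattened into 784-element vectors. From there, it's just some array manipulation with np.argsort() to grab the k closest training samples and tally their votes for the predicted label.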
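A quick sanity check of that equivalence on random stand-in data (five fake 28x28 "images" and one fake test image) compares np.linalg.norm() over axes (1, 2) with a hand-written Euclidean distance on flattened vectors.

```python
import numpy as np

rng = np.random.default_rng(1)
trainX = rng.random((5, 28, 28))  # 5 fake 28x28 "training images"
image = rng.random((28, 28))      # 1 fake test image

# With a 2-element axis, np.linalg.norm computes a matrix norm for each image;
# the default (ord=None) is the Frobenius norm
d_matrix = np.linalg.norm(trainX - image, axis=(1, 2))

# The same distances written as plain Euclidean distance between 784-element vectors
d_flat = np.sqrt(((trainX.reshape(5, -1) - image.reshape(-1)) ** 2).sum(axis=1))

print(np.allclose(d_matrix, d_flat))  # True
```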

Struggling to understand the NumPy array logic?

  1. Simplify the arrays into tiny, toy examples that are easy to follow. Then step through the code line by line.
  2. Check out my course on NumPy.
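As an example of the toy-array approach, here's the voting step from predict() and get_mode() traced by hand for one test image. The six distances and labels are made up, and k = 3.

```python
import numpy as np

# Toy problem: distances from one test image to 6 training images, plus their labels
distances = np.array([0.9, 0.1, 0.5, 0.3, 0.7, 0.2])
trainY = np.array(["4", "9", "4", "4", "7", "9"])
k = 3

# Indices of training samples from nearest to farthest
sortedIdxs = np.argsort(distances)
print(sortedIdxs)  # [1 5 3 2 4 0]

# Labels of the k nearest neighbors
topKClasses = trainY[sortedIdxs[:k]]
print(topKClasses)  # ['9' '9' '4']

# Count votes per label and pick the winner
vals, counts = np.unique(topKClasses, return_counts=True)
print(vals, counts)  # ['4' '9'] [1 2]
print(vals[np.argmax(counts)])  # 9
```

np.unique() returns labels in sorted order and np.argmax() returns the first maximum. With a tie (say k = 2 and votes ['9' '4']), this code would therefore always pick '4'. That bias is why get_mode() shuffles vals and counts before calling np.argmax().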

The final task is to make predictions on the test data. The training data has 60,000 images and the test data has 10,000 images, so predicting every test image means computing 60,000 × 10,000 = 600 million image-to-image distances. That makes prediction slow and memory intensive, so I made predictions on the test data in batches and then concatenated the results.
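If memory is the bottleneck, one alternative (not the approach used in the solution above) is to expand the squared distance as ‖a − b‖² = ‖a‖² + ‖b‖² − 2a·b. That turns the pairwise computation into a matrix product, and the largest temporary becomes a (batch, 60000) matrix instead of a (batch, 60000, 28, 28) array. With batch_size=500 that's about 240 MB of float64 per matrix. Below is a minimal sketch under these assumptions. pairwise_sq_dists() and knn_predict_batched() are hypothetical names, labels must be non-negative integers (e.g. trainY.astype(int)), and ties go to the smaller label instead of being broken randomly.

```python
import numpy as np


def pairwise_sq_dists(A, B):
    """Squared L2 distances between the rows of A (m, d) and the rows of B (n, d).

    Expands ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the biggest temporaries
    are (m, n) instead of an (m, n, d) difference array.
    """
    a2 = (A ** 2).sum(axis=1)[:, np.newaxis]  # shape (m, 1)
    b2 = (B ** 2).sum(axis=1)[np.newaxis, :]  # shape (1, n)
    d2 = a2 + b2 - 2 * (A @ B.T)              # shape (m, n)
    # Floating-point rounding can leave tiny negative values; clip them to 0
    return np.maximum(d2, 0)


def knn_predict_batched(trainX, trainY, testX, k=5, batch_size=500):
    """Hypothetical helper: KNN predictions for testX, batch_size test images at a time.

    Assumes trainY holds non-negative integers (e.g. trainY.astype(int)).
    """
    trainFlat = trainX.reshape(len(trainX), -1)
    testFlat = testX.reshape(len(testX), -1)
    preds = []
    for start in range(0, len(testFlat), batch_size):
        d2 = pairwise_sq_dists(testFlat[start:start + batch_size], trainFlat)
        # argpartition puts the k smallest distances (in no particular order) first
        nearest = np.argpartition(d2, k - 1, axis=1)[:, :k]
        for votes in trainY[nearest]:
            # Majority vote; ties go to the smaller label
            preds.append(np.bincount(votes).argmax())
    return np.array(preds)


# Toy check with random data (labels 0-2)
rng = np.random.default_rng(0)
trainX = rng.random((50, 4, 4))
trainY = rng.integers(0, 3, size=50)
testX = rng.random((7, 4, 4))
preds = knn_predict_batched(trainX, trainY, testX, k=3, batch_size=3)
print(preds.shape)  # (7,)
```

Skipping the square root is safe because it doesn't change which neighbors are nearest. np.argpartition() also avoids fully sorting 60,000 distances per test image when only the k smallest matter.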


This model scores 96.9% accuracy on the test data. Not bad!

Grouping by label, we see that the model did a great job predicting 1s (1.00 after rounding) but struggled more with 8s (0.94).

import pandas as pd
predsDF = pd.DataFrame({'label': testY, 'pred': preds})
# Fraction of correct predictions within each true label
(predsDF.label == predsDF.pred).groupby(predsDF.label).mean().round(2)
# label
# 0    0.99
# 1    1.00
# 2    0.96
# 3    0.96
# 4    0.96
# 5    0.97
# 6    0.99
# 7    0.96
# 8    0.94
# 9    0.96
# dtype: float64

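Breaking the results into a confusion matrix (rows are true labels, columns are predictions), we see that the most common mistakes were 4s predicted as 9s (22), 7s predicted as 1s (21), and 2s predicted as 7s (17). Those confusions make sense, since these pairs of digits can look alike when handwritten.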

predsDF.groupby(['label', 'pred']).size().unstack(fill_value=0)
# pred     0     1    2    3    4    5    6    7    8    9
# label                                                   
# 0      973     1    1    0    0    1    3    1    0    0
# 1        0  1132    2    0    0    0    1    0    0    0
# 2       10     7  989    3    1    0    1   17    4    0
# 3        0     3    3  973    1   13    1    6    4    6
# 4        3     6    0    0  945    0    4    2    0   22
# 5        4     0    0    9    2  861    6    1    3    6
# 6        5     3    0    0    3    1  946    0    0    0
# 7        0    21    4    0    2    0    0  990    0   11
# 8        7     3    5   13    4   11    4    5  916    6
# 9        5     5    3    8    7    3    1   10    2  965
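The same summaries can be read straight off the confusion matrix with NumPy. Here's a small sketch that re-enters the counts from the table above by hand. It computes overall accuracy from the diagonal and per-label accuracy from each row, then ranks the off-diagonal cells to find the most frequent mistakes.

```python
import numpy as np

# Confusion matrix copied from the table above (rows = true label, columns = predicted label)
cm = np.array([
    [973, 1, 1, 0, 0, 1, 3, 1, 0, 0],
    [0, 1132, 2, 0, 0, 0, 1, 0, 0, 0],
    [10, 7, 989, 3, 1, 0, 1, 17, 4, 0],
    [0, 3, 3, 973, 1, 13, 1, 6, 4, 6],
    [3, 6, 0, 0, 945, 0, 4, 2, 0, 22],
    [4, 0, 0, 9, 2, 861, 6, 1, 3, 6],
    [5, 3, 0, 0, 3, 1, 946, 0, 0, 0],
    [0, 21, 4, 0, 2, 0, 0, 990, 0, 11],
    [7, 3, 5, 13, 4, 11, 4, 5, 916, 6],
    [5, 5, 3, 8, 7, 3, 1, 10, 2, 965],
])

# Overall accuracy: correct predictions sit on the diagonal
accuracy = np.trace(cm) / cm.sum()
print(accuracy)  # 0.969

# Per-label accuracy: diagonal count divided by each row's total
per_label = np.diag(cm) / cm.sum(axis=1)
print(per_label.round(2))  # same values as the pandas groupby output above

# Most frequent mistakes: zero out the diagonal, then rank the remaining cells
errors = cm.copy()
np.fill_diagonal(errors, 0)
top = np.argsort(errors, axis=None)[::-1][:3]
for true, pred in zip(*np.unravel_index(top, errors.shape)):
    print(f"{true} predicted as {pred}: {errors[true, pred]} times")
# 4 predicted as 9: 22 times
# 7 predicted as 1: 21 times
# 2 predicted as 7: 17 times
```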