import glob
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Paths to the MNIST PNG train/test directories
TRAINPATH = "/Users/bgorman/Datasets/mnist-png/train"
TESTPATH = "/Users/bgorman/Datasets/mnist-png/test"

# Seeded generator so weight/bias sampling is reproducible across runs
RNG = np.random.default_rng(2024)
def load_images(path):
    """
    Given a path, recursively identify all PNG files below the path and load
    them into a list of PIL Images.

    Args:
        path (str): Directory path like 'pics/whales/'

    Returns:
        tuple(list[Image], list[str]): Returns a 2-element tuple. The first
        element is a list of Image instances. The second is a list of
        corresponding filenames.
    """
    images = []
    filepaths = []
    # recursive=True is required for '**' to match nested subdirectories
    # (e.g. train/2/img.png); without it, '**' behaves like a single '*'
    # and the per-digit subfolders are never searched.
    for filename in glob.glob(f"{path}/**/*.png", recursive=True):
        im = Image.open(filename)
        images.append(im)
        filepaths.append(filename)
    return images, filepaths
# Load the data
trainX, trainFiles = load_images(TRAINPATH)
testX, testFiles = load_images(TESTPATH)

# Determine the labels: each file lives in a folder named after its digit,
# e.g. ".../train/2/0001.png" -> label "2" (second-to-last path component).
# `fp` instead of `file` avoids shadowing the builtin.
trainY = np.array([fp.split("/")[-2] for fp in trainFiles])
testY = np.array([fp.split("/")[-2] for fp in testFiles])

# Convert list of Images to arrays with float64 vals
trainX = np.array(trainX, dtype="float64")
testX = np.array(testX, dtype="float64")

# Scale pixel values from [0, 255] down to [0, 1]
# (min-max scaling, not standardization)
trainX = trainX / 255
testX = testX / 255
# Declare the model
class Perceptron:
    """Linear threshold unit: predicts True where w·x + b >= 0."""

    def __init__(self, w, b):
        """
        Args:
            w (np.ndarray): Flat weight vector, one weight per pixel.
            b (float): Bias term.
        """
        self.w = w
        self.b = b

    def predict(self, X):
        """
        Classify a single image or a batch of images.

        Args:
            X (np.ndarray): Single image (H, W) or batch (N, H, W); the
                flattened image must have the same length as self.w.

        Returns:
            bool or np.ndarray of bool: True where w·x + b >= 0.

        Raises:
            ValueError: If X is not 2- or 3-dimensional.
        """
        if X.ndim == 2:
            z = np.sum(X.reshape((1, -1)) * self.w) + self.b
        elif X.ndim == 3:
            z = np.sum(X.reshape((len(X), -1)) * self.w, axis=1) + self.b
        else:
            # Previously any other ndim fell through with z unbound,
            # raising a confusing NameError; fail with a clear message.
            raise ValueError(f"X must be 2- or 3-dimensional, got ndim={X.ndim}")
        return z >= 0
# Best accuracy seen so far, with the weights/bias that produced it
best = {"accuracy": 0.0, "w": None, "b": None}

# Random search: repeatedly sample a candidate model and keep the best one
for iteration in range(5000):
    # Progress heartbeat every 1000 candidates
    if iteration % 1000 == 0:
        print("Iteration", iteration)

    # Sample a random weight vector (one weight per 28x28 pixel) and bias
    w = RNG.uniform(-1, 1, size=784)
    b = RNG.uniform(-20, 20)
    perceptron = Perceptron(w, b)

    # Score the candidate: fraction of training images whose "is it a 2?"
    # prediction matches the true label
    preds = perceptron.predict(trainX)
    accuracy = np.mean(preds == (trainY == "2"))

    # Record the candidate if it beats the current best
    if accuracy > best["accuracy"]:
        print("New best accuracy:", accuracy)
        best["accuracy"] = accuracy
        best["w"] = w
        best["b"] = b
Once the random search completes, the best model classifies 2s with ~91% accuracy on the training data and ~90% accuracy on the test data.
Iteration 0
New best accuracy: 0.5606166666666667
New best accuracy: 0.82925
New best accuracy: 0.9007
New best accuracy: 0.9007166666666667
New best accuracy: 0.9008
New best accuracy: 0.9014666666666666
New best accuracy: 0.90215
New best accuracy: 0.9054166666666666
Iteration 1000
New best accuracy: 0.9103333333333333
Iteration 2000
Iteration 3000
Iteration 4000
Iteration 5000
Iteration 6000
Iteration 7000
Iteration 8000
Iteration 9000
# Report the winning weight vector. Note: after the loop, the loose variable
# `w` holds the LAST sampled vector, not the best one — use best["w"].
print(best["w"].round(3))
# [-0.887 0.006 -0.014 0.592 0.716 0.063 -0.957 0.387 -0.814 -0.565
# 0.613 -0.09 0.032 -0.334 0.568 -0.168 -0.516 -0.748 0.607 0.626
# -0.615 -0.64 -0.697 -0.767 0.667 0.393 0.564 -0.012 0.804 -0.735
# 0.609 0.53 -0.592 -0.846 0.464 -0.201 -0.327 0.257 0.992 -0.267
# -0.292 0.329 0.073 0.976 0.843 0.685 0.627 -0.899 0.677 0.873
# -0.87 -0.2 -0.411 -0.826 0.557 0.781 0.487 0.299 -0.842 -0.144
# 0.454 0.477 -0.136 -0.54 0.34 0.5 0.692 0.737 -0.929 0.102
# -0.199 -0.063 -0.144 -0.846 -0.619 -0.148 -0.463 0.196 -0.822 0.093
# -0.678 0.397 0.809 -0.04 -0.451 -0.122 0.341 0.6 -0.322 -0.767
# 0.482 0.252 -0.801 -0.82 -0.063 0.136 -0.219 0.537 -0.996 0.968
# 0.243 0.175 0.543 -0.766 0.261 -0.168 0.933 -0.112 0.172 0.089
# 0.983 -0.782 0.876 0.612 0.294 -0.165 -0.443 -0.142 0.064 0.817
# -0.444 0.206 -0.66 -0.641 -0.215 -0.425 -0.612 -0.373 -0.693 -0.635
# 0.896 0.55 -0.348 -0.2 0.931 0.937 -0.087 0.566 -0.302 0.012
# 0.542 -0.296 0.436 -0.347 0.562 -0.051 0.221 -0.976 -0.581 -0.061
# 0.899 0.512 -0.778 0.279 -0.032 0.252 -0.342 0.737 -0.539 -0.499
# 0.667 0.126 -0.624 0.325 -0.188 -0.938 -0.466 -0.476 -0.635 -0.45
# -0.796 -0.301 -0.18 0.964 -0.524 0.419 -0.527 0.638 -0.277 -0.667
# -0.978 -0.753 -0.816 -0.744 -0.713 -0.354 0.246 0.723 0.012 0.868
# 0.776 -0.81 0.909 0.984 -0.168 0.401 0.082 0.576 0.217 0.865
# -0.671 0.34 0.026 0.424 -0.068 -0.046 -0.471 -0.577 -0.37 0.897
# 0.944 -0.3 0.7 -0.556 -0.113 -0.491 0.937 -0.802 -0.582 -0.917
# 0.362 0.711 0.614 0.155 0.874 0.203 0.056 0.332 -0.373 -0.937
# 0.589 -0.646 -0.836 0.477 0.546 0.251 -0.575 0.729 -0.182 -0.516
# -0.086 -0.118 0.421 -0.011 -0.711 0.536 -0.371 -0.184 0.257 -0.094
# 0.331 -0.117 0.095 0.208 0.9 0.635 -0.864 -0.643 0.622 -0.329
# -0.385 -0.64 0.804 0.24 0.573 -0.319 0.243 -0.461 -0.16 -0.373
# -0.915 0.168 0.142 0.035 0.217 0.851 -0.455 -0.173 0.56 -0.135
# -0.116 0.518 0.915 0.602 -0.053 0.7 0.828 -0.614 -0.274 0.31
# 0.486 0.232 -0.623 -0.055 0.243 -0.126 0.341 0.609 -0.235 -0.849
# 0.457 0.263 -0.337 -0.708 -0.599 0.355 -0.079 -0.098 0.223 -0.612
# 0.806 -0.026 -0.42 -0.182 -0.45 -0.923 -0.424 0.722 -0.315 -0.268
# 0.625 -0.61 0.38 -0.876 -0.924 0.263 -0.822 -0.291 -0.774 -0.53
# -0.012 0.444 -0.498 -0.83 -0.407 0.379 0.463 0.213 -0.536 0.997
# -0.618 -0.306 -0.989 0.818 -0.523 -0.478 0.576 0.99 -0.772 0.077
# 0.373 -0.054 -0.553 -0.063 -0.903 0.521 0.069 -0.469 0.786 0.208
# -0.829 0.495 -0.403 0.209 0.955 -0.773 0.476 -0.248 0.999 -0.865
# 0.391 0.426 0.085 0.27 -0.846 0.89 -0.081 -0.974 -0.515 -0.552
# 0.481 -0.959 -0.503 -0.623 0.88 -0.901 -0.398 0.672 -0.579 -0.169
# -0.621 0.92 -0.08 0.024 0.881 -0.015 0.237 -0.272 0.286 -0.188
# 0.278 0.636 0.156 0.27 0.882 0.018 -0.969 0.56 -0.734 0.514
# 0.626 0.297 0.776 -0.95 -0.214 -0.538 0.073 0.026 -0.979 0.142
# -0.129 0.215 -0.735 0.303 -0.636 -0.615 0.951 0.539 -0.878 0.16
# 0.266 0.527 0.218 -0.898 -0.071 -0.706 0.353 -0.057 -0.254 -0.644
# -0.532 0.019 0.998 0.712 0.358 0.006 0.401 -0.079 -0.956 -0.612
# 0.421 0.7 -0.763 0.111 -0.55 0.653 0.072 -0.629 -0.016 -0.984
# 0.176 -0.646 0.445 0.169 0.213 -0.616 -0.048 0.057 -0.497 0.191
# 0.506 -0.109 0.575 -0.72 -0.065 0.837 0.714 -0.267 0.128 0.976
# 0.226 0.966 0.335 0.797 0.127 -0.674 0.012 -0.903 -0.491 0.153
# 0.905 0.845 -0.46 0.706 0.939 0.576 0.907 -0.477 0.245 -0.72
# 0.889 -0.942 0.797 0.352 0.476 -0.583 -0.235 0.22 0.456 -0.939
# -0.435 -0.651 0.391 0.878 0.77 0.775 -0.103 0.181 0.823 -0.961
# 0.081 0.307 -0.429 -0.124 0.217 0.259 -0.519 0.072 0.461 -0.839
# 0.25 0.057 -0.788 -0.805 0.063 -0.072 0.372 -0.215 0.428 -0.852
# -0.393 -0.246 -0.36 0.866 -0.562 0.551 -0.06 0.503 0.823 0.064
# -0.861 -0.253 0.285 -0.757 0.838 0.924 -0.282 0.682 0.38 -0.157
# 0.326 0.195 -0.612 -0.14 -0.534 0.218 0.378 -0.582 0.21 0.511
# 0.72 0.639 0.971 -0.034 0.274 0.93 0.519 -0.919 -0.202 -0.252
# -0.972 -0.258 0.905 -0.649 -0.752 -0.295 0.744 -0.188 0.795 0.524
# -0.384 0.109 -0.974 -0.959 -0.065 -0.948 -0.01 0.797 -0.844 -0.15
# -0.411 0.585 0.106 0.662 0.245 0.15 -0.635 0.412 0.291 -0.162
# 0.924 0.057 -0.985 -0.864 -0.04 -0.953 -0.931 0.908 0.701 0.195
# 0.921 0.08 -0.356 -0.919 0.652 -0.862 -0.221 -0.859 -0.379 0.746
# 0.434 -0.443 -0.903 -0.85 0.236 -0.629 0.787 0.085 -0.062 1.
# -0.659 -0.772 -0.704 -0.828 0.153 0.979 0.002 -0.676 -0.223 0.141
# -0.782 -0.561 -0.337 0.811 -0.888 0.109 0.269 -0.862 0.184 0.508
# 0.219 0.434 0.954 -0.869 0.954 0.643 0.662 -0.933 0.618 0.485
# 0.097 -0.346 -0.625 0.988 -0.816 -0.217 0.369 -0.058 -0.601 0.482
# 0.091 -0.567 0.299 0.243 -0.964 0.139 -0.433 0.265 -0.468 0.935
# -0.59 -0.106 0.023 0.585 0.154 -0.597 0.225 -0.887 0.01 -0.631
# 0.677 0.618 -0.605 -0.099 -0.357 -0.412 -0.12 -0.459 0.29 -0.709
# -0.952 -0.432 0.911 0.309 0.276 -0.666 0.695 0.055 -0.242 0.715
# -0.7 -0.327 0.94 0.688 0.844 -0.447 -0.058 0.79 0.449 0.072
# 0.427 0.152 0.306 -0.011 -0.363 -0.779 0.136 -0.65 -0.82 0.833
# 0.586 0.415 -0.289 -0.19 0.792 0.996 -0.691 -0.622 -0.866 0.148
# -0.695 0.326 -0.852 0.423 -0.146 0.831 -0.918 -0.307 0.476 -0.574
# -0.302 -0.24 -0.925 -0.999 -0.405 -0.45 0.921 0.834 -0.412 0.035
# 0.636 -0.594 -0.735 -0.975 0.112 0.676 0.125 -0.092 0.371 0.442
# 0.752 0.165 0.267 0.927]
# Report the winning bias. As with the weights, the loose variable `b` is
# whatever the final iteration sampled — best["b"] is the one we kept.
print(best["b"])
# 14.605564462995053
Model Review
We can identify the most important pixels by finding the indices of the positive and negative weights with the greatest magnitude.
# Flat indices of the strongest positive and negative weights; dividing by
# the 28-pixel row width converts a flat index to a (row, col) position.
print(np.argmax(best['w']))  # 582 ==> (20, 22)
print(np.argmin(best['w']))  # 214 ==> (7, 18)
In this case, the model implies that lightening pixel (20, 22) adds the most evidence that the image is a 2, and that darkening pixel (7, 18) likewise adds the most evidence that the image is a 2 (since a darker pixel contributes less of that large negative weight to the score).
Here are some example images with these pixels highlighted by green and red arrows.
The large positive weight for pixel (20, 22)
makes intuitive sense because it locates the trailing tail of the number "2" - a pixel that is rarely activated by other numbers. Pixel (7, 18)
appears to have a large negative weight because it's a popular pixel that's activated by many different numbers.