Understanding your data is key in any data science project.
Visualization is useful for this, but can be challenging when the data has a high dimensionality.
This includes complex data such as text, images, and sensor measurements.
You could create a scatter plot matrix, but this can only show how any two features interact and fails to capture structure across many dimensions.
There’s an entire subfield within machine learning concerned with exactly this challenge: dimensionality reduction.
Dimensionality reduction algorithms can help you gain insight into your high-dimensional data and reveal whether there’s any structure.
A widely used, state-of-the-art dimensionality reduction algorithm is UMAP, which stands for Uniform Manifold Approximation and Projection.
In this post we’re going to apply UMAP to the MNIST dataset, a collection of 70,000 handwritten digits. Each digit is an image of 28 by 28 grayscale pixels, making the dimensionality of the dataset 784. UMAP will embed this dataset into two dimensions, allowing us to understand its structure.
To be honest, the MNIST dataset is already used way too often in machine learning demos, including those about UMAP.
What’s done less often is looking at how UMAP creates an embedding by visualizing the intermediate result at every step (also known as an epoch).
Result
The result is an animation where each frame is an epoch.
On the right are random samples of each digit, to give us an idea of what the variations look like and to serve as a legend.
There are a couple of interesting things to note:
UMAP has a decent embedding already after a couple of epochs.
The long shape of cluster “1” is probably because the way a “1” is written depends mostly on the angle. Other digits have more degrees of freedom.
The clusters “3”, “5”, and “8” are close together, probably because they’re often confused.
The same can be said for clusters “4”, “7”, and “9”.
Towards the end of the optimization the noisiness is reduced.
Applying UMAP
The notebook prepare.ipynb applies UMAP and prepares the data necessary to generate the visualizations and animations.
It creates two files: data/digits.parquet and data/epochs.parquet.
Because these two files are already present in the repository, you don’t necessarily need to run this notebook.
from plotnine import *
import polars as pl
from util import combine_plots
pl.Config.set_tbl_cols(10);
Understanding handwritten digits as high-dimensional vectors
Machine learning algorithms, such as UMAP, assume that each row is a data point.
That means that one handwritten digit is a row of 784 values; in other words, a high-dimensional vector.
The values, by the way, vary from 0 to 255 to represent the amount of ink.
Here’s a portion of the very wide DataFrame df_digits.
Our goal is to reduce the dimensionality from 784 to 2, so that we can create a scatter plot.
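As a small illustration of what "one digit is one row of 784 values" means, the sketch below flattens a made-up 28 by 28 grid into a single vector in plain Python. The names `flatten` and `unflatten_index` are hypothetical and not part of the notebook, and the row-major pixel order is an assumption.

```python
WIDTH = HEIGHT = 28  # MNIST images are 28 by 28 pixels

def flatten(image):
    """Concatenate the rows of a 2D grid into one flat list (row-major order)."""
    return [value for row in image for value in row]

def unflatten_index(i, width=WIDTH):
    """Map a position in the flat vector back to (row, column)."""
    return divmod(i, width)

# A made-up image: all background (0) except one fully inked pixel (255).
image = [[0] * WIDTH for _ in range(HEIGHT)]
image[3][5] = 255

vector = flatten(image)
print(len(vector))                         # 784
print(vector.index(255))                   # 3 * 28 + 5 = 89
print(unflatten_index(vector.index(255)))  # (3, 5)
```

Every image becomes a point in a 784-dimensional space this way, which is the space UMAP starts from.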
Plot digits
To better understand what the handwritten digits actually look like, we can visualize them using Plotnine.
For this, we need to wrangle a wide row into a long DataFrame of x and y values.
def get_pixels(df_, seed=None):
    if seed is not None:
        pl.set_random_seed(seed)
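To make the wide-to-long idea concrete, here is a rough sketch in plain Python rather than Polars: every value in the wide row becomes its own record with an x and a y coordinate. The function name `wide_to_long`, the row-major pixel order, and the choice to flip y so that the first image row gets the largest y are all assumptions for illustration, not the notebook's actual code.

```python
def wide_to_long(row, width=28, height=28):
    """Turn a flat list of width*height pixel values into long records."""
    records = []
    for i, value in enumerate(row):
        r, c = divmod(i, width)       # row-major: i = r * width + c
        records.append({
            "x": c,
            "y": height - 1 - r,      # flip so image row 0 gets the largest y
            "value": value,
        })
    return records

# Tiny 2 by 2 example so the result is easy to check by eye.
long_rows = wide_to_long([10, 20, 30, 40], width=2, height=2)
for rec in long_rows:
    print(rec)
```

A long layout like this, with one row per pixel, is the shape that a tile-based plot expects.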
If we make this long DataFrame square, and squint our eyes, we can actually see a “7”:
with pl.Config(tbl_rows=30, tbl_cols=30):
    display(df_pixels.filter(pl.col("digit") == "7").pivot("x", values="value"))
shape: (28, 30)
digit   y   0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27
cat   u16 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64 i64
"7"    27   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    26   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    25   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    24   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    23   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    22   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    21   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"    20   0   0   0   0   0   0   0   0   0   0   0   0   0  32  66 160 159 159 187 254 130   0   0   0   0   0   0   0
"7"    19   0   0   0   0   0   0   0   0  40 184 225 225 225 239 253 254 253 253 253 253 185   0   0   0   0   0   0   0
"7"    18   0   0   0   0   0   0   0   0  45 240 251 253 253 245 243 150 101  67 253 253 102   0   0   0   0   0   0   0
"7"    17   0   0   0   0   0   0   0   0   0  32  69  84  84  13   0   0   0 131 253 239  33   0   0   0   0   0   0   0
"7"    16   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  39 228 253 188   0   0   0   0   0   0   0   0
"7"    15   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 138 253 247  69   0   0   0   0   0   0   0   0
"7"    14   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  23 234 253 146   0   0   0   0   0   0   0   0   0
"7"    13   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 122 253 226  34   0   0   0   0   0   0   0   0   0
"7"    12   0   0   0   0   0   0   0   0   0   0   0   0   0   0  11 240 253 139   0   0   0   0   0   0   0   0   0   0
"7"    11   0   0   0   0   0   0   0   0   0   0   0   0   0   0 164 254 239  35   0   0   0   0   0   0   0   0   0   0
"7"    10   0   0   0   0   0   0   0   0   0   0   0   0   0 113 254 255  98   0   0   0   0   0   0   0   0   0   0   0
"7"     9   0   0   0   0   0   0   0   0   0   0   0   0   9 221 253 181   3   0   0   0   0   0   0   0   0   0   0   0
"7"     8   0   0   0   0   0   0   0   0   0   0   0   0 203 253 249  56   0   0   0   0   0   0   0   0   0   0   0   0
"7"     7   0   0   0   0   0   0   0   0   0   0   0   0 209 253 159   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     6   0   0   0   0   0   0   0   0   0   0   0  48 249 247  60   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     5   0   0   0   0   0   0   0   0   0   0   0 127 253 182   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     4   0   0   0   0   0   0   0   0   0   0   7 203 253 130   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     3   0   0   0   0   0   0   0   0   0   0 165 253 240  32   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     2   0   0   0   0   0   0   0   0   0  23 229 253 117   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     1   0   0   0   0   0   0   0   0   0  13 148 221  30   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
"7"     0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
However, we have Plotnine at our disposal, so let’s use that instead.
The following function uses the geom_tile() and facet_wrap() functions.
We’re going to use it to create a legend of all 10 digits.
Note that we’re using an unsupervised algorithm here, meaning that UMAP doesn’t know which high-dimensional vector belongs to which digit. The coloring is only used for visualization purposes.
Plot embedding
To plot an embedding, we’re going to read the Parquet file data/epochs.parquet into a DataFrame called df_epochs.
This is a very long DataFrame containing the x and y position of each digit at every epoch.
Of the 60,000 available digits, we use 30,000.
We let UMAP optimize for 200 epochs.
Both choices keep the final animation below 10 MB so that GitHub can include it in the README.
The following function creates a scatter plot using the geom_point() function.
The max_epoch argument is used to add a progress bar (using the annotate() function) underneath the title.
The majority of the code is responsible for styling the plot.
The legend argument is set to False when we combine the scatter plot with the visualizations of the digits shown above.
The dpi argument is only used to create the small image at the beginning of this notebook.
# Create a small image to add to the beginning of this notebook
plot_embedding(df_epochs, num_epochs - 1, num_epochs, legend=True, dpi=50).save("images/intro.png", verbose=False)
Initial embedding
UMAP begins by constructing a weighted k-nearest neighbor graph from the high-dimensional data. It then performs spectral embedding, which involves computing the eigenvectors of the graph Laplacian. This step provides an initial low-dimensional representation of the data.
plot_embedding(df_epochs, 0, num_epochs)
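The two steps described above, a k-nearest neighbor graph followed by the eigenvectors of the graph Laplacian, can be sketched with NumPy on a small made-up dataset. This is a simplified illustration and not UMAP's implementation: the function name `spectral_init` is hypothetical, the edges are unweighted (UMAP uses fuzzy edge weights), and the graph is assumed to be connected.

```python
import numpy as np

def spectral_init(X, n_neighbors=5, n_components=2):
    """Toy spectral embedding from a symmetric k-nearest neighbor graph."""
    n = X.shape[0]
    # Pairwise squared Euclidean distances.
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor

    # Unweighted kNN adjacency (a simplification of UMAP's weighted graph).
    A = np.zeros((n, n))
    nearest = np.argsort(d2, axis=1)[:, :n_neighbors]
    A[np.repeat(np.arange(n), n_neighbors), nearest.ravel()] = 1.0
    A = np.maximum(A, A.T)  # symmetrize

    # Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
    deg = A.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    L = np.eye(n) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    # eigh returns eigenvalues in ascending order; skip the first
    # (trivial, for a connected graph) eigenvector and keep the next ones.
    _, vecs = np.linalg.eigh(L)
    return vecs[:, 1:n_components + 1]

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 10))  # 60 made-up points in 10 dimensions
Y = spectral_init(X)
print(Y.shape)                 # (60, 2)
```

The resulting two columns play the role of the x and y positions at epoch 0, before any optimization has taken place.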
Intermediate embeddings
After the spectral embedding, UMAP refines the embedding using a non-linear optimization technique. It minimizes a cross-entropy loss function that aligns the high-dimensional data structure with the low-dimensional representation. This optimization is done using stochastic gradient descent.
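As a rough sketch of the loss mentioned here, the snippet below evaluates the cross-entropy for a single pair of points, given its edge weight in the high-dimensional graph (`v`) and the weight implied by its distance in the embedding (`w`). The function names are hypothetical, the value of `v` is made up, and setting the curve parameters `a` and `b` to 1 is an assumption for readability.

```python
import math

def low_dim_membership(dist, a=1.0, b=1.0):
    """Edge strength implied by an embedding distance: 1 / (1 + a * d^(2b))."""
    return 1.0 / (1.0 + a * dist ** (2 * b))

def edge_cross_entropy(v, w, eps=1e-12):
    """Cross-entropy between high-dimensional weight v and low-dimensional weight w."""
    v = min(max(v, eps), 1 - eps)
    w = min(max(w, eps), 1 - eps)
    attract = v * math.log(v / w)                   # penalizes connected points that are far apart
    repel = (1 - v) * math.log((1 - v) / (1 - w))   # penalizes unconnected points that are close
    return attract + repel

v = 0.9  # made-up: two digits that are strongly connected in the kNN graph
for dist in (0.1, 1.0, 5.0):
    w = low_dim_membership(dist)
    print(f"distance={dist:>4}: loss={edge_cross_entropy(v, w):.3f}")
```

The loss for a pair is zero when `w` equals `v` and grows as the embedding distance disagrees with the graph, which is what the optimization tries to reduce across all pairs.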
As the optimization progresses, the algorithm approaches a local minimum of the loss function. The gradients of the loss function become smaller, leading to smaller updates to the embeddings. This is a natural part of gradient-based optimization processes, where convergence typically slows as the solution nears the optimum.
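The shrinking updates can be seen in a toy example that has nothing to do with UMAP's actual loss: plain gradient descent on f(x) = x**2. The starting point, learning rate, and number of epochs are made up for illustration.

```python
def gradient(x):
    """Gradient of the toy loss f(x) = x**2."""
    return 2.0 * x

x = 10.0             # made-up starting position
learning_rate = 0.1
step_sizes = []
for epoch in range(20):
    update = learning_rate * gradient(x)
    x -= update
    step_sizes.append(abs(update))

print(f"first update: {step_sizes[0]:.4f}")
print(f"last update:  {step_sizes[-1]:.4f}")
print(f"final x:      {x:.4f}")
```

Each update is smaller than the previous one because the gradient shrinks as x approaches the minimum, which mirrors how the points in the animation move less and less in later epochs.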
To create the animation, we first generate all the frames as individual PNG files.
for i in range(num_epochs):
    combine_plots([
        plot_embedding(df_epochs, i, num_epochs, legend=False),
        plot_digits(get_pixels(df_digits), height=8)
    ], f"frames/combined-{i:06}.png", orientation="horizontal")
We then use ffmpeg to stitch the 200 PNG files into an MP4 movie.
(Plotnine offers a PlotnineAnimation class to animate ggplot objects, but at the time of writing it had some issues.) Note that Plotnine uses matplotlib under the hood, which, in turn, uses ffmpeg to create animations.
The most important argument is -i, which specifies the PNG files. The arguments -pix_fmt, -vcodec, and -crf define the video encoding and compression. The argument -y causes ffmpeg to overwrite existing files.
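For readers who prefer to stay in Python, the sketch below assembles such an ffmpeg call with the arguments described above and runs it through `subprocess`. The input pattern follows the frame file names used earlier; the output file name, frame rate, pixel format, codec, and CRF value are assumptions, not the values used for the original animation.

```python
import shutil
import subprocess
from pathlib import Path

def build_ffmpeg_command(pattern, output, framerate=20, crf=23):
    """Assemble an ffmpeg call that stitches numbered PNG frames into an MP4."""
    return [
        "ffmpeg",
        "-y",                          # overwrite the output file if it exists
        "-framerate", str(framerate),  # assumed frame rate
        "-i", pattern,                 # numbered input frames
        "-pix_fmt", "yuv420p",         # assumed pixel format (widely playable)
        "-vcodec", "libx264",          # assumed codec (H.264)
        "-crf", str(crf),              # assumed quality; lower means larger files
        output,
    ]

# "umap-mnist.mp4" is a hypothetical output name.
cmd = build_ffmpeg_command("frames/combined-%06d.png", "umap-mnist.mp4")
print(" ".join(cmd))

# Only run when ffmpeg is installed and the first frame exists.
if shutil.which("ffmpeg") and Path("frames/combined-000000.png").exists():
    subprocess.run(cmd, check=True)
```

The `%06d` in the pattern matches the six-digit, zero-padded frame numbers produced by the loop that saves the PNG files.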
There are a couple of interesting things to note about this animation:
UMAP has a decent embedding already after a couple of epochs.
The long shape of cluster “1” is probably because the way a “1” is written depends mostly on the angle. Other digits have more degrees of freedom.
The clusters “3”, “5”, and “8” are close together, probably because they’re often confused.
The same can be said for clusters “4”, “7”, and “9”.
Towards the end of the optimization the noisiness is reduced.
This animation suggests that visualizing and animating intermediate results of otherwise complicated algorithms can help us understand them. They can be complementary to the math and the implementation of the algorithms.