The objective of this post is to implement word2vec from scratch -- starting with the underlying equations, using bare-minimum dependencies, and taking a single sentence as input. We do this initially by determining a probability matrix which maximizes the likelihood of observing the sequence of input words. Then we view the equations as a 'neural network'.
This is implemented initially using only numpy. With that in place, we write the equivalent minimal pytorch code to solve the exact same problem, gradually bringing in new packages. This should help build some confidence when using pytorch, since it is easy to have an implementation that runs without error, and even "works", but does not actually do what you think it does, or runs suboptimally.
Given a large input corpus of text in the form of coherent English sentences, we are interested in finding associations between words. As the name suggests, this is done by associating each vocabulary word with a vector. The popularized illustrative example commonly used for word2vec is as follows: assume that the large corpus of text covers the full English language. Then, from the learned word-vector association, we may have a relationship between the vectors resembling $v_{\mathrm{king}} - v_{\mathrm{man}} + v_{\mathrm{woman}} \approx v_{\mathrm{queen}}$, where $v_{\mathrm{word}}$ denotes the learned vector for a given word.
The overall outline of the approach, which will be detailed below, is as follows.
While there are two common models that can be considered, we will focus on the Skip-gram model. Given a corpus of text, we generate a training dataset by sliding a fixed-width window across the corpus. The word at the center of the window is the center word, and the neighboring context words within the window are the targets.
In this post, our corpus of text will be a single sentence: 'The quick brown fox jumps over the lazy dog.' Needless to say, we consider only a single sentence solely for simplicity and demonstration.
In this example, we choose a fixed word-window size of 5. The window is passed through the sentence, advancing by one word at each step. In the table below, the center word of the window is shown in bold and the surrounding neighbor words within the window are shown in italics. The Skip-grams are the pairs formed by the center word and each of its neighboring words:
| i | Window | Skip-grams |
|---|---|---|
| 0 | **The** *quick* *brown* fox jumps over the lazy dog | (the, quick), (the, brown) |
| 1 | *The* **quick** *brown* *fox* jumps over the lazy dog | (quick, the), (quick, brown), (quick, fox) |
| 2 | *The* *quick* **brown** *fox* *jumps* over the lazy dog | (brown, the), (brown, quick), (brown, fox), (brown, jumps) |
| 3 | The *quick* *brown* **fox** *jumps* *over* the lazy dog | (fox, quick), (fox, brown), (fox, jumps), (fox, over) |
| 4 | The quick *brown* *fox* **jumps** *over* *the* lazy dog | (jumps, brown), (jumps, fox), (jumps, over), (jumps, the) |
| 5 | The quick brown *fox* *jumps* **over** *the* *lazy* dog | (over, fox), (over, jumps), (over, the), (over, lazy) |
| 6 | The quick brown fox *jumps* *over* **the** *lazy* *dog* | (the, jumps), (the, over), (the, lazy), (the, dog) |
| 7 | The quick brown fox jumps *over* *the* **lazy** *dog* | (lazy, over), (lazy, the), (lazy, dog) |
| 8 | The quick brown fox jumps over *the* *lazy* **dog** | (dog, the), (dog, lazy) |
Based on the above table, the resulting dataset has 30 entries. Since there are 9 words, 8 of which are unique, the full dataset will be 30 x 8. The above text can be encoded as integers by building a vocabulary.
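As a quick check of that count, the short sketch below enumerates the window pairs directly in plain Python (`wing` here is half the window width, i.e. a window size of 5):

sentence = 'The quick brown fox jumps over the lazy dog'.lower().split()
wing = 2  # half the window width, i.e. a window size of 5
pairs = [
    (sentence[i], sentence[j])
    for i in range(len(sentence))
    for j in range(max(0, i - wing), min(len(sentence), i + wing + 1))
    if i != j
]
print(len(pairs))  # 30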
As in previous posts (see e.g. logistic regression), we consider maximization of a likelihood function. Specifically, given a center word $w_i$, we wish to maximize the likelihood of having its context words $w_j$:

$$\mathcal{L} = \prod_{i=1}^{M} \prod_{j} P(w_j \mid w_i)$$
In the above, $i$ and $j$ indicate iterating over the Skip-grams, i.e. iterating over the center words and the neighboring words of the center words. $M$ is the number of center words, and $j$ runs over the neighboring words of the $i$-th center word.
Instead of working with the likelihood directly, as has been the pattern before (see e.g. part 1, part 2, and part 3), we consider the maximization of the log-likelihood. Denoting the log-likelihood by $\ell = \log \mathcal{L}$, the objective function becomes

$$\ell = \sum_{i=1}^{M} \sum_{j} \log P(w_j \mid w_i)$$
Note that $P(w_j \mid w_i)$ is unknown; it will be approximated through the optimization process. Also note that maximizing $\ell$ is equivalent to minimizing $-\ell$, i.e. minimizing the negative log-likelihood.
If we let $N$ denote the total number of center words and their Skip-grams (i.e. the number of skip-gram pairs), and $V$ the size of the vocabulary, then the relation between the center word and the context words for given Skip-grams can be expressed through sparse matrices $X, Y \in \{0, 1\}^{N \times V}$ as follows. Each row of matrix $X$ marks the center word with a 1. Similarly, each corresponding row of $Y$ marks the corresponding context word of that center word (see example in the next section). $X$ and $Y$ can be readily constructed for given Skip-grams. Then the objective function for a given dataset can be written as

$$J(P) = -\left\langle Y, \log (X P) \right\rangle$$

where $\log(\cdot)$ here denotes the component-wise logarithm of an input matrix, $P \in \mathbb{R}^{V \times V}$ is a matrix whose rows contain the (unknown) probabilities for each center word, and $\langle A, B \rangle = \sum_{ij} A_{ij} B_{ij}$ denotes the matrix inner product. Note that $X$ effectively only "picks" the rows of $P$ specified in $X$, i.e. the rows of $X P$ contain duplicates and are made up of rows of $P$. The columns of $P$ represent the predicted likelihood of a given vocabulary word.
We write $P$ as the row-wise softmax of the product of two unknown matrices:

$$P = \mathrm{softmax}(W_1 W_2)$$

where $W_1 \in \mathbb{R}^{V \times d}$ and $W_2 \in \mathbb{R}^{d \times V}$, with $d$ denoting the embedding dimension.
Neural networks can be thought of as function approximators. In neural networks and deep learning, it is not unusual to construct a model by stacking a large number of "layers" on top of each other (sometimes blindly) and solving for the unknown coefficients in the layers by optimizing the objective function ("loss" function). Here, however, the network is constructed explicitly based on the equations above.
Recall the objective function is given by

$$J(W_1, W_2) = -\left\langle Y, \log\big(X\,\mathrm{softmax}(W_1 W_2)\big)\right\rangle$$

We can view this as two distinct steps: a forward pass that produces the predicted probabilities, followed by the evaluation of the objective function (the "error"). Noting that $X\,\mathrm{softmax}(W_1 W_2) = \mathrm{softmax}(X W_1 W_2)$ (each row of $X$ is one-hot and simply selects a row), the neural network layers are

$$\text{(L1)}\quad A_1 = X W_1$$
$$\text{(L2)}\quad A_2 = A_1 W_2$$
$$\text{(L3)}\quad J = -\left\langle Y, \log\big(\mathrm{softmax}(A_2)\big)\right\rangle$$
The last layer (L3) can be viewed as two separate operations (layers): row-wise softmax, followed by matrix inner product. But here we use a single layer because it corresponds to the likelihood evaluation described in the preceding section, and as will be seen later, it can correspond to a single function call in pytorch.
In order to minimize this objective function, we need at least the first gradient of the objective function with respect to the model parameters. While gradients are typically computed via backward passes within deep learning frameworks, here we compute them directly.
Before going further, it is useful to note two equalities that follow from the structure of $X$ and $Y$ (each row contains a single 1 and is zero elsewhere):

$$\mathrm{softmax}(X A) = X\,\mathrm{softmax}(A) \quad \text{for any matrix } A, \qquad Y \mathbf{1} = \mathbf{1}$$

The first holds because each row of $X$ simply selects a row of $A$ (and the softmax is applied row-wise); the second states that each row of $Y$ sums to one.
The objective function can thus be rewritten as

$$J(W_1, W_2) = -\left\langle Y, \log\big(\mathrm{softmax}(X W_1 W_2)\big)\right\rangle$$
Then $X W_1$ represents a matrix whose rows (some duplicated) are made up of the vector-embedded center words. And besides representing (un-normalized) probabilities, $X W_1 W_2$ can be viewed as the projection of the vector-embedded center words onto the vector-embedded neighbor words.
In order to minimize this objective function, we need at least the first gradient of the objective function with respect to the model parameters. First denote

$$A_1 = X W_1, \qquad A_2 = A_1 W_2, \qquad Z = \mathrm{softmax}(A_2)$$
Now note that the objective function can be expanded as

$$J = -\left\langle Y, A_2 \right\rangle + \sum_{i} \log \sum_{k} \exp\big((A_2)_{ik}\big)$$

where the second term uses the fact that each row of $Y$ sums to one.
Then,

$$\frac{\partial J}{\partial A_2} = Z - Y$$
And so,

$$\frac{\partial J}{\partial A_1} = \frac{\partial J}{\partial A_2}\, W_2^\top = (Z - Y)\, W_2^\top$$
With these, the first gradients of the objective function with respect to the weights are

$$\frac{\partial J}{\partial W_2} = A_1^\top (Z - Y), \qquad \frac{\partial J}{\partial W_1} = X^\top (Z - Y)\, W_2^\top$$
The above two equations are easily determined by reverting to index notation and converting back.
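As a quick way to gain confidence in these expressions, here is a minimal numpy sketch (the sizes, seed, and step used below are arbitrary choices for the check) comparing them against a finite-difference approximation:

import numpy as np
# Finite-difference spot check of the gradient formulas derived above
rng = np.random.default_rng(0)
N, V, d = 6, 5, 3                         # skip-gram pairs, vocab size, embedding size
X = np.eye(V)[rng.integers(0, V, N)]      # one-hot center words
Y = np.eye(V)[rng.integers(0, V, N)]      # one-hot context words
W1, W2 = rng.normal(size=(V, d)), rng.normal(size=(d, V))
def objective(W1, W2):
    A2 = X @ W1 @ W2
    Z = np.exp(A2 - A2.max(axis=1, keepdims=True))
    Z /= Z.sum(axis=1, keepdims=True)     # row-wise softmax
    return -np.sum(Y * np.log(Z)), Z
J, Z = objective(W1, W2)
dW1 = X.T @ ((Z - Y) @ W2.T)              # analytic gradients from above
dW2 = (X @ W1).T @ (Z - Y)
eps = 1e-6
W1p = W1.copy(); W1p[0, 0] += eps
W2p = W2.copy(); W2p[0, 0] += eps
print((objective(W1p, W2)[0] - J) / eps, dW1[0, 0])   # the two numbers should be close
print((objective(W1, W2p)[0] - J) / eps, dW2[0, 0])   # the two numbers should be close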
To summarize, the process will be as follows:

1. Tokenize the input text and build the vocabulary (the word-to-int and int-to-word maps).
2. Slide the fixed-width window over the tokens to generate the Skip-grams, and encode them as the one-hot matrices $X$ and $y$.
3. Initialize the weight matrices $W_1$ and $W_2$ randomly.
4. Run gradient descent, using the gradients derived above, until convergence.
As an example, consider the previous single sentence: 'The quick brown fox jumps over the lazy dog.' We break up the string and drop all punctuation to get a list of tokens:
tokens = [
'the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog'
]
Create a word-to-int `w2i` map to encode the tokens into integers:
w2i = {
"the": 0,
"quick": 1,
"brown": 2,
"fox": 3,
"jumps": 4,
"over": 5,
"lazy": 6,
"dog": 7
}
Our vocabulary here has a size of 8. There is an analogous reverse int-to-word `i2w` map to convert integers back to tokens. Encoding the tokens gives:
encoded_tokens = [
    0, 1, 2, 3, 4, 5, 0, 6, 7
]
Using a window size of 5, we found the 30 different Skip-grams in the table above. We will denote by $X$ the matrix containing the (one-hot) encoded center words, and by $y$ the matrix containing the target context words.
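For instance, the first five of the 30 rows (corresponding to the center words 'the' and 'quick', with the vocabulary indices given by `w2i` above) look as follows; the remaining rows follow the same pattern:

$$
X_{1:5} = \begin{bmatrix}
1 & 0 & 0 & 0 & 0 & 0 & 0 & 0\\
1 & 0 & 0 & 0 & 0 & 0 & 0 & 0\\ \hline
0 & 1 & 0 & 0 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 0 & 0 & 0 & 0
\end{bmatrix},
\qquad
y_{1:5} = \begin{bmatrix}
0 & 1 & 0 & 0 & 0 & 0 & 0 & 0\\
0 & 0 & 1 & 0 & 0 & 0 & 0 & 0\\ \hline
1 & 0 & 0 & 0 & 0 & 0 & 0 & 0\\
0 & 0 & 1 & 0 & 0 & 0 & 0 & 0\\
0 & 0 & 0 & 1 & 0 & 0 & 0 & 0
\end{bmatrix}
$$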
The horizontal lines are visual guides to indicate the change in center words.
The gradients of the objective function with respect to the weights were determined in the preceding section. Utilizing those results, a bare gradient descent algorithm for solving the optimization problem becomes simple:
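In pseudo-code, the iteration reads as follows (a sketch only: `random_init`, `converged`, and the step size `gamma` are placeholders; the actual helper functions are implemented below):

# Gradient descent (sketch)
W1 = random_init(V, d)
W2 = random_init(d, V)
while not converged:
    Z = softmax(X @ W1 @ W2)        # forward pass: predicted probabilities
    J = -np.sum(Y * np.log(Z))      # objective function ("error")
    dW2 = (X @ W1).T @ (Z - Y)      # gradients derived above
    dW1 = X.T @ ((Z - Y) @ W2.T)
    W1 = W1 - gamma * dW1           # gradient descent update
    W2 = W2 - gamma * dW2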
Here 'convergence' can be a predetermined tolerance on the objective function (or on its change between iterations), or simply a fixed number of iterations.
We will start out with a numpy-only implementation, then migrate to pytorch. Before going into the numpy implementation, let's quickly look at where we want to be: I would like the overall approach to resemble a typical pytorch workflow. If you have never worked with pytorch before, the implementation may look somewhat odd, unusual, and/or counter-intuitive (see the intro to pytorch article). That's by design.
While pytorch does not place any restrictions per se, the objective function ("loss function") is typically decoupled from the forward pass and computed separately via a predefined loss function from the `nn` module. The gradients are computed by calling `backward` on the final error, which is a "scalar". I am putting scalar in quotes because the returned error is actually a "tensor" which carries the full network graph.
Here's the overall target pytorch-ish pseudo-code that our implementation is going to resemble:
class my_model(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(...)
model = my_model()
#
# 1. Forward pass
#
prediction = model(dataset)
#
# 2. Evaluate error.
#
# the objective function is generally
# separate from the forward pass
#
error = objective_function(prediction, target)
#
# 3. Evaluate gradients on model parameters.
# This call populates `model.weight.grad`
#
error.backward()
#
# 4. Update weights using gradients
#
# The gradient of `error` with respect to 'weight'
# is now populated. Use it to update the weights for
# the next iteration.
#
model.weight -= gamma * model.weight.grad
This is what I mean by counter-intuitive: calling `error.backward()` changes the state of other variables. To make it more confusing, consider the last line, namely `model.weight -= gamma * model.weight.grad`. Intuitively, `model.weight.grad` reads as "the gradient of `model.weight`". It is not, and on its own that would not even make sense (the gradient of the weights with respect to what?). It turns out, as you may have recognized the last line as the weight update of a gradient descent iteration, that `model.weight.grad` is actually the gradient of the `error` (the objective function) with respect to `model.weight`. Very confusing. Alas, our goal is for our implementation to somewhat resemble the above.
numpy implementation

Following the steps listed in the preceding section, we first tokenize the text:
import string

def tokenize_text(text):
    """Split and lowercase words in input string, removing punctuation
    Args:
        text (str): input string
    Return:
        list[str]: list of words
    """
    return (
        text
        .translate(str.maketrans('', '', string.punctuation))
        .lower()
        .split()
    )
from collections import OrderedDict
def tokens_encode(tokens):
"""Generate int-to-word and word-to-int maps
Args:
tokens (list): list of string tokens
Return:
(dict): word-to-int mapping of tokens
(dict): int-to-word mapping of tokens
"""
w2i, i2w = {}, {}
# Remove duplicate entries while preserving order
tokens = list(OrderedDict.fromkeys(tokens))
for i, token in enumerate(tokens):
w2i[token] = i
i2w[i] = token
return w2i, i2w
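A quick usage sketch (the small input here is chosen just for illustration):

# Example usage of tokens_encode
w2i, i2w = tokens_encode(['the', 'quick', 'brown', 'the'])
# w2i == {'the': 0, 'quick': 1, 'brown': 2}
# i2w == {0: 'the', 1: 'quick', 2: 'brown'}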
The vocabulary is the keys of the `w2i` map, and its length is the size of the vocabulary.
We need to define a couple of helper functions for convenience: one that returns the indices of a window centered around a given index, and one that generates a binary vector of length `N` whose entries are all zero except at a given index:
def window_range(center, wing, max_index):
    """Return indices of a window centered around an index
    Example: a window with a wing of 2 centered around index 4 of a
    vector of length 11 covers indices [2, 3, 4, 5, 6].
    Args:
        center (int): center index
        wing (int): half the window width
        max_index (int): maximum index the window can slide over
    Returns:
        (range) range centered around `center`, bounded by `max_index`
    """
    return range(
        max(0, center-wing),
        min(max_index, center+wing+1)
    )
def bin_vec(size, idx):
    """Generate a vector whose entries are either 0 or 1
    Args:
        size (int): length of the vector
        idx (int): index where the vector value is one
    Returns:
        (list) zero vector whose 'idx' entry is equal to 1
    """
    ret = [0]*size
    ret[idx] = 1
    return ret
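A quick usage sketch of the two helpers:

# wing of 2 around index 1 of a 9-token sentence
print(list(window_range(1, 2, 9)))   # [0, 1, 2, 3]
# one-hot vector of length 8 with a 1 at index 3
print(bin_vec(8, 3))                 # [0, 0, 0, 1, 0, 0, 0, 0]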
With these defined, we can generate the dataset by iterating through the input tokens:
import numpy as np

def gen_dataset(tokens, w2i, window):
"""Generate a training dataset given input tokens
For text `foo bar baz qux quux corge grault garply`, a window of
length 2 centered around `qux` would result in training targets of
['bar', 'baz', 'quux', 'corge']
Args:
tokens (list[str]): list of words
w2i (dict[str:int]): string to index mapping of input words
window (int): half the window width centered around each word,
see above
Returns:
        (numpy.ndarray) input dataset of dimension (num_skipgrams, vocab_size)
        (numpy.ndarray) target dataset of dimension (num_skipgrams, vocab_size)
"""
X, y = [], []
token_count = len(tokens)
vocab_size = len(w2i)
for token_id in range(token_count):
for window_id in window_range(token_id, window, token_count):
if token_id == window_id:
continue
X.append(bin_vec(vocab_size, w2i[tokens[token_id]]))
y.append(bin_vec(vocab_size, w2i[tokens[window_id]]))
return np.array(X), np.array(y)
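Putting the pieces so far together, a quick check of the dataset dimensions for our sentence:

tokens = tokenize_text('The quick brown fox jumps over the lazy dog.')
w2i, i2w = tokens_encode(tokens)
X, y = gen_dataset(tokens, w2i, window=2)
print(X.shape, y.shape)   # (30, 8) (30, 8)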
In order to implement our objective function, we need two more helper functions:
def softmax(X):
"""Return softmax on rows of a 2d array
Args:
X (ndarray) 2 dimensional matrix
Returns
(ndarray) 2 dimensional matrix with normalized rows
"""
res = []
for x in X:
# subtract max val to avoid overflow
exp = np.exp(x - np.max(x))
res.append(exp/exp.sum())
return np.array(res)
def cross_entropy(z, y):
"""Cross entropy between two 2d arrays
Args:
z (ndarray) 2 dimensional matrix
y (ndarray) 2 dimensional target matrix
Returns
(float) cross entropy between z and y
"""
# component-wise multiplication
return - np.sum(np.log(z) * y)
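A small sketch of these two helpers in action (the numbers are illustrative only):

z = softmax(np.array([[1.0, 2.0, 3.0]]))
print(z.sum(axis=1))        # [1.] -- each row is normalized
y = np.array([[0, 0, 1]])
print(cross_entropy(z, y))  # equals -log(z[0, 2]), about 0.408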
Finally, the model can be defined:
class w2v_model:
def __init__(self, vocab_size, embed_size):
self.w1 = np.random.randn(vocab_size, embed_size)
self.w2 = np.random.randn(embed_size, vocab_size)
def predict(self, X):
cache = {}
A1 = X @ self.w1
A2 = A1 @ self.w2
z = softmax(A2)
cache = {
'a1': A1,
'a2': A2,
'z': z
}
return z, cache
    def forward(self, X):
        """Forward pass of input through the network
        Args:
            X (np.array): Input data
        Returns:
            (np.array): network output (row-wise probabilities)
            (dict): cache of intermediate values for running backward
        """
        return self.predict(X)
def __call__(self, X):
return self.forward(X)
    def backward(self, cache, X, y):
        """Perform backward pass and store the parameter gradients
        Args:
            cache (dict): cache returned via forward call
            X (np.array): input dataset
            y (np.array): expected output
        Returns:
            None
        """
da2 = cache['z'] - y
dw2 = cache['a1'].T @ da2
da1 = da2 @ self.w2.T
dw1 = X.T @ da1
# to resemble some form of consistency with pytorch
#
# gradients of the `error` with respect to w1 and w2
self.w1_grad = dw1
self.w2_grad = dw2
def error(self, output, target):
return cross_entropy(output, target)
Now we can run a simple gradient descent algorithm:
LR = 0.05
WINDOW = 2
ITER_COUNT = 20
text = 'The quick brown fox jumps over the lazy dog.'
tokens = tokenize_text(text)
w2i, i2w = tokens_encode(tokens)
X, y = gen_dataset(tokens, w2i, WINDOW)
model = w2v_model(len(w2i), embed_size=10)
#
# Run gradient descent
#
err_hist = []
for i in range(ITER_COUNT):
# Model prediction
pred, cache = model(X)
# objective function
err = model.error(pred, y)
#
# Update model weight gradients
#
model.backward(cache, X, y)
#
# Update model weights
#
model.w1 -= LR * model.w1_grad
model.w2 -= LR * model.w2_grad
err_hist.append(err)
print(i, err)
numpy sanity check

It turns out, especially in pytorch, it's quite easy to have an implementation that runs, and even "works", but what it does may not correspond to what you think, or it runs suboptimally. A good example of this is due to operator overloading. For example, visually one may expect that
C = A * B
corresponds to the matrix multiplication of `A` and `B`. But if these variables are `numpy.ndarray`s, the above code may be performing component-wise multiplication with (conventionally) "incompatible" dimensions via broadcasting.
Going back to our implementation: if we have optimized the coefficients sufficiently well ("trained" well), for this trivial example we expect the neighboring words to have the highest probabilities of occurring given the center words.
Consider for example the second center word, quick. Since this was our training data, we then expect that, given quick, the words the, brown, and fox have the highest probabilities of occurring.
# second window
cent_vec = bin_vec(len(w2i), w2i['quick'])
probabilities, _ = model([cent_vec])
probabilities = probabilities[0]
# sorted probability indices in descending order of probability
pred_next_indices = np.argsort(probabilities)[::-1]
neighbors = ['the', 'brown', 'fox']
pred_next_indices_window = pred_next_indices[:len(neighbors)]
for neighbor in neighbors:
print('assert ', neighbor)
assert w2i[neighbor] in pred_next_indices_window
pytorch implementation

There are different ways to implement this in pytorch. But in our implementation, we want a gradual transition to pytorch without pulling in a lot of black boxes. This first implementation uses bare parameters.
import numpy as np
import torch
from torch import nn

class w2v_torch_model(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        #
        # Only for consistency with the numpy model, initialize the
        # parameters with numpy randn with a preset seed
        #
        self.w1 = nn.Parameter(
            torch.tensor(
                np.random.randn(vocab_size, embed_size),
                dtype=torch.float32
            ),
            requires_grad=True
        )
        # We could have used torch.nn.Linear instead for the weights
        self.w2 = nn.Parameter(
            torch.tensor(
                np.random.randn(embed_size, vocab_size),
                dtype=torch.float32
            ),
            requires_grad=True
        )
        self.smax = torch.nn.Softmax(dim=1)
def predict(self, X):
A1 = X @ self.w1
A2 = A1 @ self.w2
z = self.smax(A2)
return z
def forward(self, X):
return self.predict(X)
    @staticmethod
    def obj_func(X, y):
        # Component-wise multiplication of log(X) and y. Since y is
        # binary (one-hot rows), the multiplication picks out only the
        # log-probability of the target word in each row.
        return -torch.sum(torch.log(X) * y)
The solver is going to look almost identical to the numpy implementation:
LR = 0.05
EMBED_DIM = 10
WINDOW = 2
ITER_COUNT = 20
text = 'The quick brown fox jumps over the lazy dog.'
tokens = tokenize_text(text)
w2i, i2w = tokens_encode(tokens)
X, y = gen_dataset(tokens, w2i, WINDOW)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
model = w2v_torch_model(len(w2i), embed_size=EMBED_DIM)
err_hist = []
for i in range(ITER_COUNT):
pred = model(X)
err = model.obj_func(pred, y)
model.zero_grad()
err.backward()
with torch.no_grad():
model.w1 -= LR*model.w1.grad
model.w2 -= LR*model.w2.grad
err_hist.append(err.data.item())
print(f'{i}/{ITER_COUNT}: {err.data.item()}')
In the next implementation, we will offload the parameters to built-in "layers", and use a built-in function (`CrossEntropyLoss`) to evaluate the likelihood. In our model, we only have simple matrix multiplications, so we can use `nn.Linear` layers with no bias.
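Note that `nn.CrossEntropyLoss` expects un-normalized scores ("logits") and applies the log-softmax internally, which is why the forward pass of the next model will not apply a softmax itself. A small sketch of the equivalence we rely on (this assumes a pytorch version, 1.10 or later, that accepts class-probability targets):

import torch
from torch import nn
logits = torch.randn(4, 8)                     # un-normalized scores (like A2)
y = torch.eye(8)[torch.randint(0, 8, (4,))]    # one-hot targets
loss = nn.CrossEntropyLoss(reduction='sum')(logits, y)
manual = -(y * nn.LogSoftmax(dim=1)(logits)).sum()
print(torch.allclose(loss, manual))            # True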
import numpy as np
import torch
from torch import nn
np.random.seed(0)

class w2v_torch_model_2(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        # Pytorch implicitly initializes model weights
        L1 = nn.Linear(vocab_size, embed_size, bias=False)
        L2 = nn.Linear(embed_size, vocab_size, bias=False)
        # To get the same results per iteration as the numpy impl,
        # explicitly overwrite the weights with the exact same
        # weights used in the numpy implementation. This is *not*
        # required in other scenarios
        #
        w1_0 = np.random.randn(vocab_size, embed_size)
        w2_0 = np.random.randn(embed_size, vocab_size)
        # nn.Linear stores its weight as (out_features, in_features),
        # hence the transposes; the assignment must be an nn.Parameter
        L1.weight = nn.Parameter(torch.tensor(w1_0.T, dtype=torch.float32))
        L2.weight = nn.Parameter(torch.tensor(w2_0.T, dtype=torch.float32))
        # Stack the layers
        self.layers = nn.Sequential(L1, L2)
def forward(self, x):
return self.layers(x)
@staticmethod
def obj_func(X, y):
f = nn.CrossEntropyLoss(reduction='sum')
return f(X, y)
Then the driver would be similar to before:
LR = 0.05
EMBED_DIM = 10
WINDOW = 2
ITER_COUNT = 20
text = 'The quick brown fox jumps over the lazy dog.'
tokens = tokenize_text(text)
w2i, i2w = tokens_encode(tokens)
X, y = gen_dataset(tokens, w2i, WINDOW)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
model = w2v_torch_model_2(len(w2i), embed_size=EMBED_DIM)
err_hist = []
for i in range(ITER_COUNT):
# forward: predict (evaluate probabilities)
pred = model(X)
# evaluate objective function
likelihood = model.obj_func(pred, y)
# Backward pass: compute gradients
likelihood.backward()
# Update parameters for this iteration of gradient decent
with torch.no_grad():
for layer in model.layers:
# Gradient decent iteration update
layer.weight -= LR*layer.weight.grad
# Clear the accumulated gradients for the next
# iteration.
model.zero_grad()
    err_hist.append(likelihood.item())
    print(f'{i}/{ITER_COUNT}: {likelihood.item()}')
For the final iteration, we will use pytorch's built-in optimizers for updating the model weights in each iteration. Since we are using gradient descent as the algorithm, we can use pytorch's built-in `SGD` optimizer:
LR = 0.05
EMBED_DIM = 10
WINDOW = 2
ITER_COUNT = 20
text = 'The quick brown fox jumps over the lazy dog.'
tokens = tokenize_text(text)
w2i, i2w = tokens_encode(tokens)
X, y = gen_dataset(tokens, w2i, WINDOW)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)
model = w2v_torch_model_2(len(w2i), embed_size=EMBED_DIM)
optim = torch.optim.SGD(model.parameters(), lr=LR)
err_hist = []
for i in range(ITER_COUNT):
# forward: predict (evaluate probabilities)
pred = model(X)
# evaluate objective function
likelihood = model.obj_func(pred, y)
# Backward pass: compute gradients
likelihood.backward()
optim.step()
# Clear the accumulated gradients for the next
# iteration.
model.zero_grad()
    err_hist.append(likelihood.item())
    print(f'{i}/{ITER_COUNT}: {likelihood.item()}')
From the above, note another oddity of pytorch: calling `optim.step()` modifies the model weights under the hood. Bizarre.
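To make that concrete, here is a tiny sketch (with arbitrary numbers) showing that for plain SGD, `optim.step()` performs exactly the `w -= lr * w.grad` update we had been writing by hand:

import torch
w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([w], lr=0.1)
(w.sum() ** 2).backward()   # some scalar objective; its gradient is 2*(w1+w2) = 6
opt.step()                  # equivalent to: w = w - 0.1 * w.grad
print(w.data)               # tensor([0.4000, 1.4000])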
Since we initialize with the same weights and use the same update rule, the results per iteration should match the numpy implementation up to floating-point precision.
Looking at the output for the single sentence is not going to be very revealing. Instead, consider now an (admittedly heavily manufactured) example which is meant to be purely conceptual. A very small corpus is constructed from various articles regarding the sun, daylight, sunrise, and earth. For illustration purposes, and to avoid introducing additional topics (dimensionality reduction), the dimension of the embedding vector is taken to be 3, and the window size is taken to be 5.
The following figure shows the resulting embedding vectors in $\mathbb{R}^3$ and the corresponding words.