What Is Candidate Sampling
Say we have a multiclass or multilabel problem where each training example (xi, Ti) consists of
a context xi and a small (multi)set of target classes Ti out of a large universe L of possible
classes. For example, the problem might be to predict the next word (or the set of future
words) in a sentence given the previous words.
We wish to learn a compatibility function F(x, y) which says something about the compatibility of
a class y with a context x: for example, the probability of the class given the context.
“Exhaustive” training methods such as softmax and logistic regression require us to compute
F (x, y ) for every class y ∈ L for every training example. When |L| is very large, this can be
prohibitively expensive.
“Candidate Sampling” training methods involve constructing a training task in which for each
training example (xi , T i ) , we only need to evaluate F (x, y ) for a small set of candidate classes
C i ⊂ L . Typically, the set of candidates C i is the union of the target classes with a randomly
chosen sample of (other) classes S i ⊂ L .
Ci = T i ⋃ Si
The random choice of S i may or may not depend on xi and/or T i .
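As a purely illustrative sketch of this construction, the NumPy snippet below builds Ci from a target set Ti and a uniformly sampled Si; the function name and the uniform choice of sampling distribution are assumptions, not part of any particular library.

```python
import numpy as np

def build_candidates(target_classes, num_classes, num_sampled, seed=0):
    """Form Ci = Ti ∪ Si, with Si drawn uniformly from L (one possible choice of Q)."""
    rng = np.random.default_rng(seed)
    sampled = rng.choice(num_classes, size=num_sampled, replace=False)  # Si
    return np.union1d(target_classes, sampled)                          # Ci = Ti ∪ Si

# Example: universe of 1,000,000 classes, targets Ti = {7, 42}, 5 sampled classes.
print(build_candidates(np.array([7, 42]), 1_000_000, 5))
```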
The training algorithm takes the form of a neural network, where the layer representing F (x, y )
is trained by backpropagation from a loss function.
Table of Candidate Sampling Algorithms
Each algorithm in the table is described by: the positive training classes POSi and the negative training classes NEGi associated with a training example (xi, Ti); the input to the training loss, G(x, y); the training loss used; and what F(x, y) gets trained to approximate.
● Negative Sampling: POSi = Ti; NEGi = Si; G(x, y) = F(x, y); training loss: Logistic; F(x, y) gets trained to approximate log(P(y|x) / Q(y|x))
● Sampled Logistic: POSi = Ti; NEGi = (Si − Ti); G(x, y) = F(x, y) − log(Q(y|x)); training loss: Logistic; F(x, y) gets trained to approximate log(odds(y|x)) = log(P(y|x) / (1 − P(y|x)))
● Full Logistic: POSi = Ti; NEGi = (L − Ti); G(x, y) = F(x, y); training loss: Logistic; F(x, y) gets trained to approximate log(odds(y|x)) = log(P(y|x) / (1 − P(y|x)))
● Full Softmax: POSi = Ti = {ti}; NEGi = (L − Ti); G(x, y) = F(x, y); training loss: Softmax; F(x, y) gets trained to approximate log(P(y|x)) + K(x)
● logistic training loss = Σ_i [ Σ_{y ∈ POSi} log(1 + exp(−G(xi, y))) + Σ_{y ∈ NEGi} log(1 + exp(G(xi, y))) ]
● softmax training loss = Σ_i [ −G(xi, ti) + log( Σ_{y ∈ POSi ∪ NEGi} exp(G(xi, y)) ) ] (both losses are sketched in NumPy after this list)
● NCE and Negative Sampling generalize to the case where T i is a multiset. In this
case, P (y|x) denotes the expected count of y in T i . Similarly, NCE, Negative
Sampling, and Sampled Logistic generalize to the case where S i is a multiset. In this
case Q(y|x) denotes the expected count of y in S i .
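To make the two losses above concrete, here is a minimal NumPy sketch that evaluates them for a single training example from precomputed values of G(xi, y); the function names are illustrative only.

```python
import numpy as np

def logistic_training_loss(g_pos, g_neg):
    """Sum of log(1 + exp(-G(xi, y))) over y in POSi plus log(1 + exp(G(xi, y))) over y in NEGi."""
    return np.sum(np.log1p(np.exp(-g_pos))) + np.sum(np.log1p(np.exp(g_neg)))

def softmax_training_loss(g_target, g_candidates):
    """-G(xi, ti) plus the log of the sum of exp(G(xi, y)) over y in POSi ∪ NEGi."""
    return -g_target + np.log(np.sum(np.exp(g_candidates)))

# Made-up values of G(xi, y) for one training example.
g_pos = np.array([2.0])          # y in POSi
g_neg = np.array([-1.0, 0.5])    # y in NEGi
print(logistic_training_loss(g_pos, g_neg))
print(softmax_training_loss(g_pos[0], np.concatenate([g_pos, g_neg])))
```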
Sampled Softmax
(A faster way to train a softmax classifier)
Reference:
http://arxiv.org/abs/1412.2007
Assume that we have a single-label problem. Each training example (xi, {ti}) consists of a
context and one target class. We write P(y|x) for the probability that the one target class is
y given that the context is x.
We would like to train a function F(x, y) to produce softmax logits, that is, relative log
probabilities of the class given the context:
F(x, y) ← log(P(y|x)) + K(x)
where K(x) is an arbitrary function that does not depend on y.
In full softmax training, for every training example (xi, {ti}), we would need to compute logits
F(xi, y) for all classes y ∈ L. This can get expensive if the universe of classes L is very
large.
In “Sampled Softmax”, for each training example (x i , {ti }) , we pick a small set S i ⊂ L of
“sampled” classes according to a chosen sampling function Q(y|x) . Each class y ∈ L is
included in S i independently with probability Q(y|xi ) .
P(Si = S | xi) = ∏_{y ∈ S} Q(y|xi) ∏_{y ∈ (L − S)} (1 − Q(y|xi))
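A small sketch of this independent-inclusion sampling, where q is simply an array of the inclusion probabilities Q(y|xi) (an assumption for illustration):

```python
import numpy as np

def sample_independent(q, seed=0):
    """Include each class y in Si independently with probability Q(y|xi).

    q is an array of shape [num_classes] with q[y] = Q(y|xi)."""
    rng = np.random.default_rng(seed)
    return np.flatnonzero(rng.random(q.shape[0]) < q)   # class ids included in Si

# Example: 10 classes, each included with probability 0.3.
print(sample_independent(np.full(10, 0.3)))
```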
We create a set of candidates C i containing the union of the target class and the sampled
classes:
C i = S i ⋃ {ti }
Our training task is to figure out, given this set C i , which of the classes in C i is the target class.
For each class y ∈ Ci, we want to compute the posterior probability that y is the target class
given our knowledge of xi and Ci. We call this P(ti = y | xi, Ci).
Applying Bayes’ rule:
P (ti = y |xi , C i ) = P (ti = y , C i |xi ) / P (C i |xi )
= P (ti = y |xi ) P (C i |ti = y , xi ) / P (C i |xi )
= P (y|xi ) P (C i |ti = y , xi ) / P (C i |xi )
Now to compute P (C i |ti = y , xi ) , we note that in order for this to happen, S i may or may not
contain y , must contain all other elements of C i , and must not contain any classes not in C i .
So:
P(ti = y | xi, Ci) = P(y|xi) ∏_{y′ ∈ Ci − {y}} Q(y′|xi) ∏_{y′ ∈ (L − Ci)} (1 − Q(y′|xi)) / P(Ci|xi)
= (P(y|xi) / Q(y|xi)) ∏_{y′ ∈ Ci} Q(y′|xi) ∏_{y′ ∈ (L − Ci)} (1 − Q(y′|xi)) / P(Ci|xi)
= (P(y|xi) / Q(y|xi)) / K(xi, Ci)
where K (xi , C i ) is a function that does not depend on y . So:
log(P (ti = y |xi , C i )) = log(P (y|xi )) − log(Q(y|xi )) + K ′(xi , C i )
These are the relative logits that should feed into a softmax classifier predicting which of the
candidates in C i is the true one.
Since we are trying to train the function F (x, y ) to approximate log(P (y|x)) , we take the layer in
our network representing F (x, y ) , subtract log(Q(y|x)) , and pass the result to a softmax
classifier predicting which candidate is the true one.
Training Softmax Input = F(x, y) − log(Q(y|x))
Backpropagating the gradients from that classifier trains F to give us what we want.
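The following is a minimal NumPy sketch of that computation for one training example. It assumes, purely for illustration, that F(x, y) is the dot product of a context feature vector with a class embedding and that Q is uniform; it is not the API of any particular library, and collisions between the target and the sampled classes are ignored.

```python
import numpy as np

def sampled_softmax_loss(context_vec, class_embeddings, target, sampled, log_q):
    """Softmax cross-entropy over the candidates Ci = {ti} ∪ Si, applied to
    F(x, y) - log(Q(y|x)), where F(x, y) is taken to be a dot product."""
    candidates = np.concatenate([[target], sampled])      # target class first
    logits = class_embeddings[candidates] @ context_vec   # F(xi, y) for y in Ci
    logits = logits - log_q[candidates]                   # F(xi, y) - log(Q(y|xi))
    log_probs = logits - np.log(np.sum(np.exp(logits)))   # log softmax over Ci
    return -log_probs[0]                                  # the target sits at index 0

# Example: 1000 classes, 16-dim embeddings, uniform Q over the universe.
rng = np.random.default_rng(0)
emb = rng.normal(size=(1000, 16))
x = rng.normal(size=16)
log_q = np.full(1000, np.log(1.0 / 1000))
sampled = rng.choice(1000, 5, replace=False)
print(sampled_softmax_loss(x, emb, target=3, sampled=sampled, log_q=log_q))
```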
Noise Contrastive Estimation (NCE)
Reference:
http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf
Each training example (xi, Ti) consists of a context and a small multiset of target classes. In
practice, Ti may always be a set or even a single class, but we use a multiset here for
generality.
We use the following as a shorthand for the expected count of a class in the set of target
classes for a context. In the case of sets with no duplicates, this is the probability of the class
given the context:
P (y|x) := E(T (y) | x)
We would like to train a function F(x, y) to approximate the log expected count of the class
given the context or, in the case of sets, the log probability of the class given the context.
F (x, y ) ← log (P (y|x))
For each example (xi , T i ) , we pick a multiset of sampled classes S i . In practice, it probably
makes sense to pick a set, but we use a multiset here for generality. Our sampling algorithm
may or may not depend on xi but may not depend on T i . We construct a multiset of
candidates consisting of the sum of the target classes and the sampled classes.
Ci = T i + Si
Our training task is to distinguish the true candidates from the sampled candidates. We have
one positive training meta-example for each element of Ti and one negative training
meta-example for each element of Si.
We introduce the shorthand Q(y|x) to denote the expected count, according to our sampling
algorithm, of a particular class in the set of sampled classes. If S never contains duplicates,
then this is a probability.
Q(y|x) := E(S(y) | x)
log-odds(y came from T vs S | x) = log(P(y|x) / Q(y|x))
= log(P(y|x)) − log(Q(y|x))
The first term, log (P (y|x)) , is what we would like to train F (x, y ) to estimate.
We have a layer in our model which represents F (x, y ) . We add to it the second term,
− log(Q(y|x)) , which we compute analytically, and we pass the result to a logistic regression loss
whose “label” indicates whether y came from T as opposed to S .
Logistic Regression Input = F (x, y ) − log(Q(y|x))
The backpropagation signal trains F (x, y ) to approximate what we want it to.
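A minimal sketch of this step for one training example, with precomputed F values and an assumed uniform Q (the names are illustrative only):

```python
import numpy as np

def nce_loss(f_pos, f_neg, log_q_pos, log_q_neg):
    """Logistic loss on G = F(x, y) - log(Q(y|x)): positives are classes drawn
    from Ti, negatives are classes drawn from Si."""
    g_pos = f_pos - log_q_pos
    g_neg = f_neg - log_q_neg
    return np.sum(np.log1p(np.exp(-g_pos))) + np.sum(np.log1p(np.exp(g_neg)))

# One target class, three sampled classes, uniform Q over a universe of 1000 classes.
log_q = np.log(1.0 / 1000)
print(nce_loss(np.array([2.0]), np.array([0.1, -0.3, 1.2]), log_q, log_q))
```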
Negative Sampling
Reference:
http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
Negative Sampling is a simplified variant of Noise Contrastive Estimation where we neglect to
subtract off log(Q(y|x)) during training. As a result, F(x, y) is trained to approximate
log(P(y|x)) − log(Q(y|x)).
It is noteworthy that in Negative Sampling, we are optimizing F (x, y ) to approximate something
that depends on the sampling distribution Q . This will make the results highly dependent on
the choice of sampling distribution. This is not true for the other algorithms described here.
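For comparison with the NCE sketch above, the negative-sampling version simply omits the log(Q(y|x)) correction (same illustrative assumptions):

```python
import numpy as np

def negative_sampling_loss(f_pos, f_neg):
    """Same logistic loss as the NCE sketch, but applied to F(x, y) directly,
    with no log(Q(y|x)) correction."""
    return np.sum(np.log1p(np.exp(-f_pos))) + np.sum(np.log1p(np.exp(f_neg)))

print(negative_sampling_loss(np.array([2.0]), np.array([0.1, -0.3, 1.2])))
```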
Sampled Logistic
Sampled Logistic is a variant on Noise Contrastive Estimation where we discard, without
replacement, all sampled classes that happen to also be target classes. This requires Ti to be
a set, as opposed to a multiset, though Si may be a multiset. As a result, we learn an estimator
of the log-odds of a class as opposed to the log-probability of a class. The math changes from
the NCE math as follows:
log-odds(y came from T vs (S − T) | x) = log( P(y|x) / (Q(y|x) (1 − P(y|x))) ) = log( P(y|x) / (1 − P(y|x)) ) − log(Q(y|x))
The first term, log(P(y|x) / (1 − P(y|x))), is what we would like to train F(x, y) to estimate.
We have a layer in our model, which represents F (x, y ) . We add to it the second term,
− log(Q(y|x)) , which we compute analytically, and we pass the result to a logistic regression loss
predicting whether y came from T i vs (S i − T i ) .
Logistic Regression Input = F(x, y) − log(Q(y|x))
The backpropagation signal trains the F (x, y ) layer to approximate what we want it to.
F(x, y) ← log( P(y|x) / (1 − P(y|x)) )
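A sketch of the Sampled Logistic computation, under the same illustrative assumptions as the NCE sketch; the only change is that sampled classes that are also target classes are discarded before forming the negatives.

```python
import numpy as np

def sampled_logistic_loss(f, targets, sampled, log_q):
    """NCE-style logistic loss in which the negatives are (Si - Ti): any sampled
    class that is also a target class is discarded before the loss is computed.

    f and log_q are per-class arrays of F(xi, y) and log(Q(y|xi))."""
    negatives = np.setdiff1d(sampled, targets)   # Si - Ti
    g_pos = f[targets] - log_q[targets]
    g_neg = f[negatives] - log_q[negatives]
    return np.sum(np.log1p(np.exp(-g_pos))) + np.sum(np.log1p(np.exp(g_neg)))

# Example: 10 classes, Ti = {2}, Si = {2, 5, 7}; class 2 is discarded from the negatives.
rng = np.random.default_rng(0)
f = rng.normal(size=10)
log_q = np.full(10, np.log(3 / 10))
print(sampled_logistic_loss(f, np.array([2]), np.array([2, 5, 7]), log_q))
```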
Context-Specific vs. Generic Sampling
In the methods discussed, the sampling algorithm is allowed to depend on the context. It is
possible that for some models, context-specific sampling will be very useful, in that we can
generate context-dependent hard negatives and provide a more useful training signal. The
authors have to this point focused on generic sampling algorithms such as uniform sampling
and unigram sampling, which do not make use of the context. The reason is described in the
next section.
Batchwise Sampling
We have focused on models which use the same set S of sampled classes across a whole
batch of training examples. This seems counterintuitive: shouldn't convergence be faster if we
use different sampled classes for each training example? The reason for using the same
sampled classes across a batch is computational.
In many of our models, F (x, y ) is computed as the dot product of a feature vector for the context
(the top hidden layer of a neural network), and an embedding vector for the class. Computing
the dot products of many feature vectors with many embedding vectors is a matrix multiplication,
which is highly efficient on modern hardware, especially on GPUs. Batching like this often
allows us to use hundreds or thousands of sampled classes without noticeable slowdown.
Another way to see it is that the overhead of fetching a class embedding across devices is
greater than the time it takes to compute its dot products with hundreds or even thousands of
feature vectors. So if we are going to use a sampled class with one context, it is virtually free to
use it with all of the other contexts in the batch as well.
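For instance, with one shared set of sampled classes for the whole batch, every value of F(x, y) that is needed can come from a single matrix multiplication; the shapes and names below are illustrative, not tied to any particular model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, num_classes, num_sampled = 128, 64, 100_000, 512

contexts = rng.normal(size=(batch, dim))                  # top hidden layer, one row per example
class_embeddings = rng.normal(size=(num_classes, dim))    # one embedding vector per class
sampled = rng.choice(num_classes, num_sampled, replace=False)  # shared S for the whole batch

# F(x, y) for every context in the batch and every sampled class, in one matmul:
logits = contexts @ class_embeddings[sampled].T           # shape [batch, num_sampled]
print(logits.shape)
```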