The Gumbel-Max Trick: Explained

Leonard Tang
Published in The Startup · 4 min read · Jan 28, 2021

Softmax’s slicker sibling.

Motivation

I’ve recently been playing around with a few nature-inspired metaheuristic algorithms (think genetic algorithms, simulated annealing, etc.). In such settings, an algorithm iterates through a search space of candidate solutions with the objective of converging to an optimal solution (for some sense of the word optimal, i.e. as characterized by the algorithm’s loss function). If executed with well-chosen hyperparameters (especially temporal ones), such algorithms are able to avoid locally optimal solutions early on, while gradually homing in on globally optimal solutions by making more and more precise, and careful, decisions.

Concretely, an algorithm in this class usually iterates through the candidate space by probabilistically choosing the next candidate(s) and performing some set of action(s) on them. Speaking abstractly, for a genetic algorithm, this manifests in the form of chromosome sampling and crossover, as well as mutation. As for simulated annealing, a single next candidate s’ is sampled, and the action is simply whether or not to transition to this next candidate from the current candidate s.

The probability of sampling the next candidates is derived from the algorithm-specific loss function. In my case, I’ve come across a specific instance of a genetic algorithm with a loss function that scores solutions by log(p) — that is, by the log of the probability that the solution is chosen.

Given this choice of metric, the question becomes how to best sample from a distribution parametrized by log-probabilities.

Good Old Fashioned Exp-Norm, i.e. Softmax

So we have a distribution specified by log-probabilities xₖ. The standard way to sample from this distribution is to use the softmax function (the same one you might have heard of for classification using neural networks). A more descriptive synonym for softmax is the normalized exponential function. As the name implies, we exponentiate the log-probabilities xₖ to obtain exp(xₖ). Recall that by the basic probability axioms, the probabilities of all events need to sum to 1. We guarantee this by dividing each exponentiated value by the sum of all of them. Ultimately, this transformation is what allows us to sample from the distribution.

That is, the probability of sampling category k from N categories is:

$$\pi_k = \frac{\exp(x_k)}{\sum_{i=1}^{N} \exp(x_i)}$$

Critically, the xₖ are unconstrained in ℝ, but the πₖ lie on the probability simplex (i.e. ∀ k, πₖ ≥ 0, and ∑ πₖ = 1), as desired.
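As a rough illustration of the exp-then-normalize recipe, here is a minimal Python sketch using only the standard library. The function names `softmax` and `sample_softmax` and the example probabilities are hypothetical, chosen just for this example.

```python
import math
import random


def softmax(log_probs):
    """Exponentiate each x_k, then normalize so the results sum to 1."""
    exps = [math.exp(x) for x in log_probs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_softmax(log_probs, rng=random):
    """Draw one category index k with probability pi_k."""
    probs = softmax(log_probs)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Illustrative log-probabilities (assumed values, not from any real run).
x = [math.log(0.5), math.log(0.3), math.log(0.2)]
pi = softmax(x)        # lies on the probability simplex
k = sample_softmax(x)  # an index in {0, 1, 2}
```

Note that this direct version calls `math.exp` on every xₖ, so very large or very small inputs can overflow or underflow.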

The Gumbel-Max Trick

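Interestingly, sampling with the following formulation is equivalent to sampling from the softmax distribution:

$$y = \arg\max_{k \in \{1, \dots, N\}} \left( x_k + z_k \right), \qquad z_k \overset{\text{i.i.d.}}{\sim} \mathrm{Gumbel}(0, 1)$$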

There are multiple benefits to using the Gumbel-Max Trick. Most saliently:

  • It operates primarily in log-space, thereby avoiding potentially nasty numerical overflow/underflow errors and the unexpected or incorrect sampling behavior they cause.
  • It entirely bypasses the need for normalization (i.e. the exp-sum), which can be expensive for a large number of categories.
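To make the two points above concrete, here is a minimal Python sketch of a Gumbel-Max sampler. The helper names `gumbel_noise` and `gumbel_max_sample` and the example scores are hypothetical. Gumbel(0,1) noise is drawn as -log(-log(U)) with U uniform on (0, 1), the usual inverse-CDF construction.

```python
import math
import random


def gumbel_noise(rng=random):
    """Draw z ~ Gumbel(0, 1) as -log(-log(U)), with U uniform on (0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() can return 0.0, where log is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(log_probs, rng=random):
    """Return argmax_k (x_k + z_k); the x_k are never exponentiated or summed."""
    perturbed = [x + gumbel_noise(rng) for x in log_probs]
    return max(range(len(perturbed)), key=perturbed.__getitem__)


# Hypothetical large-magnitude scores: math.exp(1000.0) would raise
# OverflowError, but the perturbed argmax stays entirely in log-space.
x = [1000.0, 1001.0, 999.0]
k = gumbel_max_sample(x)
```

The only per-category work is one noise draw and one addition, followed by a single argmax.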

Definitely pretty slick! Let’s derive this result next.

Derivation

First, recall that for a random variable X ∼ Gumbel(0,1), we have the following PDF and CDF:

$$f(x) = e^{-\left(x + e^{-x}\right)}, \qquad F(x) = e^{-e^{-x}}$$

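Now, define the Gumbel-Max Trick as:

$$y = \arg\max_{k} \, u_k$$

where, for notational convenience, we use uₖ to denote:

$$u_k = \log(\pi_k) + z_k, \qquad z_k \overset{\text{i.i.d.}}{\sim} \mathrm{Gumbel}(0, 1)$$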

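Critically, observe that the randomness in uₖ comes from zₖ, whereas the log(πₖ) = xₖ are known. (If the xₖ are unnormalized, they differ from the log(πₖ) only by a constant shared across all k, which leaves the argmax unchanged.) We then have that:

$$P(y = k) = P\left(u_k \ge u_j \;\; \forall j \ne k\right) = \pi_k$$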
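One standard way to carry out the computation of P(y = k) is sketched below. It assumes uₖ = log(πₖ) + zₖ with independent zₖ ∼ Gumbel(0,1) and ∑ πₖ = 1. The integration variable t and the density symbol f_{u_k} are notation introduced only for this sketch. The idea is to condition on the value of uₖ and require every other uⱼ to fall below it.

```latex
\begin{align*}
P(u_j \le t) &= P\big(z_j \le t - \log \pi_j\big) = \exp\!\big(-\pi_j e^{-t}\big), \\
f_{u_k}(t) &= \pi_k e^{-t} \exp\!\big(-\pi_k e^{-t}\big), \\
P(y = k) &= \int_{-\infty}^{\infty} f_{u_k}(t) \prod_{j \ne k} P(u_j \le t)\, dt \\
&= \pi_k \int_{-\infty}^{\infty} e^{-t} \exp\!\Big(-e^{-t} \sum_{j=1}^{N} \pi_j\Big)\, dt \\
&= \pi_k \int_{-\infty}^{\infty} e^{-t} \exp\!\big(-e^{-t}\big)\, dt
 = \pi_k \Big[\exp\!\big(-e^{-t}\big)\Big]_{-\infty}^{\infty} = \pi_k .
\end{align*}
```

The first two lines follow from substituting t − log(πⱼ) into the Gumbel(0,1) CDF and PDF. The last integral equals 1 because its integrand is exactly the Gumbel(0,1) density.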

Thus, sampling with the Gumbel-Max Trick is indeed equivalent to sampling from the softmax distribution.
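For a quick empirical sanity check of this equivalence, the following Python sketch compares Gumbel-Max sampling frequencies against target probabilities. The function names, the target probabilities, the sample count, and the seed are all illustrative assumptions.

```python
import math
import random
from collections import Counter


def gumbel_max_sample(log_probs, rng):
    """Return argmax_k (log_probs[k] + z_k) with z_k ~ Gumbel(0, 1)."""
    best_k, best_u = None, -math.inf
    for k, x in enumerate(log_probs):
        u01 = rng.random()
        while u01 == 0.0:  # avoid log(0)
            u01 = rng.random()
        u_k = x - math.log(-math.log(u01))  # x_k + Gumbel(0, 1) noise
        if u_k > best_u:
            best_k, best_u = k, u_k
    return best_k


def empirical_frequencies(pi, n_samples, seed=0):
    """Estimate P(y = k) by repeated Gumbel-Max sampling."""
    rng = random.Random(seed)
    log_pi = [math.log(p) for p in pi]
    counts = Counter(gumbel_max_sample(log_pi, rng) for _ in range(n_samples))
    return [counts[k] / n_samples for k in range(len(pi))]


pi = [0.5, 0.3, 0.2]  # illustrative target probabilities
freqs = empirical_frequencies(pi, n_samples=20_000)
for p, f in zip(pi, freqs):
    print(f"target {p:.2f}  empirical {f:.3f}")
```

With enough samples, the empirical frequencies should land close to the targets, up to ordinary sampling noise.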
