Discrete Stochastic Optimization


This post covers stochastic optimization with discrete latent random variables. Unlike continuous random variables, discrete random variables compress information into a small number of bits, which makes them an attractive way to capture the relevant structure of the data compactly. However, differentiating through discrete variables is challenging. We will look at the challenges posed by discrete stochastic optimization and at reparameterization methods which overcome them.

The Stochastic Optimization Problem

Stochastic optimization is the problem of optimizing an objective over latent variables which are sampled from a base distribution. Differentiating through distributions is challenging for modern autograd frameworks. To see this, let's look at a loss function $\mathcal{L}(\theta, z)$ which depends on the parameters $\theta$ and a latent variable $z$ sampled from the latent distribution $z \sim q_{\theta}(z)$. Computing the gradient of the expected objective yields the following,

\begin{gather} \nabla_{\theta} \mathbb{E}_{z \sim q_{\theta}(z)} [\mathcal{L} (\theta,z)] = \nabla_{\theta} \int \mathcal{L}(\theta,z) q_{\theta}(z) dz = \int \nabla_{\theta} [\mathcal{L}(\theta,z) q_{\theta}(z)] dz \end{gather}

Applying the product rule inside the integral gives us the following two terms,

\begin{gather} \int [\nabla_{\theta} \mathcal{L}(\theta,z)] q_{\theta}(z) dz + \int [\nabla_{\theta} q_{\theta}(z)] \mathcal{L}(\theta,z) dz \end{gather}

While the first term is easy to estimate (it is an expectation of $\nabla_{\theta} \mathcal{L}$ under $q_{\theta}$), the second term is problematic: it requires differentiating through the sampling distribution itself. Since $z$ is produced by a sampling operation $z \sim q_{\theta}(z)$, there is no differentiable path from $\theta$ to $z$ for the chain rule to follow. In the case of continuous random variables, we overcome this limitation by reparameterizing the distribution and computing pathwise derivatives. However, it is not obvious how to reparameterize a distribution over discrete variables.
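
For intuition, here is the standard pathwise reparameterization in the continuous case: a Gaussian latent $z \sim \mathcal{N}(\mu_{\theta}, \sigma_{\theta}^{2})$ can be rewritten as a deterministic transform of parameter-free noise, so the gradient moves inside an expectation that no longer depends on $\theta$,

\begin{gather} z = \mu_{\theta} + \sigma_{\theta} \epsilon, \quad \epsilon \sim \mathcal{N}(0,1) \newline \nabla_{\theta} \mathbb{E}_{z \sim q_{\theta}(z)} [\mathcal{L}(\theta, z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,1)} [\nabla_{\theta} \mathcal{L}(\theta, \mu_{\theta} + \sigma_{\theta} \epsilon)] \end{gather}

No analogous smooth transform exists for a categorical variable: any map from noise to a discrete index involves a non-differentiable step such as argmax, which is exactly the problem the Gumbel tricks below address.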

Gumbel Max Trick

We now look at reparameterization methods which help us deal with the problematic $\nabla_{\theta} q_{\theta}(z)$ term. One such method is the Gumbel max trick. Suppose we have a discrete one-hot vector $d \sim \text{Cat}(d \mid \alpha)$ with $k$ entries $d_{k} \in \{ 0,1 \}$, where $\alpha_{k}$ are the (possibly unnormalized) class probabilities. The Gumbel max trick reparameterizes sampling from this distribution as a deterministic transform of independent noise: we sample Gumbel noise, add it to the log probabilities, and take the argmax. The winning index is then one-hot encoded to obtain the discrete sample.

\begin{gather} u_{k} \sim \text{Unif}(0,1) \newline \epsilon_{k} = -\log (-\log (u_{k})) \newline d = \texttt{onehot}(\text{argmax}_{k} (\epsilon_{k} + \log \alpha_{k})) \end{gather}

This is straightforward to implement with an autograd library such as JAX,

import jax
import jax.numpy as jnp

eps = 1e-20
key = jax.random.PRNGKey(42)
logits_key, noise_key = jax.random.split(key)
# Dummy logits: a batch of 256 categorical variables over 32 classes.
logits = jax.random.normal(logits_key, (256, 32))

# Sample Gumbel noise by transforming uniform samples, perturb the logits,
# and take the one-hot argmax along the class axis.
u = jax.random.uniform(noise_key, logits.shape, minval=0, maxval=1)
epsilon = -jnp.log(-jnp.log(u + eps) + eps)
y = logits + epsilon
d = jnp.equal(y, jnp.max(y, axis=1, keepdims=True)).astype(jnp.float32)
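
As a quick sanity check, over many noise draws the argmax indices should be distributed according to the softmax of the logits. The following is a rough sketch under the setup above (the helper name and sample count are arbitrary), verifying this for a single row of logits:

# Empirical frequencies of Gumbel max samples should match softmax(logits)
# for a single categorical variable.
row = logits[0]
keys = jax.random.split(jax.random.PRNGKey(0), 10000)

def sample_index(k):
    g = -jnp.log(-jnp.log(jax.random.uniform(k, row.shape) + eps) + eps)
    return jnp.argmax(row + g)

samples = jax.vmap(sample_index)(keys)
empirical = jnp.bincount(samples, length=row.shape[0]) / samples.shape[0]
print(jnp.max(jnp.abs(empirical - jax.nn.softmax(row))))  # should be small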

Gumbel Softmax Trick

While the Gumbel max trick gives us an exact reparameterization of categorical sampling, it is still impossible to backpropagate through the sampling procedure. This is because the gradient of argmax is zero everywhere except at the decision boundaries, where it is undefined. We thus need a continuous relaxation with a well-defined, useful gradient.

Instead of using the hard argmax, we can use the softmax operator with a temperature $\tau$. The result is a continuous relaxation $y$ on the probability simplex rather than a hard one-hot vector. In the limit $\tau \rightarrow 0$, the softmax output approaches the one-hot argmax and we recover the Gumbel max trick; for $\tau > 0$, we get a reparameterization which is differentiable everywhere.

\begin{gather} u_{k} \sim \text{Unif}(0,1) \newline \epsilon_{k} = -\log (-\log (u_{k})) \newline y = \text{softmax} ((\epsilon_{k} + \log \alpha_{k}) / \tau) \end{gather}

Implementing this in JAX yields the following,

import jax
import jax.numpy as jnp

eps = 1e-20
key = jax.random.PRNGKey(42)
logits_key, noise_key = jax.random.split(key)
# Dummy logits: a batch of 256 categorical variables over 32 classes.
logits = jax.random.normal(logits_key, (256, 32))
temperature = 2.0

# Sample Gumbel noise, perturb the logits, and apply a tempered softmax
# to obtain the relaxed (continuous) sample y.
u = jax.random.uniform(noise_key, logits.shape, minval=0, maxval=1)
epsilon = -jnp.log(-jnp.log(u + eps) + eps)
y = logits + epsilon
y = jax.nn.softmax(y / temperature)
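
Since the relaxation is differentiable, autograd can now produce a gradient with respect to the logits. As a rough sketch under the setup above (the loss and function names here are made up purely for illustration), we can wrap the sampler in a function and differentiate a simple linear loss through it:

# Differentiate a simple loss through the relaxed sample.
def relaxed_sample(logits, key, temperature):
    u = jax.random.uniform(key, logits.shape)
    gumbel = -jnp.log(-jnp.log(u + eps) + eps)
    return jax.nn.softmax((logits + gumbel) / temperature)

def loss_fn(logits, key, temperature):
    y = relaxed_sample(logits, key, temperature)
    return jnp.mean(jnp.sum(y * jnp.arange(y.shape[-1]), axis=-1))

grads = jax.grad(loss_fn)(logits, noise_key, temperature)
print(grads.shape)  # (256, 32): well-defined, non-zero gradients

Lowering the temperature pushes $y$ toward a one-hot vector at the cost of higher-variance gradients, the usual bias-variance trade-off of this relaxation.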

Straight-Through Gumbel Softmax Trick

While the Gumbel softmax trick provides a continuous relaxation with a well-defined gradient, many models still need a genuinely discrete sample in the forward pass. We want to use the one-hot sample $d$ in the forward computation but backpropagate through the relaxed variable $y$. This is problematic because the discretization step that produces $d$ has no useful gradient.

We circumvent this problem with a straight-through estimator. The estimator creates a differentiable pathway in the computation graph that backpropagates through the relaxed variable $y$ while leaving the forward value of the sample $d$ unchanged.

\begin{gather} d = d + y - \text{sg}(y) \end{gather}

In the above equation, sg represents the stop_gradient operation. In the forward pass, $y - \text{sg}(y)$ evaluates to zero, so we obtain the sample $d$ as is. In the backward pass, the gradient flows only through the second term $y$: the first term $d$ has no gradient, and the last term has its gradient stopped. The result is a biased gradient estimate which nevertheless keeps the stochastic computation graph intact.
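
Written out, the value and the gradient of this straight-through surrogate separate cleanly,

\begin{gather} d + y - \text{sg}(y) = d \quad \text{(forward value)}, \qquad \nabla_{\theta} [d + y - \text{sg}(y)] = \nabla_{\theta} y \quad \text{(backward pass)} \end{gather}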

The above approach of biasing the gradient with a straight-through estimator is called the straight-through Gumbel softmax trick, which is implemented as follows,

\begin{gather} u_{k} \sim \text{Unif}(0,1) \newline \epsilon_{k} = -\log (-\log (u_{k})) \newline y = \text{softmax} ((\epsilon_{k} + \log \alpha_{k}) / \tau) \newline d = \texttt{onehot}(\text{argmax}_{k} (y)) \newline d = d + y - \text{sg}(y) \end{gather}

When implemented in JAX, it yields the following,

import jax
import jax.numpy as jnp

eps = 1e-20
key = jax.random.PRNGKey(42)
logits_key, noise_key = jax.random.split(key)
# Dummy logits: a batch of 256 categorical variables over 32 classes.
logits = jax.random.normal(logits_key, (256, 32))
temperature = 2.0

# Relaxed sample y, hard one-hot sample d, then the straight-through
# surrogate: the forward value is exactly d, the gradient flows through y.
u = jax.random.uniform(noise_key, logits.shape, minval=0, maxval=1)
epsilon = -jnp.log(-jnp.log(u + eps) + eps)
y = logits + epsilon
y = jax.nn.softmax(y / temperature)
d = jnp.equal(y, jnp.max(y, axis=1, keepdims=True)).astype(y.dtype)
d = d + y - jax.lax.stop_gradient(y)
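
To see both properties at once, we can check that the forward value stays exactly one-hot while jax.grad still returns a non-zero gradient for the logits. This is a rough sketch under the setup above; the loss and helper names are arbitrary and only for illustration:

# Hard one-hot forward value, non-zero gradient via y.
def st_sample(logits, key, temperature):
    u = jax.random.uniform(key, logits.shape)
    gumbel = -jnp.log(-jnp.log(u + eps) + eps)
    y = jax.nn.softmax((logits + gumbel) / temperature)
    d = jnp.equal(y, jnp.max(y, axis=-1, keepdims=True)).astype(y.dtype)
    return d + y - jax.lax.stop_gradient(y)

def loss_fn(logits):
    d = st_sample(logits, noise_key, temperature)
    return jnp.mean(jnp.sum(d * jnp.arange(d.shape[-1]), axis=-1))

print(jnp.unique(st_sample(logits, noise_key, temperature)))  # values are 0 and 1
print(jnp.abs(jax.grad(loss_fn)(logits)).sum())               # non-zero gradient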