Reparametrization trick

Author

Hugo Gangloff

Published

April 14, 2023

General idea

In a general setting, let us say that we want to compute a derivative of the form \nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)]. \tag{1} The derivative with respect to \phi is problematic because we need to sample from q_{\phi}, a distribution which itself depends on \phi. Backpropagation fails because we cannot backpropagate through stochastic nodes: intuitively, sampling from q_{\phi} is a process that makes the dependency on \phi disappear.

The reparametrization trick is a solution to this problem. The idea is to rewrite z\sim q_{\phi} with the help of a deterministic differentiable function g_{\phi} and a random variable \epsilon\sim p such that z=g_{\phi}(\epsilon). p is a probability distribution from which it is easy to sample and, most importantly, does not depend on \phi. Using the definition of the integral and the change of variable formula:

\mathbb{E}_{z\sim q_{\phi}}[f(z)] = \int q_{\phi}(z)f(z) \mathrm{d}z = \int p(\epsilon) f(g_{\phi}(\epsilon)) \mathrm{d}\epsilon= \mathbb{E}_{\epsilon\sim p}[f(g_{\phi}(\epsilon))]. Then,

\begin{align*} \nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)] &= \nabla_{\phi}\mathbb{E}_{\epsilon\sim p}[f(g_{\phi}(\epsilon))],\\ &= \mathbb{E}_{\epsilon\sim p}[\nabla_{\phi}f(g_{\phi}(\epsilon))],\\ &= \mathbb{E}_{\epsilon\sim p}[\nabla_{z}f(z)\nabla_{\phi}g_{\phi}(\epsilon)] \text{ (chain rule formula)}.\\ \end{align*}

Reparametrization examples

  • With q_\phi = \mathcal N(\mu,\sigma^2), where \phi = (\mu,\sigma^2), we can use: g_\phi: \epsilon\mapsto\mu+\sigma\epsilon and p=\mathcal N(0,1).

  • With q_\phi = \mathcal E(\phi), we can use: g_\phi = -\frac1\phi\log\epsilon with p=\mathcal U(]0,1[).

  • With q_\phi = \operatorname{Weibull}(\lambda,k), where \phi=(\lambda, k), we can use: g_\phi: \epsilon\mapsto\lambda{\left(-\log\epsilon\right)}^{1/k} and p=\mathcal U(]0,1[).

  • If q_\phi is absolutely continuous with respect to the Lebesgue measure and its inverse CDF F^{-1}_\phi is known (and easily computable and differentiable), we can use g_\phi = F^{-1}_\phi and p=\mathcal U(]0,1[) (see the sketch below).
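For instance, the exponential example above is exactly this inverse-CDF reparametrization and can be checked numerically (a minimal sketch in Pytorch; the value of \phi and the sample size are arbitrary):

import torch

phi = torch.tensor([2.], requires_grad=True)
eps = torch.rand((100000,))  # eps ~ U(]0,1[), does not depend on phi
z = -torch.log(eps) / phi    # z ~ E(phi), differentiable with respect to phi
loss = z.mean()              # Monte-Carlo estimate of E[z] = 1/phi
loss.backward()
print(phi.grad)              # close to d(1/phi)/dphi = -1/phi^2 = -0.25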

Reparametrization trick with respect to a continuous random variable

Reparametrization trick for Variational Autoencoders

The reparametrization trick has become popular because of its usage in a very popular generative model: the Variational Autoencoder (VAE). A VAE is trained by maximizing an objective called the evidence lower bound (ELBO), so named because it is a lower bound of the model log-likelihood \log p_{\theta}(x), \begin{equation*} \mathcal{L}_{\theta,\phi}(x)=\mathbb{E}_{z\sim q_{\phi}}[\log p_{\theta}(x|z)] -D_{KL}(q_{\phi}||p_{\theta}), \end{equation*} where z are latent random variables and p_{\theta} is the prior distribution. During the parameter estimation procedure, one wants to differentiate \mathcal{L}_{\theta,\phi}(x) with respect to \phi (and \theta, but this is straightforward) while drawing samples z under the variational distribution q_{\phi}, parametrized by this same \phi. The computation of the first term of the ELBO is thus a particular case of Equation 1, with f(z)=\log p_{\theta}(x|z). In the setting of classical VAEs, z is drawn from a Gaussian distribution, which is reparametrizable, and f is differentiable with respect to z, therefore the reparametrization trick can be used. Kingma’s original paper on VAEs can be consulted for details about the VAE model and the associated reparametrization trick.
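As an illustration, the reparametrized sampling step at the output of a VAE encoder typically looks like the following sketch (the function and variable names are hypothetical placeholders, not Kingma’s exact implementation):

import torch

def sample_latent(mu, log_var):
    # z = mu + sigma * eps with eps ~ N(0, I): the randomness is moved into
    # eps, so that gradients can flow back to mu and log_var (i.e., to phi)
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + std * eps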

Experiments with Pytorch

Let us switch to a minimal example for the reparametrization trick in which we want to compute: \begin{equation*} \frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) \text{ and } \frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) \end{equation*}

0. Theoretical result

In this toy example, using the identity \mathbb{E}[z^2]=\mathrm{Var}(z)+(\mathbb{E}[z])^2, we can compute by hand that: \begin{equation*} \frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) = \frac{\partial}{\partial \mu}\left(\sigma^2 + \mu^2\right) = 2\mu \text{ and } \frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) = \frac{\partial}{\partial \sigma}\left(\sigma^2 + \mu^2\right) = 2\sigma. \end{equation*}

1. Direct computation with backpropagation

Here we will use Pytorch backpropagation without the reparametrization trick, relying on the Monte-Carlo approximations: \begin{equation*} \frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) = \frac{\partial}{\partial \mu}\left(\frac{1}{N}\sum_{i=1}^Nz_i^2\right) = \frac{1}{N}\sum_{i=1}^N\frac{\partial}{\partial \mu}z_i^2 \end{equation*} and \begin{equation*} \frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) = \frac{\partial}{\partial \sigma}\left(\frac{1}{N}\sum_{i=1}^Nz_i^2\right) = \frac{1}{N}\sum_{i=1}^N\frac{\partial}{\partial \sigma}z_i^2, \end{equation*} where, each time, the z_i are N i.i.d. samples from \mathcal{N}(\mu,\sigma^2).

We now implement the Monte-Carlo estimators we have derived above.

import torch
import numpy as np

N = 1000
mu_grads = []
std_grads = []

mu = torch.tensor([12.], requires_grad=True) # set mu = 12 and store gradient
std = torch.tensor([42.], requires_grad=True) # set std = 42

for i in range(N):
    mu.grad = None # reset the gradient value
    std.grad = None

    z = torch.normal(mu, std) # the random sampling happens here
    z2 = z ** 2
    z2.backward()

    mu_grads.append(mu.grad.detach().cpu().numpy())
    std_grads.append(std.grad.detach().cpu().numpy())

print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")
Estimated dE[z^2]/dmu=0.00
Estimated dE[z^2]/dstd=0.00

We do not get the expected values: the differentiations with respect to \mu and \sigma have failed.

Important

Zeroing gradients in Pytorch is an important step that is not done automatically (e.g., between two .backward() calls). It is achieved with a call to the optimizer method optimizer.zero_grad() or by manually resetting the grad attribute of the tensor.

Important

requires_grad means that we should compute gradients for this tensor and that the gradient should be stored in its grad attribute. This concerns leaf tensors. A leaf tensor is a leaf in the computational graph, i.e., a tensor which does not have a grad_fn function. If we want to keep in memory the gradient values of a non-leaf tensor after a .backward() call, we need to call tensor.retain_grad(). More on this in a following section.
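A minimal illustration of this behavior (the tensor values are arbitrary):

a = torch.tensor([3.], requires_grad=True)  # leaf tensor: no grad_fn
b = 2 * a                                   # non-leaf tensor: has a grad_fn
b.retain_grad()                             # ask Pytorch to keep b.grad
c = b ** 2
c.backward()
print(a.grad)  # tensor([24.]): stored by default for leaf tensors
print(b.grad)  # tensor([12.]): would be None without retain_grad()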

Important

retain_graph=True is useful in the backward() call in order to perform several backpropagation calls through the same computational graph.
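For instance (a minimal sketch):

x = torch.tensor([2.], requires_grad=True)
y = x ** 2
y.backward(retain_graph=True)  # keep the graph buffers alive
y.backward()                   # would raise a RuntimeError without
                               # retain_graph=True in the previous call
print(x.grad)                  # tensor([8.]): the two gradients 2x accumulate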

2. Computation with backpropagation and reparametrization trick

We will now use the reparametrization z=g_{\mu,\sigma}(\epsilon) with g_{\mu,\sigma}:\epsilon\mapsto\mu+\sigma\epsilon, where \epsilon\sim\mathcal{N}(0,1).

Note

Mathematically, we can check the procedure for a better understanding. Using the chain rule formula: \begin{align*} \frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) &=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}\left[\frac{\partial}{\partial z}z^2 \frac{\partial}{\partial\mu}g_{\mu,\sigma}(\epsilon)\right]\\ &=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2z \cdot 1]\\ &=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2(\mu+\sigma\epsilon) \cdot 1]\\ &=2\mu + 2\sigma\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon]\\ &=2\mu, \end{align*} \begin{align*} \frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right) &=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}\left[\frac{\partial}{\partial z}z^2 \frac{\partial}{\partial\sigma}g_{\mu,\sigma}(\epsilon)\right]\\ &=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2(\mu+\sigma\epsilon)\epsilon]\\ &=2\mu\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon]+ 2\sigma\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon^2]\\ &=2\sigma, \end{align*} since \mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon]=0 and \mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon^2]=1. As expected, sampling the z_i with reparametrization is mathematically equivalent to the direct derivation we have done above.

Now let us define a function which implements the reparametrization g_{\phi}.

def reparametrize(mu, std):
    eps = torch.randn_like(std)  # eps ~ N(0, 1), independent of mu and std
    return eps * std + mu        # z = g_{mu,sigma}(eps) = mu + sigma * eps

We then adapt the previous code: the Monte-Carlo approximation is unchanged, but the z_i are now sampled using the reparametrization trick.

mu_grads = []
std_grads = []

mu = torch.tensor([12.], requires_grad=True)
std = torch.tensor([42.], requires_grad=True)

for i in range(N):
    mu.grad = None
    std.grad = None

    z = reparametrize(mu, std)
    z2 = z ** 2
    z2.backward(retain_graph=True)

    mu_grads.append(mu.grad.detach().cpu().numpy())
    std_grads.append(std.grad.detach().cpu().numpy())

print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")
Estimated dE[z^2]/dmu=26.72
Estimated dE[z^2]/dstd=79.62

✔️ Given that \mu=12 and \sigma=42, so that 2\mu=24 and 2\sigma=84, we now have a plausible estimation!

Note

Traditionally in the ELBO computation of VAEs, only one Monte-Carlo sample is used.

A word on backpropagation

Chain rule formula

Let us consider the one-dimensional case. In the reparametrization trick, after reparametrization, one has to compute \frac{\mathrm{d}y}{\mathrm{d}\phi}, with y=f(z) and z=g_{\phi}(\epsilon).

In standard notation, the chain rule formula can be written: \frac{\mathrm{d}y}{\mathrm{d}\phi}= \frac{\mathrm{d}y}{\mathrm{d}z} \frac{\mathrm{d}z}{\mathrm{d}\phi}. The backpropagation algorithm relies on a forward and a backward pass to compute the derivative, by following the chain rule formula, with a fixed \epsilon:

  • The forward pass computes values from leaf (\phi) to root (y):
    • compute z=g_\phi(\epsilon). Store z.
    • compute y=f(z). Store y.
  • The backward pass uses the previously stored values and follows the computational graph from root (y) to leaf (\phi), applying the chain rule:
    • Compute \frac{\mathrm dy}{\mathrm dz} at the stored values of y and z. Here this is f'(z).
    • Compute \frac{\mathrm dz}{\mathrm d\phi} at the stored values of z and \phi. Here this is \frac{\mathrm dg_\phi(\epsilon)}{\mathrm d\phi}. Deduce \frac{\mathrm dy}{\mathrm d\phi}=\frac{\mathrm dy}{\mathrm dz}\frac{\mathrm dz}{\mathrm d\phi}.
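These two passes can be mimicked by hand on our toy example f(z)=z^2 and g_{\mu,\sigma}(\epsilon)=\mu+\sigma\epsilon (a sketch without autograd; the values are arbitrary):

mu, sigma, eps = 12., 42., .5  # eps is fixed during the two passes

# forward pass: from leaf to root, store the intermediate values
z = mu + sigma * eps  # z = g_phi(eps)
y = z ** 2            # y = f(z)

# backward pass: from root to leaf, reuse the stored values
dy_dz = 2 * z                  # f'(z)
dz_dmu, dz_dsigma = 1., eps    # derivatives of g w.r.t. mu and sigma
dy_dmu = dy_dz * dz_dmu        # chain rule
dy_dsigma = dy_dz * dz_dsigma  # chain rule
print(dy_dmu, dy_dsigma)       # 66.0 33.0 here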

Goodfellow’s Deep Learning book contains a precise description of the backpropagation algorithm.

Visualizing backpropagation graph

A simple way to visualize this graph is to crawl along the grad_fn attributes of tensors.

print("Backward operation at the root tensor:", z2.grad_fn)
Backward operation at the root tensor: <PowBackward0 object at 0x7fa714662be0>
print("Backward operation(s) one step before the root tensor (i.e. "
    "the next function to execute the the backpropagation order)",
    z2.grad_fn.next_functions)
Backward operation(s) one step before the root tensor (i.e. the next function to execute the the backpropagation order) ((<AddBackward0 object at 0x7fa6ecf4fd60>, 0),)

A complete function could be:

def get_graph(fn, level=0):
    # recursively pretty-print the backward graph, following next_functions
    prefix = ('│    ' * (level-1)) + '├── ' if level else ''
    print(prefix+str(fn))
    # AccumulateGrad nodes expose the leaf tensor via their variable attribute
    tensor = getattr(fn, 'variable', None)
    if tensor is not None:
        runningprefix = prefix.replace('├── ','    ')
        print(runningprefix+f"Tensor: {tensor}")
        print(runningprefix+f"Gradient: {tensor.grad}")
    else:
        for sub in fn.next_functions:
            if sub[0]:
                get_graph(sub[0], level+1)

get_graph(z2.grad_fn)
<PowBackward0 object at 0x7fa724269c70>
├── <AddBackward0 object at 0x7fa724269be0>
│    ├── <MulBackward0 object at 0x7fa7242690a0>
│    │    ├── <AccumulateGrad object at 0x7fa724269ca0>
│    │        Tensor: tensor([42.], requires_grad=True)
│    │        Gradient: tensor([257.3424])
│    ├── <AccumulateGrad object at 0x7fa724269bb0>
│        Tensor: tensor([12.], requires_grad=True)
│        Gradient: tensor([159.5153])

Reparametrization trick with respect to a discrete random variable

The reparametrization trick is based on two assumptions:

  1. f is differentiable in z.

  2. q_{\phi} can be reparametrized as a differentiable function of \phi.

In the case of a discrete random variable, the second assumption does not hold: for a fixed \epsilon, a function g_{\phi} with discrete outputs is piecewise constant in \phi, and things become a little more involved. We present several solutions used in the machine learning literature to backpropagate through discrete random variables.

Straight-Through Estimator (STE)

This solution is not in itself a reparametrization, but it is widely used in order to circumvent non-differentiable operations. Let us continue with another minimal example. One wishes to compute

\begin{equation*} \frac{\partial}{\partial p}\left(\mathbb{E}_{y\sim\mathcal{B}(p)}[y^2]\right), \end{equation*} which, by hand, gives 1 for all p (indeed, y^2=y for y\in\{0,1\}, hence \mathbb{E}_{y\sim\mathcal{B}(p)}[y^2]=p).

1. Direct computation with backpropagation

Let us first observe again a backpropagation failure, in the naive case where we draw Monte-Carlo samples to compute the previous expectation.

N = 1000
p_grads = []

p = torch.tensor([0.25], requires_grad=True)

for i in range(N):
    p.grad = None

    y = torch.bernoulli(p)
    y2 = y ** 2
    y2.backward()

    p_grads.append(p.grad.detach().cpu().numpy())

print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.00

The backpropagation failed!

Warning

Creating p with p = 0.5 * torch.ones((1), requires_grad=True) is possible but requires adding the line p.retain_grad(), since this expression makes p a non-leaf tensor (the multiplication by 0.5 is an operation of the computational graph)!

2. Computation with backpropagation and simple STE

We now implement the STE in its simplest (and trickiest) form. It consists in copying the gradient but not the value: the forward pass keeps the sampled value, while the backward pass lets the gradient flow as if y were p. Such a simple copy solves the problem of having undefined gradients during the backward pass, but it introduces a bias in the estimation!

N = 1000
p_grads = []

p = torch.tensor([0.25], requires_grad=True)

for i in range(N):
    p.grad = None

    y = torch.bernoulli(p)
    y = p + (y - p).detach() # Straight-Through Estimator
    y2 = y ** 2
    y2.backward()

    p_grads.append(p.grad.detach().cpu().numpy())

print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.54

⁉️ Our estimation is highly biased! More on that in the next section…

3. Computation with backpropagation and complete STE module

Note that a more complex but more complete and generalizable way to implement the STE estimator is to define a custom autograd function, by subclassing torch.autograd.Function and implementing a custom backward operation, and then to wrap it in a nn.Module.

The custom backward function:

  • takes the gradient(s) at the output of the module as argument(s).
  • returns the gradient(s) at the input of the module.

class BernoulliSTE_FB(torch.autograd.Function):
    '''
    Implements the STE estimator
    '''
    @staticmethod
    def forward(ctx, input):
        '''
        input : p the Bernoulli parameter
        '''
        return torch.bernoulli(input)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        - This function returns the gradient at the input of the module
        - Argument grad_output is the gradient at the output of the module
        - The return statement then sets the gradient at the input equal to
          the gradient at the output (STE estimator)
        '''
        return grad_output

We then use this custom function in a module. Note that apply(x) is the way we call this custom autograd function.

class BernoulliSTE(torch.nn.Module):
    def __init__(self):
        super(BernoulliSTE, self).__init__()

    def forward(self, x):
        x = BernoulliSTE_FB.apply(x) # apply is the way we call a
        # torch.autograd function
        return x

Warning

Recall that in the backward pass, the operations are traversed in the reverse order of the forward pass. Thus, if the forward() function of the BernoulliSTE_FB class had two inputs, the backward() function of this same class would have two outputs, in order to specify the gradients with respect to both arguments, as the sketch below illustrates.
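Here is a sketch of a custom function with two inputs (the name Mul_FB is a hypothetical example, not part of the STE code above): its backward() returns one gradient per input of forward().

class Mul_FB(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)  # store the inputs for the backward pass
        return a * b

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        # one gradient per input: d(ab)/da = b and d(ab)/db = a
        return grad_output * b, grad_output * a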

We now use this module which replaces the tricky line from the previous example.

N = 1000
p_grads = []
b_ste = BernoulliSTE()

p = torch.tensor([0.25], requires_grad=True)

for i in range(N):
    p.grad = None

    y = b_ste(p) # Straight Through Estimator
    y2 = y ** 2
    y2.backward()

    p_grads.append(p.grad.detach().cpu().numpy())

print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.56

⁉️ Our estimation does not look better than the previous one!

Note

Let us investigate the bias in the gradient estimation of this STE. In terms of derivatives, the line y = p + (y - p).detach() (or its equivalent STE module) replaces y by p, the expected value of y. More precisely, one wants to compute

\begin{align*} \frac{\partial}{\partial p}\left(\frac 1 N \sum_{i=1}^N y_i^2\right), \end{align*} which is replaced by \begin{align*} \frac{\partial}{\partial p}\left(\frac 1 N \sum_{i=1}^N p^2\right)= \frac{\partial}{\partial p}(p^2)=2p. \end{align*}

More generally, for an arbitrary function \xi (in our previous example we had \xi\colon x\mapsto x^2), we can see that the STE replaces the estimation of \begin{align*} \frac{\partial}{\partial p}\mathbb{E}[\xi(Y)] \end{align*} by \begin{align*} \frac{\partial}{\partial p}\xi(\mathbb{E}[Y]), \end{align*} which is biased in general since \mathbb{E}[\xi(Y)]\neq\xi(\mathbb{E}[Y]). Note that the STE is exact for linear \xi.

Gumbel-Softmax reparametrization trick

We first present a reparametrization for a categorical random variable, and then provide a method to make it differentiable.

1. Gumbel-Max trick

Let \pi=(\pi_1,\dots,\pi_k)\in(\mathbb R_+^{*})^k be such that \sum_i \pi_i=1.

Let g be a vector with i.i.d. components g_i\sim\mathrm{Gumbel}(0,1), i.e., g_i=-\log(-\log u_i) with u_i\sim\mathcal{U}(]0,1[).

The Gumbel-Max trick introduces: \begin{equation*} y = \operatorname{one\_hot}(\argmax_i(g_i+\log\pi_i)), \end{equation*}

then:

\begin{equation*} \forall i,\, \mathbb P(y_i=1) = \pi_i \qquad\text{i.e.}\qquad y\sim\mathcal M(1,\pi) \end{equation*}

i.e., y is the one-hot encoding of a categorical variable with probabilities \pi.
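We can check this empirically (a minimal sketch; \pi and the number of samples are arbitrary):

pi = torch.tensor([0.2, 0.5, 0.3])
n_samples = 100000
g = -torch.log(-torch.log(torch.rand((n_samples, 3))))  # Gumbel(0,1) samples
idx = torch.argmax(g + torch.log(pi), dim=1)            # Gumbel-Max trick
freqs = torch.bincount(idx, minlength=3) / n_samples
print(freqs)  # close to pi = [0.2, 0.5, 0.3]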

2. Gumbel-Softmax reparametrization trick

In order to make the latter reparametrization differentiable, the idea of using a \operatorname{softmax} operation instead of the \operatorname{one\_hot}\circ\argmax has been introduced. This is known as the Gumbel-Softmax relaxation: \hat{y} is a continuous random variable, sampled according to \begin{equation*} \hat{y}=\operatorname{softmax}\left(\frac{g+\log \pi}{\tau}\right), \end{equation*}

where g is a vector of i.i.d. \mathrm{Gumbel}(0,1) samples. See the Gumbel distribution on Wikipedia. \hat{y} approximates a sample from the categorical distribution \pi as the temperature parameter \tau>0 decreases to 0.

Note

We can prove:

\forall \pi,g,\qquad \operatorname{softmax}\left(\frac{g+\log \pi}{\tau}\right) \xrightarrow[\tau\to0]{} \operatorname{one\_hot}\left(\argmax\left(g+\log \pi\right)\right)

We first create a function which performs the reparametrization.

def reparametrize_gumbel(p, temp=0.1):
    # Note: following Jang's remark in their article (appendix B.1), we would
    # have an invertibility issue because the softmax operator removes one
    # degree of freedom. This can cause problems in the backpropagation and
    # bias our estimation. Thus, we overparametrize the softmax here to
    # preserve invertibility.
    g = -torch.log(-torch.log(torch.rand((1))))   # Gumbel(0, 1) sample
    g_ = -torch.log(-torch.log(torch.rand((1))))  # Gumbel(0, 1) sample
    logits = torch.log(p)      # log-probability of drawing 1
    logits_ = torch.log1p(-p)  # log-probability of drawing 0
    e = torch.exp(((logits_ + g_) - (logits + g)) / temp)
    return 1 / (1 + e)         # two-class softmax written as a sigmoid

N = 20000
p_grads = []

for i in range(N):
    p = torch.tensor([0.25], requires_grad=True)
    y = reparametrize_gumbel(p) # Gumbel Softmax reparametrization trick
    y = y ** 2
    y.backward()

    p_grads.append(p.grad.detach().cpu().numpy())

print(f"Estimated dE[y^2]/dp={np.nanmean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.94

✔️ This yields a correct estimation.

Note

Pytorch provides a torch.nn.functional.gumbel_softmax() function that can be used directly.
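For instance (a minimal sketch; the logits and the temperature are arbitrary):

import torch.nn.functional as F

logits = torch.log(torch.tensor([0.25, 0.75]))         # two-class logits
y_soft = F.gumbel_softmax(logits, tau=0.1)             # relaxed sample
y_hard = F.gumbel_softmax(logits, tau=0.1, hard=True)  # one-hot forward pass,
                                                       # STE-like backward pass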

An alternative: the log-derivative trick

The log-derivative trick is also known as the REINFORCE trick or the score function estimator. It is a direct approach which uses the identity \nabla_{\phi}\log q_{\phi}(z) = \frac{1}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z), i.e.:

\begin{align*} \nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)] &= \sum_{z}\nabla_{\phi}q_{\phi}(z)f(z)= \sum_{z}\frac{q_{\phi}(z)}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z)f(z),\\ &=\sum_z q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)f(z)= \mathbb{E}_{z\sim q_{\phi}}[\nabla_{\phi}\log q_{\phi}(z)f(z)]. \end{align*} The differential operator is now inside the expectation, which can then be approximated by Monte-Carlo sampling. This approach is applicable to both continuous and discrete distributions (for a continuous z, the sums above become integrals). However, it suffers from high variance and must be combined with variance reduction techniques. See this report to learn more about variance reduction in this precise scenario.
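As a check on our running Bernoulli example, a minimal score function estimator (a sketch without any variance reduction) could look as follows:

N = 100000
p = torch.tensor([0.25], requires_grad=True)

y = torch.bernoulli(p.expand(N)).detach()  # samples, no gradient flows here
log_q = torch.distributions.Bernoulli(probs=p).log_prob(y)
surrogate = (log_q * y ** 2).mean()        # its gradient is the estimator
surrogate.backward()
print(p.grad)  # close to the true gradient, 1, but with a high variance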

References

  • Auto-Encoding Variational Bayes, Diederik P. Kingma et al., link
  • Deep Learning, Ian Goodfellow et al., link
  • Discrete Latent Variable Models, lecture slides by Stefano Ermon and Yang Song, link
  • Categorical reparameterization with Gumbel-softmax, Eric Jang et al., link
  • Latent space conditioning for improved classification and anomaly detection, Erik Norlander and Alexandros Sopasakis, link
  • Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation, Yoshua Bengio et al., link
  • Black Box Variational Inference, Miguel Biron Lattes, link