In a more general setting, let us say that we want to compute a derivative of the form
\nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)]. \tag{1}
Computing this derivative is problematic because the expectation is taken over samples from q_{\phi}, which itself depends on \phi. Backpropagation fails because we cannot backpropagate through stochastic nodes. Intuitively, once z has been sampled from q_{\phi}, its dependency on \phi has disappeared from the computational graph.
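The reparametrization trick is a solution to this problem. The idea is to rewrite z\sim q_{\phi} with the help of a deterministic differentiable function g_{\phi} and a random variable \epsilon\sim p such that z=g_{\phi}(\epsilon), where p is a probability distribution that is easy to sample from and, most importantly, does not depend on \phi. Using the definition of the integral and the change of variable formula:
\begin{equation*}
\nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)]
= \nabla_{\phi}\mathbb{E}_{\epsilon\sim p}[f(g_{\phi}(\epsilon))]
= \mathbb{E}_{\epsilon\sim p}[\nabla_{\phi}f(g_{\phi}(\epsilon))].
\end{equation*}
The gradient is now inside an expectation that does not depend on \phi, so it can be estimated by Monte-Carlo sampling of \epsilon. Some examples of reparametrizations: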
With q_\phi = \mathcal N(\mu,\sigma^2), with \phi = (\mu,\sigma^2), we can use: g_\phi: \epsilon\mapsto\mu+\sigma\epsilon and p=\mathcal N(0,1).
With q_\phi = \mathcal E(\phi) (exponential distribution with rate \phi), we can use: g_\phi: \epsilon\mapsto-\frac1\phi\log\epsilon and p=\mathcal U(]0,1[).
With q_\phi = \operatorname{Weibull}(\lambda,k), with \phi=(\lambda, k), we can use: g_\phi: \epsilon\mapsto\lambda{\left(-\log\epsilon\right)}^{1/k} and p=\mathcal U(]0,1[).
More generally, if q_\phi is absolutely continuous with respect to the Lebesgue measure and its inverse cumulative distribution function F^{-1}_\phi is known (and easily computable and differentiable with respect to \phi), we can use g_\phi = F^{-1}_\phi and p=\mathcal U(]0,1[).
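The inverse-CDF reparametrizations listed above can be sketched in PyTorch as follows. This is only an illustration: the function names `sample_exponential` and `sample_weibull` are hypothetical, and the clamp on \epsilon is an assumption added to avoid evaluating \log 0. The check at the end uses the fact that \mathbb{E}[z]=1/\phi for the exponential distribution, so the gradient of the mean with respect to \phi should be close to -1/\phi^2.

```python
import math
import torch

torch.manual_seed(0)

def sample_exponential(rate, n_samples):
    """Reparametrized samples z = -log(eps) / rate, with eps ~ U(]0,1[)."""
    eps = torch.rand(n_samples).clamp_min(1e-12)  # guard against log(0)
    return -torch.log(eps) / rate

def sample_weibull(scale, shape, n_samples):
    """Reparametrized samples z = scale * (-log(eps))**(1/shape)."""
    eps = torch.rand(n_samples).clamp_min(1e-12)
    return scale * (-torch.log(eps)) ** (1.0 / shape)

# Gradient of E[z] with respect to the rate of an exponential distribution.
rate = torch.tensor(2.0, requires_grad=True)
z = sample_exponential(rate, 100_000)
z.mean().backward()
grad_rate = rate.grad.item()
print(f"Estimated dE[z]/drate={grad_rate:.3f} (theory: {-1 / 2.0 ** 2:.3f})")

# Sanity check of the Weibull sampler against its known mean.
w = sample_weibull(3.0, 2.0, 100_000)
weibull_mean = w.mean().item()
print(f"Weibull mean={weibull_mean:.3f} (theory: {3.0 * math.gamma(1.5):.3f})")
```

Because the noise \epsilon is drawn independently of the parameters, autograd can differentiate the samples with respect to them.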
Reparametrization trick with respect to a continuous random variable
Reparametrization trick for Variational Autoencoders
The reparametrization trick has become popular because of its use in a very popular generative model: the Variational Autoencoder (VAE). A VAE is trained by maximizing an objective function called the evidence lower bound (ELBO), so named because it is a lower bound of the model log-likelihood \log p_{\theta}(x), \begin{equation*}
\mathcal{L}_{\theta,\phi}(x)=\mathbb{E}_{z\sim q_{\phi}}[\log p_{\theta}(x|z)]
-D_{KL}(q_{\phi}||p_{\theta}),
\end{equation*} where z are latent random variables and, in the KL term, p_{\theta} denotes the prior distribution on z. During the parameter estimation procedure, one wants to differentiate \mathcal{L}_{\theta,\phi}(x) with respect to \phi (and \theta, but this is straightforward) while drawing samples z under the variational distribution q_{\phi}, parametrized by this same \phi. The first term of the ELBO is thus a particular case of Equation 1, with f(z)=\log p_{\theta}(x|z). In the setting of classical VAEs, z is drawn from a Gaussian distribution, which is reparametrizable, and f is differentiable with respect to z, so the reparametrization trick can be used. Kingma’s original paper on VAEs can be consulted for details about the VAE model and the associated reparametrization trick.
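To make the link with Equation 1 concrete, here is a minimal single-sample ELBO sketch. Everything specific in it is an assumption made for illustration and is not taken from the text above: the linear `encoder` and `decoder`, the dimensions, the unit-variance Gaussian likelihood and the standard normal prior (for which the KL term has a closed form).

```python
import torch

torch.manual_seed(0)

x_dim, z_dim = 5, 2
# Hypothetical linear encoder (outputs mean and log-variance) and decoder.
encoder = torch.nn.Linear(x_dim, 2 * z_dim)
decoder = torch.nn.Linear(z_dim, x_dim)

def elbo_single_sample(x):
    """One-sample Monte-Carlo estimate of the ELBO, averaged over the batch."""
    mu, log_var = encoder(x).chunk(2, dim=-1)
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)   # eps ~ N(0, I), does not depend on phi
    z = mu + std * eps            # reparametrized sample from q_phi
    # log p_theta(x|z) for a unit-variance Gaussian decoder, up to a constant
    log_lik = -0.5 * ((x - decoder(z)) ** 2).sum(dim=-1)
    # Closed-form KL(N(mu, std^2) || N(0, I))
    kl = 0.5 * (mu ** 2 + std ** 2 - 1.0 - log_var).sum(dim=-1)
    return (log_lik - kl).mean()

x = torch.randn(8, x_dim)         # a batch of fake observations
elbo = elbo_single_sample(x)
elbo.backward()                   # gradients reach both encoder and decoder
print(f"ELBO estimate: {elbo.item():.3f}")
```

The encoder parameters play the role of \phi: they receive gradients only because z is written as a deterministic function of them and of \epsilon.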
Experiments with PyTorch
Let us switch to a minimal example for the reparametrization trick in which we want to compute: \begin{equation*}
\frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
\text{ and }
\frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
\end{equation*}
0. Theoretical result
In this toy example, using the identity \mathbb{E}[z^2]=\mathrm{Var}(z)+(\mathbb{E}[z])^2, we can compute by hand that: \begin{equation*}
\frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
= \frac{\partial}{\partial \mu}\left(\sigma^2 + \mu^2\right) = 2\mu
\text{ and }
\frac{\partial}{\partial
\sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
= \frac{\partial}{\partial \sigma}\left(\sigma^2 + \mu^2\right) = 2\sigma.
\end{equation*}
1. Direct computation with backpropagation
Here we will use PyTorch backpropagation without the reparametrization trick, relying on the Monte-Carlo approximations: \begin{equation*}
\frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
= \frac{\partial}{\partial \mu}\left(\frac{1}{N}\sum_{i=1}^Nz_i^2\right)
= \frac{1}{N}\sum_{i=1}^N\frac{\partial}{\partial \mu}z_i^2
\end{equation*} and \begin{equation*}
\frac{\partial}{\partial
\sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
= \frac{\partial}{\partial \sigma}\left(\frac{1}{N}\sum_{i=1}^Nz_i^2\right)
= \frac{1}{N}\sum_{i=1}^N\frac{\partial}{\partial \sigma}z_i^2,
\end{equation*} where, each time, the z_i are N i.i.d. samples from \mathcal{N}(\mu,\sigma^2).
We now implement the Monte-Carlo estimators we have derived above.
import torch
import numpy as np

N = 1000
mu_grads = []
std_grads = []
mu = torch.tensor([12.], requires_grad=True)  # set mu = 12 and store gradient
std = torch.tensor([42.], requires_grad=True)  # set std = 42
for i in range(N):
    mu.grad = None  # reset the gradient value
    std.grad = None
    z = torch.normal(mu, std)  # the random sampling happens here
    z2 = z ** 2
    z2.backward()
    mu_grads.append(mu.grad.detach().cpu().numpy())
    std_grads.append(std.grad.detach().cpu().numpy())
print(f"Estimated dE[z^2]/dmu={np.mean(mu_grads):.2f}")
print(f"Estimated dE[z^2]/dstd={np.mean(std_grads):.2f}")
❌ We do not get the expected values: the differentiations with respect to \mu and \sigma have failed.
Important
Zeroing gradients in PyTorch is an important step that is not done automatically: gradients accumulate across successive .backward() calls. Zeroing is achieved with a call to the optimizer method optimizer.zero_grad() or by manually resetting the grad attribute of the tensor.
Important
requires_grad=True means that gradients should be computed for this tensor and stored in its grad attribute. This concerns leaf tensors: a leaf tensor is a leaf of the computational graph, i.e., a tensor which does not have a grad_fn. To keep the gradient of a non-leaf tensor in memory after a .backward() call, one needs to call tensor.retain_grad() beforehand. More on this in a following section.
Important
Passing retain_graph=True to backward() keeps the computational graph in memory, which makes it possible to backpropagate several times through the same graph, possibly from different tensors.
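The three remarks above can be checked on a toy computation. This is only a small illustrative sketch; the variable names are arbitrary.

```python
import torch

# Gradients accumulate across backward() calls unless they are reset.
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
y.backward(retain_graph=True)      # keep the graph for a second backward pass
first_grad = x.grad.clone()        # dy/dx = 2x = 4
y.backward()                       # same graph: the gradient is added to x.grad
accumulated_grad = x.grad.clone()  # 4 + 4 = 8
x.grad = None                      # manual reset of the gradient

# A tensor produced by an operation is not a leaf: its gradient is kept
# only if retain_grad() is called before backward().
p = 0.5 * torch.ones(1, requires_grad=True)
p.retain_grad()
(p ** 2).backward()
print(first_grad, accumulated_grad, p.is_leaf, p.grad)
```

Without the call to p.retain_grad(), p.grad would stay empty because p is the result of a multiplication and therefore not a leaf tensor.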
2. Computation with backpropagation and reparametrization trick
We will now use the reparametrization z=g_{\mu,\sigma}(\epsilon) with g_{\mu,\sigma}:\epsilon\mapsto\mu+\sigma*\epsilon where \epsilon\sim\mathcal{N}(0,1).
Note
Mathematically, we can check the procedure for a better understanding. Using the chain rule formula: \begin{align*}
\frac{\partial}{\partial \mu}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
&=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\frac{\partial}{\partial z}z^2
\frac{\partial}{\partial\mu}g_{\mu,\sigma}(\epsilon)]\\
&=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2z * 1]\\
&=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2(\mu+\sigma\epsilon) * 1]\\
&=2\mu + 2\sigma\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon]\\
&=2\mu,
\end{align*}\begin{align*}
\frac{\partial}{\partial \sigma}\left(\mathbb{E}_{z\sim\mathcal{N}(\mu,\sigma^2)}[z^2]\right)
&=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\frac{\partial}{\partial z}z^2
\frac{\partial}{\partial\sigma}g_{\mu,\sigma}(\epsilon)]\\
&=\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[2(\mu+\sigma\epsilon)\epsilon]\\
&=2\mu\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon]+
2\sigma\mathbb{E}_{\epsilon\sim\mathcal{N}(0, 1)}[\epsilon^2]\\
&=2\sigma.
\end{align*} Fortunately, the derivatives obtained through the reparametrization match the ones computed by hand in the theoretical result above.
Now let us define a function which implements the reparametrization g_{\phi}.
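A minimal sketch of such a function is given below. The name `reparametrize`, the helper `estimate_gradients` and the vectorized averaging over samples are assumptions made for this illustration. With \mu=12 and \sigma=42, the two estimates should be close to 2\mu=24 and 2\sigma=84.

```python
import torch

torch.manual_seed(0)

def reparametrize(mu, std, n_samples=1):
    """Return z = g(eps) = mu + std * eps, with eps ~ N(0, 1)."""
    eps = torch.randn(n_samples)  # the noise does not depend on mu or std
    return mu + std * eps

def estimate_gradients(n_samples):
    """Monte-Carlo estimates of dE[z^2]/dmu and dE[z^2]/dstd."""
    mu = torch.tensor([12.], requires_grad=True)
    std = torch.tensor([42.], requires_grad=True)
    z = reparametrize(mu, std, n_samples)
    z2 = z ** 2
    z2.mean().backward()          # averages the per-sample gradients
    return mu.grad.item(), std.grad.item()

mu_grad, std_grad = estimate_gradients(200_000)
print(f"Estimated dE[z^2]/dmu={mu_grad:.2f}")
print(f"Estimated dE[z^2]/dstd={std_grad:.2f}")
```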
✔️ Provided that \mu=12 and \sigma=42, we now have a plausible estimation!
Note
Traditionally in the ELBO computation of VAEs, only one Monte-Carlo sample is used.
A word on backpropagation
Chain rule formula
Let us consider the one-dimensional case. In the reparametrization trick, after reparametrization, one has to compute \frac{\mathrm{d}y}{\mathrm{d}\phi}, with y=f(z) and z=g_{\phi}(\epsilon).
In standard notation, the chain rule formula can be written:
\frac{\mathrm{d}y}{\mathrm{d}\phi}=
\frac{\mathrm{d}y}{\mathrm{d}z}
\frac{\mathrm{d}z}{\mathrm{d}\phi}.
The backpropagation algorithm relies on a forward and a backward pass to compute the derivative, by following the chain rule formula, with a fixed \epsilon:
The forward pass computes values from leaf (\phi) to root (y):
compute z=g_\phi(\epsilon). Store z.
compute y=f(z). Store y.
The backward pass reuses the stored values and follows the computational graph from root (y) to leaf (\phi), applying the chain rule at each step:
Compute \frac{\mathrm dy}{\mathrm dz} at the stored value of z. Here this is f'(z).
Compute \frac{\mathrm dz}{\mathrm d\phi} at the stored values of z and \phi. Here this is \frac{\mathrm dg_\phi(\epsilon)}{\mathrm d\phi}. Deduce \frac{\mathrm dy}{\mathrm d\phi}=\frac{\mathrm dy}{\mathrm dz}\frac{\mathrm dz}{\mathrm d\phi}.
Goodfellow’s Deep Learning book contains a precise description of the backpropagation algorithm.
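The forward and backward passes described above can be reproduced by hand and compared with autograd. The sketch below assumes the Gaussian example of the previous section, f(z)=z^2 and g_{\mu,\sigma}(\epsilon)=\mu+\sigma\epsilon, with a fixed \epsilon.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(12.0, requires_grad=True)
sigma = torch.tensor(42.0, requires_grad=True)
eps = torch.randn(())             # fixed noise, shared by both computations

# Forward pass: from leaves (mu, sigma) to root (y), storing z and y.
z = mu + sigma * eps              # z = g_phi(eps)
y = z ** 2                        # y = f(z)

# Backward pass by hand: from root to leaves.
dy_dz = 2 * z.detach()            # f'(z) at the stored value of z
dz_dmu = torch.tensor(1.0)        # d(mu + sigma * eps) / dmu
dz_dsigma = eps                   # d(mu + sigma * eps) / dsigma
manual_grad_mu = dy_dz * dz_dmu
manual_grad_sigma = dy_dz * dz_dsigma

# The same backward pass performed by autograd.
y.backward()
print(manual_grad_mu, mu.grad)
print(manual_grad_sigma, sigma.grad)
```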
Visualizing backpropagation graph
A simple way to visualize this graph is to crawl along the grad_fn attributes of tensors.
print("Backward operation at the root tensor:", z2.grad_fn)
Backward operation at the root tensor: <PowBackward0 object at 0x7fa714662be0>
print("Backward operation(s) one step before the root tensor (i.e. "
      "the next function to execute in the backpropagation order)",
      z2.grad_fn.next_functions)
Backward operation(s) one step before the root tensor (i.e. the next function to execute in the backpropagation order) ((<AddBackward0 object at 0x7fa6ecf4fd60>, 0),)
<PowBackward0 object at 0x7fa724269c70>
├── <AddBackward0 object at 0x7fa724269be0>
│ ├── <MulBackward0 object at 0x7fa7242690a0>
│ │ ├── <AccumulateGrad object at 0x7fa724269ca0>
│ │ Tensor: tensor([42.], requires_grad=True)
│ │ Gradient: tensor([257.3424])
│ ├── <AccumulateGrad object at 0x7fa724269bb0>
│ Tensor: tensor([12.], requires_grad=True)
│ Gradient: tensor([159.5153])
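A tree like the one above can be obtained by recursively following next_functions. The sketch below is one possible way to do it: the helper name `graph_lines` and its output layout are assumptions, and it relies on the `variable` attribute that AccumulateGrad nodes expose for leaf tensors.

```python
import torch

def graph_lines(fn, depth=0):
    """Return one text line per node of the backward graph rooted at `fn`."""
    indent = "    " * depth
    lines = [indent + type(fn).__name__]
    if hasattr(fn, "variable"):    # AccumulateGrad node: wraps a leaf tensor
        lines.append(f"{indent}  Tensor: {fn.variable}")
        lines.append(f"{indent}  Gradient: {fn.variable.grad}")
    for next_fn, _ in fn.next_functions:
        if next_fn is not None:    # None: input that does not require grad
            lines.extend(graph_lines(next_fn, depth + 1))
    return lines

mu = torch.tensor([12.], requires_grad=True)
std = torch.tensor([42.], requires_grad=True)
z2 = (std * torch.randn(1) + mu) ** 2
z2.backward()
lines = graph_lines(z2.grad_fn)
print("\n".join(lines))
```

The gradient values printed for the two leaves depend on the sampled noise, so they differ from run to run.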
Reparametrization trick with respect to a discrete random variable
The reparametrization trick is based on two assumptions:
f is differentiable in z.
q_{\phi} can be reparametrized as a differentiable function of \phi.
In the case of a discrete random variable, the second hypothesis does not hold: z=g_{\phi}(\epsilon) only takes values in a discrete set, so it cannot vary smoothly with \phi, and things become a little more involved. We present several solutions used in the machine learning literature to backpropagate through discrete random variables.
Straight-Through Estimator (STE)
This solution is not in itself a reparametrization, but it is widely used in order to circumvent non-differentiable operations. Let us continue with another minimal example. One wishes to compute
\begin{equation*}
\frac{\partial}{\partial p}\left(\mathbb{E}_{y\sim\mathcal{B}(p)}[y^2]\right),
\end{equation*} which equals 1 for all p: since y\in\{0,1\}, we have y^2=y and thus \mathbb{E}[y^2]=\mathbb{E}[y]=p.
1. Direct computation with backpropagation
Let us first observe again a backpropagation failure, in the naive case where we draw Monte-Carlo samples to compute the previous expectation.
N = 1000
p_grads = []
p = torch.tensor([0.25], requires_grad=True)
for i in range(N):
    p.grad = None
    y = torch.bernoulli(p)
    y2 = y ** 2
    y2.backward()
    p_grads.append(p.grad.detach().cpu().numpy())
print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.00
❌ The backpropagation failed!
Warning
Creating p with p = 0.5 * torch.ones((1), requires_grad=True) is possible but requires adding the line p.retain_grad(), since this expression makes p a non-leaf tensor!
2. Computation with backpropagation and simple STE
We now implement the STE in its simplest (if tricky) form. It consists in copying the gradient but not the value: in the forward pass y keeps its sampled value, while in the backward pass the gradient flows as if y were equal to p. Such a simple copy solves the problem of uninformative gradients during the backward pass, but it introduces a bias in the estimation!
N = 1000
p_grads = []
p = torch.tensor([0.25], requires_grad=True)
for i in range(N):
    p.grad = None
    y = torch.bernoulli(p)
    y = p + (y - p).detach()  # Straight-Through Estimator
    y2 = y ** 2
    y2.backward()
    p_grads.append(p.grad.detach().cpu().numpy())
print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.54
⁉️ Our estimation is highly biased! More on that in the next section…
3. Computation with backpropagation and complete STE module
Note that a more complex but more complete and generalizable way to implement the STE estimator would be to create a nn.Module whose computations implement their own forward() and backward() methods. To do so we create an autograd function by subclassing from torch.autograd.Function and implement a custom backward operation for this function.
The custom backward function:
takes the gradient(s) with respect to the output(s) of forward() as argument(s).
returns the gradient(s) with respect to the input(s) of forward().
class BernoulliSTE_FB(torch.autograd.Function):
    ''' Implements the STE estimator '''

    @staticmethod
    def forward(ctx, input):
        ''' input : p the Bernoulli parameter '''
        return torch.bernoulli(input)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        - This function returns the gradient at the input of the module
        - Argument grad_output is the gradient at the output of the module
        - The return statement then sets the gradient at the input equal
          to the gradient at the output (STE estimator)
        '''
        return grad_output
We then use this custom function in a module. Note that apply(x) is the way we call this custom autograd function.
class BernoulliSTE(torch.nn.Module):
    def __init__(self):
        super(BernoulliSTE, self).__init__()

    def forward(self, x):
        # apply is the way we call a torch.autograd function
        x = BernoulliSTE_FB.apply(x)
        return x
Warning
Recall that the backward pass traverses the operations of the forward pass in reverse order. Thus, if the forward() function of the BernoulliSTE_FB class had two inputs, the backward() function of this same class would have to return two outputs, in order to specify the gradients with respect to both arguments.
We now use this module which replaces the tricky line from the previous example.
N = 1000
p_grads = []
b_ste = BernoulliSTE()
p = torch.tensor([0.25], requires_grad=True)
for i in range(N):
    p.grad = None
    y = b_ste(p)  # Straight-Through Estimator
    y2 = y ** 2
    y2.backward()
    p_grads.append(p.grad.detach().cpu().numpy())
print(f"Estimated dE[y^2]/dp={np.mean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.56
⁉️ Our estimation does not look better than the previous one!
Note
Let us investigate the bias in the gradient estimation of this STE. As far as differentiation is concerned, the line y = p + (y - p).detach() (or its equivalent STE module) replaces y by p, the expected value of y. More precisely, one wants to compute
\begin{align*}
\frac{\partial}{\partial p}\left(\frac 1 N \sum_{i=1}^N y_i^2\right),
\end{align*} which is replaced by \begin{align*}
\frac{\partial}{\partial p}\left(\left(\frac 1 N \sum_{i=1}^N
p\right)^2\right)=
\frac{\partial}{\partial p}(p^2)=2p.
\end{align*}
More generally, for an arbitrary function \xi (in our previous example we had \xi\colon x\rightarrow x^2), we can see that the STE replaces the estimation of \begin{align*}
\frac{\partial}{\partial p}\mathbb{E}\xi(Y)
\end{align*} by \begin{align*}
\frac{\partial}{\partial p}\xi(\mathbb{E}Y),
\end{align*} which is biased in general since \mathbb{E}\xi(Y)\neq\xi(\mathbb{E}Y). Note that the STE is exact for linear \xi.
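The value 2p can be recovered exactly, without Monte-Carlo noise, by enumerating the two outcomes of the Bernoulli draw and weighting the STE gradient of each by its probability. The helper `ste_gradient` below is a hypothetical name introduced for this check. For p=0.25 it gives 2p=0.5, which is consistent with the estimates of about 0.55 observed above and differs from the true value 1.

```python
import torch

def ste_gradient(p_value, y_value):
    """Gradient of y**2 w.r.t. p when the sample y goes through the STE line."""
    p = torch.tensor([p_value], requires_grad=True)
    y = torch.tensor([y_value])
    y = p + (y - p).detach()   # value is y_value, derivative w.r.t. p is 1
    (y ** 2).backward()
    return p.grad.item()

p_value = 0.25
# Exact expectation over the two outcomes of a Bernoulli(p) draw.
expected_ste = ((1 - p_value) * ste_gradient(p_value, 0.0)
                + p_value * ste_gradient(p_value, 1.0))
print(f"Expected STE gradient: {expected_ste:.2f}, 2p = {2 * p_value:.2f}")
```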
Gumbel-Softmax reparametrization trick
In a first part, we present a reparametrization for a categorical random variable. We then provide a method to make it differentiable.
1. Gumbel-Max trick
Let \pi=(\pi_1,\dots,\pi_k)\in\mathbb R_+^{*k} be such that \sum_i \pi_i=1.
Let g=(g_1,\dots,g_k) be a vector of i.i.d. components g_i\sim\mathrm{Gumbel}(0,1), i.e., g_i=-\log(-\log u_i) with u_i\sim\mathcal{U}([0,1]).
The Gumbel-Max trick introduces: \begin{equation*}
y = \operatorname{one\_hot}(\argmax_i(g_i+\log\pi_i)),
\end{equation*}
then y is the one-hot encoded vector of a categorical variable with probabilities \pi.
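The Gumbel-Max trick can be checked empirically: the frequencies of the sampled categories should approach \pi. In this sketch, the function name `gumbel_max_sample`, the example vector \pi and the clamp on the uniform draws (added to avoid \log 0) are assumptions.

```python
import torch

torch.manual_seed(0)

def gumbel_max_sample(pi, n_samples):
    """Draw category indices via argmax_i(g_i + log pi_i)."""
    u = torch.rand(n_samples, pi.numel()).clamp_min(1e-10)
    g = -torch.log(-torch.log(u))   # i.i.d. Gumbel(0, 1) noise
    return torch.argmax(g + torch.log(pi), dim=-1)

pi = torch.tensor([0.1, 0.6, 0.3])
idx = gumbel_max_sample(pi, 50_000)
one_hot = torch.nn.functional.one_hot(idx, num_classes=pi.numel())
frequencies = one_hot.float().mean(dim=0)
print("Empirical frequencies:", frequencies)
print("Target probabilities: ", pi)
```

The argmax is not differentiable with respect to \pi, so this sampler alone does not allow backpropagation.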
2. Gumbel-Softmax reparametrization trick
In order to make the latter reparametrization differentiable, the \operatorname{one\_hot}\circ\argmax operation is replaced by a \operatorname{softmax} operation. This is called the Gumbel-Softmax relaxation: \hat{y} is a continuous random variable, sampled according to \begin{equation*}
\hat{y}=\operatorname{softmax}\left(\frac{g+\log \pi}{\tau}\right),
\end{equation*}
where g is a vector of i.i.d. \mathrm{Gumbel}(0,1) variables (see the Gumbel distribution on Wikipedia). \hat{y} approximates a sample from the categorical distribution \pi as the temperature parameter \tau > 0 decreases to 0.
We first create a function which performs the reparametrization.
def reparametrize_gumbel(p, temp=0.1):
    # Note: following Jang et al.'s remark in their article (appendix B.1),
    # we have an invertibility issue because the softmax operator removes one
    # degree of freedom. This can cause problems in backpropagation to get our
    # correct estimation. Thus, we overparametrize the softmax here to
    # preserve the invertibility.
    g = -torch.log(-torch.log(torch.rand((1))))
    g_ = -torch.log(-torch.log(torch.rand((1))))
    logits = torch.log(p)
    logits_ = torch.log1p(-p)
    e = torch.exp(((logits_ + g_) - (logits + g)) / temp)
    return 1 / (1 + e)
N = 20000
p_grads = []
for i in range(N):
    p = torch.tensor([0.25], requires_grad=True)
    y = reparametrize_gumbel(p)  # Gumbel-Softmax reparametrization trick
    y = y ** 2
    y.backward()
    p_grads.append(p.grad.detach().cpu().numpy())
print(f"Estimated dE[y^2]/dp={np.nanmean(p_grads):.2f}")
Estimated dE[y^2]/dp=0.94
✔️ This yields a correct estimation.
Note
PyTorch provides a torch.nn.functional.gumbel_softmax() function that can be used directly.
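A short usage sketch of this built-in function follows. The probabilities, the temperature and the toy loss are arbitrary choices made for illustration. With hard=True the output is one-hot in the forward pass, while the gradient is the one of the relaxed sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pi = torch.tensor([0.1, 0.6, 0.3])
logits = torch.log(pi).requires_grad_()   # unnormalized log-probabilities

# Relaxed (continuous) sample: non-negative entries that sum to one.
y_soft = F.gumbel_softmax(logits, tau=0.5, hard=False)

# hard=True: one-hot sample in the forward pass, soft gradient in the
# backward pass (a straight-through variant).
y_hard = F.gumbel_softmax(logits, tau=0.5, hard=True)

loss = (y_hard * torch.arange(3.0)).sum()  # toy loss: index of the category
loss.backward()
print(y_soft, y_hard, logits.grad)
```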
An alternative: the log-derivative trick
The log-derivative trick is also known as the REINFORCE trick or score function estimator. It is a direct approach which consists in using the chain rule identity \nabla_{\phi}\log q_{\phi}(z) = \frac{1}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z), i.e.:
\begin{align*}
\nabla_{\phi}\mathbb{E}_{z\sim q_{\phi}}[f(z)] &=
\sum_{z}\nabla_{\phi}q_{\phi}(z)f(z)=
\sum_{z}\frac{q_{\phi}(z)}{q_{\phi}(z)}\nabla_{\phi}q_{\phi}(z)f(z),\\
&=\sum_z q_{\phi}(z)\nabla_{\phi}\log q_{\phi}(z)f(z)=
\mathbb{E}_{z\sim q_{\phi}}[\nabla_{\phi}\log q_{\phi}(z)f(z)].
\end{align*} The differential operator is now inside the expectation, which can be approximated by Monte-Carlo sampling. This approach is applicable to both continuous and discrete distributions (with integrals in place of sums in the continuous case). However, it suffers from high variance and usually has to be combined with variance reduction techniques. See this report to learn more about variance reduction in this precise scenario.
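As an illustration, the log-derivative trick can be applied to the Bernoulli example of the previous section, \frac{\partial}{\partial p}\mathbb{E}_{y\sim\mathcal{B}(p)}[y^2]=1. The sketch below assumes a hypothetical helper `score_function_gradient` and uses torch.distributions.Bernoulli to evaluate \log q_p(y). The samples themselves carry no gradient; only the log-probability does.

```python
import torch

torch.manual_seed(0)

def score_function_gradient(p_value, n_samples, f=lambda y: y ** 2):
    """Monte-Carlo estimate of d/dp E_{y ~ B(p)}[f(y)] with the score function."""
    p = torch.tensor([p_value], requires_grad=True)
    # Samples are drawn without tracking gradients.
    y = torch.bernoulli(torch.full((n_samples,), p_value))
    log_q = torch.distributions.Bernoulli(probs=p).log_prob(y)
    # The gradient of this surrogate is the mean of f(y) * grad log q_p(y).
    surrogate = (f(y) * log_q).mean()
    surrogate.backward()
    return p.grad.item()

grad_estimate = score_function_gradient(0.25, 100_000)
print(f"Estimated dE[y^2]/dp={grad_estimate:.2f}")
```

No differentiability of f is needed here, which is why the approach also works for discrete variables.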
References
Auto-Encoding Variational Bayes, Diederik P. Kingma et al.