<p>Vanilla RNNs are covered in detail in quite a few works of machine learning literature. However, I find that some of the intricate details are lacking, particularly around what happens in the layers surrounding the RNN layer. Furthermore, when I first started I had questions like:</p>
<ul>
<li><code class="highlighter-rouge">Do we need to iterate all layers backprop_interval times?</code></li>
<li><code class="highlighter-rouge">Does every layer need to hold a list of inputs / outputs for each backprop_interval?</code></li>
<li><code class="highlighter-rouge">Do we need backprop_interval number of weights for the RNN layer?</code></li>
<li><code class="highlighter-rouge">Do we need to cache all the state derivatives of the RNN layer?</code></li>
</ul>
<p>For this reason we will work through the backprop-through-time equations for an RNN with a preceding and a following ANN layer and see how we can cache/reuse certain parameters.</p>
<p><img src="https://jramapuram.github.io/images/rnn_diag.png" alt="RNN_Diag" /></p>
<h1 id="forward-propagation">Forward Propagation</h1>
<p>It’s always best to start off by defining what each variable means and assuming a sample sizing.<br />
This ensures that we get our dimensions right along the way.</p>
<table>
<thead>
<tr>
<th>Variable Definition</th>
<th>Assumed Sizing</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">x_t</script>: input at time t</td>
<td>[128 x 10]</td>
</tr>
<tr>
<td><script type="math/tex">a_t^0</script>: <script type="math/tex">ANN_0</script> activation at time t</td>
<td>[128 x 20]</td>
</tr>
<tr>
<td><script type="math/tex">h_t</script>: rnn output at time t</td>
<td>[128 x 5]</td>
</tr>
<tr>
<td><script type="math/tex">W_{x_t}^T, W_{h_t}^T, W_{y_t}^T</script>: Weights for <script type="math/tex">ANN_0</script>, RNN and <script type="math/tex">ANN_1</script></td>
<td>[10 x 20] , [20 x 5], [5 x 100]</td>
</tr>
<tr>
<td><script type="math/tex">U_{h_t}^T</script>: Recurrent weights at time t</td>
<td>[5 x 5]</td>
</tr>
<tr>
<td><script type="math/tex">b_{x_t}, b_{h_t}, b_{y_t}</script>: Biases for <script type="math/tex">ANN_0</script>, RNN and <script type="math/tex">ANN_1</script></td>
<td>[20], [5], [100]</td>
</tr>
<tr>
<td><script type="math/tex">\sigma_{0,1,2}</script>: Activation of first, second and third layers</td>
<td>[128 x 20], [128 x 5], [128 x 100]</td>
</tr>
<tr>
<td><script type="math/tex">\sigma_{0,1,2}^\prime</script>: Derivative of the activation of first, second and third layers</td>
<td>[128 x 20], [128 x 5], [128 x 100]</td>
</tr>
<tr>
<td><script type="math/tex">\hat{y_t}</script>: prediction at time t</td>
<td>[128, 100]</td>
</tr>
</tbody>
</table>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
a_t^0 &= \sigma_0(x_tW_{x_t}^T + b_{x_t}) \\
h_t &= \sigma_1(a_t^0W_{h_t}^T + h_{t-1}U_{h_t}^T + b_{h_t}) \\
\hat{y}_t &= \sigma_2(h_tW_{y_t}^T + b_{y_t}) \\
\mathcal{L} &= f(\hat{y}_t, y_t)
\end{aligned} %]]></script>
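The forward equations above can be sketched in NumPy under the assumed sizings from the table (the sigmoid activation and the small random initialization are illustrative assumptions; variable names mirror the math, with e.g. `W_x` stored so that `x_t @ W_x.T` matches the transpose-weight notation):

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

B, I, A, H, O = 128, 10, 20, 5, 100        # batch, input, ANN_0, RNN, output sizes

x_t    = rng.standard_normal((B, I))       # input at time t
h_prev = np.zeros((B, H))                  # h_{t-1}, zero at the first time step

W_x = 0.1 * rng.standard_normal((A, I)); b_x = np.zeros(A)   # ANN_0
W_h = 0.1 * rng.standard_normal((H, A)); b_h = np.zeros(H)   # RNN input-to-hidden
U_h = 0.1 * rng.standard_normal((H, H))                      # RNN recurrent
W_y = 0.1 * rng.standard_normal((O, H)); b_y = np.zeros(O)   # ANN_1

a_t   = sigmoid(x_t @ W_x.T + b_x)                     # [128 x 20]
h_t   = sigmoid(a_t @ W_h.T + h_prev @ U_h.T + b_h)    # [128 x 5]
y_hat = sigmoid(h_t @ W_y.T + b_y)                     # [128 x 100]
```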
<p>We leave the loss <script type="math/tex">\mathcal{L}</script> to be arbitrary for generalization purposes.
An example loss could be an L2 loss for regression or perhaps a cross-entropy loss for classification.
We leave the sizing in <strong>transpose-weight</strong> notation because it keeps logic consistent with data being in the shape of [batch_size, feature]</p>
<h1 id="backpropagation-through-time">Backpropagation Through Time</h1>
<p><strong>The chain rule for the final ANN [i.e. the emitter of <script type="math/tex">\hat{y}_t</script>]:</strong></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\frac{\delta\mathcal{L}}{\delta W_{y_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta W_{y_t}} = [\delta_t^{Loss}\odot\sigma_2^\prime(h_tW_{y_t}^T + b_{y_t})]^Th_t \ \ = [\delta_t^{Loss}\odot\sigma_2^\prime(z_{y_t})]^T h_t \\
\frac{\delta\mathcal{L}}{\delta b_{y_t}} &= \frac{\delta\mathcal{L}}{\delta\hat{y}_t} \frac{\delta\hat{y}_t} {\delta b_{y_t}} = \sum_{batch}[\delta_t^{Loss}\odot\sigma_2^\prime(h_tW_{y_t}^T + b_{y_t})] = \sum_{batch}[\delta_t^{Loss}\odot\sigma_2^\prime(z_{y_t})]
\end{aligned} %]]></script>
<p>Note that <script type="math/tex">\delta_t^{Loss}</script> is merely the loss derivative. For an L2 loss this is just <script type="math/tex">(\hat{y}_t - y)</script>.<br />
We then pass the following <script type="math/tex">\delta_t^{L-1}</script> back to the previous RNN layer:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\delta_t^{L-1} &= \frac{\delta\mathcal{L}}{\delta\hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} = [\delta_t^{Loss}\odot\sigma_2^\prime(z_{y_t})]W_{y_t}
\end{aligned} %]]></script>
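As a sketch, the final ANN’s update and its outgoing delta look like this in NumPy (the sigmoid activation and the random stand-in tensors, shaped per the table above, are illustrative assumptions; with sigmoid, σ′ can be written in terms of the cached activation itself):

```python
import numpy as np

rng = np.random.default_rng(1)
B, H, O = 128, 5, 100
sig_prime = lambda s: s * (1.0 - s)   # sigmoid derivative in terms of its output

# stand-ins for cached forward-pass quantities
h_t   = rng.random((B, H))            # RNN output at time t
y_hat = rng.random((B, O))            # prediction
y     = rng.random((B, O))            # target
W_y   = 0.1 * rng.standard_normal((O, H))

delta_loss = y_hat - y                       # L2-loss delta, [128 x 100]
local      = delta_loss * sig_prime(y_hat)   # δ_t^Loss ⊙ σ2'(z_y)
dW_y = local.T @ h_t                         # [100 x 5], same shape as W_y
db_y = local.sum(axis=0)                     # [100], summed over the batch
delta_prev = local @ W_y                     # δ_t^{L-1} passed to the RNN, [128 x 5]
```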
<p><strong>The chain rule for the RNN:</strong></p>
<p>A key point that makes the RNN different from a standard ANN is that the derivative of the hidden state is sent backwards through time and compounded through a simple addition operation. More formally this means:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\frac{\delta h_t}{\delta h_{t-1}} &= \sigma_1^\prime(a_t^0W_{h_t}^T + h_{t-1}U_{h_t}^T + b_{h_t}) U_{h_t} \ = \sigma_1^\prime(z_{h_t}) U_{h_t}
\end{aligned} %]]></script>
<p>This value is sent backwards through time from the final BPTT time step (<script type="math/tex">T_f</script> below) until time t=0.
In computational terms this means that <script type="math/tex">\frac{\delta {h_t}}{\delta h_{t-1}}</script> needs to be cached and added to itself.
Listed below is an example of two time steps of accumulation of the hidden state:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\Delta h_{T_f} &= \frac{\delta h_{T_f}}{\delta h_{T_{f-1}}} \\
\Delta h_{T_{f-1}} &= \Delta h_{T_f} + \frac{\delta h_{T_{f-1}}}{\delta h_{T_{f-2}}}
\end{aligned} %]]></script>
<p>This process is repeated up until the last unroll, after which the hidden state’s derivative accumulator is zeroed out (you can also do this when you zero out the gradients in the optimizer).</p>
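One way to realize this single running accumulator is sketched below (the random stand-ins for the cached hidden activations and per-step deltas are illustrative assumptions; at each step the incoming delta is combined with the accumulator, pushed through σ1′ and <script type="math/tex">U_h</script>, and carried one step further back):

```python
import numpy as np

rng = np.random.default_rng(2)
B, H, T = 128, 5, 3
sig_prime = lambda s: s * (1.0 - s)   # sigmoid derivative in terms of its output
U_h = 0.1 * rng.standard_normal((H, H))

h      = [rng.random((B, H)) for _ in range(T)]  # cached σ1 outputs, t = 0..T-1
deltas = [rng.random((B, H)) for _ in range(T)]  # δ_t^{L-1} from the layer above

dh_acc = np.zeros((B, H))                        # Δh accumulator, starts zeroed
for t in reversed(range(T)):
    local  = (deltas[t] + dh_acc) * sig_prime(h[t])
    dh_acc = local @ U_h                         # δh_t/δh_{t-1}, carried to t-1
# after the last unroll, dh_acc is zeroed again (e.g. with the optimizer's grads)
```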
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\frac{\delta\mathcal{L}}{\delta W_{h_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta W_{h_t}} = [(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(a_t^0W_{h_t}^T + h_{t-1}U_{h_t}^T + b_{h_t})]^Ta_t^0 \ \ \ = [(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(z_{h_{t}})]^Ta_t^0 \\
\frac{\delta\mathcal{L}}{\delta U_{h_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta U_{h_t}} = [(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(a_t^0W_{h_t}^T + h_{t-1}U_{h_t}^T + b_{h_t})]^Th_{t-1} = [(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(z_{h_{t}})]^Th_{t-1} \\
\frac{\delta\mathcal{L}}{\delta b_{h_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta b_{h_t}} = \sum_{batch}[(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(a_t^0W_{h_t}^T + h_{t-1}U_{h_t}^T + b_{h_t})] \ \ \ = \sum_{batch}[(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(z_{h_{t}})]
\end{aligned} %]]></script>
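These three gradients share the one bracketed term, so it can be computed once per time slice. A minimal NumPy sketch (sigmoid activation and the random stand-in tensors, shaped per the table above, are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(3)
B, A, H = 128, 20, 5
sig_prime = lambda s: s * (1.0 - s)   # sigmoid derivative in terms of its output

# stand-ins for cached forward quantities and incoming deltas at one time slice
a_t    = rng.random((B, A))           # ANN_0 activation
h_t    = rng.random((B, H))           # RNN output
h_prev = rng.random((B, H))           # h_{t-1}
delta_Lm1 = rng.random((B, H))        # δ_t^{L-1} from ANN_1
dh_acc    = rng.random((B, H))        # Δh_t accumulated from later time steps
W_h = 0.1 * rng.standard_normal((H, A))
U_h = 0.1 * rng.standard_normal((H, H))

local = (delta_Lm1 + dh_acc) * sig_prime(h_t)   # shared bracketed term, [128 x 5]
dW_h = local.T @ a_t        # [5 x 20], input-to-hidden gradient
dU_h = local.T @ h_prev     # [5 x 5], recurrent gradient
db_h = local.sum(axis=0)    # [5], summed over the batch
delta_Lm2 = local @ W_h     # δ_t^{L-2} to ANN_0, [128 x 20]
```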
<p>As shown above, all the parameters of the RNN depend on <script type="math/tex">h_{t-1}</script>, which is defined as:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
h_{t-1} &= \sigma_1(a_{t-1}^0W_{h_{t-1}}^T + h_{t-2}U_{h_{t-1}}^T + b_{h_{t-1}}) \\
\end{aligned} %]]></script>
<p>An insight here is that the parameters are shared across time: we don’t in fact use different parameters for the RNN at each step, but merely share and jointly update them using BPTT. With this little tidbit of knowledge we can rewrite our equation from above as follows:</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
h_{t-1} &= \sigma_1(a_{t-1}^0W_{h_t}^T + h_{t-2}U_{h_t}^T + b_{h_{t}}) \\
\end{aligned} %]]></script>
<p>This leaves us with one unsatisfied requirement: <script type="math/tex">a_{t-1}^0</script>, the activated response from the first ANN. Before we sort out the logistics of the entire BPTT algorithm, let’s write out the equations for the first layer.</p>
<p>The RNN layer will emit <script type="math/tex">\delta_t^{L-2}</script> which will be used to update the parameters of the first ANN.</p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\delta_t^{L-2} &= \frac{\delta\mathcal{L}}{\delta\hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta a_t^0} = [(\delta_t^{L-1} + \Delta h_t)\odot\sigma_1^\prime(z_{h_t})]W_{h_t} \\
\end{aligned} %]]></script>
<p><strong>The chain rule for the first ANN [i.e. the emitter of <script type="math/tex">a_t^0</script>]:</strong></p>
<script type="math/tex; mode=display">% <![CDATA[
\begin{aligned}
\frac{\delta\mathcal{L}}{\delta W_{x_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta a_t^0} \frac{\delta a_t^0}{\delta W_{x_t}} = [\delta_t^{L-2}\odot\sigma_0^\prime(x_t W_{x_t}^T + b_{x_t})]^T x_t = [\delta_t^{L-2} \odot \sigma_0^\prime(z_{x_t})]^T x_t \\
\frac{\delta\mathcal{L}}{\delta b_{x_t}} &= \frac{\delta \mathcal{L}}{\delta \hat{y}_t} \frac{\delta\hat{y}_t}{\delta h_t} \frac{\delta h_t}{\delta a_t^0} \frac{\delta a_t^0}{\delta b_{x_t}} = \sum_{batch}[\delta_t^{L-2}\odot\sigma_0^\prime(x_t W_{x_t}^T + b_{x_t})] = \sum_{batch}[\delta_t^{L-2} \odot \sigma_0^\prime(z_{x_t})]
\end{aligned} %]]></script>
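The first ANN’s update follows the same pattern, just one hop further down the chain (again a sketch with sigmoid and random stand-in tensors as illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(4)
B, I, A = 128, 10, 20
sig_prime = lambda s: s * (1.0 - s)   # sigmoid derivative in terms of its output

x_t       = rng.random((B, I))        # cached input at time t
a_t       = rng.random((B, A))        # cached σ0 output
delta_Lm2 = rng.random((B, A))        # δ_t^{L-2} emitted by the RNN layer

local = delta_Lm2 * sig_prime(a_t)    # δ^{L-2} ⊙ σ0'(z_x)
dW_x = local.T @ x_t                  # [20 x 10], same shape as W_x
db_x = local.sum(axis=0)              # [20], summed over the batch
```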
<p><strong>Sizing Verification:</strong></p>
<p>Before we move on, let’s verify that the sizes of the following are as expected. <br />
This is left as an exercise for the reader.</p>
<table>
<thead>
<tr>
<th>Variable Name</th>
<th>Desired Size</th>
</tr>
</thead>
<tbody>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta W_{y_t}^T}</script> : derivative of <script type="math/tex">ANN_1</script>’s weight matrix</td>
<td>[5 x 100]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta b_{y_t}}</script> : derivative of <script type="math/tex">ANN_1</script>’s bias vector</td>
<td>[100]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta W_{h_t}^T}</script> : derivative of RNN input-to-hidden weight matrix</td>
<td>[20 x 5]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta U_{h_t}^T}</script> : derivative of the RNN’s recurrent weight matrix</td>
<td>[5 x 5]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta b_{h_t}}</script> : derivative of the RNN’s bias vector</td>
<td>[5]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta W_{x_t}^T}</script> : derivative of <script type="math/tex">ANN_0</script>’s weight matrix</td>
<td>[10 x 20]</td>
</tr>
<tr>
<td><script type="math/tex">\frac{\delta\mathcal{L}}{\delta b_{x_t}}</script> : derivative of <script type="math/tex">ANN_0</script>’s bias vector</td>
<td>[20]</td>
</tr>
<tr>
<td><script type="math/tex">\delta_t^{Loss}</script> : loss delta</td>
<td>[128 x 100]</td>
</tr>
<tr>
<td><script type="math/tex">\delta_t^{L-1}</script> : delta from <script type="math/tex">ANN_1</script> to RNN</td>
<td>[128 x 5]</td>
</tr>
<tr>
<td><script type="math/tex">\delta_t^{L-2}</script> : delta from RNN to <script type="math/tex">ANN_0</script></td>
<td>[128 x 20]</td>
</tr>
</tbody>
</table>
<p><strong>Step-by-step Forward / Backward Procedure:</strong></p>
<p>Let’s think about this another way: let’s start by writing the above equations from <script type="math/tex">t=0</script>.<br />
In the example below we will walk through three time steps [i.e. our BPTT interval is 3]:</p>
<p><img src="https://jramapuram.github.io/images/rnn_forward.gif" alt="RNN_Forward" /></p>
<ol>
<li>We start off by initializing the hidden state <script type="math/tex">h_{-1}</script> to zero since we have no data up to this point.</li>
<li>We can then use <script type="math/tex">h_{-1}</script> and the new <script type="math/tex">x_0</script> to calculate <script type="math/tex">a_0</script>, <script type="math/tex">h_0</script>, <script type="math/tex">\hat{y}_0</script> and finally <script type="math/tex">L_0</script></li>
<li>We repeat this (i.e. we pass the hidden state from the current time step to the next) for the whole BPTT interval.</li>
</ol>
<p>It is good to note here that each of these losses is a vector. While the loss is eventually reduced via a sum / mean to a single number, for the purposes of backpropagation it is still a vector (e.g. <script type="math/tex">[\hat{y} - y]</script>).</p>
<p><img src="https://jramapuram.github.io/images/rnn_backward.gif" alt="RNN_Backward" /></p>
<p>The backward pass is shown above. The key here is that you just update one time slice at a time in the same way you would in a classical neural network.
So to revisit the questions posed earlier:</p>
<ul>
<li><code class="highlighter-rouge">Do we need to iterate all layers backprop_interval times?</code> Yes, you need to get all the loss vectors for the corresponding input minibatches (at each timestep).</li>
<li><code class="highlighter-rouge">Does every layer need to hold a list of inputs / outputs for each backprop_interval?</code> Yes. SGD updates are only done after backprop_interval forward passes, so each layer needs to hold the input / output mapping for each one of those forward passes. This is what makes an RNN trained with backprop-through-time memory intensive: the memory scales with the length of the BPTT interval.</li>
<li><code class="highlighter-rouge">Do we need backprop_interval number of weights for the RNN layer?</code> No! The weights/biases are shared across all the unrolled recurrent steps and are updated in one fell swoop!</li>
<li><code class="highlighter-rouge">Do we need to cache all the state derivatives of the RNN layer?</code> No! Since they are compounded via a simple addition operator you merely need to keep one value and recursively add to it.</li>
</ul>
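The four answers above can be tied together in one truncated-BPTT step. Below is a minimal NumPy sketch (sigmoid activations, an L2 loss, and the initializations are all illustrative assumptions): one shared parameter set, per-step caches, a single Δh accumulator, and one SGD update at the end.

```python
import numpy as np

rng = np.random.default_rng(5)
sigmoid   = lambda z: 1.0 / (1.0 + np.exp(-z))
sig_prime = lambda s: s * (1.0 - s)

B, I, A, H, O, T = 128, 10, 20, 5, 100, 3   # T = backprop_interval

# one shared set of parameters, regardless of T (answer 3)
W_x, b_x = 0.1 * rng.standard_normal((A, I)), np.zeros(A)
W_h, b_h = 0.1 * rng.standard_normal((H, A)), np.zeros(H)
U_h      = 0.1 * rng.standard_normal((H, H))
W_y, b_y = 0.1 * rng.standard_normal((O, H)), np.zeros(O)

xs = [rng.standard_normal((B, I)) for _ in range(T)]
ys = [rng.random((B, O)) for _ in range(T)]

# forward: cache per-step inputs / outputs (answer 2); h[0] is h_{-1}
a, h, y_hat = [], [np.zeros((B, H))], []
for t in range(T):
    a.append(sigmoid(xs[t] @ W_x.T + b_x))
    h.append(sigmoid(a[t] @ W_h.T + h[t] @ U_h.T + b_h))
    y_hat.append(sigmoid(h[t + 1] @ W_y.T + b_y))

# backward: iterate all layers T times (answer 1), accumulate into shared grads
grads = {k: 0.0 for k in ("W_x", "b_x", "W_h", "b_h", "U_h", "W_y", "b_y")}
dh_acc = np.zeros((B, H))                    # single Δh accumulator (answer 4)
for t in reversed(range(T)):
    local_y = (y_hat[t] - ys[t]) * sig_prime(y_hat[t])   # L2-loss delta ⊙ σ2'
    grads["W_y"] += local_y.T @ h[t + 1]
    grads["b_y"] += local_y.sum(axis=0)
    delta = local_y @ W_y                    # δ^{L-1}

    local_h = (delta + dh_acc) * sig_prime(h[t + 1])
    grads["W_h"] += local_h.T @ a[t]
    grads["U_h"] += local_h.T @ h[t]         # h[t] is h_{t-1}
    grads["b_h"] += local_h.sum(axis=0)
    dh_acc = local_h @ U_h                   # Δh sent back to time t-1
    delta2 = local_h @ W_h                   # δ^{L-2}

    local_x = delta2 * sig_prime(a[t])
    grads["W_x"] += local_x.T @ xs[t]
    grads["b_x"] += local_x.sum(axis=0)

# one SGD update in one fell swoop, then zero the accumulator
lr = 0.1
W_x -= lr * grads["W_x"]; W_h -= lr * grads["W_h"]
U_h -= lr * grads["U_h"]; W_y -= lr * grads["W_y"]
# (biases b_x, b_h, b_y are updated likewise)
dh_acc[:] = 0.0
```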
<h2 id="issues">Issues</h2>
<p>If you find any errors with any of the math or logic here please leave a comment below.</p>
<p><a href="https://jramapuram.github.io/ramblings/rnn-backrpop/">RNN Backprop Through Time Equations</a> was originally published by Jason Ramapuram at <a href="https://jramapuram.github.io">Back Propaganda</a> on June 06, 2016.</p>