|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | + |
| 5 | +from torch_xla.experimental.scan import scan |
| 6 | + |
| 7 | + |
| 8 | +class GRU(nn.Module): |
| 9 | + r""" |
| 10 | + PyTorch/XLA GRU implemented using scan. |
| 11 | +
|
| 12 | + Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. |
| 13 | + For each element in the input sequence, each layer computes the following |
| 14 | + function: |
| 15 | +
|
| 16 | + .. math:: |
| 17 | + \begin{array}{ll} |
| 18 | + r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ |
| 19 | + z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ |
| 20 | + n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ |
| 21 | + h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} |
| 22 | + \end{array} |
| 23 | +
|
| 24 | + where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input |
| 25 | + at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer |
| 26 | + at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, |
| 27 | + :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. |
| 28 | + :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. |
| 29 | +
|
| 30 | + In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer |
| 31 | + (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by |
| 32 | + dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random |
| 33 | + variable which is :math:`0` with probability :attr:`dropout`. |
| 34 | +
|
| 35 | + Args: |
| 36 | + input_size: The number of expected features in the input `x` |
| 37 | + hidden_size: The number of features in the hidden state `h` |
| 38 | + num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` |
| 39 | + would mean stacking two GRUs together to form a `stacked GRU`, |
| 40 | + with the second GRU taking in outputs of the first GRU and |
| 41 | + computing the final results. Default: 1 |
| 42 | + bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. |
| 43 | + Default: ``True`` |
| 44 | + dropout: If non-zero, introduces a `Dropout` layer on the outputs of each |
| 45 | + GRU layer except the last layer, with dropout probability equal to |
| 46 | + :attr:`dropout`. Default: 0 |
| 47 | +
|
| 48 | + This implementation has the following differences from the GRU module in PyTorch upstream: |
| 49 | +
|
| 50 | + - Only supports unidirectional GRU. |
| 51 | + - Only supports inputs in the `(seq, batch, feature)` format (i.e. `batch_first = False`). |
| 52 | +
|
| 53 | + """ |
| 54 | + |
| 55 | + def __init__(self, |
| 56 | + input_size, |
| 57 | + hidden_size, |
| 58 | + num_layers=1, |
| 59 | + bias=True, |
| 60 | + dropout=0.0): |
| 61 | + super().__init__() |
| 62 | + |
| 63 | + self.input_size = input_size |
| 64 | + self.hidden_size = hidden_size |
| 65 | + self.num_layers = num_layers |
| 66 | + self.bias = bias |
| 67 | + self.dropout = dropout |
| 68 | + |
| 69 | + # Create parameters for each layer. |
| 70 | + # For layer 0, the input dimension is `input_size`, otherwise it's `hidden_size`. |
| 71 | + self.weight_ih = nn.ParameterList() |
| 72 | + self.weight_hh = nn.ParameterList() |
| 73 | + if bias: |
| 74 | + self.bias_ih = nn.ParameterList() |
| 75 | + self.bias_hh = nn.ParameterList() |
| 76 | + |
| 77 | + for layer in range(num_layers): |
| 78 | + layer_input_size = input_size if layer == 0 else hidden_size |
| 79 | + # weight_ih: combines weights for reset, update, and new gates. |
| 80 | + w_ih = nn.Parameter(torch.Tensor(3 * hidden_size, layer_input_size)) |
| 81 | + w_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size)) |
| 82 | + self.weight_ih.append(w_ih) |
| 83 | + self.weight_hh.append(w_hh) |
| 84 | + if bias: |
| 85 | + b_ih = nn.Parameter(torch.Tensor(3 * hidden_size)) |
| 86 | + b_hh = nn.Parameter(torch.Tensor(3 * hidden_size)) |
| 87 | + self.bias_ih.append(b_ih) |
| 88 | + self.bias_hh.append(b_hh) |
| 89 | + self.reset_parameters() |
| 90 | + |
| 91 | + def reset_parameters(self): |
| 92 | + # Initialize parameters uniformly as in the upstream PyTorch GRU. |
| 93 | + stdv = 1.0 / (self.hidden_size**0.5) |
| 94 | + for weight in self.parameters(): |
| 95 | + weight.data.uniform_(-stdv, stdv) |
| 96 | + |
| 97 | + def forward(self, input, hx=None): |
| 98 | + """ |
| 99 | + Args: |
| 100 | + input: Tensor of shape (seq_len, batch, input_size) |
| 101 | + hx: Optional initial hidden state of shape (num_layers, batch, hidden_size). |
| 102 | + If not provided, defaults to zeros. |
| 103 | + Returns: |
| 104 | + output: Tensor of shape (seq_len, batch, hidden_size) from the last GRU layer. |
| 105 | + hidden: Tensor of shape (num_layers, batch, hidden_size) containing the final hidden state per layer. |
| 106 | + """ |
| 107 | + seq_len, batch_size, _ = input.size() |
| 108 | + if hx is None: |
| 109 | + hx = input.new_zeros(self.num_layers, batch_size, self.hidden_size) |
| 110 | + else: |
| 111 | + assert hx.size(0) == self.num_layers, \ |
| 112 | + "Mismatch in number of layers for hidden state." |
| 113 | + |
| 114 | + # The output of one layer is the input to the next. |
| 115 | + output = input |
| 116 | + hidden_states = [] |
| 117 | + |
| 118 | + # Loop over each layer. |
| 119 | + for layer in range(self.num_layers): |
| 120 | + init = { |
| 121 | + 'h': hx[layer], |
| 122 | + 'w_ih': self.weight_ih[layer], |
| 123 | + 'w_hh': self.weight_hh[layer] |
| 124 | + } |
| 125 | + if self.bias: |
| 126 | + init['b_ih'] = self.bias_ih[layer] |
| 127 | + init['b_hh'] = self.bias_hh[layer] |
| 128 | + |
| 129 | + # Define the step function for scanning over time. |
| 130 | + # x_t: (batch, current_input_size) |
| 131 | + # h: (batch, hidden_size) |
| 132 | + # carry: dictionary containing h and weights/biases. |
| 133 | + def step_fn(carry, x_t): |
| 134 | + h = carry['h'] |
| 135 | + w_ih = carry['w_ih'] |
| 136 | + w_hh = carry['w_hh'] |
| 137 | + b_ih = carry.get('b_ih') |
| 138 | + b_hh = carry.get('b_hh') |
| 139 | + |
| 140 | + # Get input projections |
| 141 | + x_linear = F.linear(x_t, w_ih, b_ih) |
| 142 | + x_r, x_z, x_n = x_linear.chunk(3, dim=1) |
| 143 | + |
| 144 | + # Get hidden projections |
| 145 | + h_linear = F.linear(h, w_hh, b_hh) |
| 146 | + h_r, h_z, h_n = h_linear.chunk(3, dim=1) |
| 147 | + |
| 148 | + # Compute reset and update gates |
| 149 | + r = torch.sigmoid(x_r + h_r) |
| 150 | + z = torch.sigmoid(x_z + h_z) |
| 151 | + |
| 152 | + # Compute the new gate with proper reset gate application |
| 153 | + n = torch.tanh(x_n + r * h_n) |
| 154 | + |
| 155 | + # Update hidden state |
| 156 | + h_new = (1 - z) * n + z * h |
| 157 | + |
| 158 | + carry_new = { |
| 159 | + 'h': h_new, |
| 160 | + 'w_ih': w_ih, |
| 161 | + 'w_hh': w_hh, |
| 162 | + } |
| 163 | + if b_ih is not None: |
| 164 | + carry_new['b_ih'] = b_ih |
| 165 | + if b_hh is not None: |
| 166 | + carry_new['b_hh'] = b_hh |
| 167 | + return carry_new, h_new |
| 168 | + |
| 169 | + # Use scan to iterate over the time dimension. |
| 170 | + # Here, scan(fn, init, xs) applies step_fn to each time slice of `output`. |
| 171 | + final_carry, layer_output = scan(fn=step_fn, init=init, xs=output) |
| 172 | + hidden_states.append(final_carry['h']) |
| 173 | + # Apply dropout on the output of the current layer (if not the final layer). |
| 174 | + if layer < self.num_layers - 1 and self.dropout > 0: |
| 175 | + layer_output = F.dropout( |
| 176 | + layer_output, p=self.dropout, training=self.training) |
| 177 | + output = layer_output |
| 178 | + assert output.size(0) == seq_len |
| 179 | + |
| 180 | + # Stack the final hidden states for each layer. |
| 181 | + hidden = torch.stack(hidden_states, dim=0) |
| 182 | + return output, hidden |
0 commit comments