
Commit a954763

tengyifei authored and pgmoka committed
Introduce a GRU module implemented with scan (#8777)
1 parent 76bdd9d commit a954763
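
For context, a minimal usage sketch of the module this commit introduces. The constructor and forward signatures and the `(seq, batch, feature)` layout follow the `torch_xla/experimental/gru.py` code below; the specific sizes and this exact call sequence are illustrative assumptions, not part of the commit.

    import torch
    import torch_xla
    from torch_xla.experimental.gru import GRU

    # Hypothetical example: a 2-layer scan-based GRU on the XLA device.
    gru = GRU(input_size=16, hidden_size=32, num_layers=2).to('xla')
    inp = torch.randn(10, 4, 16).to('xla')   # (seq_len, batch, input_size)
    hx = torch.zeros(2, 4, 32).to('xla')     # (num_layers, batch, hidden_size)
    output, hidden = gru(inp, hx)            # shapes (10, 4, 32) and (2, 4, 32)
    torch_xla.sync()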

File tree: 4 files changed, +363 −0 lines changed


test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -197,6 +197,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/scan/test_scan.py"
   run_test "$CDIR/scan/test_scan_spmd.py"
   run_test "$CDIR/scan/test_scan_layers.py"
+  run_test "$CDIR/test_gru.py"
   run_test "$CDIR/test_as_stride_use_slice.py"
   run_xla_hlo_debug run_test "$CDIR/scan/test_scan_debug.py"
   run_test "$CDIR/test_autocast.py"

test/test_gru.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@ (new file)

import torch
import torch.nn as nn

import torch_xla
from torch_xla.experimental.gru import GRU

from absl.testing import absltest, parameterized


class TestGRU(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    torch.manual_seed(0)
    torch_xla.manual_seed(0)

  def build_models(self, input_size, hidden_size, num_layers, bias):
    gru = nn.GRU(
        input_size,
        hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=False,
        dropout=0.0,
        bidirectional=False)
    scan_gru = GRU(
        input_size, hidden_size, num_layers=num_layers, bias=bias, dropout=0.0)

    # Copy parameters from the upstream GRU to our scan-based GRU.
    for layer in range(num_layers):
      scan_gru.weight_ih[layer].data.copy_(
          getattr(gru, f'weight_ih_l{layer}').data)
      scan_gru.weight_hh[layer].data.copy_(
          getattr(gru, f'weight_hh_l{layer}').data)
      if gru.bias:
        scan_gru.bias_ih[layer].data.copy_(
            getattr(gru, f'bias_ih_l{layer}').data)
        scan_gru.bias_hh[layer].data.copy_(
            getattr(gru, f'bias_hh_l{layer}').data)

    return gru, scan_gru

  def check_gradients(self,
                      inp1,
                      hx1,
                      inp2,
                      hx2,
                      num_layers,
                      gru,
                      scan_gru,
                      atol=None,
                      rtol=None):
    # Compare gradients for input and initial hidden state.
    assert inp1.grad is not None
    assert hx1.grad is not None
    assert inp2.grad is not None
    assert hx2.grad is not None
    torch.testing.assert_close(
        inp1.grad,
        inp2.grad,
        msg=lambda msg: f"Input gradient mismatch. {msg}",
        check_device=False,
        atol=atol,
        rtol=rtol)
    torch.testing.assert_close(
        hx1.grad,
        hx2.grad,
        msg=lambda msg: f"Hidden state gradient mismatch. {msg}",
        check_device=False,
        atol=atol,
        rtol=rtol)

    # Compare gradients for all parameters.
    params_to_check = ['weight_ih', 'weight_hh']
    assert scan_gru.bias == gru.bias
    if scan_gru.bias:
      params_to_check += ['bias_ih', 'bias_hh']
    for layer in range(num_layers):
      for name in params_to_check:
        param1 = getattr(gru, f'{name}_l{layer}')
        param2 = getattr(scan_gru, name)[layer]
        torch.testing.assert_close(
            param1.grad,
            param2.grad,
            msg=lambda msg:
            f"Gradient mismatch in {name} at layer {layer}. {msg}",
            check_device=False,
            atol=atol,
            rtol=rtol)

  @parameterized.parameters(True, False)
  def test_scan_gru_vs_pytorch_xla_for_loop(self, bias):
    """
    Compare the scan-based GRU against the upstream GRU, both compiled with XLA.
    """
    seq_len, batch_size, input_size, hidden_size, num_layers = 16, 4, 16, 32, 2
    gru, scan_gru = self.build_models(input_size, hidden_size, num_layers, bias)
    gru, scan_gru = gru.to('xla'), scan_gru.to('xla')
    torch_xla.sync()

    # Prepare input and initial hidden states.
    inp1 = torch.randn(seq_len, batch_size,
                       input_size).to('xla').requires_grad_(True)
    inp2 = inp1.clone().detach().requires_grad_(True)
    hx1 = torch.randn(num_layers, batch_size,
                      hidden_size).to('xla').requires_grad_(True)
    hx2 = hx1.clone().detach().requires_grad_(True)
    torch_xla.sync()

    # Forward passes.
    out1, h1 = gru(inp1, hx1)
    torch_xla.sync()

    out2, h2 = scan_gru(inp2, hx2)
    torch_xla.sync()

    # Compare the numerical outputs.
    torch.testing.assert_close(out1, out2, check_device=False)
    torch.testing.assert_close(h1, h2, check_device=False)

    # Compute losses.
    loss1 = out1.sum() + h1.sum()
    loss2 = out2.sum() + h2.sum()

    # Backward passes.
    loss1.backward()
    loss2.backward()
    torch_xla.sync()

    self.check_gradients(inp1, hx1, inp2, hx2, num_layers, gru, scan_gru)

  @parameterized.parameters(True, False)
  def test_scan_gru_vs_pytorch_native_cpu(self, bias):
    """
    Compare the scan-based GRU compiled with XLA against the upstream GRU run with PyTorch eager.
    """
    seq_len, batch_size, input_size, hidden_size, num_layers = 2048, 4, 16, 32, 5
    gru, scan_gru = self.build_models(input_size, hidden_size, num_layers, bias)
    gru = gru.cpu()
    scan_gru = scan_gru.to('xla')
    torch_xla.sync()

    # Prepare input and initial hidden states.
    inp1 = torch.randn(seq_len, batch_size, input_size).requires_grad_(True)
    inp2 = inp1.to('xla').clone().detach().requires_grad_(True)
    hx1 = torch.randn(num_layers, batch_size, hidden_size).requires_grad_(True)
    hx2 = hx1.to('xla').clone().detach().requires_grad_(True)
    torch_xla.sync()

    # Forward passes.
    out1, h1 = gru(inp1, hx1)
    torch_xla.sync()

    out2, h2 = scan_gru(inp2, hx2)
    torch_xla.sync()

    # Compare the numerical outputs.
    torch.testing.assert_close(
        out1, out2, check_device=False, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(h1, h2, check_device=False, atol=1e-3, rtol=1e-3)

    # Compute losses.
    loss1 = out1.sum() + h1.sum()
    loss2 = out2.sum() + h2.sum()

    # Backward passes.
    loss1.backward()
    loss2.backward()
    torch_xla.sync()

    # Gradient thresholds are relaxed because numerical differences between TPU
    # and CPU add up to a non-trivial impact over 2048 steps.
    self.check_gradients(
        inp1, hx1, inp2, hx2, num_layers, gru, scan_gru, atol=0.05, rtol=0.05)


if __name__ == "__main__":
  torch_xla._XLAC._xla_set_mat_mul_precision('highest')
  absltest.main()

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ python3 "$TEST_CDIR/scan/test_scan.py"
 python3 "$TEST_CDIR/scan/test_scan_spmd.py"
 python3 "$TEST_CDIR/scan/test_scan_pallas.py"
 python3 "$TEST_CDIR/scan/test_scan_layers.py"
+python3 "$TEST_CDIR/test_gru.py"
 python3 "$TEST_CDIR/test_as_stride_use_slice.py"
 run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v

torch_xla/experimental/gru.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@ (new file)

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_xla.experimental.scan import scan


class GRU(nn.Module):
  r"""
  PyTorch/XLA GRU implemented using scan.

  Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
  For each element in the input sequence, each layer computes the following
  function:

  .. math::
    \begin{array}{ll}
      r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
      z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
      n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn})) \\
      h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
    \end{array}

  where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
  at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
  at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
  :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
  :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

  In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l`-th layer
  (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
  dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
  variable which is :math:`0` with probability :attr:`dropout`.

  Args:
    input_size: The number of expected features in the input `x`
    hidden_size: The number of features in the hidden state `h`
    num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
      would mean stacking two GRUs together to form a `stacked GRU`,
      with the second GRU taking in outputs of the first GRU and
      computing the final results. Default: 1
    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
      Default: ``True``
    dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
      GRU layer except the last layer, with dropout probability equal to
      :attr:`dropout`. Default: 0

  This implementation has the following differences from the GRU module in PyTorch upstream:

  - Only supports unidirectional GRU.
  - Only supports inputs in the `(seq, batch, feature)` format (i.e. `batch_first = False`).
  """

  def __init__(self,
               input_size,
               hidden_size,
               num_layers=1,
               bias=True,
               dropout=0.0):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bias = bias
    self.dropout = dropout

    # Create parameters for each layer.
    # For layer 0, the input dimension is `input_size`, otherwise it's `hidden_size`.
    self.weight_ih = nn.ParameterList()
    self.weight_hh = nn.ParameterList()
    if bias:
      self.bias_ih = nn.ParameterList()
      self.bias_hh = nn.ParameterList()

    for layer in range(num_layers):
      layer_input_size = input_size if layer == 0 else hidden_size
      # weight_ih: combines weights for reset, update, and new gates.
      w_ih = nn.Parameter(torch.Tensor(3 * hidden_size, layer_input_size))
      w_hh = nn.Parameter(torch.Tensor(3 * hidden_size, hidden_size))
      self.weight_ih.append(w_ih)
      self.weight_hh.append(w_hh)
      if bias:
        b_ih = nn.Parameter(torch.Tensor(3 * hidden_size))
        b_hh = nn.Parameter(torch.Tensor(3 * hidden_size))
        self.bias_ih.append(b_ih)
        self.bias_hh.append(b_hh)
    self.reset_parameters()

  def reset_parameters(self):
    # Initialize parameters uniformly as in the upstream PyTorch GRU.
    stdv = 1.0 / (self.hidden_size**0.5)
    for weight in self.parameters():
      weight.data.uniform_(-stdv, stdv)

  def forward(self, input, hx=None):
    """
    Args:
      input: Tensor of shape (seq_len, batch, input_size)
      hx: Optional initial hidden state of shape (num_layers, batch, hidden_size).
        If not provided, defaults to zeros.
    Returns:
      output: Tensor of shape (seq_len, batch, hidden_size) from the last GRU layer.
      hidden: Tensor of shape (num_layers, batch, hidden_size) containing the final hidden state per layer.
    """
    seq_len, batch_size, _ = input.size()
    if hx is None:
      hx = input.new_zeros(self.num_layers, batch_size, self.hidden_size)
    else:
      assert hx.size(0) == self.num_layers, \
          "Mismatch in number of layers for hidden state."

    # The output of one layer is the input to the next.
    output = input
    hidden_states = []

    # Loop over each layer.
    for layer in range(self.num_layers):
      init = {
          'h': hx[layer],
          'w_ih': self.weight_ih[layer],
          'w_hh': self.weight_hh[layer]
      }
      if self.bias:
        init['b_ih'] = self.bias_ih[layer]
        init['b_hh'] = self.bias_hh[layer]

      # Define the step function for scanning over time.
      # x_t: (batch, current_input_size)
      # h: (batch, hidden_size)
      # carry: dictionary containing h and weights/biases.
      def step_fn(carry, x_t):
        h = carry['h']
        w_ih = carry['w_ih']
        w_hh = carry['w_hh']
        b_ih = carry.get('b_ih')
        b_hh = carry.get('b_hh')

        # Get input projections.
        x_linear = F.linear(x_t, w_ih, b_ih)
        x_r, x_z, x_n = x_linear.chunk(3, dim=1)

        # Get hidden projections.
        h_linear = F.linear(h, w_hh, b_hh)
        h_r, h_z, h_n = h_linear.chunk(3, dim=1)

        # Compute reset and update gates.
        r = torch.sigmoid(x_r + h_r)
        z = torch.sigmoid(x_z + h_z)

        # Compute the new gate with proper reset gate application.
        n = torch.tanh(x_n + r * h_n)

        # Update hidden state.
        h_new = (1 - z) * n + z * h

        carry_new = {
            'h': h_new,
            'w_ih': w_ih,
            'w_hh': w_hh,
        }
        if b_ih is not None:
          carry_new['b_ih'] = b_ih
        if b_hh is not None:
          carry_new['b_hh'] = b_hh
        return carry_new, h_new

      # Use scan to iterate over the time dimension.
      # Here, scan(fn, init, xs) applies step_fn to each time slice of `output`.
      final_carry, layer_output = scan(fn=step_fn, init=init, xs=output)
      hidden_states.append(final_carry['h'])
      # Apply dropout on the output of the current layer (if not the final layer).
      if layer < self.num_layers - 1 and self.dropout > 0:
        layer_output = F.dropout(
            layer_output, p=self.dropout, training=self.training)
      output = layer_output
    assert output.size(0) == seq_len

    # Stack the final hidden states for each layer.
    hidden = torch.stack(hidden_states, dim=0)
    return output, hidden
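
As a side note, `forward` relies on the step-function contract of `torch_xla.experimental.scan.scan` as used above: `fn(carry, x_t)` returns `(new_carry, y_t)`, `xs` is consumed along dimension 0, and the call returns the final carry plus the per-step outputs stacked along dimension 0. A minimal sketch of that contract outside the GRU, using a hypothetical cumulative-sum step function that is not part of this commit:

    import torch
    import torch_xla
    from torch_xla.experimental.scan import scan

    def step_fn(carry, x_t):
      new_carry = carry + x_t      # running sum carried across time steps
      return new_carry, new_carry  # also emit the running sum as this step's output

    xs = torch.arange(6, dtype=torch.float32).reshape(3, 2).to('xla')
    init = torch.zeros(2).to('xla')
    final_carry, ys = scan(fn=step_fn, init=init, xs=xs)  # ys has shape (3, 2)
    torch_xla.sync()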
