|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 | 7 | # pyre-strict
|
| 8 | +import copy |
8 | 9 | import os
|
9 | 10 | import tempfile
|
10 | 11 | import unittest
|
@@ -1639,3 +1640,36 @@ def forward(self, x):
|
1639 | 1640 | FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
|
1640 | 1641 | gm.code
|
1641 | 1642 | )
|
| 1643 | + |
| 1644 | + def test_constant_prop_pass_for_no_grad(self) -> None: |
| 1645 | + class LSTM(torch.nn.Module): |
| 1646 | + def __init__(self, input_size, hidden_size, num_layers): |
| 1647 | + super(LSTM, self).__init__() |
| 1648 | + self.hidden_size = hidden_size |
| 1649 | + self.num_layers = num_layers |
| 1650 | + self.lstm = torch.nn.LSTM( |
| 1651 | + input_size, hidden_size, num_layers, batch_first=True |
| 1652 | + ) |
| 1653 | + |
| 1654 | + def forward(self, text_tokens): |
| 1655 | + # input: (seq_len, batch, input_size) |
| 1656 | + lstm_out, (new_hidden_state, new_cell_state) = self.lstm( |
| 1657 | + input=text_tokens, hx=None |
| 1658 | + ) |
| 1659 | + return lstm_out |
| 1660 | + |
| 1661 | + lstm = LSTM(input_size=200, hidden_size=203, num_layers=2) |
| 1662 | + example_input = (torch.rand(2, 10, 200),) |
| 1663 | + |
| 1664 | + aten = torch.export.export(lstm, example_input, strict=False) |
| 1665 | + _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( |
| 1666 | + _check_ir_validity=True, |
| 1667 | + _skip_dim_order=True, # TODO(T189114319): Reuse dim order op after solving the ios oss issue |
| 1668 | + ) |
| 1669 | + |
| 1670 | + edge_manager: EdgeProgramManager = to_edge( |
| 1671 | + aten, |
| 1672 | + compile_config=_EDGE_COMPILE_CONFIG, |
| 1673 | + ) |
| 1674 | + new_ep = constant_prop_pass(edge_manager._edge_programs["forward"]) |
| 1675 | + _ = copy.deepcopy(new_ep.module_call_graph) |
0 commit comments