Skip to content

Commit d1242e6

Browse files
authored
Merge pull request #804 from sl-sergei:multi_output_gather
Add testdata for Gather with multiple outputs
2 parents 77b41e6 + b2c4f02 commit d1242e6

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

+17
Original file line numberDiff line numberDiff line change
@@ -1023,3 +1023,20 @@ def forward(self, x, kernel, bias):
10231023
bias = Variable(torch.randn(4))
10241024
model = Conv1dBias()
10251025
save_data_and_model_multy_inputs("conv1d_variable_wb", model, x, kernel, bias)
1026+
1027+
class GatherMultiOutput(nn.Module):
1028+
def __init__(self, in_dim = 2):
1029+
super(GatherMultiOutput, self).__init__()
1030+
self.in_dim = in_dim
1031+
self.lin_inp = nn.Linear(in_dim, 2, bias=False)
1032+
def forward(self, x):
1033+
x_projected = self.lin_inp(x).long()
1034+
x_gather = x_projected[:,0,:]
1035+
x_float1 = x_gather.float()
1036+
x_float2 = x_gather.float()
1037+
x_float3 = x_gather.float()
1038+
return x_float1+x_float2+x_float3
1039+
1040+
x = Variable(torch.zeros([1, 2, 2]))
1041+
model = GatherMultiOutput()
1042+
save_data_and_model("gather_multi_output", x, model)
432 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)