|
| 1 | +// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s |
| 2 | + |
| 3 | +// CHECK-LABEL: func.func @torch.aten.view$view_like( |
| 4 | +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { |
| 5 | +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32> |
| 6 | +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 |
| 7 | +// CHECK: %[[INT224:.*]] = torch.constant.int 224 |
| 8 | +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]224 : (!torch.int, !torch.int) -> !torch.list<int> |
| 9 | +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1 |
| 10 | +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]] |
| 11 | +// CHECK: %[[T4:.*]] = arith.trunci %[[T2]] : i64 to i32 |
| 12 | +// CHECK: %[[T5:.*]] = arith.trunci %[[T3]] : i64 to i32 |
| 13 | +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T4]], %[[T5]] : tensor<2xi32> |
| 14 | +// CHECK: %[[T7:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T6]]) : (tensor<?x?x?x?xf32>, tensor<2xi32>) -> tensor<?x224xf32> |
| 15 | +// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x224xf32> -> !torch.vtensor<[?,224],f32> |
| 16 | +// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32> |
| 17 | +func.func @torch.aten.view$view_like(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { |
| 18 | + %int-1 = torch.constant.int -1 |
| 19 | + %int224 = torch.constant.int 224 |
| 20 | + %0 = torch.prim.ListConstruct %int-1, %int224 : (!torch.int, !torch.int) -> !torch.list<int> |
| 21 | + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,224],f32> |
| 22 | + return %1 : !torch.vtensor<[?,224],f32> |
| 23 | +} |
| 24 | + |
| 25 | +// ----- |
| 26 | +// CHECK-LABEL: func.func @torch.aten.reshape$view_like( |
| 27 | +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { |
| 28 | +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?,?],f32> -> tensor<?x?x?x?x?xf32> |
| 29 | +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 |
| 30 | +// CHECK: %[[INT120:.*]] = torch.constant.int 120 |
| 31 | +// CHECK: %[[INT4:.*]] = torch.constant.int 4 |
| 32 | +// CHECK: %[[INT64:.*]] = torch.constant.int 64 |
| 33 | +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT]]-1, %[[INT]]120, %[[INT]]4, %[[INT]]64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> |
| 34 | +// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT]]-1 |
| 35 | +// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT120]] |
| 36 | +// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[INT4]] |
| 37 | +// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]] |
| 38 | +// CHECK: %[[T6:.*]] = arith.trunci %[[T2]] : i64 to i32 |
| 39 | +// CHECK: %[[T7:.*]] = arith.trunci %[[T3]] : i64 to i32 |
| 40 | +// CHECK: %[[T8:.*]] = arith.trunci %[[T4]] : i64 to i32 |
| 41 | +// CHECK: %[[T9:.*]] = arith.trunci %[[T5]] : i64 to i32 |
| 42 | +// CHECK: %[[T10:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]], %[[T9]] : tensor<4xi32> |
| 43 | +// CHECK: %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<?x?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x120x4x64xf32> |
| 44 | +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x120x4x64xf32> -> !torch.vtensor<[?,120,4,64],f32> |
| 45 | +// CHECK: return %[[T12]] : !torch.vtensor<[?,120,4,64],f32> |
| 46 | +func.func @torch.aten.reshape$view_like(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { |
| 47 | + %int-1 = torch.constant.int -1 |
| 48 | + %int120 = torch.constant.int 120 |
| 49 | + %int4 = torch.constant.int 4 |
| 50 | + %int64 = torch.constant.int 64 |
| 51 | + %0 = torch.prim.ListConstruct %int-1, %int120, %int4, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> |
| 52 | + %1 = torch.aten.reshape %arg0, %0 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,120,4,64],f32> |
| 53 | + return %1 : !torch.vtensor<[?,120,4,64],f32> |
| 54 | +} |
| 55 | + |
| 56 | +// ----- |
| 57 | +// CHECK-LABEL: func.func @torch.aten.view.minus1$view_like( |
| 58 | +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> { |
| 59 | +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32> |
| 60 | +// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 |
| 61 | +// CHECK: %[[INT1:.*]] = torch.constant.int 1 |
| 62 | +// CHECK: %[[INT0:.*]] = torch.constant.int 0 |
| 63 | +// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int |
| 64 | +// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int |
| 65 | +// CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INT]]-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> |
| 66 | +// CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]] |
| 67 | +// CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]] |
| 68 | +// CHECK: %[[T6:.*]] = torch_c.to_i64 %[[INT]]-1 |
| 69 | +// CHECK: %[[T7:.*]] = arith.trunci %[[T4]] : i64 to i32 |
| 70 | +// CHECK: %[[T8:.*]] = arith.trunci %[[T5]] : i64 to i32 |
| 71 | +// CHECK: %[[T9:.*]] = arith.trunci %[[T6]] : i64 to i32 |
| 72 | +// CHECK: %[[T10:.*]] = tensor.from_elements %[[T7]], %[[T8]], %[[T9]] : tensor<3xi32> |
| 73 | +// CHECK: %[[T11:.*]] = "chlo.dynamic_reshape"(%[[T0]], %[[T10]]) : (tensor<2x3x?x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> |
| 74 | +// CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<2x3x?xf32> -> !torch.vtensor<[2,3,?],f32> |
| 75 | +// CHECK: return %[[T12]] : !torch.vtensor<[2,3,?],f32> |
| 76 | +func.func @torch.aten.view.minus1$view_like(%arg0: !torch.vtensor<[2,3,?,?],f32>) -> !torch.vtensor<[2,3,?],f32> { |
| 77 | + %int-1 = torch.constant.int -1 |
| 78 | + %int1 = torch.constant.int 1 |
| 79 | + %int0 = torch.constant.int 0 |
| 80 | + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int |
| 81 | + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int |
| 82 | + %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> |
| 83 | + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[2,3,?,?],f32>, !torch.list<int> -> !torch.vtensor<[2,3,?],f32> |
| 84 | + return %3 : !torch.vtensor<[2,3,?],f32> |
| 85 | +} |
| 86 | + |
| 87 | +// ----- |
| 88 | +// CHECK-LABEL: func.func @torch.aten.view.to_rank1$view_like( |
| 89 | +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { |
| 90 | +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor<f32> |
| 91 | +// CHECK: %[[INT1:.*]] = torch.constant.int 1 |
| 92 | +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list<int> |
| 93 | +// CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<f32>) -> tensor<1xf32> |
| 94 | +// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32> |
| 95 | +// CHECK: return %[[T3]] : !torch.vtensor<[1],f32> |
| 96 | +func.func @torch.aten.view.to_rank1$view_like(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { |
| 97 | + %int1 = torch.constant.int 1 |
| 98 | + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int> |
| 99 | + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[],f32>, !torch.list<int> -> !torch.vtensor<[1],f32> |
| 100 | + return %1 : !torch.vtensor<[1],f32> |
| 101 | +} |
| 102 | + |
| 103 | +// ----- |
| 104 | +// CHECK-LABEL: func.func @torch.aten.view.to_rank0$view_like( |
| 105 | +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { |
| 106 | +// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32> |
| 107 | +// CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list<int> |
| 108 | +// CHECK: %[[T2:.*]] = "mhlo.reshape"(%[[T0]]) : (tensor<1xf32>) -> tensor<f32> |
| 109 | +// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<f32> -> !torch.vtensor<[],f32> |
| 110 | +// CHECK: return %[[T3]] : !torch.vtensor<[],f32> |
| 111 | +func.func @torch.aten.view.to_rank0$view_like(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { |
| 112 | + %0 = torch.prim.ListConstruct : () -> !torch.list<int> |
| 113 | + %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1],f32>, !torch.list<int> -> !torch.vtensor<[],f32> |
| 114 | + return %1 : !torch.vtensor<[],f32> |
| 115 | +} |
0 commit comments