Skip to content

Commit b412ddc

Browse files
authored
Update dim_order type
Differential Revision: D68041866 Pull Request resolved: #7610
1 parent 83c0da5 commit b412ddc

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

exir/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class Tensor:
6464
scalar_type: ScalarType
6565
storage_offset: int
6666
sizes: List[int]
67-
dim_order: List[bytes]
67+
dim_order: List[int]
6868
requires_grad: bool
6969
layout: int
7070
data_buffer_idx: int

exir/tensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
7676
return tuple(typing.cast(Tuple[bytes], sorted_dims))
7777

7878

79-
def stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]:
79+
def stride_from_dim_order(sizes: List[int], dim_order: List[int]) -> List[int]:
8080
"""
8181
Converts dim order to stride using sizes
8282
e.g. if sizes = (2, 3, 4) and dim_order = (0, 1, 2) then strides = (12, 4, 1)

exir/tests/common.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-strict
8-
import typing
9-
from typing import List
108

119
import torch
1210

@@ -46,7 +44,7 @@ def get_test_program() -> Program:
4644
scalar_type=ScalarType.FLOAT,
4745
storage_offset=0,
4846
sizes=[2, 2],
49-
dim_order=typing.cast(List[bytes], [0, 1]),
47+
dim_order=[0, 1],
5048
requires_grad=False,
5149
layout=0,
5250
data_buffer_idx=0,

0 commit comments

Comments
 (0)