diff --git a/core/lowering/passes/unpack_var.cpp b/core/lowering/passes/unpack_var.cpp index bbb25374be..4a697833a0 100644 --- a/core/lowering/passes/unpack_var.cpp +++ b/core/lowering/passes/unpack_var.cpp @@ -26,11 +26,18 @@ void UnpackVar(std::shared_ptr& graph) { %var: Tensor = aten::sub(%sqrdmean, %meansqrd, %1) %varout : Tensor = prim::If(%unbiased) block0(): - %shape: int[] = aten::size(%input) - %shapet: Tensor = aten::tensor(%shape, %f32_dtype, %none, %false) - %dim: int = prim::ListUnpack(%dims) - %reduceddims: Tensor = aten::select(%shapet, %0, %dim) - %numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none) + # Compute number of elements in original input tensor + %originalshape: int[] = aten::size(%input) + %originalshapet: Tensor = aten::tensor(%originalshape, %f32_dtype, %none, %false) + %originalnumel: Tensor = aten::prod(%originalshapet, %0, %false, %none) + # Compute number of elements in resulting output tensor + %resultingshape: int[] = aten::size(%var) + %resultingshapet: Tensor = aten::tensor(%resultingshape, %f32_dtype, %none, %false) + %resultingnumel: Tensor = aten::prod(%resultingshapet, %0, %false, %none) + # Quotient of original number of elements and resulting number of elements + # is equal to the number of elements used per variance calculation + %numel: Tensor = aten::div(%originalnumel, %resultingnumel) + # Perform Bessel's correction on computed variance %mul: Tensor = aten::mul(%var, %numel) %sub: Tensor = aten::sub(%numel, %1, %1) %v: Tensor = aten::div(%mul, %sub) diff --git a/tests/core/conversion/converters/test_reduce.cpp b/tests/core/conversion/converters/test_reduce.cpp index bf5234dc19..e3e1e6d252 100644 --- a/tests/core/conversion/converters/test_reduce.cpp +++ b/tests/core/conversion/converters/test_reduce.cpp @@ -465,7 +465,8 @@ TEST(Converters, UnpackStdUnbiasedKeepDimsLowersCorrectly) { %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) + %one : int = prim::Constant[value=1]() + %6 : int[] = prim::ListConstruct(%3, %one) %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 return (%7))IR"; diff --git a/tests/core/lowering/test_unpack_reduce_ops.cpp b/tests/core/lowering/test_unpack_reduce_ops.cpp index 146e49891a..8ae6be8e8c 100644 --- a/tests/core/lowering/test_unpack_reduce_ops.cpp +++ b/tests/core/lowering/test_unpack_reduce_ops.cpp @@ -134,7 +134,8 @@ TEST(LoweringPasses, UnpackStdKeepDimsLowersCorrectly) { %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) + %one : int = prim::Constant[value=1]() + %6 : int[] = prim::ListConstruct(%3, %one) %7 : Tensor = aten::std(%x.1, %6, %5, %5) # test_zeros.py:10:26 return (%7))IR"; @@ -184,7 +185,8 @@ TEST(LoweringPasses, UnpackStdUnbiasedKeepDimsLowersCorrectly) { %5 : bool = prim::Constant[value=0]() # test_zeros.py:10:65 %4 : bool = prim::Constant[value=1]() # test_zeros.py:10:50 %3 : int = prim::Constant[value=0]() # test_zeros.py:10:39 - %6 : int[] = prim::ListConstruct(%3) + %one : int = prim::Constant[value=1]() + %6 : int[] = prim::ListConstruct(%3, %one) %7 : Tensor = aten::std(%x.1, %6, %4, %5) # test_zeros.py:10:26 return (%7))IR";