Skip to content

Commit ec6aa6b

Browse files
committed
Updated schema of NumToTensor clean pattern
1 parent 1410ca3 commit ec6aa6b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

core/lowering/passes/device_casting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
4141
// 0D Tensors are initialized on cpu, and need to be casted to CUDA
4242
// to avoid device mismatch issues
4343
std::string num_to_tensor_clean_pattern = R"IR(
44-
graph(%1: int):
44+
graph(%1: Scalar):
4545
%2: Tensor = prim::NumToTensor(%1)
4646
%device: Device = prim::Constant[value="cuda"]()
4747
%dtype: NoneType = prim::Constant()

0 commit comments

Comments
 (0)