We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1410ca3 commit ec6aa6bCopy full SHA for ec6aa6b
core/lowering/passes/device_casting.cpp
@@ -41,7 +41,7 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
41
// 0D Tensors are initialized on cpu, and need to be casted to CUDA
42
// to avoid device mismatch issues
43
std::string num_to_tensor_clean_pattern = R"IR(
44
- graph(%1: int):
+ graph(%1: Scalar):
45
%2: Tensor = prim::NumToTensor(%1)
46
%device: Device = prim::Constant[value="cuda"]()
47
%dtype: NoneType = prim::Constant()
0 commit comments