@@ -8,68 +8,86 @@ namespace core {
namespace lowering {
namespace passes {

-void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string masked_fill_pattern = R"IR(
     graph(%self, %mask, %value):
       %out: Tensor = aten::masked_fill_(%self, %mask, %value)
       return (%out))IR";

   // Calls to masked_fill_ often utilize CPU tensors, and as such
-  // should be casted to CUDA to avoid device mismatch errors
-  std::string unpacked_pattern = R"IR(
+  // should be moved to the GPU to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%self, %mask, %value):
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
       %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
-      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
       return (%out))IR";

+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter masked_fill_rewriter;
   masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
   masked_fill_rewriter.runOnGraph(graph);
   LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
}
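
For illustration (editorial note, not part of the diff): with a hypothetical target_device_name of "cuda:0", the two raw-string portions above concatenate into the replacement pattern below, splicing the device string into the prim::Constant attribute. Note that the replacement also swaps the in-place aten::masked_fill_ for the out-of-place aten::masked_fill.

    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value="cuda:0"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
      return (%out)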

-void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string num_to_tensor_cast_pattern = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
       return (%2))IR";

-  // 0D Tensors are initialized on cpu, and need to be casted to CUDA
+  // 0D Tensors are initialized on the CPU, and need to be moved to the GPU
   // to avoid device mismatch issues
-  std::string num_to_tensor_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1: Scalar):
       %2: Tensor = prim::NumToTensor(%1)
-      %device: Device = prim::Constant[value="cuda"]()
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
       %dtype: NoneType = prim::Constant()
       %false: bool = prim::Constant[value=0]()
       %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
       return (%3))IR";

+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
   num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
   num_to_tensor_cast_rewriter.runOnGraph(graph);

   LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
}
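
For illustration (editorial note, not part of the diff): a minimal sketch of how this pass could be exercised, assuming a Torch-TensorRT build where the pass is visible in the enclosing torch_tensorrt::core::lowering::passes namespace (the qualified name and the inclusion of the project's passes header are assumptions). The graph is parsed with torch::jit::parseIR from torch/csrc/jit/ir/irparser.h:

    #include <memory>
    #include <torch/csrc/jit/ir/ir.h>
    #include <torch/csrc/jit/ir/irparser.h>
    // plus the Torch-TensorRT header declaring the pass (path assumed)

    int main() {
      auto g = std::make_shared<torch::jit::Graph>();
      // Build a graph containing a bare prim::NumToTensor, which the pass rewrites
      torch::jit::parseIR(R"IR(
        graph(%1: Scalar):
          %2: Tensor = prim::NumToTensor(%1)
          return (%2))IR", g.get());

      // Hypothetical qualified call; afterwards the 0D tensor is followed by
      // an aten::to that moves it to the requested device
      torch_tensorrt::core::lowering::passes::UnpackAndCastNumToTensor(g, "cuda:0");
      g->dump();
      return 0;
    }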

-void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
   std::string full_cast_pattern = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
       %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
       return (%out))IR";

-  // Tensors created via aten::full are initialized on cpu, and need to be casted to CUDA
+  // Tensors created via aten::full are initialized on the CPU, and need to be cast to the GPU
   // to avoid device mismatch issues
-  std::string full_clean_pattern = R"IR(
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
     graph(%1, %2, %3, %4, %5, %6):
-      %cuda: Device = prim::Constant[value="cuda"]()
-      %out: Tensor = aten::full(%1, %2, %3, %4, %cuda, %6)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
       return (%out))IR";

+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
   torch::jit::SubgraphRewriter full_cast_rewriter;
   full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
   full_cast_rewriter.runOnGraph(graph);
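
Editorial note, not part of the diff: unlike the other two rewrites, aten::full already accepts a device argument, so this pattern substitutes %device directly instead of inserting an aten::to. Taking target_device_name as a plain string also lets callers pass a fully qualified device such as "cuda:0" rather than the previously hard-coded "cuda". A hypothetical aggregator (name and wiring assumed, not from the source) could thread one device string through all three passes:

    // Hypothetical helper; assumes the three passes above are in scope
    void CastDefaultDeviceTensors(
        std::shared_ptr<torch::jit::Graph>& graph,
        const std::string& target_device_name) {
      UnpackAndCastMaskedFill(graph, target_device_name);
      UnpackAndCastNumToTensor(graph, target_device_name);
      UnpackAndCastFull(graph, target_device_name);
    }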