@@ -80,10 +80,10 @@ int main() {
       // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
       // Round a, b to tf32
       for (auto i = 0 ; i < 4 ; ++i)
-        sub_a.data [i] = float_to_tf32 (sub_a.data [i]);
+        sub_a.data [i] = round_to_tf32 (sub_a.data [i]);

       for (auto i = 0 ; i < 4 ; ++i)
-        sub_b.data [i] = float_to_tf32 (sub_b.data [i]);
+        sub_b.data [i] = round_to_tf32 (sub_b.data [i]);

       // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 %{{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
       sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
@@ -125,10 +125,10 @@ int main() {
       // CHECK: tail call i32 @llvm.nvvm.f2tf32.rna(float {{.*}}
       // Round a, b to tf32
       for (auto i = 0 ; i < 4 ; ++i)
-        sub_a.data [i] = float_to_tf32 (sub_a.data [i]);
+        sub_a.data [i] = round_to_tf32 (sub_a.data [i]);

       for (auto i = 0 ; i < 4 ; ++i)
-        sub_b.data [i] = float_to_tf32 (sub_b.data [i]);
+        sub_b.data [i] = round_to_tf32 (sub_b.data [i]);

       // CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}) #{{.*}}
       sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
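For context, tf32 keeps float32's sign bit and 8-bit exponent but only 10 of the 23 mantissa bits, and the first CHECK line in each hunk verifies that round_to_tf32 lowers to the llvm.nvvm.f2tf32.rna intrinsic (in PTX, .rna is round-to-nearest with ties away from zero). The sketch below is a hypothetical host-side emulation of that rounding, not the round_to_tf32 under test; the name round_to_tf32_emulated and the NaN/Inf pass-through are assumptions.

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Hypothetical emulation of tf32 .rna rounding: keep the sign and 8-bit
    // exponent, round the 23-bit mantissa down to the 10 bits tf32 retains.
    float round_to_tf32_emulated(float x) {
      if (!std::isfinite(x))
        return x; // NaN/Inf passed through; hardware semantics not modeled here
      uint32_t bits;
      std::memcpy(&bits, &x, sizeof(bits));
      bits += 0x1000u;  // half the range of the 13 discarded bits: ties round away from zero
      bits &= ~0x1FFFu; // clear the 13 mantissa bits tf32 drops
      float r;
      std::memcpy(&r, &bits, sizeof(r));
      return r;
    }

A value rounded this way still occupies a full 32-bit float, which is why the loops above write the results back into sub_a.data[i] and sub_b.data[i] before joint_matrix_mad consumes them.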