import os

# Optional verbose TF C++ logging, kept commented out for debugging sessions:
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'

# Pin this process to device 2 (ROCm/HIP naming).  NOTE(review): this must be
# set before TensorFlow is imported, which is why the import happens below.
os.environ['HIP_VISIBLE_DEVICES'] = '2'

# Ask XLA to dump the HLO of every compilation pass into ./generated
# (text, GraphViz dot, and html renderings) so the fp8 matmul lowering
# can be inspected after the run.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=./generated "
    "--xla_dump_hlo_as_dot "
    "--xla_dump_hlo_as_text "
    "--xla_dump_hlo_as_html "
    "--xla_dump_hlo_pass_re=.*"
    # "--xla_gpu_enable_triton_gemm=false"
    # "--xla_gpu_gemm_rewrite_size_threshold=0"
)
import numpy as np
import tensorflow as tf
@tf.function(jit_compile=True)
def fp8_matmul(x_fp8, y_fp8, scale_x, scale_y):
    """Dequantize two fp8 tensors to fp32 and return their matmul.

    JIT-compiled with XLA (``jit_compile=True``) so the dequantize + matmul
    pattern is visible in the HLO dumps configured at the top of the file.

    Args:
        x_fp8: Left operand, an fp8 tensor (e.g. float8_e4m3fn); its matmul
            contraction dimension must match ``y_fp8``'s leading dimension.
        y_fp8: Right operand, an fp8 tensor.
        scale_x: fp32 dequantization scale applied to ``x_fp8`` after the cast.
        scale_y: fp32 dequantization scale applied to ``y_fp8`` after the cast.

    Returns:
        fp32 tensor equal to ``matmul(cast(x_fp8) * scale_x,
        cast(y_fp8) * scale_y)``.
    """
    # 1. dequantize x_fp8 and y_fp8 to fp32
    x_fp32_unscaled = tf.cast(x_fp8, tf.float32) * scale_x
    y_fp32_unscaled = tf.cast(y_fp8, tf.float32) * scale_y
    # 2. perform matmul in fp32
    z_fp32_unscaled = tf.matmul(x_fp32_unscaled, y_fp32_unscaled)
    return z_fp32_unscaled
if __name__ == "__main__":
    # Random fp32 operands for a [16, 32] x [32, 16] matmul.
    x_fp32 = tf.random.uniform([16, 32], dtype=tf.float32)
    y_fp32 = tf.random.uniform([32, 16], dtype=tf.float32)
    # Quantize the operands to fp8 (e4m3fn) by a plain cast.
    x_fp8 = tf.cast(x_fp32, tf.dtypes.experimental.float8_e4m3fn)
    y_fp8 = tf.cast(y_fp32, tf.dtypes.experimental.float8_e4m3fn)
    # Fixed dequantization scales for both operands.
    scale_x = tf.constant(2.0, dtype=tf.float32)
    scale_y = tf.constant(2.0, dtype=tf.float32)
    # Run the XLA-jitted fp8 matmul and show the fp32 result.
    z_fp32 = fp8_matmul(x_fp8, y_fp8, scale_x, scale_y)
    print(z_fp32)