
Commit 34b8537

update
1 parent b6e6e37 commit 34b8537

1 file changed: +39 -0 lines changed
@@ -0,0 +1,39 @@
import os

# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# os.environ['TF_CPP_MAX_VLOG_LEVEL'] = '0'

# Pin the script to a single ROCm device.
os.environ['HIP_VISIBLE_DEVICES'] = '2'

# Dump the HLO after every XLA pass to ./generated as dot, text, and html.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=./generated "
    "--xla_dump_hlo_as_dot "
    "--xla_dump_hlo_as_text "
    "--xla_dump_hlo_as_html "
    "--xla_dump_hlo_pass_re=.*"
    # "--xla_gpu_enable_triton_gemm=false"
    # "--xla_gpu_gemm_rewrite_size_threshold=0"
)

import tensorflow as tf
import numpy as np


@tf.function(jit_compile=True)
def fp8_matmul(x_fp8, y_fp8, scale_x, scale_y):
    # 1. Dequantize x_fp8 and y_fp8 to fp32 by multiplying with their scales.
    x_fp32_unscaled = tf.cast(x_fp8, tf.float32) * scale_x
    y_fp32_unscaled = tf.cast(y_fp8, tf.float32) * scale_y
    # 2. Perform the matmul in fp32; XLA's GEMM rewriter can pattern-match
    #    this dequantize-then-matmul sequence into a native fp8 GEMM.
    z_fp32_unscaled = tf.matmul(x_fp32_unscaled, y_fp32_unscaled)
    return z_fp32_unscaled


if __name__ == "__main__":
    x_fp32 = tf.random.uniform([16, 32], dtype=tf.float32)
    y_fp32 = tf.random.uniform([32, 16], dtype=tf.float32)
    # Initialize x_fp8 and y_fp8.
    x_fp8 = tf.cast(x_fp32, tf.dtypes.experimental.float8_e4m3fn)
    y_fp8 = tf.cast(y_fp32, tf.dtypes.experimental.float8_e4m3fn)
    # Initialize scale_x and scale_y.
    scale_x = tf.constant(2.0, dtype=tf.float32)
    scale_y = tf.constant(2.0, dtype=tf.float32)
    # Do the fp8 matmul.
    z_fp32 = fp8_matmul(x_fp8, y_fp8, scale_x, scale_y)
    print(z_fp32)
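
For reference, a minimal sketch of how the per-tensor scales could be derived from the data instead of the hard-coded 2.0 constants, following the usual amax-based recipe. compute_e4m3_scale is a hypothetical helper, not part of the commit; 448 is the largest finite value of float8_e4m3fn.

import tensorflow as tf

def compute_e4m3_scale(x_fp32):
    # Hypothetical helper: map the largest |value| onto the e4m3fn max (448).
    amax = tf.reduce_max(tf.abs(x_fp32))
    return amax / 448.0

x_fp32 = tf.random.uniform([16, 32], dtype=tf.float32)
scale_x = compute_e4m3_scale(x_fp32)
# Quantize by dividing by the scale; fp8_matmul above dequantizes by
# multiplying with the same scale.
x_fp8 = tf.cast(x_fp32 / scale_x, tf.dtypes.experimental.float8_e4m3fn)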
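
And a small sketch for skimming the text dumps that --xla_dump_to writes to ./generated after the script runs, to see which HLO modules still carry f8e4m3 operands. The *.txt glob and the substring check are assumptions about the dump layout, not part of the commit.

import glob

for path in sorted(glob.glob("./generated/*.txt")):
    with open(path) as f:
        hlo = f.read()
    if "f8e4m3" in hlo:
        print(path)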
