@@ -331,8 +331,14 @@ def gelu(x):
# it must be explicitly set as it can conflict with some operations which do not
# benefit from Tensor core computations.
+ # Tensor Core computation can be enabled "manually" by modifying the matrix multiplication precision.
+ # The default precision is "highest", which performs the operation according to the dtype.
- torch.backends.cuda.matmul.allow_tf32
+ # Precisions "high" and "medium" can be hardware accelerated via Tensor Cores
+ # and will set torch.backends.cuda.matmul.allow_tf32 = True if available.
+
+ # Carefully consider the tradeoff between speed and precision when evaluating your models!
+ torch.set_float32_matmul_precision("high")
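
# A minimal sketch (an editorial addition, not part of the patch above) of how the
# precision setting plays out in practice; the matrix size and variable names are
# illustrative assumptions. On Ampere or newer GPUs, "high" allows TF32 Tensor Core
# kernels, so comparing against "highest" shows the speed/precision tradeoff.
import torch

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")

    torch.set_float32_matmul_precision("highest")  # full FP32 matmul (the default)
    ref = a @ b

    torch.set_float32_matmul_precision("high")  # TF32 Tensor Cores where supported
    fast = a @ b

    # Measure the numerical cost of the faster kernels before enabling them globally.
    print((ref - fast).abs().max())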
###############################################################################
# Use CUDA Graphs
@@ -341,8 +347,13 @@ def gelu(x):
# in some cases the context switch between CPU and GPU can lead to bad resource
# utilization. CUDA graphs are a way to keep computation within the GPU without
# paying the extra cost of kernel launches and host synchronization.
- #
- # It can be enabled using `torch.compile <https://pytorch.org/docs/stable/generated/torch.compile.html>`_ "reduce-overhead" and "max-autotune" modes.
+
+ # It can be enabled using
+ torch.compile(m, mode="reduce-overhead")
+ # or
+ torch.compile(m, mode="max-autotune")
+
+ ###############################################################################
# Special care must be taken when using CUDA graphs as they can lead to increased memory consumption and some models might not compile.
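
# A minimal sketch (an editorial addition, not part of the patch above) of enabling
# CUDA graphs through torch.compile's "reduce-overhead" mode; the model, batch size,
# and warm-up count are illustrative assumptions. Keeping input shapes static lets
# the captured graph be replayed on later calls.
import torch

if torch.cuda.is_available():
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.GELU(),
        torch.nn.Linear(1024, 1024),
    ).cuda()
    compiled = torch.compile(model, mode="reduce-overhead")

    x = torch.randn(64, 1024, device="cuda")
    with torch.no_grad():
        for _ in range(3):  # warm-up iterations compile the model and capture the graph
            compiled(x)
        out = compiled(x)  # later calls replay the captured CUDA graph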
###############################################################################