diff --git a/docs/source/en/perf_train_gpu_one.mdx b/docs/source/en/perf_train_gpu_one.mdx index 5e825beb7d10..0c130b417223 100644 --- a/docs/source/en/perf_train_gpu_one.mdx +++ b/docs/source/en/perf_train_gpu_one.mdx @@ -11,7 +11,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o # Efficient Training on a Single GPU -This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many). +This guide focuses on training large models efficiently on a single GPU. These approaches are still valid if you have access to a machine with multiple GPUs but you will also have access to additional methods outlined in the [multi-GPU section](perf_train_gpu_many). In this section we have a look at a few tricks to reduce the memory footprint and speed up training for large models and how they are integrated in the [`Trainer`] and [🤗 Accelerate](https://huggingface.co/docs/accelerate/). Each method can improve speed or memory usage which is summarized in the table below: @@ -367,7 +367,7 @@ Samples/second: 10.09 GPU memory occupied: 7275 MB. ``` -We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster. +We can see that with these tweaks we use about half the GPU memory as at the beginning while also being slightly faster. ### BF16 If you have access to a Ampere or newer hardware you can use bf16 for your training and evaluation. While bf16 has a worse precision than fp16, it has a much much bigger dynamic range. Therefore, if in the past you were experiencing overflow issues while training the model, bf16 will prevent this from happening most of the time. Remember that in fp16 the biggest number you can have is `65535` and any number above that will overflow. A bf16 number can be as large as `3.39e+38` (!) which is about the same as fp32 - because both have 8-bits used for the numerical range. @@ -394,7 +394,7 @@ Like all cases with reduced precision this may or may not be satisfactory for yo If you're already using fp16 or bf16 mixed precision it may help with the throughput as well. -You can enable this mode in the 🤗 Trainer with: +You can enable this mode in the 🤗 Trainer with: ```python TrainingArguments(tf32=True) ``` @@ -654,7 +654,7 @@ https://github.com/huggingface/transformers/blob/master/src/transformers/trainer ## Choice of GPU -Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture. +Sometimes, even when applying all the above tweaks the throughput on a given GPU might still not be good enough. One easy solution is to change the type of GPU. For example switching from let's say a K80 (which you typically get on Google Colab) to a fancier GPU such as the V100 or A100. Although they are more expensive they are usually more cost effective than cheaper GPUs due to their larger memory and faster architecture. Now, let's take a step back and discuss what we should optimize for when scaling the training of large models. 
@@ -718,3 +718,15 @@ For some applications, such as pretraining large language models, applying all t
 
 Another use case for training on many GPUs is if the model does not fit on a single GPU with all the mentioned tricks. There are still more methods we can apply although life starts to get a bit more complicated. This usually involves some form of pipeline or tensor parallelism where the model itself is distributed across several GPUs. One can also make use of DeepSpeed which implements some of these parallelism strategies along with some more optimization to reduce the memory footprint such as partitioning the optimizer states. You can read more about this in the ["Multi-GPU training" section](perf_train_gpu_many).
 
+## Inference with TorchDynamo
+TorchDynamo is a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. After capturing the FX graph, different backends can be deployed to lower the graph to an optimized engine. One option is to use [TensorRT](https://developer.nvidia.com/tensorrt) or NVFuser as the backend. You can choose one of the options below for a performance boost.
+```python
+TrainingArguments(torchdynamo="eager")  # enable eager mode, no performance boost (useful for debugging)
+TrainingArguments(torchdynamo="nvfuser")  # enable nvfuser
+TrainingArguments(torchdynamo="fx2trt")  # enable TensorRT fp32
+TrainingArguments(torchdynamo="fx2trt-fp16")  # enable TensorRT fp16
+```
+This feature involves three different libraries. To install them, please follow the instructions below:
+- [Torchdynamo installation](https://github.com/pytorch/torchdynamo#requirements-and-setup)
+- [Functorch installation](https://github.com/pytorch/functorch#install)
+- [Torch-TensorRT(FX) installation](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst#installation)
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 1a71e9d8408b..564d6364d259 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -71,6 +71,7 @@
     is_torch_available,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
+    is_torch_tensorrt_fx_available,
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
@@ -494,6 +495,11 @@ def require_torchdynamo(test_case):
     return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
 
 
+def require_torch_tensorrt_fx(test_case):
+    """Decorator marking a test that requires Torch-TensorRT FX."""
+    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0bebc8626ba6..dcadc02718cf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -141,6 +141,7 @@
     is_ipex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
+    is_torch_tensorrt_fx_available,
     is_torch_tpu_available,
     is_torchdynamo_available,
     logging,
@@ -598,6 +599,35 @@ def __init__(
         # very last
         self._memory_tracker.stop_and_update_metrics()
 
+        # torchdynamo
+        if args.torchdynamo:
+            if not is_torchdynamo_available():
+                raise RuntimeError("Torchdynamo is not installed.")
+            import torchdynamo
+            from torchdynamo.optimizations import backends
+            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
+
+            def get_ctx():
+                # eager and nvfuser backends
+                if args.torchdynamo == "eager":
+                    return torchdynamo.optimize("eager")
+                elif args.torchdynamo == "nvfuser":
+                    return torchdynamo.optimize(aot_autograd_speedup_strategy)
+                # TensorRT backends
+                elif args.torchdynamo in ["fx2trt-fp16", "fx2trt"]:
+                    if not is_torch_tensorrt_fx_available():
+                        raise RuntimeError("Torch-TensorRT FX path is not installed.")
+                    if args.torchdynamo == "fx2trt-fp16":
+                        return torchdynamo.optimize(backends.fx2trt_compiler_fp16)
+                    elif args.torchdynamo == "fx2trt":
+                        return torchdynamo.optimize(backends.fx2trt_compiler)
+                else:
+                    raise RuntimeError(f"Torchdynamo backend {args.torchdynamo} is not supported.")
+
+            self.ctx_manager_torchdynamo = get_ctx()
+        else:
+            self.ctx_manager_torchdynamo = contextlib.nullcontext()
+
     def add_callback(self, callback):
         """
         Add a callback to the current list of [`~transformer.TrainerCallback`].
@@ -2291,16 +2321,7 @@ def torchdynamo_smart_context_manager(self):
         """
         A helper wrapper that creates an appropriate context manager for `torchdynamo`.
         """
-        ctx_manager = contextlib.nullcontext()
-        if is_torchdynamo_available():
-            import torchdynamo
-            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
-
-            if self.args.torchdynamo == "eager":
-                ctx_manager = torchdynamo.optimize("eager")
-            elif self.args.torchdynamo == "nvfuser":
-                ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
-        return ctx_manager
+        return self.ctx_manager_torchdynamo
 
     def autocast_smart_context_manager(self):
         """
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 603015bf989b..833dc174c375 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -935,7 +935,7 @@ class TrainingArguments:
                 " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
                 " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
             ),
-            "choices": ["eager", "nvfuser"],
+            "choices": ["eager", "nvfuser", "fx2trt", "fx2trt-fp16"],
         },
     )
     ray_scope: Optional[str] = field(
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index f96c281c010f..1ee4521514af 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -132,6 +132,7 @@
     is_torch_fx_available,
     is_torch_fx_proxy,
     is_torch_onnx_dict_inputs_support_available,
+    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 31dbb536ac60..17a73890b078 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -418,6 +418,12 @@ def is_torchdynamo_available():
     return importlib.util.find_spec("torchdynamo") is not None
 
 
+def is_torch_tensorrt_fx_available():
+    if importlib.util.find_spec("torch_tensorrt") is None:
+        return False
+    return importlib.util.find_spec("torch_tensorrt.fx") is not None
+
+
 def is_datasets_available():
     return _datasets_available
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 73e7b4eeb120..4393d0220662 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torch_non_multi_gpu,
+    require_torch_tensorrt_fx,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
     require_torchdynamo,
@@ -1799,6 +1800,7 @@ def test_fp16_full_eval(self):
 
     @require_torch_non_multi_gpu
     @require_torchdynamo
+    @require_torch_tensorrt_fx
     def test_torchdynamo_full_eval(self):
         # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
         n_gpus = get_gpu_count()
@@ -1827,6 +1829,21 @@ def test_torchdynamo_full_eval(self):
         metrics = trainer.evaluate()
         self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
 
+        # 4. TorchDynamo fx2trt
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt")
+        metrics = trainer.evaluate()
+        t1 = metrics["eval_loss"]
+        t2 = original_eval_loss
+        self.assertAlmostEqual(t1, t2)
+
+        # 5. TorchDynamo fx2trt-fp16
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="fx2trt-fp16")
+        metrics = trainer.evaluate()
+        t1 = metrics["eval_loss"]
+        t2 = original_eval_loss
+        # fp16 has some accuracy degradation
+        self.assertLess(np.max(np.abs(t1 - t2)), 1e-3)
+
     @require_torch_non_multi_gpu
     @require_torchdynamo
     def test_torchdynamo_memory(self):
@@ -1852,7 +1869,7 @@ def forward(self, x):
 
         mod = MyModule()
 
-        # 1. Default - without TorchDynamo
+        # 1. without TorchDynamo (eager baseline)
         a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
         a.grad = None
         trainer = CustomTrainer(model=mod)
@@ -1860,16 +1877,15 @@ def forward(self, x):
         for _ in range(10):
             orig_loss = trainer.training_step(mod, {"x": a})
 
+        # resets
+        gc.collect()
+        torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        orig_loss = trainer.training_step(mod, {"x": a})
         orig_peak_mem = torch.cuda.max_memory_allocated()
         del trainer
 
-        # Reset the peak for another measurement
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-
         # 2. TorchDynamo nvfuser
         a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
         a.grad = None
@@ -1879,7 +1895,11 @@ def forward(self, x):
         for _ in range(10):
             loss = trainer.training_step(mod, {"x": a})
 
+        # resets
+        gc.collect()
+        torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        loss = trainer.training_step(mod, {"x": a})
         peak_mem = torch.cuda.max_memory_allocated()
         del trainer
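
For context, below is a minimal end-to-end sketch (not part of the diff above) of how the `torchdynamo` option in `TrainingArguments` can be exercised for evaluation. The `DummyRegressionDataset`, `DummyRegressionModel`, and `output_dir` are illustrative assumptions, not code from the repository; running it requires a CUDA GPU plus the TorchDynamo/functorch installs listed in the docs, and the fx2trt backends additionally require Torch-TensorRT FX.

```python
# Hypothetical usage sketch for the torchdynamo TrainingArguments option.
import torch
from torch import nn
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments


class DummyRegressionDataset(Dataset):
    """Toy dataset of (x, labels) pairs, for illustration only."""

    def __init__(self, length=64):
        self.x = torch.randn(length, 4)
        self.labels = self.x.sum(dim=-1, keepdim=True)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {"x": self.x[idx], "labels": self.labels[idx]}


class DummyRegressionModel(nn.Module):
    """Toy model whose forward returns (loss, predictions), as Trainer expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, labels=None):
        preds = self.linear(x)
        if labels is None:
            return preds
        loss = nn.functional.mse_loss(preds, labels)
        return (loss, preds)


args = TrainingArguments(
    output_dir="tmp_dynamo_eval",  # assumed scratch directory
    per_device_eval_batch_size=8,
    torchdynamo="nvfuser",  # or "eager", "fx2trt", "fx2trt-fp16"
)
trainer = Trainer(model=DummyRegressionModel(), args=args, eval_dataset=DummyRegressionDataset())
print(trainer.evaluate())  # the evaluation loop runs under the TorchDynamo context manager
```

Because the context manager is now built once in `Trainer.__init__` and returned by `torchdynamo_smart_context_manager`, an unsupported or misspelled backend string fails fast at construction time instead of silently falling back to eager during evaluation.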