-
-
Notifications
You must be signed in to change notification settings - Fork 648
[WIP] added PyTorch Profiler #2315
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 11 commits
9f928c8
333057e
58c2d18
dafb4d5
f99a2ba
6121183
e738d69
588b7fe
74647fb
79ce0c0
bbfbf8f
bf753bc
229c7ef
27dc96f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# coding: utf-8 | ||
import os | ||
from datetime import datetime | ||
from typing import Any, Callable, Union | ||
|
||
import torch | ||
|
||
import ignite.distributed as idist | ||
from ignite.engine import Engine, Events | ||
|
||
|
||
class PyTorchProfiler:
    """PyTorch Profiler for performance debugging.

    The PyTorch profiler is a tool that collects both GPU hardware and PyTorch related
    information, correlates them, performs automatic detection of bottlenecks in the model,
    and generates recommendations on how to resolve these bottlenecks.

    Args:
        cuda_activity: if True and CUDA is available, profile CUDA kernel activity in
            addition to CPU activity.
        on_trace_ready: either the string ``"tensorboard"`` (uses
            ``torch.profiler.tensorboard_trace_handler`` writing to ``output_path``) or a
            callable that is invoked each time a new trace is ready.
        record_shapes: save information about operator input shapes.
        profile_memory: track tensor memory allocation/deallocation.
        with_stack: record source information (file and line number) for the ops.
        with_flops: use a formula to estimate the FLOPs of specific operators.
        with_modules: record the module hierarchy corresponding to the callstack of the op.
            NOTE(review): requires ``torch.profiler.profile`` to accept ``with_modules``
            (torch >= 1.10) — confirm against the minimum supported torch version.
        output_path: directory where traces / result files are written. Required when
            ``on_trace_ready="tensorboard"`` or when using :meth:`write_results`.
        wait: number of steps skipped at the start of each profiling cycle.
        warmup: number of warmup steps (tracing active, results discarded).
        active: number of steps actively recorded per cycle.
        repeat: number of profiling cycles to run.

    Examples:
        .. code-block:: python

            from ignite.handlers import PyTorchProfiler

            trainer = ...
            model = ...
            optimizer = ...

            pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path="logs/train")
            pt_profiler.attach(trainer)

            # Get profiler results of time
            pt_profiler.print_results()

            # Save profiler results to a text file
            pt_profiler.write_results()

        Both these methods can also be used as the on_trace_ready function which gets called after trace is ready.

            pt_profiler = PyTorchProfiler(on_trace_ready=profiler.write_to_file(10), output_path="logs/train")

    .. versionadded:: 0.4.8
    """

    def __init__(
        self,
        cuda_activity: bool = False,
        on_trace_ready: Union[Callable[..., Any], str] = "tensorboard",
        record_shapes: bool = False,
        profile_memory: bool = False,
        with_stack: bool = False,
        with_flops: bool = False,
        with_modules: bool = False,
        output_path: Union[str, None] = None,
        wait: int = 2,
        warmup: int = 2,
        active: int = 6,
        repeat: int = 1,
    ) -> None:

        self.activities = [torch.profiler.ProfilerActivity.CPU]
        # BUG FIX: ``ProfilerActivity`` has no ``GPU`` member — CUDA kernel
        # activity is selected with ``ProfilerActivity.CUDA``.
        if cuda_activity and torch.cuda.is_available():
            self.activities.append(torch.profiler.ProfilerActivity.CUDA)

        self.output_path = output_path

        # Profiling cycle: skip `wait` steps, warm up for `warmup` steps,
        # record `active` steps; repeated `repeat` times.
        self.schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)

        # "tensorboard" selects the built-in handler writing chrome traces to
        # `output_path`; otherwise the user-supplied callable is used as-is.
        self.trace_handler = (
            torch.profiler.tensorboard_trace_handler(self.output_path)
            if on_trace_ready == "tensorboard"
            else on_trace_ready
        )

        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.with_modules = with_modules

        # Valid values for the ``sort_key`` argument of :meth:`get_results`.
        self.SORT_KEYS = {
            "cpu_time",
            "cuda_time",
            "cpu_time_total",
            "cuda_time_total",
            "cpu_memory_usage",
            "cuda_memory_usage",
            "self_cpu_memory_usage",
            "self_cuda_memory_usage",
            "count",
        }

    def _profiler_create(self) -> None:
        # (Re)create the profiler at the start of each epoch and enter its
        # context so that subsequent ``step()`` calls are recorded.
        self._profiler = torch.profiler.profile(
            activities=self.activities,
            schedule=self.schedule,
            on_trace_ready=self.trace_handler,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops,
            # BUG FIX: ``with_modules`` was accepted by ``__init__`` but never
            # forwarded to the profiler, silently ignoring the option.
            with_modules=self.with_modules,
        )
        self._profiler.__enter__()

    def _exit_profiler(self) -> None:
        # A clean (non-exception) exit passes None for all three arguments,
        # per the context-manager protocol.
        self._profiler.__exit__(None, None, None)

    def _profiler_step(self) -> None:
        # Advance the profiler schedule by one step (one batch).
        self._profiler.step()

    def attach(
        self,
        engine: Engine,
    ) -> None:
        """Attach the profiler to the engine.

        Args:
            engine: engine object.
        """
        engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_create)
        engine.add_event_handler(Events.GET_BATCH_COMPLETED, self._profiler_step)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._exit_profiler)

    def get_results(
        self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only: bool = False
    ) -> str:
        """Return a formatted table of averaged profiler events.

        Args:
            n: maximum number of rows in the table; -1 means no limit.
            sort_key: column to sort by; must be one of ``self.SORT_KEYS``.
            top_level_events_only: if True, only show top-level ops.

        Raises:
            ValueError: if ``sort_key`` is not one of the accepted keys.
        """
        if sort_key not in self.SORT_KEYS:
            raise ValueError(
                f" The sort_key {sort_key} is not accepted. Please choose a sort key from {self.SORT_KEYS}"
            )

        return self._profiler.key_averages().table(
            sort_by=sort_key, row_limit=n, top_level_events_only=top_level_events_only
        )

    def write_results(
        self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only: bool = False
    ) -> None:
        """Write the results table to ``<output_path>/<backend>_<timestamp>.txt``.

        Args:
            n: maximum number of rows in the table; -1 means no limit.
            sort_key: column to sort by; must be one of ``self.SORT_KEYS``.
            top_level_events_only: if True, only show top-level ops.

        Raises:
            ValueError: if no ``output_path`` was provided at construction time.
        """
        # Fail with a clear message instead of an opaque os.path.join TypeError.
        if self.output_path is None:
            raise ValueError("An output_path must be provided to write the profiler results to a file")

        now = datetime.now().strftime("%Y%m%d-%H%M%S")
        file_name = f"{idist.backend()}_{now}.txt"

        with open(os.path.join(self.output_path, file_name), "w") as f:
            f.write(self.get_results(n, sort_key, top_level_events_only))

    def print_results(
        self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only: bool = False
    ) -> None:
        """Print the results table to stdout (same arguments as :meth:`get_results`)."""
        print(self.get_results(n, sort_key, top_level_events_only))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import glob | ||
import os | ||
|
||
import pytest | ||
import torch | ||
|
||
import ignite.distributed as idist | ||
from ignite.engine import Engine | ||
from ignite.handlers import PyTorchProfiler | ||
|
||
|
||
def update_fn(engine, batch):
    # Dummy training step: a couple of tensor ops for the profiler to record.
    # ``engine`` and ``batch`` are unused; values are uninitialized on purpose.
    lhs = torch.empty((2, 3), dtype=torch.int32)
    rhs = torch.empty((3, 3), dtype=torch.int32)
    return lhs + torch.mm(lhs, rhs)
|
||
|
||
def get_engine():
    # Build a fresh engine wired to the dummy update function.
    return Engine(update_fn)
|
||
|
||
def test_get_results(tmp_path):
    trainer = get_engine()
    profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path)
    profiler.attach(trainer)
    trainer.run(range(10), max_epochs=1)

    # An unknown sort key must be rejected with a descriptive ValueError.
    with pytest.raises(ValueError, match=r" The sort_key cpu_times is not accepted. Please choose a sort key from"):
        profiler.get_results(sort_key="cpu_times")
|
||
|
||
def test_write_results(tmp_path):
    row_limit = 5

    trainer = get_engine()
    profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path)
    profiler.attach(trainer)
    trainer.run(range(10), max_epochs=1)
    profiler.write_results(n=row_limit)

    # Pick the most recent results file written for the current backend.
    out_file = glob.glob(os.path.join(tmp_path, f"{idist.backend()}_*"))[-1]
    assert os.path.isfile(out_file)

    with open(out_file, "r") as f:
        num_lines = sum(1 for _ in f)

    # The table contains `row_limit` event rows plus 5 header/footer lines.
    assert num_lines == row_limit + 5
Uh oh!
There was an error while loading. Please reload this page.