🐛 Bug / 🚀 Feature request
Currently, the XLA profiler can capture the TPU device trace when running on a v3-8 TPU VM, but cannot capture the device trace in trace_viewer when running on a TPU pod (e.g. a v3-128 pod). This is unlike the TensorFlow profiler which is able to capture TPU device traces when running on a pod.
As we have frequently observed higher overhead from pod training (compared to training on a v3-8 with the same per-TPU batch size, e.g. #3441), it would be great if the TPU device trace could also be captured in the pod case to help understand the performance bottlenecks.
(Not sure whether this should be a bug report or a feature request. Since many practical TPU use cases involve training in pods, it would be great if the XLA profiler could also work for the pod case.)
To Reproduce
While the XLA profiler can capture CPU host traces during TPU pod training (e.g. on a v3-128), it cannot capture the TPU device trace. In particular, no device trace shows up on the trace_viewer page in the profiler, as shown in the screenshot below (only CPU traces are captured). The tensorboard output (including the profiler results) for this case is also uploaded to https://drive.google.com/file/d/108GRRqndJJyEQEhmICx1u4aaiF1vkQ_F/view?usp=sharing.
In contrast, the profiler successfully captures TPU device traces on a v3-8 (as in the screenshot below).
To reproduce the failure case above on a v3-128:
Allocate a v3-128 TPU VM pod (e.g. with name tpu-debug-128) from the tpu-vm-pt-1.10 environment
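A sketch of one way to allocate it (the zone below is only an example; use whichever zone hosts your v3 pods):
# Zone is a placeholder; adjust to where your v3 pod slices are available
gcloud alpha compute tpus tpu-vm create tpu-debug-128 \
  --zone=europe-west4-a \
  --accelerator-type=v3-128 \
  --version=tpu-vm-pt-1.10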
Clone the PyTorch XLA repo on the TPU VM to download the profiler script (do this on all the nodes in the TPU VM, e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all):
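A sketch of one way to do this, assuming the repo goes under ~/workspace to match the script path used in the next step (add --zone if it is not set in your gcloud config):
# Clone pytorch/xla on every worker of the pod; ~/workspace mirrors the path used in the training command below
gcloud alpha compute tpus tpu-vm ssh tpu-debug-128 --worker all \
  --command "mkdir -p ~/workspace && cd ~/workspace && git clone https://github.com/pytorch/xla.git"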
Start training on the TPU pod with the profiler server:
TPU_NAME=tpu-debug-128 # change to your TPU name
cd ${HOME} && python3 -m torch_xla.distributed.xla_dist --tpu=${TPU_NAME} --restart-tpuvm-pod-server -- \
python3 -u /home/ronghanghu/workspace/xla/test/test_profile_mp_mnist.py \
--batch_size 16 --drop_last --num_epochs 200000 --lr 0.0
Forward the TensorBoard port 6006 from the TPU VM to a local machine and try to capture a profile from localhost:9012 on the TensorBoard profile page. Then check its trace_viewer tool.
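A sketch of one way to set up the port forwarding, assuming TensorBoard is already serving on port 6006 on worker 0 of the TPU VM:
# Tunnel TensorBoard's port 6006 from worker 0 to the local machine
gcloud alpha compute tpus tpu-vm ssh tpu-debug-128 --worker 0 --ssh-flag="-L 6006:localhost:6006"
# Then open http://localhost:6006 locally, go to the Profile tab, and capture from localhost:9012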
Expected behavior
It would be great if TPU device traces can also be captured in pod training.
Environment
Reproducible on XLA backend [CPU/TPU]: v3-128 TPU pod with tpu-vm-pt-1.10 runtime
torch_xla version: 1.10
Additional context
This issue (that the profiler cannot capture the TPU device trace on a pod) can be reproduced with torch_xla 1.9, torch_xla 1.10, and the 20220308 nightly.
Thanks @miladm! I've only tried running the XLA profiler for the pod case on TPU VMs (since that's my practical use case and TPU VMs are generally faster than TPU nodes), and haven't tried the older setup of a Compute Engine VM + TPU nodes.
I think this issue (that the profiler cannot capture the TPU device trace on a pod) can be reproduced on both a v3-32 pod (with 4 TPU VM nodes) and a v3-128 pod (with 16 VM nodes) following the steps above.
Following up on this issue: in our internal test on TPU v4, the XLA profiler worked well on v4-8 but failed to capture TPU device traces on v4-32 or v4-128.