🐛 [Bug] Compiling BERT model from transformers fails #1401

Closed
vishwanath-gowda opened this issue Oct 13, 2022 · 7 comments

@vishwanath-gowda

Bug Description

While compiling a PyTorch BERT model following this tutorial, compilation fails with the error below:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

To Reproduce

Here is the code used to compile:

import torch_tensorrt
import torch
import sys
from transformers import BertConfig, BertModel
model = BertModel.from_pretrained("bert-base-uncased")
inp = torch.rand(1, 256).long()
model = torch.jit.trace(model, inp)

print(next(model.parameters()).is_cuda)

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 256],
        opt_shape=[1, 256],
        max_shape=[1, 256],
        dtype=torch.int32,
    )
]
enabled_precisions = {torch.float32}  # Run with fp32
trt_ts_module = torch_tensorrt.ts.compile(
    model, inputs=inputs, enabled_precisions=enabled_precisions)

#input_data = input_data.to("cuda").half()
#result = trt_ts_module(input_data)
#print(result)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

Steps to reproduce the behavior:

  1. Copy above code to /home/compile.py
  2. sudo docker run --gpus all -v /home:/home -it --rm nvcr.io/nvidia/pytorch:22.09-py3
  3. pip install transformers==2.3.0
  4. python /home/compile.py

Environment

AWS g4dn.xlarge with DLAMI

Build information about Torch-TensorRT can be found by turning on debug messages

root@abc:/home/ec2-user# pip list | grep tensor
jupyter-tensorboard           0.2.0
tensorboard                   2.10.0
tensorboard-data-server       0.6.1
tensorboard-plugin-wit        1.8.1
tensorrt                      8.5.0.12
torch-tensorrt                1.2.0a0
functorch                     0.3.0a0
pytorch-quantization          2.1.2
torch                         1.13.0a0+d0d6b1f
torch-tensorrt                1.2.0a0
torchtext                     0.11.0a0
torchvision                   0.14.0a0
root@abc:/home/ec2-user# python --version
Python 3.8.13
[ec2-user@abc]$ uname -ra
Linux ip-172-31-36-232.us-west-2.compute.internal 4.14.291-218.527.amzn2.x86_64 #1 SMP Fri Aug 26 09:54:31 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
  • GPU models and configuration: T4

Stack trace


False
WARNING: [Torch-TensorRT] - For input input_ids, found user specified input dtype as Int32, however when inspecting the graph, the input type expected was inferred to be Float
The compiler is going to use the user setting Int32
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for input_ids
- Disable partial compilation by setting require_full_compilation to True
Traceback (most recent call last):
  File "tensorrt_compile_1.py", line 20, in <module>
    trt_ts_module = torch_tensorrt.ts.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 134, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py(2206): embedding
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/sparse.py(160): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(186): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/transformers/modeling_bert.py(735): forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1173): _slow_forward
/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py(1185): _call_impl
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
tensorrt_compile_1.py(7): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

@vishwanath-gowda vishwanath-gowda added the bug Something isn't working label Oct 13, 2022
@Mansterteddy

Looks like you need to move the model to the CUDA device: model.eval().cuda().
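
For example, a minimal sketch of that step (reusing the model setup from the repro above):

from transformers import BertModel

# Load BERT, switch to eval mode, and move it to the GPU before tracing.
model = BertModel.from_pretrained("bert-base-uncased").eval().cuda()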

@vishwanath-gowda
Author

I have tried this already, but it did not help.

model = torch.jit.trace(model, inp).eval().to('cuda')
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

@vishwanath-gowda
Author

I have also tried

model = BertModel.from_pretrained("bert-base-uncased").eval().to('cuda')

The same error occurs.

@narendasan
Collaborator

I have tried this already, but it did not help.

model = torch.jit.trace(model, inp).eval().to('cuda')
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

For the JIT trace, are your example tensors on the GPU as well?
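
For example, a minimal sketch of tracing with the example tensor on the same device as the model (shape taken from the repro above):

import torch

# Put both the model and the example input on the GPU before tracing.
model = model.eval().to("cuda")
inp = torch.rand(1, 256).long().to("cuda")
traced_model = torch.jit.trace(model, inp)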

@gs-olive
Collaborator

gs-olive commented Oct 28, 2022

I have reproduced the error with the described environment, and was able to resolve the issue in my environment with these changes:

  • Put the model on GPU
    • Add device = "cuda"
    • Add model = model.to(device)
  • Initialize inp as a torch.int32 type (instead of torch.int64) and put inp on GPU
    • inp = torch.rand(1, 256).long() → inp = torch.rand(1, 256).int().to(device)
  • Enable the truncate_long_and_double option in compilation
    • torch_tensorrt.ts.compile( ... truncate_long_and_double=True, ...)

Let me know if this resolves the issue on your end as well.
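
Putting the three changes together, a minimal sketch of the full script (shapes, dtype spec, and precision reused from the original repro; intended only to illustrate the changes above):

import torch
import torch_tensorrt
from transformers import BertModel

device = "cuda"

# Load BERT, switch to eval mode, and move it to the GPU.
model = BertModel.from_pretrained("bert-base-uncased").eval().to(device)

# Use an int32 example input on the GPU for tracing.
inp = torch.rand(1, 256).int().to(device)
traced_model = torch.jit.trace(model, inp)

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 256],
        opt_shape=[1, 256],
        max_shape=[1, 256],
        dtype=torch.int32,
    )
]
enabled_precisions = {torch.float32}

# truncate_long_and_double lets the compiler truncate int64/float64 tensors,
# which TensorRT does not support natively.
trt_ts_module = torch_tensorrt.ts.compile(
    traced_model,
    inputs=inputs,
    enabled_precisions=enabled_precisions,
    truncate_long_and_double=True,
)

torch.jit.save(trt_ts_module, "trt_ts_module.ts")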

@gs-olive gs-olive added the bug: triaged [verified] We can replicate the bug label Oct 28, 2022
@narendasan
Collaborator

  1. The user is responsible for putting the model on the GPU and in eval mode. We can work on improving detection here to help the user catch this issue, but it is not a bug.
  2. Currently TRT does not support int64 inputs; there is a feature request in Automatic int32 <=> int64 datatype conversion in fallback #1387 to add a compatibility mode.
  3. The truncate_long_and_double feature will tell you if you need this.

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
