
[BUG] Torchscripting ViT results in AttributeError for Python 3.11 #1946

Closed
narendasan opened this issue Sep 8, 2023 · 4 comments
Labels
bug Something isn't working

@narendasan

Describe the bug
When using Python 3.11, timm 0.9.7, and the latest version of PyTorch (2.0.1), the following error is thrown when trying to script ViT:

Traceback (most recent call last):
  File "/home/narens/Developer/opensource/pytorch_org/tensorrt/experiments/repro.py", line 5, in <module>
    ts_model = torch.jit.script(model)
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 477, in create_script_module
    concrete_type = get_module_concrete_type(nn_module, share_types)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 428, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 369, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 234, in infer_concrete_type_builder
    sub_concrete_type = get_module_concrete_type(item, share_types)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 428, in get_module_concrete_type
    concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 369, in get_or_create_concrete_type
    concrete_type_builder = infer_concrete_type_builder(nn_module)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 333, in infer_concrete_type_builder
    attr_type, inferred = infer_type(name, value)
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_recursive.py", line 178, in infer_type
    ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], fake_range())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/annotations.py", line 422, in ann_to_type
    the_type = try_ann_to_type(ann, loc, rcb)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/annotations.py", line 403, in try_ann_to_type
    scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/jit/_script.py", line 1465, in _recursive_compile_class
    rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/_jit_internal.py", line 457, in createResolutionCallbackForClassMethods
    captures.update(get_closure(fn))
                    ^^^^^^^^^^^^^^^
  File "/home/narens/.miniconda3/envs/torch210cu121py311/lib/python3.11/site-packages/torch/_jit_internal.py", line 207, in get_closure
    captures.update(fn.__globals__)
                    ^^^^^^^^^^^^^^
AttributeError: 'wrapper_descriptor' object has no attribute '__globals__'

To Reproduce

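import timm
import torch

# Build a pretrained ViT-Tiny from timm.
model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
# On Python 3.11 with PyTorch 2.0.1, this call raises the AttributeError above.
ts_model = torch.jit.script(model)
torch.save(ts_model, "vit_tiny.ts")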

Expected behavior
torch.jit.script(model) should succeed and return a scripted ViT that can be saved, as it does on Python versions below 3.11.

Desktop:

  • OS: Ubuntu 22.04 x86_64
  • timm version: 0.9.7 (pip)
  • PyTorch version: 2.0.1-cu121

narendasan added the bug (Something isn't working) label Sep 8, 2023
@rwightman
Collaborator

rwightman commented Sep 8, 2023 via email

@rwightman
Collaborator

rwightman commented Sep 9, 2023

I thought there was already an issue for this but couldn't find it. I did discuss it with the PyTorch devs at some point, but I guess it got lost in the shuffle: pytorch/pytorch#108933

The problem is using enums with TorchScript (JIT) in Python 3.11 (anything < 3.11 is fine).
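The last frame of the traceback in the issue shows get_closure calling fn.__globals__ on routines collected from the annotated class, and a C-level slot wrapper (a 'wrapper_descriptor') has no __globals__. The sketch below is a stdlib-only illustration of that failure mode and of one plausible guard (skip routines that lack __globals__). It is not PyTorch's code, and naive_capture, safe_capture and Demo are hypothetical names. Exactly which attribute Python 3.11's enum machinery adds is not pinned down here, so Demo plants a slot wrapper by hand to reproduce the same error on any Python version.

```python
import inspect
from typing import Any, Dict


def naive_capture(cls: type) -> Dict[str, Any]:
    """Hypothetical helper mimicking the failing pattern: read __globals__
    from every routine found in the class dict."""
    captures: Dict[str, Any] = {}
    for name in cls.__dict__:
        attr = getattr(cls, name)
        if inspect.isroutine(attr):
            # Raises AttributeError for C slot wrappers such as str.__str__.
            captures.update(attr.__globals__)
    return captures


def safe_capture(cls: type) -> Dict[str, Any]:
    """Hypothetical guarded variant: skip routines without a Python global scope."""
    captures: Dict[str, Any] = {}
    for name in cls.__dict__:
        attr = getattr(cls, name)
        if inspect.isroutine(attr) and hasattr(attr, "__globals__"):
            captures.update(attr.__globals__)
    return captures


class Demo:
    # Assumption for illustration: a C-level slot wrapper stored in the class
    # dict, standing in for whatever Python 3.11's Enum machinery leaves there.
    alias = str.__str__

    def method(self) -> int:
        return 1
```

Skipping such routines loses nothing for name resolution, because built-ins implemented in C have no module globals to capture.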

@rwightman
Collaborator

@narendasan looks like a fix for this will squeak into 2.1: pytorch/pytorch#109807

@rwightman
Collaborator

I tested the latest PyTorch 2.1 release candidate, and the issue appears to be fixed.
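Until upgrading is an option, a caller might want to decide up front whether scripting is worth attempting. The helper below encodes only what this thread reports: Python below 3.11 is unaffected, and the fix was expected in PyTorch 2.1, with the 2.1 release candidate appearing fixed. scripting_expected_to_work is a hypothetical name, not a timm or PyTorch API, and it only looks at major.minor, so a 2.1 pre-release built before the fix landed would be misjudged as working.

```python
import re
import sys
from typing import Tuple


def scripting_expected_to_work(python: Tuple[int, int], torch_version: str) -> bool:
    """Hypothetical heuristic based on this issue thread: the enum-related
    scripting error affects Python 3.11+ with PyTorch older than 2.1."""
    if python < (3, 11):
        return True  # per the thread, anything below 3.11 is unaffected
    # Accepts strings such as "2.0.1", "2.0.1+cu121" or "2.1.0rc6".
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized torch version: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 1)  # fix reported to land in PyTorch 2.1


if __name__ == "__main__":
    # In real use, pass torch.__version__ instead of this literal.
    print(scripting_expected_to_work(tuple(sys.version_info[:2]), "2.0.1+cu121"))
```

When it returns False, upgrading to PyTorch 2.1 or later is the route this thread points to.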
