Skip to content

Commit 770126f

Browse files
authored
bugfix: Another bugfix for torch.library (#828)
Followup of #823 , we should import `from .. import flashinfer_kernels, flashinfer_kernels_sm90` instead of `from .. import _kernels, _kernels_sm90`, otherwise we will be using JIT compilation all the code. Also add some logic to catch "undefined symbol" errors in case the AOT wheel compilation is successful but failed to be loaded.
1 parent 2076f72 commit 770126f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

flashinfer/jit/__init__.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,16 @@
4646
from .utils import parallel_load_modules as parallel_load_modules
4747

4848
try:
49-
from .. import _kernels, _kernels_sm90 # noqa: F401
49+
from .. import flashinfer_kernels, flashinfer_kernels_sm90 # noqa: F401
5050
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri
5151

5252
has_prebuilt_ops = True
53-
except ImportError:
53+
except ImportError as e:
54+
if "undefined symbol" in str(e):
55+
raise ImportError("Loading prebuilt ops failed.") from e
56+
57+
from .core import logger
58+
59+
logger.info("Prebuilt kernels not found, using JIT backend")
5460
prebuilt_ops_uri = {}
5561
has_prebuilt_ops = False

0 commit comments

Comments
 (0)