diff --git a/flashinfer/jit/__init__.py b/flashinfer/jit/__init__.py index 4281c37e1..c4c561960 100644 --- a/flashinfer/jit/__init__.py +++ b/flashinfer/jit/__init__.py @@ -46,10 +46,16 @@ from .utils import parallel_load_modules as parallel_load_modules try: - from .. import _kernels, _kernels_sm90 # noqa: F401 + from .. import flashinfer_kernels, flashinfer_kernels_sm90 # noqa: F401 from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri has_prebuilt_ops = True -except ImportError: +except ImportError as e: + if "undefined symbol" in str(e): + raise ImportError("Loading prebuilt ops failed.") from e + + from .core import logger + + logger.info("Prebuilt kernels not found, using JIT backend") prebuilt_ops_uri = {} has_prebuilt_ops = False