Skip to content

Experimental mypy plugin #7117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ known_first_party = xarray
exclude = properties|asv_bench|doc
files = .
show_error_codes = True
plugins = xarray.mypy_plugin

# Most of the numerical computing stack doesn't have type annotations yet.
[mypy-affine.*]
Expand Down
48 changes: 35 additions & 13 deletions xarray/core/extensions.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,53 @@
from __future__ import annotations

import warnings

from .dataarray import DataArray
from .dataset import Dataset
from typing import Callable, Generic, TypeVar, overload


class AccessorRegistrationWarning(Warning):
"""Warning for conflicts in accessor registration."""


class _CachedAccessor:
_Accessor = TypeVar("_Accessor")


class _CachedAccessor(Generic[_Accessor]):
"""Custom property-like object (descriptor) for caching accessors."""

def __init__(self, name, accessor):
_name: str
_accessor: type[_Accessor]

def __init__(self, name: str, accessor: type[_Accessor]):
self._name = name
self._accessor = accessor

def __get__(self, obj, cls):
@overload
def __get__(self, obj: None, cls) -> type[_Accessor]:
...

@overload
def __get__(self, obj: object, cls) -> _Accessor:
...

def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor:
if obj is None:
# we're accessing the attribute of the class, i.e., Dataset.geo
return self._accessor

# Use the same dict as @pandas.util.cache_readonly.
# It must be explicitly declared in obj.__slots__.
try:
cache = obj._cache
cache = obj._cache # type: ignore[attr-defined]
except AttributeError:
cache = obj._cache = {}
cache = obj._cache = {} # type: ignore[attr-defined]

try:
return cache[self._name]
except KeyError:
pass

try:
accessor_obj = self._accessor(obj)
accessor_obj = self._accessor(obj) # type: ignore[call-arg]
except AttributeError:
# __getattr__ on data object will swallow any AttributeErrors
# raised when initializing the accessor, so we need to raise as
Expand All @@ -46,8 +58,10 @@ def __get__(self, obj, cls):
return accessor_obj


def _register_accessor(name, cls):
def decorator(accessor):
def _register_accessor(
name: str, cls: type[object]
) -> Callable[[type[_Accessor]], type[_Accessor]]:
def decorator(accessor: type[_Accessor]) -> type[_Accessor]:
if hasattr(cls, name):
warnings.warn(
f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is "
Expand All @@ -61,7 +75,9 @@ def decorator(accessor):
return decorator


def register_dataarray_accessor(name):
def register_dataarray_accessor(
name: str,
) -> Callable[[type[_Accessor]], type[_Accessor]]:
"""Register a custom accessor on xarray.DataArray objects.

Parameters
Expand All @@ -74,10 +90,14 @@ def register_dataarray_accessor(name):
--------
register_dataset_accessor
"""
from .dataarray import DataArray

return _register_accessor(name, DataArray)


def register_dataset_accessor(name):
def register_dataset_accessor(
name: str,
) -> Callable[[type[_Accessor]], type[_Accessor]]:
"""Register a custom property on xarray.Dataset objects.

Parameters
Expand Down Expand Up @@ -119,4 +139,6 @@ def register_dataset_accessor(name):
--------
register_dataarray_accessor
"""
from .dataset import Dataset

return _register_accessor(name, Dataset)
51 changes: 51 additions & 0 deletions xarray/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from functools import partial
from typing import Callable

from mypy.nodes import CallExpr, ClassDef, Expression, StrExpr, TypeInfo
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_attribute_to_class
from mypy.types import Instance


def accessor_callback(
    ctx: ClassDefContext, accessor_cls: ClassDef, xr_cls: ClassDef
) -> None:
    """Declare the registered accessor as an attribute of ``xr_cls``.

    The attribute name is the string literal passed to the registering
    decorator (read from ``ctx.reason``). Its type is an instance of
    ``accessor_cls`` with the decorated class as its single type argument.

    Parameters
    ----------
    ctx : ClassDefContext
        Context for the decorated class; supplies the decorator expression
        (``ctx.reason``), the decorated class (``ctx.cls``) and ``ctx.api``.
    accessor_cls : ClassDef
        Definition of the class used as the attribute's type.
    xr_cls : ClassDef
        Definition of the class that receives the new attribute.
    """
    accessor_name = _get_name_arg(ctx.reason)
    decorated_type = Instance(ctx.cls.info, [])
    # TODO: when changing code and running mypy again, this fails?
    add_attribute_to_class(
        api=ctx.api,
        cls=xr_cls,
        name=accessor_name,
        typ=Instance(accessor_cls.info, [decorated_type]),
    )


def _get_name_arg(reason: Expression) -> str:
    """Return the string literal given as the only argument of a call.

    Raises ``AssertionError`` when ``reason`` is not a call expression, when
    the call does not have exactly one argument, or when that argument is not
    a string literal.
    """
    assert isinstance(reason, CallExpr)
    call_args = reason.args
    assert len(call_args) == 1  # only a single "name" arg
    literal = call_args[0]
    assert isinstance(literal, StrExpr)
    return literal.value


class XarrayPlugin(Plugin):
    """Mypy plugin hooking xarray's accessor-registration decorators."""

    def get_class_decorator_hook(
        self, fullname: str
    ) -> Callable[[ClassDefContext], None] | None:
        """Return a callback for the accessor-registration decorators.

        ``fullname`` is compared against the fully qualified names of
        ``register_dataarray_accessor`` and ``register_dataset_accessor``.
        On a match, ``accessor_callback`` is returned pre-bound to the
        matching xarray class and to ``_CachedAccessor``; otherwise ``None``.
        """
        for class_name in ("DataArray", "Dataset"):
            module_name = class_name.lower()
            if fullname != f"xarray.core.extensions.register_{module_name}_accessor":
                continue
            # Look up the target class first, then the accessor wrapper.
            target_def = self._get_cls(f"{module_name}.{class_name}")
            wrapper_def = self._get_cls("extensions._CachedAccessor")
            return partial(
                accessor_callback, accessor_cls=wrapper_def, xr_cls=target_def
            )
        return None

    def _get_cls(self, typename: str) -> ClassDef:
        """Look up ``xarray.core.<typename>`` and return its class definition.

        Raises ``AssertionError`` if the lookup yields nothing or the symbol's
        node is not a ``TypeInfo``.
        """
        symbol = self.lookup_fully_qualified("xarray.core." + typename)
        assert symbol is not None
        info = symbol.node
        assert isinstance(info, TypeInfo)
        return info.defn


def plugin(version: str) -> type[XarrayPlugin]:
    """Entry point called by mypy; returns the plugin class.

    ``version`` is accepted but not used: the same class is returned
    whatever value is passed.
    """
    return XarrayPlugin
12 changes: 6 additions & 6 deletions xarray/tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ExampleAccessor:

def __init__(self, xarray_obj):
self.obj = xarray_obj
self.value = "initial"


class TestAccessor:
Expand Down Expand Up @@ -43,20 +44,19 @@ def foo(self):

# check descriptor
assert ds.demo.__doc__ == "Demo accessor."
# TODO: typing doesn't seem to work with accessors
assert xr.Dataset.demo.__doc__ == "Demo accessor." # type: ignore
assert isinstance(ds.demo, DemoAccessor) # type: ignore
assert xr.Dataset.demo is DemoAccessor # type: ignore
assert xr.Dataset.demo.__doc__ == "Demo accessor."
assert isinstance(ds.demo, DemoAccessor)
assert xr.Dataset.demo is DemoAccessor

# ensure we can remove it
del xr.Dataset.demo # type: ignore
del xr.Dataset.demo
assert not hasattr(xr.Dataset, "demo")

with pytest.warns(Warning, match="overriding a preexisting attribute"):

@xr.register_dataarray_accessor("demo")
class Foo:
pass
foo: str

# it didn't get registered again
assert not hasattr(xr.Dataset, "demo")
Expand Down