Skip to content

Commit c957609

Browse files
committed
Add mypy plugin for NewType
1 parent 520698f commit c957609

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Auto-generated from tests
2+
test-output
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

marshmallow_dataclass/mypy.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import inspect
2+
from typing import Any, Callable, Optional, Type, Optional
3+
4+
from mypy import nodes
5+
from mypy.plugin import DynamicClassDefContext, Plugin
6+
7+
import marshmallow_dataclass
8+
9+
_NEW_TYPE_SIG = inspect.signature(marshmallow_dataclass.NewType)
10+
11+
12+
def plugin(version: str) -> Type[Plugin]:
13+
return MarshmallowDataclassPlugin
14+
15+
16+
class MarshmallowDataclassPlugin(Plugin):
17+
def get_dynamic_class_hook(
18+
self, fullname: str
19+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
20+
if fullname == "marshmallow_dataclass.NewType":
21+
return new_type_hook
22+
return None
23+
24+
25+
def new_type_hook(ctx: DynamicClassDefContext) -> None:
26+
"""
27+
Dynamic class hook for :func:`marshmallow_dataclass.NewType`.
28+
29+
Uses the type of the ``typ`` argument.
30+
"""
31+
typ = _get_arg_by_name(ctx.call, "typ", _NEW_TYPE_SIG)
32+
if not isinstance(typ, nodes.RefExpr):
33+
return
34+
info = typ.node
35+
if not isinstance(info, nodes.TypeInfo):
36+
return
37+
ctx.api.add_symbol_table_node(ctx.name, nodes.SymbolTableNode(nodes.GDEF, info))
38+
39+
40+
def _get_arg_by_name(
41+
call: nodes.CallExpr, name: str, sig: inspect.Signature
42+
) -> Optional[nodes.Expression]:
43+
"""
44+
Get value of argument from a call.
45+
46+
:return: The argument value, or ``None`` if it cannot be found.
47+
48+
.. warning::
49+
This probably doesn't yet work for calls with ``*args`` and/or ``*kwargs``.
50+
"""
51+
args = []
52+
kwargs = {}
53+
for arg_name, arg_value in zip(call.arg_names, call.args):
54+
if arg_name is None:
55+
args.append(arg_value)
56+
else:
57+
kwargs[arg_name] = arg_value
58+
try:
59+
bound_args = sig.bind(*args, **kwargs)
60+
except TypeError:
61+
return None
62+
try:
63+
return bound_args.arguments[name]
64+
except KeyError:
65+
return None

tests/test_mypy.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import shutil
2+
import textwrap
3+
import os
4+
import unittest
5+
6+
import mypy.api
7+
8+
HERE = os.path.dirname(__file__)
9+
TEST_OUTPUT_DIR = os.path.join(HERE, "test-output")
10+
MYPY_INI = """\
11+
[mypy]
12+
follow_imports = silent
13+
plugins = marshmallow_dataclass.mypy
14+
"""
15+
16+
17+
class TestMypyPlugin(unittest.TestCase):
18+
maxDiff = None
19+
20+
def setUp(self):
21+
"""
22+
Prepare a clean test directory at tests/test-output/test_mypy-{testname}.
23+
Also cd into it for the duration of the test to get simple filenames in mypy output.
24+
"""
25+
testname = self.id().split(".")[-1]
26+
self.testdir = os.path.join(TEST_OUTPUT_DIR, f"test_mypy-{testname}")
27+
if os.path.exists(self.testdir):
28+
shutil.rmtree(self.testdir)
29+
os.makedirs(self.testdir)
30+
self.old_cwd = os.getcwd()
31+
os.chdir(self.testdir)
32+
33+
def tearDown(self):
34+
os.chdir(self.old_cwd)
35+
36+
def mypy_test(self, contents: str, expected: str):
37+
"""
38+
Run mypy and assert output matches ``expected``.
39+
40+
The file with ``contents`` is always named ``main.py``.
41+
"""
42+
config_path = os.path.join(self.testdir, "mypy.ini")
43+
script_path = os.path.join(self.testdir, "main.py")
44+
with open(config_path, "w") as f:
45+
f.write(MYPY_INI)
46+
with open(script_path, "w") as f:
47+
f.write(textwrap.dedent(contents).strip())
48+
command = [script_path, "--config-file", config_path, "--no-error-summary"]
49+
out, err, returncode = mypy.api.run(command)
50+
err_msg = "\n".join(["", f"returncode: {returncode}", "stdout:", out, "stderr:", err])
51+
self.assertEqual(out.strip(), textwrap.dedent(expected).strip(), err_msg)
52+
53+
def test_basic(self):
54+
self.mypy_test(
55+
"""
56+
from dataclasses import dataclass
57+
import marshmallow as ma
58+
from marshmallow_dataclass import NewType
59+
60+
Email = NewType("Email", str, validate=ma.validate.Email)
61+
UserID = NewType("UserID", validate=ma.validate.Length(equal=32), typ=str)
62+
63+
@dataclass
64+
class User:
65+
id: UserID
66+
email: Email
67+
68+
user = User(id="a"*32, email="[email protected]")
69+
reveal_type(user.id)
70+
reveal_type(user.email)
71+
72+
User(id=42, email="[email protected]")
73+
User(id="a"*32, email=["not", "a", "string"])
74+
""",
75+
"""
76+
main.py:14: note: Revealed type is 'builtins.str'
77+
main.py:15: note: Revealed type is 'builtins.str'
78+
main.py:17: error: Argument "id" to "User" has incompatible type "int"; expected "str"
79+
main.py:18: error: Argument "email" to "User" has incompatible type "List[str]"; expected "str"
80+
""",
81+
)
82+
83+
84+
if __name__ == "__main__":
85+
unittest.main()

0 commit comments

Comments
 (0)