Skip to content

Commit 99e1dad

Browse files
committed
Add mypy plugin for NewType
1 parent 520698f commit 99e1dad

File tree

4 files changed

+157
-0
lines changed

4 files changed

+157
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Auto-generated from tests
2+
test-output
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

marshmallow_dataclass/mypy.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import inspect
2+
from typing import Callable, Optional, Type
3+
4+
from mypy import nodes
5+
from mypy.plugin import DynamicClassDefContext, Plugin
6+
7+
import marshmallow_dataclass
8+
9+
_NEW_TYPE_SIG = inspect.signature(marshmallow_dataclass.NewType)
10+
11+
12+
def plugin(version: str) -> Type[Plugin]:
13+
return MarshmallowDataclassPlugin
14+
15+
16+
class MarshmallowDataclassPlugin(Plugin):
17+
def get_dynamic_class_hook(
18+
self, fullname: str
19+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
20+
if fullname == "marshmallow_dataclass.NewType":
21+
return new_type_hook
22+
return None
23+
24+
25+
def new_type_hook(ctx: DynamicClassDefContext) -> None:
26+
"""
27+
Dynamic class hook for :func:`marshmallow_dataclass.NewType`.
28+
29+
Uses the type of the ``typ`` argument.
30+
"""
31+
typ = _get_arg_by_name(ctx.call, "typ", _NEW_TYPE_SIG)
32+
if not isinstance(typ, nodes.RefExpr):
33+
return
34+
info = typ.node
35+
if not isinstance(info, nodes.TypeInfo):
36+
return
37+
ctx.api.add_symbol_table_node(ctx.name, nodes.SymbolTableNode(nodes.GDEF, info))
38+
39+
40+
def _get_arg_by_name(
41+
call: nodes.CallExpr, name: str, sig: inspect.Signature
42+
) -> Optional[nodes.Expression]:
43+
"""
44+
Get value of argument from a call.
45+
46+
:return: The argument value, or ``None`` if it cannot be found.
47+
48+
.. warning::
49+
This probably doesn't yet work for calls with ``*args`` and/or ``*kwargs``.
50+
"""
51+
args = []
52+
kwargs = {}
53+
for arg_name, arg_value in zip(call.arg_names, call.args):
54+
if arg_name is None:
55+
args.append(arg_value)
56+
else:
57+
kwargs[arg_name] = arg_value
58+
try:
59+
bound_args = sig.bind(*args, **kwargs)
60+
except TypeError:
61+
return None
62+
try:
63+
return bound_args.arguments[name]
64+
except KeyError:
65+
return None

setup.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
':python_version == "3.6"': ["dataclasses"],
1919
"lint": ["pre-commit~=1.18"],
2020
"docs": ["sphinx"],
21+
"tests": ["mypy"],
2122
}
2223
EXTRAS_REQUIRE["dev"] = (
2324
EXTRAS_REQUIRE["enum"]
2425
+ EXTRAS_REQUIRE["union"]
2526
+ EXTRAS_REQUIRE["lint"]
2627
+ EXTRAS_REQUIRE["docs"]
28+
+ EXTRAS_REQUIRE["tests"]
2729
)
2830

2931
setup(

tests/test_mypy.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import shutil
2+
import textwrap
3+
import os
4+
import unittest
5+
6+
import mypy.api
7+
8+
HERE = os.path.dirname(__file__)
9+
TEST_OUTPUT_DIR = os.path.join(HERE, "test-output")
10+
MYPY_INI = """\
11+
[mypy]
12+
follow_imports = silent
13+
plugins = marshmallow_dataclass.mypy
14+
"""
15+
16+
17+
class TestMypyPlugin(unittest.TestCase):
18+
maxDiff = None
19+
20+
def setUp(self):
21+
"""
22+
Prepare a clean test directory at tests/test-output/test_mypy-{testname}.
23+
Also cd into it for the duration of the test to get simple filenames in mypy output.
24+
"""
25+
testname = self.id().split(".")[-1]
26+
self.testdir = os.path.join(TEST_OUTPUT_DIR, f"test_mypy-{testname}")
27+
if os.path.exists(self.testdir):
28+
shutil.rmtree(self.testdir)
29+
os.makedirs(self.testdir)
30+
self.old_cwd = os.getcwd()
31+
os.chdir(self.testdir)
32+
33+
def tearDown(self):
34+
os.chdir(self.old_cwd)
35+
36+
def mypy_test(self, contents: str, expected: str):
37+
"""
38+
Run mypy and assert output matches ``expected``.
39+
40+
The file with ``contents`` is always named ``main.py``.
41+
"""
42+
config_path = os.path.join(self.testdir, "mypy.ini")
43+
script_path = os.path.join(self.testdir, "main.py")
44+
with open(config_path, "w") as f:
45+
f.write(MYPY_INI)
46+
with open(script_path, "w") as f:
47+
f.write(textwrap.dedent(contents).strip())
48+
command = [script_path, "--config-file", config_path, "--no-error-summary"]
49+
out, err, returncode = mypy.api.run(command)
50+
err_msg = "\n".join(
51+
["", f"returncode: {returncode}", "stdout:", out, "stderr:", err]
52+
)
53+
self.assertEqual(out.strip(), textwrap.dedent(expected).strip(), err_msg)
54+
55+
def test_basic(self):
56+
self.mypy_test(
57+
"""
58+
from dataclasses import dataclass
59+
import marshmallow as ma
60+
from marshmallow_dataclass import NewType
61+
62+
Email = NewType("Email", str, validate=ma.validate.Email)
63+
UserID = NewType("UserID", validate=ma.validate.Length(equal=32), typ=str)
64+
65+
@dataclass
66+
class User:
67+
id: UserID
68+
email: Email
69+
70+
user = User(id="a"*32, email="[email protected]")
71+
reveal_type(user.id)
72+
reveal_type(user.email)
73+
74+
User(id=42, email="[email protected]")
75+
User(id="a"*32, email=["not", "a", "string"])
76+
""",
77+
"""
78+
main.py:14: note: Revealed type is 'builtins.str'
79+
main.py:15: note: Revealed type is 'builtins.str'
80+
main.py:17: error: Argument "id" to "User" has incompatible type "int"; expected "str"
81+
main.py:18: error: Argument "email" to "User" has incompatible type "List[str]"; expected "str"
82+
""",
83+
)
84+
85+
86+
if __name__ == "__main__":
87+
unittest.main()

0 commit comments

Comments
 (0)