Skip to content

Commit cf31345

Browse files
committed
Fixed stubgen parsing generics from C extensions
1 parent 08cd1d6 commit cf31345

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

mypy/stubgenc.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,16 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str:
201201
imports: list of import statements (may be modified during the call)
202202
"""
203203
stripped_type = typ
204-
if module and typ.startswith(module.__name__ + '.'):
204+
if any(c in typ for c in '[,'):
205+
for subtyp in re.split(r'[\[,\]]', typ):
206+
strip_or_import(subtyp.strip(), module, imports)
207+
if module:
208+
stripped_type = re.sub(
209+
r'(^|[\[, ]+)' + re.escape(module.__name__ + '.'),
210+
r'\1',
211+
typ,
212+
)
213+
elif module and typ.startswith(module.__name__ + '.'):
205214
stripped_type = typ[len(module.__name__) + 1:]
206215
elif '.' in typ:
207216
arg_module = typ[:typ.rindex('.')]

mypy/test/teststubgen.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,81 @@ def test(arg0: str) -> None:
778778
assert_equal(output, ['def test(arg0: str) -> Action: ...'])
779779
assert_equal(imports, [])
780780

781+
def test_generate_c_type_with_single_arg_generic(self) -> None:
782+
class TestClass:
783+
def test(self, arg0: str) -> None:
784+
"""
785+
test(self: TestClass, arg0: List[int])
786+
"""
787+
pass
788+
output = [] # type: List[str]
789+
imports = [] # type: List[str]
790+
mod = ModuleType(TestClass.__module__, '')
791+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
792+
self_var='self', class_name='TestClass')
793+
assert_equal(output, ['def test(self, arg0: List[int]) -> Any: ...'])
794+
assert_equal(imports, [])
795+
796+
def test_generate_c_type_with_double_arg_generic(self) -> None:
797+
class TestClass:
798+
def test(self, arg0: str) -> None:
799+
"""
800+
test(self: TestClass, arg0: Dict[str, int])
801+
"""
802+
pass
803+
output = [] # type: List[str]
804+
imports = [] # type: List[str]
805+
mod = ModuleType(TestClass.__module__, '')
806+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
807+
self_var='self', class_name='TestClass')
808+
assert_equal(output, ['def test(self, arg0: Dict[str,int]) -> Any: ...'])
809+
assert_equal(imports, [])
810+
811+
def test_generate_c_type_with_nested_generic(self) -> None:
812+
class TestClass:
813+
def test(self, arg0: str) -> None:
814+
"""
815+
test(self: TestClass, arg0: Dict[str, List[int]])
816+
"""
817+
pass
818+
output = [] # type: List[str]
819+
imports = [] # type: List[str]
820+
mod = ModuleType(TestClass.__module__, '')
821+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
822+
self_var='self', class_name='TestClass')
823+
assert_equal(output, ['def test(self, arg0: Dict[str,List[int]]) -> Any: ...'])
824+
assert_equal(imports, [])
825+
826+
def test_generate_c_type_with_generic_using_other_module_first(self) -> None:
827+
class TestClass:
828+
def test(self, arg0: str) -> None:
829+
"""
830+
test(self: TestClass, arg0: Dict[argparse.Action, int])
831+
"""
832+
pass
833+
output = [] # type: List[str]
834+
imports = [] # type: List[str]
835+
mod = ModuleType(TestClass.__module__, '')
836+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
837+
self_var='self', class_name='TestClass')
838+
assert_equal(output, ['def test(self, arg0: Dict[argparse.Action,int]) -> Any: ...'])
839+
assert_equal(imports, ['import argparse'])
840+
841+
def test_generate_c_type_with_generic_using_other_module_last(self) -> None:
842+
class TestClass:
843+
def test(self, arg0: str) -> None:
844+
"""
845+
test(self: TestClass, arg0: Dict[str, argparse.Action])
846+
"""
847+
pass
848+
output = [] # type: List[str]
849+
imports = [] # type: List[str]
850+
mod = ModuleType(TestClass.__module__, '')
851+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
852+
self_var='self', class_name='TestClass')
853+
assert_equal(output, ['def test(self, arg0: Dict[str,argparse.Action]) -> Any: ...'])
854+
assert_equal(imports, ['import argparse'])
855+
781856
def test_generate_c_type_with_overload_pybind11(self) -> None:
782857
class TestClass:
783858
def __init__(self, arg0: str) -> None:

0 commit comments

Comments
 (0)