Skip to content

Commit a172ad8

Browse files
author
Shyam Dwaraknath
authored
feat: Add dataclass and pydantic support
References: #9, #27
1 parent a74dccf commit a172ad8

File tree

5 files changed

+109
-2
lines changed

5 files changed

+109
-2
lines changed

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ mkdocs = "^1.1"
3333
mkdocstrings = "^0.10.3"
3434
mkdocs-material = "^4.6.3"
3535
mypy = "^0.770"
36-
pydantic = "^1.4"
36+
pydantic = "^1.5.1"
3737
pylint = { git = "https://github.com/PyCQA/pylint.git" }
3838
pytest = "~5.3.5"
3939
pytest-cov = "^2.8.1"

Diff for: src/pytkdocs/loader.py

+60-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def get_module_documentation(self, node: ObjectNode, members=None) -> Module:
271271
source = None
272272

273273
root_object = Module(
274-
name=name, path=path, file_path=node.file_path, docstring=inspect.getdoc(module) or "", source=source,
274+
name=name, path=path, file_path=node.file_path, docstring=inspect.getdoc(module) or "", source=source
275275
)
276276

277277
if members is False:
@@ -338,6 +338,22 @@ def get_class_documentation(self, node: ObjectNode, members=None) -> Class:
338338
elif child_node.is_property():
339339
root_object.add_child(self.get_property_documentation(child_node))
340340

341+
# First check if this is pdyantic compataible
342+
if "__fields__" in class_.__dict__:
343+
root_object.properties = ["pydantic"]
344+
for field_name, model_field in class_.__dict__.get("__fields__", {}).items():
345+
if self.select(field_name, members): # type: ignore
346+
child_node = ObjectNode(obj=model_field, name=field_name, parent=node)
347+
root_object.add_child(self.get_pydantic_field_documentation(child_node))
348+
349+
# Handle dataclasses
350+
elif "__dataclass_fields__" in class_.__dict__:
351+
root_object.properties = ["dataclass"]
352+
for field_name, annotation in class_.__dict__.get("__annotations__", {}).items():
353+
if self.select(field_name, members): # type: ignore
354+
child_node = ObjectNode(obj=annotation, name=field_name, parent=node)
355+
root_object.add_child(self.get_annotated_dataclass_field(child_node))
356+
341357
return root_object
342358

343359
def get_function_documentation(self, node: ObjectNode) -> Function:
@@ -415,6 +431,49 @@ def get_property_documentation(self, node: ObjectNode) -> Attribute:
415431
source=source,
416432
)
417433

434+
def get_pydantic_field_documentation(self, node: ObjectNode) -> Attribute:
435+
"""
436+
Get the documentation for a PyDantic Field
437+
438+
Arguments:
439+
node: The node representing the Field and its parents.
440+
441+
Return:
442+
The documented attribute object.
443+
"""
444+
prop = node.obj
445+
path = node.dotted_path
446+
properties = ["field", "pydantic"]
447+
if prop.required:
448+
properties.append("required")
449+
450+
return Attribute(
451+
name=node.name,
452+
path=path,
453+
file_path=node.file_path,
454+
docstring=prop.field_info.description,
455+
attr_type=prop.type_,
456+
properties=properties,
457+
)
458+
459+
def get_annotated_dataclass_field(self, node: ObjectNode) -> Attribute:
460+
"""
461+
Get the documentation for an dataclass annotation.
462+
463+
Arguments:
464+
node: The node representing the annotation and its parents.
465+
466+
Return:
467+
The documented attribute object.
468+
"""
469+
annotation: type = node.obj
470+
path = node.dotted_path
471+
properties = ["field"]
472+
473+
return Attribute(
474+
name=node.name, path=path, file_path=node.file_path, attr_type=annotation, properties=properties
475+
)
476+
418477
def get_classmethod_documentation(self, node: ObjectNode) -> Method:
419478
"""
420479
Get the documentation for a class-method.

Diff for: tests/fixtures/dataclass.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class Person:
6+
"""Simple dataclass for a person's information"""
7+
8+
name: str
9+
age: int

Diff for: tests/fixtures/pydantic.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pydantic import BaseModel, Field
2+
3+
4+
class Person(BaseModel):
5+
"""Simple Pydantic Model for a person's information"""
6+
7+
name: str = Field("PersonA", description="The person's name")
8+
age: int = Field(18, description="The person's age which must be at minimum 18")

Diff for: tests/test_loader.py

+31
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,37 @@ def test_loading_class():
107107
assert obj.docstring == "The class docstring."
108108

109109

110+
def test_loading_dataclass():
111+
loader = Loader()
112+
obj = loader.get_object_documentation("tests.fixtures.dataclass.Person")
113+
assert obj.docstring == "Simple dataclass for a person's information"
114+
assert len(obj.attributes) == 2
115+
name_attr = next(attr for attr in obj.attributes if attr.name == "name")
116+
assert name_attr.type == str
117+
age_attr = next(attr for attr in obj.attributes if attr.name == "age")
118+
assert age_attr.type == int
119+
assert "dataclass" in obj.properties
120+
121+
not_dataclass = loader.get_object_documentation("tests.fixtures.the_package.the_module.TheClass.TheNestedClass")
122+
assert "dataclass" not in not_dataclass.properties
123+
124+
125+
def test_loading_pydantic_model():
126+
loader = Loader()
127+
obj = loader.get_object_documentation("tests.fixtures.pydantic.Person")
128+
assert obj.docstring == "Simple Pydantic Model for a person's information"
129+
assert "pydantic" in obj.properties
130+
assert len(obj.attributes) == 2
131+
name_attr = next(attr for attr in obj.attributes if attr.name == "name")
132+
assert name_attr.type == str
133+
assert name_attr.docstring == "The person's name"
134+
assert "pydantic" in name_attr.properties
135+
age_attr = next(attr for attr in obj.attributes if attr.name == "age")
136+
assert age_attr.type == int
137+
assert age_attr.docstring == "The person's age which must be at minimum 18"
138+
assert "pydantic" in age_attr.properties
139+
140+
110141
def test_loading_nested_class():
111142
loader = Loader()
112143
obj = loader.get_object_documentation("tests.fixtures.the_package.the_module.TheClass.TheNestedClass")

0 commit comments

Comments
 (0)