Skip to content

Commit 1c91c86

Browse files
conradogarciaberrotaransydney-runkledmontagu
authored
Add Secret base type (#8519)
Co-authored-by: sydney-runkle <[email protected]> Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: David Montague <[email protected]>
1 parent 5463055 commit 1c91c86

File tree

5 files changed

+433
-42
lines changed

5 files changed

+433
-42
lines changed

docs/api/types.md

+3
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
::: pydantic.types
2+
options:
3+
show_root_heading: true
4+
merge_init_into_class: false

docs/examples/secrets.md

+100
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,103 @@ print(model.model_dump())
3737
print(model.model_dump_json())
3838
#> {"password":"IAmSensitive","password_bytes":"IAmSensitiveBytes"}
3939
```
40+
41+
## Create your own Secret field
42+
43+
Pydantic provides the generic `Secret` class as a mechanism for creating custom secret types.
44+
45+
??? api "API Documentation"
46+
[`pydantic.types.Secret`][pydantic.types.Secret]<br>
47+
48+
Pydantic provides the generic `Secret` class as a mechanism for creating custom secret types.
49+
You can either directly parametrize `Secret`, or subclass from a parametrized `Secret` to customize the `str()` and `repr()` of a secret type.
50+
51+
```py
52+
from datetime import date
53+
54+
from pydantic import BaseModel, Secret
55+
56+
# Using the default representation
57+
SecretDate = Secret[date]
58+
59+
60+
# Overwriting the representation
61+
class SecretSalary(Secret[float]):
62+
def _display(self) -> str:
63+
return '$****.**'
64+
65+
66+
class Employee(BaseModel):
67+
date_of_birth: SecretDate
68+
salary: SecretSalary
69+
70+
71+
employee = Employee(date_of_birth='1990-01-01', salary=42)
72+
73+
print(employee)
74+
#> date_of_birth=Secret('**********') salary=SecretSalary('$****.**')
75+
76+
print(employee.salary)
77+
#> $****.**
78+
79+
print(employee.salary.get_secret_value())
80+
#> 42.0
81+
82+
print(employee.date_of_birth)
83+
#> **********
84+
85+
print(employee.date_of_birth.get_secret_value())
86+
#> 1990-01-01
87+
```
88+
89+
You can enforce constraints on the underlying type through annotations:
90+
For example:
91+
92+
```py
93+
from typing_extensions import Annotated
94+
95+
from pydantic import BaseModel, Field, Secret, ValidationError
96+
97+
SecretPosInt = Secret[Annotated[int, Field(gt=0, strict=True)]]
98+
99+
100+
class Model(BaseModel):
101+
sensitive_int: SecretPosInt
102+
103+
104+
m = Model(sensitive_int=42)
105+
print(m.model_dump())
106+
#> {'sensitive_int': Secret('**********')}
107+
108+
try:
109+
m = Model(sensitive_int=-42) # (1)!
110+
except ValidationError as exc_info:
111+
print(exc_info.errors(include_url=False, include_input=False))
112+
"""
113+
[
114+
{
115+
'type': 'greater_than',
116+
'loc': ('sensitive_int',),
117+
'msg': 'Input should be greater than 0',
118+
'ctx': {'gt': 0},
119+
}
120+
]
121+
"""
122+
123+
try:
124+
m = Model(sensitive_int='42') # (2)!
125+
except ValidationError as exc_info:
126+
print(exc_info.errors(include_url=False, include_input=False))
127+
"""
128+
[
129+
{
130+
'type': 'int_type',
131+
'loc': ('sensitive_int',),
132+
'msg': 'Input should be a valid integer',
133+
}
134+
]
135+
"""
136+
```
137+
138+
1. The input value is not greater than 0, so it raises a validation error.
139+
2. The input value is not an integer, so it raises a validation error because the `SecretPosInt` type has strict mode enabled.

pydantic/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
'DirectoryPath',
163163
'NewPath',
164164
'Json',
165+
'Secret',
165166
'SecretStr',
166167
'SecretBytes',
167168
'StrictBool',
@@ -310,6 +311,7 @@
310311
'DirectoryPath': (__package__, '.types'),
311312
'NewPath': (__package__, '.types'),
312313
'Json': (__package__, '.types'),
314+
'Secret': (__package__, '.types'),
313315
'SecretStr': (__package__, '.types'),
314316
'SecretBytes': (__package__, '.types'),
315317
'StrictBool': (__package__, '.types'),

pydantic/types.py

+152-39
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
TypeVar,
2525
Union,
2626
cast,
27+
get_args,
28+
get_origin,
2729
)
2830
from uuid import UUID
2931

@@ -75,6 +77,7 @@
7577
'DirectoryPath',
7678
'NewPath',
7779
'Json',
80+
'Secret',
7881
'SecretStr',
7982
'SecretBytes',
8083
'StrictBool',
@@ -1332,7 +1335,8 @@ class Model(BaseModel):
13321335
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13331336

13341337
if TYPE_CHECKING:
1335-
Json = Annotated[AnyType, ...] # Json[list[str]] will be recognized by type checkers as list[str]
1338+
# Json[list[str]] will be recognized by type checkers as list[str]
1339+
Json = Annotated[AnyType, ...]
13361340

13371341
else:
13381342

@@ -1439,10 +1443,10 @@ def __eq__(self, other: Any) -> bool:
14391443

14401444
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
14411445

1442-
SecretType = TypeVar('SecretType', str, bytes)
1446+
SecretType = TypeVar('SecretType')
14431447

14441448

1445-
class _SecretField(Generic[SecretType]):
1449+
class _SecretBase(Generic[SecretType]):
14461450
def __init__(self, secret_value: SecretType) -> None:
14471451
self._secret_value: SecretType = secret_value
14481452

@@ -1460,29 +1464,124 @@ def __eq__(self, other: Any) -> bool:
14601464
def __hash__(self) -> int:
14611465
return hash(self.get_secret_value())
14621466

1463-
def __len__(self) -> int:
1464-
return len(self._secret_value)
1465-
14661467
def __str__(self) -> str:
14671468
return str(self._display())
14681469

14691470
def __repr__(self) -> str:
14701471
return f'{self.__class__.__name__}({self._display()!r})'
14711472

1472-
def _display(self) -> SecretType:
1473+
def _display(self) -> str | bytes:
14731474
raise NotImplementedError
14741475

1476+
1477+
class Secret(_SecretBase[SecretType]):
1478+
"""A generic base class used for defining a field with sensitive information that you do not want to be visible in logging or tracebacks.
1479+
1480+
You may either directly parametrize `Secret` with a type, or subclass from `Secret` with a parametrized type. The benefit of subclassing
1481+
is that you can define a custom `_display` method, which will be used for `repr()` and `str()` methods. The examples below demonstrate both
1482+
ways of using `Secret` to create a new secret type.
1483+
1484+
1. Directly parametrizing `Secret` with a type:
1485+
1486+
```py
1487+
from pydantic import BaseModel, Secret
1488+
1489+
SecretBool = Secret[bool]
1490+
1491+
class Model(BaseModel):
1492+
secret_bool: SecretBool
1493+
1494+
m = Model(secret_bool=True)
1495+
print(m.model_dump())
1496+
#> {'secret_bool': Secret('**********')}
1497+
1498+
print(m.model_dump_json())
1499+
#> {"secret_bool":"**********"}
1500+
1501+
print(m.secret_bool.get_secret_value())
1502+
#> True
1503+
```
1504+
1505+
2. Subclassing from parametrized `Secret`:
1506+
1507+
```py
1508+
from datetime import date
1509+
1510+
from pydantic import BaseModel, Secret
1511+
1512+
class SecretDate(Secret[date]):
1513+
def _display(self) -> str:
1514+
return '****/**/**'
1515+
1516+
class Model(BaseModel):
1517+
secret_date: SecretDate
1518+
1519+
m = Model(secret_date=date(2022, 1, 1))
1520+
print(m.model_dump())
1521+
#> {'secret_date': SecretDate('****/**/**')}
1522+
1523+
print(m.model_dump_json())
1524+
#> {"secret_date":"****/**/**"}
1525+
1526+
print(m.secret_date.get_secret_value())
1527+
#> 2022-01-01
1528+
```
1529+
1530+
The value returned by the `_display` method will be used for `repr()` and `str()`.
1531+
"""
1532+
1533+
def _display(self) -> str | bytes:
1534+
return '**********' if self.get_secret_value() else ''
1535+
14751536
@classmethod
14761537
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
1477-
if issubclass(source, SecretStr):
1478-
field_type = str
1479-
inner_schema = core_schema.str_schema()
1538+
inner_type = None
1539+
# if origin_type is Secret, then cls is a GenericAlias, and we can extract the inner type directly
1540+
origin_type = get_origin(source)
1541+
if origin_type is not None:
1542+
inner_type = get_args(source)[0]
1543+
# otherwise, we need to get the inner type from the base class
14801544
else:
1481-
assert issubclass(source, SecretBytes)
1482-
field_type = bytes
1483-
inner_schema = core_schema.bytes_schema()
1484-
error_kind = 'string_type' if field_type is str else 'bytes_type'
1545+
bases = getattr(cls, '__orig_bases__', getattr(cls, '__bases__', []))
1546+
for base in bases:
1547+
if get_origin(base) is Secret:
1548+
inner_type = get_args(base)[0]
1549+
if bases == [] or inner_type is None:
1550+
raise TypeError(
1551+
f"Can't get secret type from {cls.__name__}. "
1552+
'Please use Secret[<type>], or subclass from Secret[<type>] instead.'
1553+
)
1554+
1555+
inner_schema = handler.generate_schema(inner_type) # type: ignore
1556+
1557+
def validate_secret_value(value, handler) -> Secret[SecretType]:
1558+
if isinstance(value, Secret):
1559+
value = value.get_secret_value()
1560+
validated_inner = handler(value)
1561+
return cls(validated_inner)
1562+
1563+
return core_schema.json_or_python_schema(
1564+
python_schema=core_schema.no_info_wrap_validator_function(
1565+
validate_secret_value,
1566+
inner_schema,
1567+
serialization=core_schema.plain_serializer_function_ser_schema(lambda x: x),
1568+
),
1569+
json_schema=core_schema.no_info_after_validator_function(
1570+
lambda x: cls(x), inner_schema, serialization=core_schema.to_string_ser_schema(when_used='json')
1571+
),
1572+
)
1573+
14851574

1575+
def _secret_display(value: SecretType) -> str: # type: ignore
1576+
return '**********' if value else ''
1577+
1578+
1579+
class _SecretField(_SecretBase[SecretType]):
1580+
_inner_schema: ClassVar[CoreSchema]
1581+
_error_kind: ClassVar[str]
1582+
1583+
@classmethod
1584+
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
14861585
def serialize(
14871586
value: _SecretField[SecretType], info: core_schema.SerializationInfo
14881587
) -> str | _SecretField[SecretType]:
@@ -1494,7 +1593,7 @@ def serialize(
14941593
return value
14951594

14961595
def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
1497-
json_schema = handler(inner_schema)
1596+
json_schema = handler(cls._inner_schema)
14981597
_utils.update_not_none(
14991598
json_schema,
15001599
type='string',
@@ -1505,31 +1604,33 @@ def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchema
15051604

15061605
json_schema = core_schema.no_info_after_validator_function(
15071606
source, # construct the type
1508-
inner_schema,
1607+
cls._inner_schema,
15091608
)
1510-
s = core_schema.json_or_python_schema(
1511-
python_schema=core_schema.union_schema(
1512-
[
1513-
core_schema.is_instance_schema(source),
1514-
json_schema,
1515-
],
1516-
strict=True,
1517-
custom_error_type=error_kind,
1518-
),
1519-
json_schema=json_schema,
1520-
serialization=core_schema.plain_serializer_function_ser_schema(
1521-
serialize,
1522-
info_arg=True,
1523-
return_schema=core_schema.str_schema(),
1524-
when_used='json',
1525-
),
1526-
)
1527-
s.setdefault('metadata', {}).setdefault('pydantic_js_functions', []).append(get_json_schema)
1528-
return s
15291609

1610+
def get_secret_schema(strict: bool) -> CoreSchema:
1611+
return core_schema.json_or_python_schema(
1612+
python_schema=core_schema.union_schema(
1613+
[
1614+
core_schema.is_instance_schema(source),
1615+
json_schema,
1616+
],
1617+
custom_error_type=cls._error_kind,
1618+
strict=strict,
1619+
),
1620+
json_schema=json_schema,
1621+
serialization=core_schema.plain_serializer_function_ser_schema(
1622+
serialize,
1623+
info_arg=True,
1624+
return_schema=core_schema.str_schema(),
1625+
when_used='json',
1626+
),
1627+
)
15301628

1531-
def _secret_display(value: str | bytes) -> str:
1532-
return '**********' if value else ''
1629+
return core_schema.lax_or_strict_schema(
1630+
lax_schema=get_secret_schema(strict=False),
1631+
strict_schema=get_secret_schema(strict=True),
1632+
metadata={'pydantic_js_functions': [get_json_schema]},
1633+
)
15331634

15341635

15351636
class SecretStr(_SecretField[str]):
@@ -1556,8 +1657,14 @@ class User(BaseModel):
15561657
```
15571658
"""
15581659

1660+
_inner_schema: ClassVar[CoreSchema] = core_schema.str_schema()
1661+
_error_kind: ClassVar[str] = 'string_type'
1662+
1663+
def __len__(self) -> int:
1664+
return len(self._secret_value)
1665+
15591666
def _display(self) -> str:
1560-
return _secret_display(self.get_secret_value())
1667+
return _secret_display(self._secret_value)
15611668

15621669

15631670
class SecretBytes(_SecretField[bytes]):
@@ -1583,8 +1690,14 @@ class User(BaseModel):
15831690
```
15841691
"""
15851692

1693+
_inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema()
1694+
_error_kind: ClassVar[str] = 'bytes_type'
1695+
1696+
def __len__(self) -> int:
1697+
return len(self._secret_value)
1698+
15861699
def _display(self) -> bytes:
1587-
return _secret_display(self.get_secret_value()).encode()
1700+
return _secret_display(self._secret_value).encode()
15881701

15891702

15901703
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)