24
24
TypeVar ,
25
25
Union ,
26
26
cast ,
27
+ get_args ,
28
+ get_origin ,
27
29
)
28
30
from uuid import UUID
29
31
75
77
'DirectoryPath' ,
76
78
'NewPath' ,
77
79
'Json' ,
80
+ 'Secret' ,
78
81
'SecretStr' ,
79
82
'SecretBytes' ,
80
83
'StrictBool' ,
@@ -1332,7 +1335,8 @@ class Model(BaseModel):
1332
1335
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1333
1336
1334
1337
if TYPE_CHECKING :
1335
- Json = Annotated [AnyType , ...] # Json[list[str]] will be recognized by type checkers as list[str]
1338
+ # Json[list[str]] will be recognized by type checkers as list[str]
1339
+ Json = Annotated [AnyType , ...]
1336
1340
1337
1341
else :
1338
1342
@@ -1439,10 +1443,10 @@ def __eq__(self, other: Any) -> bool:
1439
1443
1440
1444
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1441
1445
1442
- SecretType = TypeVar ('SecretType' , str , bytes )
1446
+ SecretType = TypeVar ('SecretType' )
1443
1447
1444
1448
1445
- class _SecretField (Generic [SecretType ]):
1449
+ class _SecretBase (Generic [SecretType ]):
1446
1450
def __init__ (self , secret_value : SecretType ) -> None :
1447
1451
self ._secret_value : SecretType = secret_value
1448
1452
@@ -1460,29 +1464,124 @@ def __eq__(self, other: Any) -> bool:
1460
1464
def __hash__ (self ) -> int :
1461
1465
return hash (self .get_secret_value ())
1462
1466
1463
- def __len__ (self ) -> int :
1464
- return len (self ._secret_value )
1465
-
1466
1467
def __str__ (self ) -> str :
1467
1468
return str (self ._display ())
1468
1469
1469
1470
def __repr__ (self ) -> str :
1470
1471
return f'{ self .__class__ .__name__ } ({ self ._display ()!r} )'
1471
1472
1472
- def _display (self ) -> SecretType :
1473
+ def _display (self ) -> str | bytes :
1473
1474
raise NotImplementedError
1474
1475
1476
+
1477
+ class Secret (_SecretBase [SecretType ]):
1478
+ """A generic base class used for defining a field with sensitive information that you do not want to be visible in logging or tracebacks.
1479
+
1480
+ You may either directly parametrize `Secret` with a type, or subclass from `Secret` with a parametrized type. The benefit of subclassing
1481
+ is that you can define a custom `_display` method, which will be used for `repr()` and `str()` methods. The examples below demonstrate both
1482
+ ways of using `Secret` to create a new secret type.
1483
+
1484
+ 1. Directly parametrizing `Secret` with a type:
1485
+
1486
+ ```py
1487
+ from pydantic import BaseModel, Secret
1488
+
1489
+ SecretBool = Secret[bool]
1490
+
1491
+ class Model(BaseModel):
1492
+ secret_bool: SecretBool
1493
+
1494
+ m = Model(secret_bool=True)
1495
+ print(m.model_dump())
1496
+ #> {'secret_bool': Secret('**********')}
1497
+
1498
+ print(m.model_dump_json())
1499
+ #> {"secret_bool":"**********"}
1500
+
1501
+ print(m.secret_bool.get_secret_value())
1502
+ #> True
1503
+ ```
1504
+
1505
+ 2. Subclassing from parametrized `Secret`:
1506
+
1507
+ ```py
1508
+ from datetime import date
1509
+
1510
+ from pydantic import BaseModel, Secret
1511
+
1512
+ class SecretDate(Secret[date]):
1513
+ def _display(self) -> str:
1514
+ return '****/**/**'
1515
+
1516
+ class Model(BaseModel):
1517
+ secret_date: SecretDate
1518
+
1519
+ m = Model(secret_date=date(2022, 1, 1))
1520
+ print(m.model_dump())
1521
+ #> {'secret_date': SecretDate('****/**/**')}
1522
+
1523
+ print(m.model_dump_json())
1524
+ #> {"secret_date":"****/**/**"}
1525
+
1526
+ print(m.secret_date.get_secret_value())
1527
+ #> 2022-01-01
1528
+ ```
1529
+
1530
+ The value returned by the `_display` method will be used for `repr()` and `str()`.
1531
+ """
1532
+
1533
+ def _display (self ) -> str | bytes :
1534
+ return '**********' if self .get_secret_value () else ''
1535
+
1475
1536
@classmethod
1476
1537
def __get_pydantic_core_schema__ (cls , source : type [Any ], handler : GetCoreSchemaHandler ) -> core_schema .CoreSchema :
1477
- if issubclass (source , SecretStr ):
1478
- field_type = str
1479
- inner_schema = core_schema .str_schema ()
1538
+ inner_type = None
1539
+ # if origin_type is Secret, then cls is a GenericAlias, and we can extract the inner type directly
1540
+ origin_type = get_origin (source )
1541
+ if origin_type is not None :
1542
+ inner_type = get_args (source )[0 ]
1543
+ # otherwise, we need to get the inner type from the base class
1480
1544
else :
1481
- assert issubclass (source , SecretBytes )
1482
- field_type = bytes
1483
- inner_schema = core_schema .bytes_schema ()
1484
- error_kind = 'string_type' if field_type is str else 'bytes_type'
1545
+ bases = getattr (cls , '__orig_bases__' , getattr (cls , '__bases__' , []))
1546
+ for base in bases :
1547
+ if get_origin (base ) is Secret :
1548
+ inner_type = get_args (base )[0 ]
1549
+ if bases == [] or inner_type is None :
1550
+ raise TypeError (
1551
+ f"Can't get secret type from { cls .__name__ } . "
1552
+ 'Please use Secret[<type>], or subclass from Secret[<type>] instead.'
1553
+ )
1554
+
1555
+ inner_schema = handler .generate_schema (inner_type ) # type: ignore
1556
+
1557
+ def validate_secret_value (value , handler ) -> Secret [SecretType ]:
1558
+ if isinstance (value , Secret ):
1559
+ value = value .get_secret_value ()
1560
+ validated_inner = handler (value )
1561
+ return cls (validated_inner )
1562
+
1563
+ return core_schema .json_or_python_schema (
1564
+ python_schema = core_schema .no_info_wrap_validator_function (
1565
+ validate_secret_value ,
1566
+ inner_schema ,
1567
+ serialization = core_schema .plain_serializer_function_ser_schema (lambda x : x ),
1568
+ ),
1569
+ json_schema = core_schema .no_info_after_validator_function (
1570
+ lambda x : cls (x ), inner_schema , serialization = core_schema .to_string_ser_schema (when_used = 'json' )
1571
+ ),
1572
+ )
1573
+
1485
1574
1575
+ def _secret_display (value : SecretType ) -> str : # type: ignore
1576
+ return '**********' if value else ''
1577
+
1578
+
1579
+ class _SecretField (_SecretBase [SecretType ]):
1580
+ _inner_schema : ClassVar [CoreSchema ]
1581
+ _error_kind : ClassVar [str ]
1582
+
1583
+ @classmethod
1584
+ def __get_pydantic_core_schema__ (cls , source : type [Any ], handler : GetCoreSchemaHandler ) -> core_schema .CoreSchema :
1486
1585
def serialize (
1487
1586
value : _SecretField [SecretType ], info : core_schema .SerializationInfo
1488
1587
) -> str | _SecretField [SecretType ]:
@@ -1494,7 +1593,7 @@ def serialize(
1494
1593
return value
1495
1594
1496
1595
def get_json_schema (_core_schema : core_schema .CoreSchema , handler : GetJsonSchemaHandler ) -> JsonSchemaValue :
1497
- json_schema = handler (inner_schema )
1596
+ json_schema = handler (cls . _inner_schema )
1498
1597
_utils .update_not_none (
1499
1598
json_schema ,
1500
1599
type = 'string' ,
@@ -1505,31 +1604,33 @@ def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchema
1505
1604
1506
1605
json_schema = core_schema .no_info_after_validator_function (
1507
1606
source , # construct the type
1508
- inner_schema ,
1607
+ cls . _inner_schema ,
1509
1608
)
1510
- s = core_schema .json_or_python_schema (
1511
- python_schema = core_schema .union_schema (
1512
- [
1513
- core_schema .is_instance_schema (source ),
1514
- json_schema ,
1515
- ],
1516
- strict = True ,
1517
- custom_error_type = error_kind ,
1518
- ),
1519
- json_schema = json_schema ,
1520
- serialization = core_schema .plain_serializer_function_ser_schema (
1521
- serialize ,
1522
- info_arg = True ,
1523
- return_schema = core_schema .str_schema (),
1524
- when_used = 'json' ,
1525
- ),
1526
- )
1527
- s .setdefault ('metadata' , {}).setdefault ('pydantic_js_functions' , []).append (get_json_schema )
1528
- return s
1529
1609
1610
+ def get_secret_schema (strict : bool ) -> CoreSchema :
1611
+ return core_schema .json_or_python_schema (
1612
+ python_schema = core_schema .union_schema (
1613
+ [
1614
+ core_schema .is_instance_schema (source ),
1615
+ json_schema ,
1616
+ ],
1617
+ custom_error_type = cls ._error_kind ,
1618
+ strict = strict ,
1619
+ ),
1620
+ json_schema = json_schema ,
1621
+ serialization = core_schema .plain_serializer_function_ser_schema (
1622
+ serialize ,
1623
+ info_arg = True ,
1624
+ return_schema = core_schema .str_schema (),
1625
+ when_used = 'json' ,
1626
+ ),
1627
+ )
1530
1628
1531
- def _secret_display (value : str | bytes ) -> str :
1532
- return '**********' if value else ''
1629
+ return core_schema .lax_or_strict_schema (
1630
+ lax_schema = get_secret_schema (strict = False ),
1631
+ strict_schema = get_secret_schema (strict = True ),
1632
+ metadata = {'pydantic_js_functions' : [get_json_schema ]},
1633
+ )
1533
1634
1534
1635
1535
1636
class SecretStr (_SecretField [str ]):
@@ -1556,8 +1657,14 @@ class User(BaseModel):
1556
1657
```
1557
1658
"""
1558
1659
1660
+ _inner_schema : ClassVar [CoreSchema ] = core_schema .str_schema ()
1661
+ _error_kind : ClassVar [str ] = 'string_type'
1662
+
1663
+ def __len__ (self ) -> int :
1664
+ return len (self ._secret_value )
1665
+
1559
1666
def _display (self ) -> str :
1560
- return _secret_display (self .get_secret_value () )
1667
+ return _secret_display (self ._secret_value )
1561
1668
1562
1669
1563
1670
class SecretBytes (_SecretField [bytes ]):
@@ -1583,8 +1690,14 @@ class User(BaseModel):
1583
1690
```
1584
1691
"""
1585
1692
1693
+ _inner_schema : ClassVar [CoreSchema ] = core_schema .bytes_schema ()
1694
+ _error_kind : ClassVar [str ] = 'bytes_type'
1695
+
1696
+ def __len__ (self ) -> int :
1697
+ return len (self ._secret_value )
1698
+
1586
1699
def _display (self ) -> bytes :
1587
- return _secret_display (self .get_secret_value () ).encode ()
1700
+ return _secret_display (self ._secret_value ).encode ()
1588
1701
1589
1702
1590
1703
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
0 commit comments