15
15
import logging
16
16
import threading
17
17
import weakref
18
+ from enum import Enum
18
19
from functools import wraps
19
20
from typing import (
20
21
TYPE_CHECKING ,
21
22
Any ,
22
23
Callable ,
23
24
Collection ,
25
+ Dict ,
24
26
Generic ,
25
- Iterable ,
26
27
List ,
27
28
Optional ,
28
29
Type ,
@@ -190,7 +191,7 @@ def __init__(
190
191
root : "ListNode[_Node]" ,
191
192
key : KT ,
192
193
value : VT ,
193
- cache : "weakref.ReferenceType[LruCache]" ,
194
+ cache : "weakref.ReferenceType[LruCache[KT, VT] ]" ,
194
195
clock : Clock ,
195
196
callbacks : Collection [Callable [[], None ]] = (),
196
197
prune_unread_entries : bool = True ,
@@ -290,6 +291,12 @@ def move_to_front(self, clock: Clock, cache_list_root: ListNode) -> None:
290
291
self ._global_list_node .update_last_access (clock )
291
292
292
293
294
+ class _Sentinel (Enum ):
295
+ # defining a sentinel in this way allows mypy to correctly handle the
296
+ # type of a dictionary lookup.
297
+ sentinel = object ()
298
+
299
+
293
300
class LruCache (Generic [KT , VT ]):
294
301
"""
295
302
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
@@ -302,7 +309,7 @@ def __init__(
302
309
max_size : int ,
303
310
cache_name : Optional [str ] = None ,
304
311
cache_type : Type [Union [dict , TreeCache ]] = dict ,
305
- size_callback : Optional [Callable ] = None ,
312
+ size_callback : Optional [Callable [[ VT ], int ] ] = None ,
306
313
metrics_collection_callback : Optional [Callable [[], None ]] = None ,
307
314
apply_cache_factor_from_config : bool = True ,
308
315
clock : Optional [Clock ] = None ,
@@ -339,7 +346,7 @@ def __init__(
339
346
else :
340
347
real_clock = clock
341
348
342
- cache = cache_type ()
349
+ cache : Union [ Dict [ KT , _Node [ KT , VT ]], TreeCache ] = cache_type ()
343
350
self .cache = cache # Used for introspection.
344
351
self .apply_cache_factor_from_config = apply_cache_factor_from_config
345
352
@@ -374,7 +381,7 @@ def __init__(
374
381
# creating more each time we create a `_Node`.
375
382
weak_ref_to_self = weakref .ref (self )
376
383
377
- list_root = ListNode [_Node ].create_root_node ()
384
+ list_root = ListNode [_Node [ KT , VT ] ].create_root_node ()
378
385
379
386
lock = threading .Lock ()
380
387
@@ -422,7 +429,7 @@ def cache_len() -> int:
422
429
def add_node (
423
430
key : KT , value : VT , callbacks : Collection [Callable [[], None ]] = ()
424
431
) -> None :
425
- node = _Node (
432
+ node : _Node [ KT , VT ] = _Node (
426
433
list_root ,
427
434
key ,
428
435
value ,
@@ -439,10 +446,10 @@ def add_node(
439
446
if caches .TRACK_MEMORY_USAGE and metrics :
440
447
metrics .inc_memory_usage (node .memory )
441
448
442
- def move_node_to_front (node : _Node ) -> None :
449
+ def move_node_to_front (node : _Node [ KT , VT ] ) -> None :
443
450
node .move_to_front (real_clock , list_root )
444
451
445
- def delete_node (node : _Node ) -> int :
452
+ def delete_node (node : _Node [ KT , VT ] ) -> int :
446
453
node .drop_from_lists ()
447
454
448
455
deleted_len = 1
@@ -496,7 +503,7 @@ def cache_get(
496
503
497
504
@synchronized
498
505
def cache_set (
499
- key : KT , value : VT , callbacks : Iterable [Callable [[], None ]] = ()
506
+ key : KT , value : VT , callbacks : Collection [Callable [[], None ]] = ()
500
507
) -> None :
501
508
node = cache .get (key , None )
502
509
if node is not None :
@@ -590,8 +597,6 @@ def cache_clear() -> None:
590
597
def cache_contains (key : KT ) -> bool :
591
598
return key in cache
592
599
593
- self .sentinel = object ()
594
-
595
600
# make sure that we clear out any excess entries after we get resized.
596
601
self ._on_resize = evict
597
602
@@ -608,18 +613,18 @@ def cache_contains(key: KT) -> bool:
608
613
self .clear = cache_clear
609
614
610
615
def __getitem__ (self , key : KT ) -> VT :
611
- result = self .get (key , self .sentinel )
612
- if result is self .sentinel :
616
+ result = self .get (key , _Sentinel .sentinel )
617
+ if result is _Sentinel .sentinel :
613
618
raise KeyError ()
614
619
else :
615
- return cast ( VT , result )
620
+ return result
616
621
617
622
def __setitem__ (self , key : KT , value : VT ) -> None :
618
623
self .set (key , value )
619
624
620
625
def __delitem__ (self , key : KT , value : VT ) -> None :
621
- result = self .pop (key , self .sentinel )
622
- if result is self .sentinel :
626
+ result = self .pop (key , _Sentinel .sentinel )
627
+ if result is _Sentinel .sentinel :
623
628
raise KeyError ()
624
629
625
630
def __len__ (self ) -> int :
0 commit comments