|
15 | 15 |
|
16 | 16 | from abc import abstractmethod
|
17 | 17 | from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
|
| 18 | +from typing_extensions import Final |
18 | 19 |
|
19 | 20 | from mypy_extensions import mypyc_attr, trait
|
20 | 21 |
|
@@ -417,3 +418,158 @@ def visit_type_alias_type(self, t: TypeAliasType) -> T:
|
417 | 418 | def query_types(self, types: Iterable[Type]) -> T:
|
418 | 419 | """Perform a query for a list of types using the strategy to combine the results."""
|
419 | 420 | return self.strategy([t.accept(self) for t in types])
|
| 421 | + |
| 422 | + |
| 423 | +# Return True if at least one type component returns True |
| 424 | +ANY_STRATEGY: Final = 0 |
| 425 | +# Return True if no type component returns False |
| 426 | +ALL_STRATEGY: Final = 1 |
| 427 | + |
| 428 | + |
| 429 | +class BoolTypeQuery(SyntheticTypeVisitor[bool]): |
| 430 | + """Visitor for performing recursive queries of types with a bool result. |
| 431 | +
|
| 432 | + Use TypeQuery if you need non-bool results. |
| 433 | +
|
| 434 | + 'strategy' is used to combine results for a series of types. It must |
| 435 | + be ANY_STRATEGY or ALL_STRATEGY. |
| 436 | +
|
| 437 | + Note: This visitor keeps an internal state (tracks type aliases to avoid |
| 438 | + recursion), so it should *never* be re-used for querying different types |
| 439 | + unless you call reset() first. |
| 440 | + """ |
| 441 | + |
| 442 | + def __init__(self, strategy: int) -> None: |
| 443 | + self.strategy = strategy |
| 444 | + if strategy == ANY_STRATEGY: |
| 445 | + self.default = False |
| 446 | + else: |
| 447 | + assert strategy == ALL_STRATEGY |
| 448 | + self.default = True |
| 449 | + # Keep track of the type aliases already visited. This is needed to avoid |
| 450 | + # infinite recursion on types like A = Union[int, List[A]]. An empty set is |
| 451 | + # represented as None as a micro-optimization. |
| 452 | + self.seen_aliases: set[TypeAliasType] | None = None |
| 453 | + # By default, we eagerly expand type aliases, and query also types in the |
| 454 | + # alias target. In most cases this is a desired behavior, but we may want |
| 455 | + # to skip targets in some cases (e.g. when collecting type variables). |
| 456 | + self.skip_alias_target = False |
| 457 | + |
| 458 | + def reset(self) -> None: |
| 459 | + """Clear mutable state (but preserve strategy). |
| 460 | +
|
| 461 | + This *must* be called if you want to reuse the visitor. |
| 462 | + """ |
| 463 | + self.seen_aliases = None |
| 464 | + |
| 465 | + def visit_unbound_type(self, t: UnboundType) -> bool: |
| 466 | + return self.query_types(t.args) |
| 467 | + |
| 468 | + def visit_type_list(self, t: TypeList) -> bool: |
| 469 | + return self.query_types(t.items) |
| 470 | + |
| 471 | + def visit_callable_argument(self, t: CallableArgument) -> bool: |
| 472 | + return t.typ.accept(self) |
| 473 | + |
| 474 | + def visit_any(self, t: AnyType) -> bool: |
| 475 | + return self.default |
| 476 | + |
| 477 | + def visit_uninhabited_type(self, t: UninhabitedType) -> bool: |
| 478 | + return self.default |
| 479 | + |
| 480 | + def visit_none_type(self, t: NoneType) -> bool: |
| 481 | + return self.default |
| 482 | + |
| 483 | + def visit_erased_type(self, t: ErasedType) -> bool: |
| 484 | + return self.default |
| 485 | + |
| 486 | + def visit_deleted_type(self, t: DeletedType) -> bool: |
| 487 | + return self.default |
| 488 | + |
| 489 | + def visit_type_var(self, t: TypeVarType) -> bool: |
| 490 | + return self.query_types([t.upper_bound] + t.values) |
| 491 | + |
| 492 | + def visit_param_spec(self, t: ParamSpecType) -> bool: |
| 493 | + return self.default |
| 494 | + |
| 495 | + def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool: |
| 496 | + return self.default |
| 497 | + |
| 498 | + def visit_unpack_type(self, t: UnpackType) -> bool: |
| 499 | + return self.query_types([t.type]) |
| 500 | + |
| 501 | + def visit_parameters(self, t: Parameters) -> bool: |
| 502 | + return self.query_types(t.arg_types) |
| 503 | + |
| 504 | + def visit_partial_type(self, t: PartialType) -> bool: |
| 505 | + return self.default |
| 506 | + |
| 507 | + def visit_instance(self, t: Instance) -> bool: |
| 508 | + return self.query_types(t.args) |
| 509 | + |
| 510 | + def visit_callable_type(self, t: CallableType) -> bool: |
| 511 | + # FIX generics |
| 512 | + # Avoid allocating any objects here as an optimization. |
| 513 | + args = self.query_types(t.arg_types) |
| 514 | + ret = t.ret_type.accept(self) |
| 515 | + if self.strategy == ANY_STRATEGY: |
| 516 | + return args or ret |
| 517 | + else: |
| 518 | + return args and ret |
| 519 | + |
| 520 | + def visit_tuple_type(self, t: TupleType) -> bool: |
| 521 | + return self.query_types(t.items) |
| 522 | + |
| 523 | + def visit_typeddict_type(self, t: TypedDictType) -> bool: |
| 524 | + return self.query_types(list(t.items.values())) |
| 525 | + |
| 526 | + def visit_raw_expression_type(self, t: RawExpressionType) -> bool: |
| 527 | + return self.default |
| 528 | + |
| 529 | + def visit_literal_type(self, t: LiteralType) -> bool: |
| 530 | + return self.default |
| 531 | + |
| 532 | + def visit_star_type(self, t: StarType) -> bool: |
| 533 | + return t.type.accept(self) |
| 534 | + |
| 535 | + def visit_union_type(self, t: UnionType) -> bool: |
| 536 | + return self.query_types(t.items) |
| 537 | + |
| 538 | + def visit_overloaded(self, t: Overloaded) -> bool: |
| 539 | + return self.query_types(t.items) # type: ignore[arg-type] |
| 540 | + |
| 541 | + def visit_type_type(self, t: TypeType) -> bool: |
| 542 | + return t.item.accept(self) |
| 543 | + |
| 544 | + def visit_ellipsis_type(self, t: EllipsisType) -> bool: |
| 545 | + return self.default |
| 546 | + |
| 547 | + def visit_placeholder_type(self, t: PlaceholderType) -> bool: |
| 548 | + return self.query_types(t.args) |
| 549 | + |
| 550 | + def visit_type_alias_type(self, t: TypeAliasType) -> bool: |
| 551 | + # Skip type aliases already visited types to avoid infinite recursion. |
| 552 | + # TODO: Ideally we should fire subvisitors here (or use caching) if we care |
| 553 | + # about duplicates. |
| 554 | + if self.seen_aliases is None: |
| 555 | + self.seen_aliases = set() |
| 556 | + elif t in self.seen_aliases: |
| 557 | + return self.default |
| 558 | + self.seen_aliases.add(t) |
| 559 | + if self.skip_alias_target: |
| 560 | + return self.query_types(t.args) |
| 561 | + return get_proper_type(t).accept(self) |
| 562 | + |
| 563 | + def query_types(self, types: list[Type] | tuple[Type, ...]) -> bool: |
| 564 | + """Perform a query for a sequence of types using the strategy to combine the results.""" |
| 565 | + # Special-case for lists and tuples to allow mypyc to produce better code. |
| 566 | + if isinstance(types, list): |
| 567 | + if self.strategy == ANY_STRATEGY: |
| 568 | + return any(t.accept(self) for t in types) |
| 569 | + else: |
| 570 | + return all(t.accept(self) for t in types) |
| 571 | + else: |
| 572 | + if self.strategy == ANY_STRATEGY: |
| 573 | + return any(t.accept(self) for t in types) |
| 574 | + else: |
| 575 | + return all(t.accept(self) for t in types) |
0 commit comments