Expand Up @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Iterable, Sequence from typing_extensions import TypeAlias as _TypeAlias from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type Expand All @@ -27,6 +28,10 @@ ) from mypy.typestate import type_state Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]" Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]" Solutions: _TypeAlias = "dict[TypeVarId, Type | None]" def solve_constraints( original_vars: Sequence[TypeVarLikeType], Expand All @@ -36,20 +41,22 @@ def solve_constraints( ) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Solve type constraints. Return the best type(s) for type variables; each type can be None if the value of the variable could not be solved. Return the best type(s) for type variables; each type can be None if the value of the variable could not be solved. If a variable has no constraints, if strict=True then arbitrarily pick NoneType as the value of the type variable. If strict=False, pick AnyType. pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType. If allow_polymorphic=True, then use the full algorithm that can potentially return free type variables in solutions (these require special care when applying). Otherwise, use a simplified algorithm that just solves each type variable individually if possible. """ vars = [tv.id for tv in original_vars] if not vars: return [], [] originals = {tv.id: tv for tv in original_vars} extra_vars: list[TypeVarId] = [] # Get additional variables from generic actuals. # Get additional type variables from generic actuals. 
for c in constraints: extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) Expand All @@ -66,7 +73,7 @@ def solve_constraints( if allow_polymorphic: if constraints: solutions, free_vars =solve_non_linear ( solutions, free_vars =solve_with_dependent ( vars + extra_vars, constraints, vars, originals ) else: Expand All @@ -80,7 +87,7 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] solution = solve_one(lowers, uppers, [] ) solution = solve_one(lowers, uppers) # Do not leak type variables in non-polymorphic solutions. if solution is None or not get_vars( Expand All @@ -104,20 +111,20 @@ def solve_constraints( return res, [originals[tv] for tv in free_vars] defsolve_non_linear ( defsolve_with_dependent ( vars: list[TypeVarId], constraints: list[Constraint], original_vars: list[TypeVarId], originals: dict[TypeVarId, TypeVarLikeType], ) -> tuple[dict[TypeVarId, Type | None] , list[TypeVarId]]: """Solve set of constraints that may include non-linear ones, like T <: List[S]. ) -> tuple[Solutions , list[TypeVarId]]: """Solve set of constraints that may depend on each other, like T <: List[S]. The whole algorithm consists of five steps: * Propagate via linear constraints to get all possible constraints for each variable * Propagate via linear constraints and use secondary constraints to get transitive closure * Find dependencies between type variables, group them in SCCs, and sort topologically * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Check that all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) * Solve constraints iteratively starting from leafs, updating targets after each step. 
* Solve constraints iteratively starting from leafs, updating bounds after each step. """ graph, lowers, uppers = transitive_closure(vars, constraints) Expand All @@ -129,14 +136,12 @@ def solve_non_linear( free_vars = [] for scc in raw_batches[0]: # If all constraint targets in this SCC are type variables within the # same SCC then the only meaningful solution we can express, is that # each variable is equal to a new free variable. For example if we # have T <: S, S <: U, we deduce: T = S = U = <free>. # If there are no bounds on this SCC, then the only meaningful solution we can # express, is that each variable is equal to a new free variable. For example, # if we have T <: S, S <: U, we deduce: T = S = U = <free>. if all(not lowers[tv] and not uppers[tv] for tv in scc): # For convenience with current type application machinery, we randomly # choose one of the existing type variables in SCC and designate it as free # instead of defining a new type variable as a common solution. # For convenience with current type application machinery, we use a stable # choice that prefers the original type variables (not polymorphic ones) in SCC. # TODO: be careful about upper bounds (or values) when introducing free vars. free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) Expand All @@ -159,32 +164,29 @@ def solve_non_linear( solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: res = solve_iteratively(flat_batch, graph, lowers, uppers, free_vars ) res = solve_iteratively(flat_batch, graph, lowers, uppers) solutions.update(res) return solutions, free_vars def solve_iteratively( batch: list[TypeVarId], graph: set[tuple[TypeVarId, TypeVarId]], lowers: dict[TypeVarId, set[Type]], uppers: dict[TypeVarId, set[Type]], free_vars: list[TypeVarId], ) -> dict[TypeVarId, Type | None]: """Solve constraints sequentially, updating constraint targets after each step. We solve for type variables that appear in `batch`. 
If a constraint target is not constant (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...]. This way we can gradually solve for all variables in the batch taking one solvable variable at a time (i.e. such a variable that has at least one constant bound). Importantly, variables in free_vars are considered constants, so for example if we have just one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first designate S as free, and therefore T = List[S] is a valid solution for T. batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> Solutions: """Solve transitive closure sequentially, updating upper/lower bounds after each step. Transitive closure is represented as a linear graph plus lower/upper bounds for each type variable, see transitive_closure() docstring for details. We solve for type variables that appear in `batch`. If a bound is not constant (i.e. it looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...] after solving the batch. Importantly, after solving each variable in a batch, we move it from linear graph to upper/lower bounds, this way we can guarantee consistency of solutions (see comment below for an example when this is important). """ solutions = {} s_batch = set(batch) not_allowed_vars = [v for v in batch if v not in free_vars] while s_batch: for tv in sorted(s_batch, key=lambda x: x.raw_id): if lowers[tv] or uppers[tv]: Expand All @@ -194,7 +196,7 @@ def solve_iteratively( break # Solve each solvable type variable separately. s_batch.remove(solvable_tv) result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars ) result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) solutions[solvable_tv] = result if result is None: # TODO: support backtracking lower/upper bound choices and order within SCCs. 
Expand Down Expand Up @@ -227,9 +229,7 @@ def solve_iteratively( return solutions def solve_one( lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] ) -> Type | None: def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: """Solve constraints by using meets of upper bounds, and joins of lower bounds.""" bottom: Type | None = None top: Type | None = None Expand All @@ -239,10 +239,6 @@ def solve_one( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for target in lowers: # There may be multiple steps needed to solve all vars within a # (linear) SCC. We ignore targets pointing to not yet solved vars. if get_vars(target, not_allowed_vars): continue if bottom is None: bottom = target else: Expand All @@ -254,9 +250,6 @@ def solve_one( bottom = join_types(bottom, target) for target in uppers: # Same as above. if get_vars(target, not_allowed_vars): continue if top is None: top = target else: Expand Down Expand Up @@ -291,6 +284,7 @@ def normalize_constraints( This includes two things currently: * Complement T :> S by S <: T * Remove strict duplicates * Remove constraints for unrelated variables """ res = constraints.copy() for c in constraints: Expand All @@ -301,23 +295,29 @@ def normalize_constraints( def transitive_closure( tvars: list[TypeVarId], constraints: list[Constraint] ) -> tuple[ set[tuple[TypeVarId, TypeVarId]], dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]] ]: ) -> tuple[Graph, Bounds, Bounds]: """Find transitive closure for given constraints on type variables. Transitive closure gives maximal set of lower/upper bounds for each type variable, such that we cannot deduce any further bounds by chaining other existing bounds. The transitive closure is represented by: * A set of lower and upper bounds for each type variable, where only constant and non-linear terms are included in the bounds. 
* A graph of linear constraints between type variables (represented as a set of pairs) Such separation simplifies reasoning, and allows an efficient and simple incremental transitive closure algorithm that we use here. For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive closure is given by: * {} <: T <: {S, U, int} * {T} <: S <: {U, int} * {T, S} <: U <: {int} * {} <: T <: {int} * {} <: S <: {int} * {} <: U <: {int} * {T <: S, S <: U, T <: U} """ uppers:dict[TypeVarId, set[Type]] = defaultdict(set) lowers:dict[TypeVarId, set[Type]] = defaultdict(set) graph:set[tuple[TypeVarId, TypeVarId]] = {(tv, tv) for tv in tvars} uppers:Bounds = defaultdict(set) lowers:Bounds = defaultdict(set) graph:Graph = {(tv, tv) for tv in tvars} remaining = set(constraints) while remaining: Expand All @@ -329,10 +329,9 @@ def transitive_closure( lower, upper = c.target.id, c.type_var if (lower, upper) in graph: continue extras = {graph | = { (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph } graph |= extras for u in tvars: if (upper, u) in graph: lowers[u] |= lowers[lower] Expand Down Expand Up @@ -364,10 +363,7 @@ def transitive_closure( def compute_dependencies( tvars: list[TypeVarId], graph: set[tuple[TypeVarId, TypeVarId]], lowers: dict[TypeVarId, set[Type]], uppers: dict[TypeVarId, set[Type]], tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> dict[TypeVarId, list[TypeVarId]]: """Compute dependencies between type variables induced by constraints. Expand All @@ -383,17 +379,14 @@ def compute_dependencies( deps |= get_vars(ut, tvars) for other in tvars: if other == tv: # TODO: is there a value in either skipping or adding trivial deps? 
continue if (tv, other) in graph or (other, tv) in graph: deps.add(other) res[tv] = list(deps) return res def check_linear( scc: set[TypeVarId], lowers: dict[TypeVarId, set[Type]], uppers: dict[TypeVarId, set[Type]] ) -> bool: def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: """Check there are only linear constraints between type variables in SCC. Linear are constraints like T <: S (while T <: F[S] are non-linear). Expand Down