Skip to content

gh-104909: Implement conditional stack effects for macros #105748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions Tools/cases_generator/generate_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,12 @@ def block(self, head: str):

def stack_adjust(
self,
diff: int,
input_effects: list[StackEffect],
output_effects: list[StackEffect],
):
# TODO: Get rid of 'diff' parameter
shrink, isym = list_effect_size(input_effects)
grow, osym = list_effect_size(output_effects)
diff += grow - shrink
diff = grow - shrink
if isym and isym != osym:
self.emit(f"STACK_SHRINK({isym});")
if diff < 0:
Expand Down Expand Up @@ -355,7 +353,6 @@ def write(self, out: Formatter) -> None:

# Write net stack growth/shrinkage
out.stack_adjust(
0,
[ieff for ieff in self.input_effects],
[oeff for oeff in self.output_effects],
)
Expand Down Expand Up @@ -848,9 +845,14 @@ def stack_analysis(

Ignore cache effects.

Return the list of variable names and the initial stack pointer.
Return the list of variables (as StackEffects) and the initial stack pointer.
"""
lowest = current = highest = 0
conditions: dict[int, str] = {} # Indexed by 'current'.
last_instr: Instruction | None = None
for thing in components:
if isinstance(thing, Instruction):
last_instr = thing
for thing in components:
match thing:
case Instruction() as instr:
Expand All @@ -863,9 +865,24 @@ def stack_analysis(
"which are not supported in macro instructions",
instr.inst, # TODO: Pass name+location of macro
)
if any(eff.cond for eff in instr.input_effects):
self.error(
f"Instruction {instr.name!r} has conditional input stack effect, "
"which are not supported in macro instructions",
instr.inst, # TODO: Pass name+location of macro
)
if any(eff.cond for eff in instr.output_effects) and instr is not last_instr:
self.error(
f"Instruction {instr.name!r} has conditional output stack effect, "
"but is not the last instruction in a macro",
instr.inst, # TODO: Pass name+location of macro
)
current -= len(instr.input_effects)
lowest = min(lowest, current)
current += len(instr.output_effects)
for eff in instr.output_effects:
if eff.cond:
conditions[current] = eff.cond
current += 1
highest = max(highest, current)
case parser.CacheEffect():
pass
Expand All @@ -874,9 +891,9 @@ def stack_analysis(
# At this point, 'current' is the net stack effect,
# and 'lowest' and 'highest' are the extremes.
# Note that 'lowest' may be negative.
# TODO: Reverse the numbering.
stack = [
StackEffect(f"_tmp_{i+1}", "") for i in reversed(range(highest - lowest))
StackEffect(f"_tmp_{i}", "", conditions.get(highest - i, ""))
for i in reversed(range(1, highest - lowest + 1))
]
return stack, -lowest

Expand Down Expand Up @@ -908,15 +925,17 @@ def effect_str(effects: list[StackEffect]) -> str:
low = 0
sp = 0
high = 0
pushed_symbolic: list[str] = []
for comp in parts:
for effect in comp.instr.input_effects:
assert not effect.cond, effect
assert not effect.size, effect
sp -= 1
low = min(low, sp)
for effect in comp.instr.output_effects:
assert not effect.cond, effect
assert not effect.size, effect
if effect.cond:
pushed_symbolic.append(maybe_parenthesize(f"{maybe_parenthesize(effect.cond)} ? 1 : 0"))
sp += 1
high = max(sp, high)
if high != max(0, sp):
Expand All @@ -926,7 +945,8 @@ def effect_str(effects: list[StackEffect]) -> str:
# calculations to use the micro ops.
self.error("Macro has virtual stack growth", thing)
popped = str(-low)
pushed = str(sp - low)
pushed_symbolic.append(str(sp - low - len(pushed_symbolic)))
pushed = " + ".join(pushed_symbolic)
case parser.Pseudo():
instr = self.pseudo_instrs[thing.name]
popped = pushed = None
Expand Down Expand Up @@ -1203,16 +1223,23 @@ def wrap_macro(self, mac: MacroInstruction):
with self.out.block(f"TARGET({mac.name})"):
if mac.predicted:
self.out.emit(f"PREDICTED({mac.name});")
for i, var in reversed(list(enumerate(mac.stack))):

# The input effects should have no conditionals.
# Only the output effects do (for now).
ieffects = [
StackEffect(eff.name, eff.type) if eff.cond else eff
for eff in mac.stack
]

for i, var in reversed(list(enumerate(ieffects))):
src = None
if i < mac.initial_sp:
src = StackEffect(f"stack_pointer[-{mac.initial_sp - i}]", "")
self.out.declare(var, src)

yield

# TODO: Use slices of mac.stack instead of numeric values
self.out.stack_adjust(mac.final_sp - mac.initial_sp, [], [])
self.out.stack_adjust(ieffects[:mac.initial_sp], mac.stack[:mac.final_sp])

for i, var in enumerate(reversed(mac.stack[: mac.final_sp]), 1):
dst = StackEffect(f"stack_pointer[-{i}]", "")
Expand Down
40 changes: 40 additions & 0 deletions Tools/cases_generator/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,43 @@ def test_cond_effect():
}
"""
run_cases_test(input, output)

def test_macro_cond_effect():
input = """
op(A, (left, middle, right --)) {
# Body of A
}
op(B, (-- deep, extra if (oparg), res)) {
# Body of B
}
macro(M) = A + B;
"""
output = """
TARGET(M) {
PyObject *_tmp_1 = stack_pointer[-1];
PyObject *_tmp_2 = stack_pointer[-2];
PyObject *_tmp_3 = stack_pointer[-3];
{
PyObject *right = _tmp_1;
PyObject *middle = _tmp_2;
PyObject *left = _tmp_3;
# Body of A
}
{
PyObject *deep;
PyObject *extra = NULL;
PyObject *res;
# Body of B
_tmp_3 = deep;
if (oparg) { _tmp_2 = extra; }
_tmp_1 = res;
}
STACK_SHRINK(1);
STACK_GROW((oparg ? 1 : 0));
stack_pointer[-1] = _tmp_1;
if (oparg) { stack_pointer[-2] = _tmp_2; }
stack_pointer[-3] = _tmp_3;
DISPATCH();
}
"""
run_cases_test(input, output)