Skip to content

Commit c99b9ad

Browse files
committed
CABI: fix flatten_functype() to match canon_lift() and remove extraneous core arg from dtor
1 parent 42e78f3 commit c99b9ad

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

design/mvp/CanonicalABI.md

+16-5
Original file line numberDiff line numberDiff line change
@@ -1629,8 +1629,10 @@ def flatten_functype(opts, ft, context):
16291629
else:
16301630
match context:
16311631
case 'lift':
1632-
flat_params = []
1633-
flat_results = []
1632+
if opts.callback:
1633+
flat_results = ['i32']
1634+
else:
1635+
flat_results = []
16341636
case 'lower':
16351637
if len(flat_params) > 1:
16361638
flat_params = ['i32']
@@ -2077,16 +2079,21 @@ Based on this, `canon_lift` is defined:
20772079
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block):
20782080
task = Task(opts, inst, ft, caller, on_return, on_block)
20792081
flat_args = await task.enter(on_start)
2082+
flat_ft = flatten_functype(opts, ft, 'lift')
2083+
assert(types_match_values(flat_ft.params, flat_args))
20802084
if opts.sync:
20812085
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
2086+
assert(types_match_values(flat_ft.results, flat_results))
20822087
task.return_(flat_results)
20832088
if opts.post_return is not None:
20842089
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
20852090
else:
20862091
if not opts.callback:
20872092
[] = await call_and_trap_on_throw(callee, task, flat_args)
2093+
assert(types_match_values(flat_ft.results, []))
20882094
else:
20892095
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
2096+
assert(types_match_values(flat_ft.results, [packed_ctx]))
20902097
while packed_ctx != 0:
20912098
is_yield = bool(packed_ctx & 1)
20922099
ctx = packed_ctx & ~1
@@ -2144,6 +2151,8 @@ Given this, `canon_lower` is defined:
21442151
```python
21452152
async def canon_lower(opts, ft, callee, task, flat_args):
21462153
trap_if(not task.inst.may_leave)
2154+
flat_ft = flatten_functype(opts, ft, 'lower')
2155+
assert(types_match_values(flat_ft.params, flat_args))
21472156
subtask = Subtask(opts, ft, task, flat_args)
21482157
if opts.sync:
21492158
await task.call_sync(callee, task, subtask.on_start, subtask.on_return)
@@ -2162,6 +2171,7 @@ async def canon_lower(opts, ft, callee, task, flat_args):
21622171
flat_results = [i | (int(subtask.state) << 30)]
21632172
case Returned():
21642173
flat_results = [0]
2174+
assert(types_match_values(flat_ft.results, flat_results))
21652175
return flat_results
21662176
```
21672177
In the asynchronous case, if `do_call` blocks before `Subtask.finish`
@@ -2252,7 +2262,7 @@ async def canon_resource_drop(rt, sync, task, i):
22522262
callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback)
22532263
ft = FuncType([U32Type()],[])
22542264
callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor)
2255-
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0])
2265+
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep])
22562266
else:
22572267
task.trap_if_on_the_stack(rt.impl)
22582268
else:
@@ -2384,12 +2394,13 @@ Calling `$f` does a non-blocking check for whether an event is already
23842394
available, returning whether or not there was such an event as a boolean and,
23852395
if there was an event, storing the `i32` event+payload pair as an outparam.
23862396
```python
2387-
async def canon_task_poll(task, ptr):
2397+
async def canon_task_poll(opts, task, ptr):
23882398
trap_if(not task.inst.may_leave)
23892399
ret = await task.poll()
23902400
if ret is None:
23912401
return [0]
2392-
store(task, ret, TupleType([U32Type(), U32Type()]), ptr)
2402+
cx = CallContext(opts, task.inst, task)
2403+
store(cx, ret, TupleType([U32Type(), U32Type()]), ptr)
23932404
return [1]
23942405
```
23952406
Note that the `await` of `task.poll` indicates that `task.poll` can yield to

design/mvp/canonical-abi/definitions.py

+24-3
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ class CoreFuncType(CoreExternType):
5252
def __eq__(self, other):
5353
return self.params == other.params and self.results == other.results
5454

55+
def types_match_values(ts, vs):
56+
if len(ts) != len(vs):
57+
return False
58+
return all(type_matches_value(t, v) for t,v in zip(ts, vs))
59+
60+
def type_matches_value(t, v):
61+
match t:
62+
case 'i32' | 'i64': return type(v) == int
63+
case 'f32' | 'f64': return type(v) == float
64+
assert(False)
65+
5566
@dataclass
5667
class CoreMemoryType(CoreExternType):
5768
initial: list[int]
@@ -1138,8 +1149,10 @@ def flatten_functype(opts, ft, context):
11381149
else:
11391150
match context:
11401151
case 'lift':
1141-
flat_params = []
1142-
flat_results = []
1152+
if opts.callback:
1153+
flat_results = ['i32']
1154+
else:
1155+
flat_results = []
11431156
case 'lower':
11441157
if len(flat_params) > 1:
11451158
flat_params = ['i32']
@@ -1421,16 +1434,21 @@ def lower_heap_values(cx, vs, ts, out_param):
14211434
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_return, on_block = default_on_block):
14221435
task = Task(opts, inst, ft, caller, on_return, on_block)
14231436
flat_args = await task.enter(on_start)
1437+
flat_ft = flatten_functype(opts, ft, 'lift')
1438+
assert(types_match_values(flat_ft.params, flat_args))
14241439
if opts.sync:
14251440
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
1441+
assert(types_match_values(flat_ft.results, flat_results))
14261442
task.return_(flat_results)
14271443
if opts.post_return is not None:
14281444
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
14291445
else:
14301446
if not opts.callback:
14311447
[] = await call_and_trap_on_throw(callee, task, flat_args)
1448+
assert(types_match_values(flat_ft.results, []))
14321449
else:
14331450
[packed_ctx] = await call_and_trap_on_throw(callee, task, flat_args)
1451+
assert(types_match_values(flat_ft.results, [packed_ctx]))
14341452
while packed_ctx != 0:
14351453
is_yield = bool(packed_ctx & 1)
14361454
ctx = packed_ctx & ~1
@@ -1452,6 +1470,8 @@ async def call_and_trap_on_throw(callee, task, args):
14521470

14531471
async def canon_lower(opts, ft, callee, task, flat_args):
14541472
trap_if(not task.inst.may_leave)
1473+
flat_ft = flatten_functype(opts, ft, 'lower')
1474+
assert(types_match_values(flat_ft.params, flat_args))
14551475
subtask = Subtask(opts, ft, task, flat_args)
14561476
if opts.sync:
14571477
await task.call_sync(callee, task, subtask.on_start, subtask.on_return)
@@ -1470,6 +1490,7 @@ async def do_call(on_block):
14701490
flat_results = [i | (int(subtask.state) << 30)]
14711491
case Returned():
14721492
flat_results = [0]
1493+
assert(types_match_values(flat_ft.results, flat_results))
14731494
return flat_results
14741495

14751496
### `canon resource.new`
@@ -1499,7 +1520,7 @@ async def canon_resource_drop(rt, sync, task, i):
14991520
callee_opts = CanonicalOptions(sync = rt.dtor_sync, callback = rt.dtor_callback)
15001521
ft = FuncType([U32Type()],[])
15011522
callee = partial(canon_lift, callee_opts, rt.impl, ft, rt.dtor)
1502-
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep, 0])
1523+
flat_results = await canon_lower(caller_opts, ft, callee, task, [h.rep])
15031524
else:
15041525
task.trap_if_on_the_stack(rt.impl)
15051526
else:

design/mvp/canonical-abi/run_tests.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ async def dtor(task, args):
419419
nonlocal dtor_value
420420
assert(len(args) == 1)
421421
dtor_value = args[0]
422+
return []
422423

423424
rt = ResourceType(ComponentInstance(), dtor) # usable in imports and exports
424425
inst = ComponentInstance()
@@ -558,7 +559,7 @@ async def core_blocking_producer(task, args):
558559
async def consumer(task, args):
559560
[b] = args
560561
ptr = consumer_heap.realloc(0, 0, 1, 1)
561-
[ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [0, ptr])
562+
[ret] = await canon_lower(consumer_opts, eager_ft, eager_callee, task, [ptr])
562563
assert(ret == 0)
563564
u8 = consumer_heap.memory[ptr]
564565
assert(u8 == 43)
@@ -596,6 +597,7 @@ async def dtor(task, args):
596597
assert(len(args) == 1)
597598
await task.on_block(dtor_fut)
598599
dtor_value = args[0]
600+
return []
599601
rt = ResourceType(producer_inst, dtor)
600602

601603
[i] = await canon_resource_new(rt, task, 50)
@@ -652,10 +654,10 @@ async def core_producer_pre(fut, task, args):
652654
async def consumer(task, args):
653655
assert(len(args) == 0)
654656

655-
[ret] = await canon_lower(opts, producer_ft, producer1, task, [0, 0])
657+
[ret] = await canon_lower(opts, producer_ft, producer1, task, [])
656658
assert(ret == (1 | (CallState.STARTED << 30)))
657659

658-
[ret] = await canon_lower(opts, producer_ft, producer2, task, [0, 0])
660+
[ret] = await canon_lower(opts, producer_ft, producer2, task, [])
659661
assert(ret == (2 | (CallState.STARTED << 30)))
660662

661663
fut1.set_result(None)
@@ -730,10 +732,10 @@ async def producer2_core(task, args):
730732
async def consumer(task, args):
731733
assert(len(args) == 0)
732734

733-
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0])
735+
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
734736
assert(ret == (1 | (CallState.STARTED << 30)))
735737

736-
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0])
738+
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
737739
assert(ret == (2 | (CallState.STARTING << 30)))
738740

739741
assert(await task.poll() is None)
@@ -808,10 +810,10 @@ async def producer2_core(task, args):
808810
async def consumer(task, args):
809811
assert(len(args) == 0)
810812

811-
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [0, 0])
813+
[ret] = await canon_lower(consumer_opts, producer_ft, producer1, task, [])
812814
assert(ret == (1 | (CallState.RETURNED << 30)))
813815

814-
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [0, 0])
816+
[ret] = await canon_lower(consumer_opts, producer_ft, producer2, task, [])
815817
assert(ret == (2 | (CallState.STARTING << 30)))
816818

817819
assert(await task.poll() is None)
@@ -872,9 +874,9 @@ async def core_hostcall_pre(fut, task, args):
872874
lower_opts.sync = False
873875

874876
async def core_func(task, args):
875-
[ret] = await canon_lower(lower_opts, ft, hostcall1, task, [0,0])
877+
[ret] = await canon_lower(lower_opts, ft, hostcall1, task, [])
876878
assert(ret == (1 | (CallState.STARTED << 30)))
877-
[ret] = await canon_lower(lower_opts, ft, hostcall2, task, [0,0])
879+
[ret] = await canon_lower(lower_opts, ft, hostcall2, task, [])
878880
assert(ret == (2 | (CallState.STARTED << 30)))
879881

880882
fut1.set_result(None)

0 commit comments

Comments
 (0)