diff --git a/crates/threads-xform/src/lib.rs b/crates/threads-xform/src/lib.rs index 7a7cb898f73..5f8361c21b0 100644 --- a/crates/threads-xform/src/lib.rs +++ b/crates/threads-xform/src/lib.rs @@ -155,6 +155,8 @@ impl Config { size: self.thread_stack_size, }; + let _ = module.exports.add("__stack_alloc", stack.alloc); + inject_start(module, &tls, &stack, thread_counter_addr, memory)?; // we expose a `__wbindgen_thread_destroy()` helper function that deallocates stack space. @@ -163,6 +165,16 @@ impl Config { // After calling this function in a given agent, the instance should be considered // "destroyed" and any further invocations into it will trigger UB. This function // should not be called from an agent that cannot block (e.g. the main document thread). + // + // You can also call it from a "leader" agent, passing appropriate values, if said leader + // is in charge of cleaning up after a "follower" agent. In that case: + // - The "appropriate values" are the values of the `__tls_base` and `__stack_alloc` globals + // from the follower thread, after initialization. + // - The leader does _not_ need to block. + // - Similar restrictions apply: the follower thread should be considered unusable afterwards, + // and the leader should not call this function with the same set of parameters twice. + // - Moreover, concurrent calls can lead to UB: the follower could be in the middle of a + // call while the leader is destroying its stack! You should make sure that this cannot happen. inject_destroy(module, &tls, &stack, memory)?; Ok(()) @@ -365,37 +377,63 @@ fn inject_destroy( ) -> Result<(), Error> { let free = find_function(module, "__wbindgen_free")?; - let mut builder = walrus::FunctionBuilder::new(&mut module.types, &[], &[]); + let mut builder = + walrus::FunctionBuilder::new(&mut module.types, &[ValType::I32, ValType::I32], &[]); builder.name("__wbindgen_thread_destroy".into()); let mut body = builder.func_body(); + // if no explicit parameters are passed (i.e. 
their value is 0) then we assume + // we're being called from the agent that must be destroyed and rely on its globals + let tls_base = module.locals.add(ValType::I32); + let stack_alloc = module.locals.add(ValType::I32); + // Ideally, at this point, we would destroy the values stored in TLS. // We can't really do that without help from the standard library. // See https://github.com/rustwasm/wasm-bindgen/pull/2769#issuecomment-1015775467. - // free the TLS space - body.global_get(tls.base) - .i32_const(tls.size as i32) - .call(free); - - // set tls.base = i32::MIN to trigger invalid memory - body.i32_const(i32::MIN).global_set(tls.base); - - // free the stack callin `__wbindgen_free(stack.alloc, stack.size)` - with_temp_stack(&mut body, memory, stack, |body| { - body.global_get(stack.alloc) - .i32_const(stack.size as i32) - .call(free); - }); + body.local_get(tls_base).if_else( + None, + |body| { + body.local_get(tls_base) + .i32_const(tls.size as i32) + .call(free); + }, + |body| { + body.global_get(tls.base) + .i32_const(tls.size as i32) + .call(free); + + // set tls.base = i32::MIN to trigger invalid memory + body.i32_const(i32::MIN).global_set(tls.base); + }, + ); + + // free the stack calling `__wbindgen_free(stack.alloc, stack.size)` + body.local_get(stack_alloc).if_else( + None, + |body| { + // we're destroying somebody else's stack, so we can use our own + body.local_get(stack_alloc) + .i32_const(stack.size as i32) + .call(free); + }, + |mut body| { + with_temp_stack(&mut body, memory, stack, |body| { + body.global_get(stack.alloc) + .i32_const(stack.size as i32) + .call(free); + }); - // set stack.alloc = 0 to trigger invalid memory - body.i32_const(0).global_set(stack.alloc); + // set stack.alloc = 0 to trigger invalid memory + body.i32_const(0).global_set(stack.alloc); + }, + ); - let free_id = builder.finish(Vec::new(), &mut module.funcs); + let destroy_id = builder.finish(Vec::new(), &mut module.funcs); - 
module.exports.add("__wbindgen_thread_destroy", free_id); + module.exports.add("__wbindgen_thread_destroy", destroy_id); Ok(()) } diff --git a/crates/threads-xform/tests/basic.wat b/crates/threads-xform/tests/basic.wat index 2c0b03aeee9..5f22078d8fc 100644 --- a/crates/threads-xform/tests/basic.wat +++ b/crates/threads-xform/tests/basic.wat @@ -76,41 +76,56 @@ global.set 1 global.get 1 call $__wasm_init_tls) - (func $__wbindgen_thread_destroy (type 0) - global.get 1 - i32.const 128 - call $__wbindgen_free - i32.const -2147483648 - global.set 1 - i32.const 393216 - global.set 2 - loop ;; label = @1 - i32.const 327684 - i32.const 0 - i32.const 1 - i32.atomic.rmw.cmpxchg - if ;; label = @2 + (func $__wbindgen_thread_destroy (type 3) (param i32 i32) + (local i32 i32) + local.get 0 + if ;; label = @1 + local.get 0 + i32.const 128 + call $__wbindgen_free + else + global.get 1 + i32.const 128 + call $__wbindgen_free + i32.const -2147483648 + global.set 1 + end + local.get 1 + if ;; label = @1 + local.get 1 + i32.const 1048576 + call $__wbindgen_free + else + i32.const 393216 + global.set 2 + loop ;; label = @2 i32.const 327684 + i32.const 0 i32.const 1 - i64.const -1 - memory.atomic.wait32 - drop - br 1 (;@1;) - else + i32.atomic.rmw.cmpxchg + if ;; label = @3 + i32.const 327684 + i32.const 1 + i64.const -1 + memory.atomic.wait32 + drop + br 1 (;@2;) + else + end end - end - global.get 3 - i32.const 1048576 - call $__wbindgen_free - i32.const 327684 - i32.const 0 - i32.atomic.store - i32.const 327684 - i32.const 1 - memory.atomic.notify - drop - i32.const 0 - global.set 3) + global.get 3 + i32.const 1048576 + call $__wbindgen_free + i32.const 327684 + i32.const 0 + i32.atomic.store + i32.const 327684 + i32.const 1 + memory.atomic.notify + drop + i32.const 0 + global.set 3 + end) (func $__wasm_init_tls (type 1) (param i32) i32.const 232323 drop) @@ -120,14 +135,15 @@ (func $__wbindgen_malloc (type 2) (param i32) (result i32) i32.const 999999) (func $__wbindgen_free (type 
3) (param i32 i32)) - (global (;0;) i32 (i32.const 393216)) - (global (;1;) (mut i32) (i32.const 0)) - (global (;2;) (mut i32) (i32.const 65536)) - (global (;3;) (mut i32) (i32.const 0)) + (global (;0;) i32 i32.const 393216) + (global (;1;) (mut i32) i32.const 0) + (global (;2;) (mut i32) i32.const 65536) + (global (;3;) (mut i32) i32.const 0) (export "__wbindgen_malloc" (func $__wbindgen_malloc)) (export "__wbindgen_free" (func $__wbindgen_free)) (export "__heap_base" (global 0)) (export "__tls_base" (global 1)) + (export "__stack_alloc" (global 3)) (export "__wbindgen_thread_destroy" (func $__wbindgen_thread_destroy)) (start 0)) ;) diff --git a/crates/threads-xform/tests/unaligned.wat b/crates/threads-xform/tests/unaligned.wat index 241f7ce05d4..816868946e4 100644 --- a/crates/threads-xform/tests/unaligned.wat +++ b/crates/threads-xform/tests/unaligned.wat @@ -76,41 +76,56 @@ global.set 1 global.get 1 call $__wasm_init_tls) - (func $__wbindgen_thread_destroy (type 0) - global.get 1 - i32.const 128 - call $__wbindgen_free - i32.const -2147483648 - global.set 1 - i32.const 393216 - global.set 2 - loop ;; label = @1 - i32.const 327688 - i32.const 0 - i32.const 1 - i32.atomic.rmw.cmpxchg - if ;; label = @2 + (func $__wbindgen_thread_destroy (type 3) (param i32 i32) + (local i32 i32) + local.get 0 + if ;; label = @1 + local.get 0 + i32.const 128 + call $__wbindgen_free + else + global.get 1 + i32.const 128 + call $__wbindgen_free + i32.const -2147483648 + global.set 1 + end + local.get 1 + if ;; label = @1 + local.get 1 + i32.const 1048576 + call $__wbindgen_free + else + i32.const 393216 + global.set 2 + loop ;; label = @2 i32.const 327688 + i32.const 0 i32.const 1 - i64.const -1 - memory.atomic.wait32 - drop - br 1 (;@1;) - else + i32.atomic.rmw.cmpxchg + if ;; label = @3 + i32.const 327688 + i32.const 1 + i64.const -1 + memory.atomic.wait32 + drop + br 1 (;@2;) + else + end end - end - global.get 3 - i32.const 1048576 - call $__wbindgen_free - i32.const 327688 - 
i32.const 0 - i32.atomic.store - i32.const 327688 - i32.const 1 - memory.atomic.notify - drop - i32.const 0 - global.set 3) + global.get 3 + i32.const 1048576 + call $__wbindgen_free + i32.const 327688 + i32.const 0 + i32.atomic.store + i32.const 327688 + i32.const 1 + memory.atomic.notify + drop + i32.const 0 + global.set 3 + end) (func $__wasm_init_tls (type 1) (param i32) i32.const 232323 drop) @@ -120,14 +135,15 @@ (func $__wbindgen_malloc (type 2) (param i32) (result i32) i32.const 999999) (func $__wbindgen_free (type 3) (param i32 i32)) - (global (;0;) i32 (i32.const 393219)) - (global (;1;) (mut i32) (i32.const 0)) - (global (;2;) (mut i32) (i32.const 65536)) - (global (;3;) (mut i32) (i32.const 0)) + (global (;0;) i32 i32.const 393219) + (global (;1;) (mut i32) i32.const 0) + (global (;2;) (mut i32) i32.const 65536) + (global (;3;) (mut i32) i32.const 0) (export "__wbindgen_malloc" (func $__wbindgen_malloc)) (export "__wbindgen_free" (func $__wbindgen_free)) (export "__heap_base" (global 0)) (export "__tls_base" (global 1)) + (export "__stack_alloc" (global 3)) (export "__wbindgen_thread_destroy" (func $__wbindgen_thread_destroy)) (start 0)) ;)