Skip to content

Commit ca2851c

Browse files
authored
Merge pull request #81 from WebAssembly/fix-traps
Add some missing trap cases
2 parents e06f269 + 92173ef commit ca2851c

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

design/mvp/CanonicalABI.md

+28
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ class Opts:
207207

208208
def load(opts, ptr, t):
209209
assert(ptr == align_to(ptr, alignment(t)))
210+
assert(ptr + size(t) <= len(opts.memory))
210211
match despecialize(t):
211212
case Bool() : return convert_int_to_bool(load_int(opts, ptr, 1))
212213
case U8() : return load_int(opts, ptr, 1)
@@ -297,19 +298,23 @@ UTF16_TAG = 1 << 31
297298
def load_string_from_range(opts, ptr, tagged_code_units):
298299
match opts.string_encoding:
299300
case 'utf8':
301+
alignment = 1
300302
byte_length = tagged_code_units
301303
encoding = 'utf-8'
302304
case 'utf16':
305+
alignment = 2
303306
byte_length = 2 * tagged_code_units
304307
encoding = 'utf-16-le'
305308
case 'latin1+utf16':
309+
alignment = 2
306310
if bool(tagged_code_units & UTF16_TAG):
307311
byte_length = 2 * (tagged_code_units ^ UTF16_TAG)
308312
encoding = 'utf-16-le'
309313
else:
310314
byte_length = tagged_code_units
311315
encoding = 'latin-1'
312316

317+
trap_if(ptr != align_to(ptr, alignment))
313318
trap_if(ptr + byte_length > len(opts.memory))
314319
try:
315320
s = opts.memory[ptr : ptr+byte_length].decode(encoding)
@@ -403,6 +408,7 @@ The `store` function defines how to write a value `v` of a given value type
403408
```python
404409
def store(opts, v, t, ptr):
405410
assert(ptr == align_to(ptr, alignment(t)))
411+
assert(ptr + size(t) <= len(opts.memory))
406412
match despecialize(t):
407413
case Bool() : store_int(opts, int(bool(v)), ptr, 1)
408414
case U8() : store_int(opts, v, ptr, 1)
@@ -522,6 +528,8 @@ def store_string_copy(opts, src, src_code_units, dst_code_unit_size, dst_alignme
522528
dst_byte_length = dst_code_unit_size * src_code_units
523529
trap_if(dst_byte_length > MAX_STRING_BYTE_LENGTH)
524530
ptr = opts.realloc(0, 0, dst_alignment, dst_byte_length)
531+
trap_if(ptr != align_to(ptr, dst_alignment))
532+
trap_if(ptr + dst_byte_length > len(opts.memory))
525533
encoded = src.encode(dst_encoding)
526534
assert(dst_byte_length == len(encoded))
527535
opts.memory[ptr : ptr+len(encoded)] = encoded
@@ -546,15 +554,18 @@ def store_latin1_to_utf8(opts, src, src_code_units):
546554
def store_string_to_utf8(opts, src, src_code_units, worst_case_size):
547555
assert(src_code_units <= MAX_STRING_BYTE_LENGTH)
548556
ptr = opts.realloc(0, 0, 1, src_code_units)
557+
trap_if(ptr + src_code_units > len(opts.memory))
549558
encoded = src.encode('utf-8')
550559
assert(src_code_units <= len(encoded))
551560
opts.memory[ptr : ptr+src_code_units] = encoded[0 : src_code_units]
552561
if src_code_units < len(encoded):
553562
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
554563
ptr = opts.realloc(ptr, src_code_units, 1, worst_case_size)
564+
trap_if(ptr + worst_case_size > len(opts.memory))
555565
opts.memory[ptr+src_code_units : ptr+len(encoded)] = encoded[src_code_units : ]
556566
if worst_case_size > len(encoded):
557567
ptr = opts.realloc(ptr, worst_case_size, 1, len(encoded))
568+
trap_if(ptr + len(encoded) > len(opts.memory))
558569
return (ptr, len(encoded))
559570
```
560571

@@ -567,10 +578,14 @@ def store_utf8_to_utf16(opts, src, src_code_units):
567578
worst_case_size = 2 * src_code_units
568579
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
569580
ptr = opts.realloc(0, 0, 2, worst_case_size)
581+
trap_if(ptr != align_to(ptr, 2))
582+
trap_if(ptr + worst_case_size > len(opts.memory))
570583
encoded = src.encode('utf-16-le')
571584
opts.memory[ptr : ptr+len(encoded)] = encoded
572585
if len(encoded) < worst_case_size:
573586
ptr = opts.realloc(ptr, worst_case_size, 2, len(encoded))
587+
trap_if(ptr != align_to(ptr, 2))
588+
trap_if(ptr + len(encoded) > len(opts.memory))
574589
code_units = int(len(encoded) / 2)
575590
return (ptr, code_units)
576591
```
@@ -587,6 +602,8 @@ bytes):
587602
def store_string_to_latin1_or_utf16(opts, src, src_code_units):
588603
assert(src_code_units <= MAX_STRING_BYTE_LENGTH)
589604
ptr = opts.realloc(0, 0, 2, src_code_units)
605+
trap_if(ptr != align_to(ptr, 2))
606+
trap_if(ptr + src_code_units > len(opts.memory))
590607
dst_byte_length = 0
591608
for usv in src:
592609
if ord(usv) < (1 << 8):
@@ -596,17 +613,23 @@ def store_string_to_latin1_or_utf16(opts, src, src_code_units):
596613
worst_case_size = 2 * src_code_units
597614
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
598615
ptr = opts.realloc(ptr, src_code_units, 2, worst_case_size)
616+
trap_if(ptr != align_to(ptr, 2))
617+
trap_if(ptr + worst_case_size > len(opts.memory))
599618
for j in range(dst_byte_length-1, -1, -1):
600619
opts.memory[ptr + 2*j] = opts.memory[ptr + j]
601620
opts.memory[ptr + 2*j + 1] = 0
602621
encoded = src.encode('utf-16-le')
603622
opts.memory[ptr+2*dst_byte_length : ptr+len(encoded)] = encoded[2*dst_byte_length : ]
604623
if worst_case_size > len(encoded):
605624
ptr = opts.realloc(ptr, worst_case_size, 2, len(encoded))
625+
trap_if(ptr != align_to(ptr, 2))
626+
trap_if(ptr + len(encoded) > len(opts.memory))
606627
tagged_code_units = int(len(encoded) / 2) | UTF16_TAG
607628
return (ptr, tagged_code_units)
608629
if dst_byte_length < src_code_units:
609630
ptr = opts.realloc(ptr, src_code_units, 2, dst_byte_length)
631+
trap_if(ptr != align_to(ptr, 2))
632+
trap_if(ptr + dst_byte_length > len(opts.memory))
610633
return (ptr, dst_byte_length)
611634
```
612635

@@ -625,6 +648,8 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
625648
src_byte_length = 2 * src_code_units
626649
trap_if(src_byte_length > MAX_STRING_BYTE_LENGTH)
627650
ptr = opts.realloc(0, 0, 2, src_byte_length)
651+
trap_if(ptr != align_to(ptr, 2))
652+
trap_if(ptr + src_byte_length > len(opts.memory))
628653
encoded = src.encode('utf-16-le')
629654
opts.memory[ptr : ptr+len(encoded)] = encoded
630655
if any(ord(c) >= (1 << 8) for c in src):
@@ -634,6 +659,7 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
634659
for i in range(latin1_size):
635660
opts.memory[ptr + i] = opts.memory[ptr + 2*i]
636661
ptr = opts.realloc(ptr, src_byte_length, 1, latin1_size)
662+
trap_if(ptr + latin1_size > len(opts.memory))
637663
return (ptr, latin1_size)
638664
```
639665

@@ -1046,6 +1072,7 @@ def lift(opts, max_flat, vi, ts):
10461072
ptr = vi.next('i32')
10471073
tuple_type = Tuple(ts)
10481074
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1075+
trap_if(ptr + size(tuple_type) > len(opts.memory))
10491076
return list(load(opts, ptr, tuple_type).values())
10501077
else:
10511078
return [ lift_flat(opts, vi, t) for t in ts ]
@@ -1067,6 +1094,7 @@ def lower(opts, max_flat, vs, ts, out_param = None):
10671094
else:
10681095
ptr = out_param.next('i32')
10691096
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
1097+
trap_if(ptr + size(tuple_type) > len(opts.memory))
10701098
store(opts, tuple_value, tuple_type, ptr)
10711099
return [ Value('i32', ptr) ]
10721100
else:

design/mvp/canonical-abi/definitions.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ class Opts:
201201

202202
def load(opts, ptr, t):
203203
assert(ptr == align_to(ptr, alignment(t)))
204+
assert(ptr + size(t) <= len(opts.memory))
204205
match despecialize(t):
205206
case Bool() : return convert_int_to_bool(load_int(opts, ptr, 1))
206207
case U8() : return load_int(opts, ptr, 1)
@@ -223,7 +224,6 @@ def load(opts, ptr, t):
223224
#
224225

225226
def load_int(opts, ptr, nbytes, signed = False):
226-
trap_if(ptr + nbytes > len(opts.memory))
227227
return int.from_bytes(opts.memory[ptr : ptr+nbytes], 'little', signed=signed)
228228

229229
#
@@ -272,19 +272,23 @@ def load_string(opts, ptr):
272272
def load_string_from_range(opts, ptr, tagged_code_units):
273273
match opts.string_encoding:
274274
case 'utf8':
275+
alignment = 1
275276
byte_length = tagged_code_units
276277
encoding = 'utf-8'
277278
case 'utf16':
279+
alignment = 2
278280
byte_length = 2 * tagged_code_units
279281
encoding = 'utf-16-le'
280282
case 'latin1+utf16':
283+
alignment = 2
281284
if bool(tagged_code_units & UTF16_TAG):
282285
byte_length = 2 * (tagged_code_units ^ UTF16_TAG)
283286
encoding = 'utf-16-le'
284287
else:
285288
byte_length = tagged_code_units
286289
encoding = 'latin-1'
287290

291+
trap_if(ptr != align_to(ptr, alignment))
288292
trap_if(ptr + byte_length > len(opts.memory))
289293
try:
290294
s = opts.memory[ptr : ptr+byte_length].decode(encoding)
@@ -358,6 +362,7 @@ def unpack_flags_from_int(i, labels):
358362

359363
def store(opts, v, t, ptr):
360364
assert(ptr == align_to(ptr, alignment(t)))
365+
assert(ptr + size(t) <= len(opts.memory))
361366
match despecialize(t):
362367
case Bool() : store_int(opts, int(bool(v)), ptr, 1)
363368
case U8() : store_int(opts, v, ptr, 1)
@@ -380,7 +385,6 @@ def store(opts, v, t, ptr):
380385
#
381386

382387
def store_int(opts, v, ptr, nbytes, signed = False):
383-
trap_if(ptr + nbytes > len(opts.memory))
384388
opts.memory[ptr : ptr+nbytes] = int.to_bytes(v, nbytes, 'little', signed=signed)
385389

386390
#
@@ -447,6 +451,8 @@ def store_string_copy(opts, src, src_code_units, dst_code_unit_size, dst_alignme
447451
dst_byte_length = dst_code_unit_size * src_code_units
448452
trap_if(dst_byte_length > MAX_STRING_BYTE_LENGTH)
449453
ptr = opts.realloc(0, 0, dst_alignment, dst_byte_length)
454+
trap_if(ptr != align_to(ptr, dst_alignment))
455+
trap_if(ptr + dst_byte_length > len(opts.memory))
450456
encoded = src.encode(dst_encoding)
451457
assert(dst_byte_length == len(encoded))
452458
opts.memory[ptr : ptr+len(encoded)] = encoded
@@ -465,15 +471,18 @@ def store_latin1_to_utf8(opts, src, src_code_units):
465471
def store_string_to_utf8(opts, src, src_code_units, worst_case_size):
466472
assert(src_code_units <= MAX_STRING_BYTE_LENGTH)
467473
ptr = opts.realloc(0, 0, 1, src_code_units)
474+
trap_if(ptr + src_code_units > len(opts.memory))
468475
encoded = src.encode('utf-8')
469476
assert(src_code_units <= len(encoded))
470477
opts.memory[ptr : ptr+src_code_units] = encoded[0 : src_code_units]
471478
if src_code_units < len(encoded):
472479
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
473480
ptr = opts.realloc(ptr, src_code_units, 1, worst_case_size)
481+
trap_if(ptr + worst_case_size > len(opts.memory))
474482
opts.memory[ptr+src_code_units : ptr+len(encoded)] = encoded[src_code_units : ]
475483
if worst_case_size > len(encoded):
476484
ptr = opts.realloc(ptr, worst_case_size, 1, len(encoded))
485+
trap_if(ptr + len(encoded) > len(opts.memory))
477486
return (ptr, len(encoded))
478487

479488
#
@@ -482,10 +491,14 @@ def store_utf8_to_utf16(opts, src, src_code_units):
482491
worst_case_size = 2 * src_code_units
483492
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
484493
ptr = opts.realloc(0, 0, 2, worst_case_size)
494+
trap_if(ptr != align_to(ptr, 2))
495+
trap_if(ptr + worst_case_size > len(opts.memory))
485496
encoded = src.encode('utf-16-le')
486497
opts.memory[ptr : ptr+len(encoded)] = encoded
487498
if len(encoded) < worst_case_size:
488499
ptr = opts.realloc(ptr, worst_case_size, 2, len(encoded))
500+
trap_if(ptr != align_to(ptr, 2))
501+
trap_if(ptr + len(encoded) > len(opts.memory))
489502
code_units = int(len(encoded) / 2)
490503
return (ptr, code_units)
491504

@@ -494,6 +507,8 @@ def store_utf8_to_utf16(opts, src, src_code_units):
494507
def store_string_to_latin1_or_utf16(opts, src, src_code_units):
495508
assert(src_code_units <= MAX_STRING_BYTE_LENGTH)
496509
ptr = opts.realloc(0, 0, 2, src_code_units)
510+
trap_if(ptr != align_to(ptr, 2))
511+
trap_if(ptr + src_code_units > len(opts.memory))
497512
dst_byte_length = 0
498513
for usv in src:
499514
if ord(usv) < (1 << 8):
@@ -503,17 +518,23 @@ def store_string_to_latin1_or_utf16(opts, src, src_code_units):
503518
worst_case_size = 2 * src_code_units
504519
trap_if(worst_case_size > MAX_STRING_BYTE_LENGTH)
505520
ptr = opts.realloc(ptr, src_code_units, 2, worst_case_size)
521+
trap_if(ptr != align_to(ptr, 2))
522+
trap_if(ptr + worst_case_size > len(opts.memory))
506523
for j in range(dst_byte_length-1, -1, -1):
507524
opts.memory[ptr + 2*j] = opts.memory[ptr + j]
508525
opts.memory[ptr + 2*j + 1] = 0
509526
encoded = src.encode('utf-16-le')
510527
opts.memory[ptr+2*dst_byte_length : ptr+len(encoded)] = encoded[2*dst_byte_length : ]
511528
if worst_case_size > len(encoded):
512529
ptr = opts.realloc(ptr, worst_case_size, 2, len(encoded))
530+
trap_if(ptr != align_to(ptr, 2))
531+
trap_if(ptr + len(encoded) > len(opts.memory))
513532
tagged_code_units = int(len(encoded) / 2) | UTF16_TAG
514533
return (ptr, tagged_code_units)
515534
if dst_byte_length < src_code_units:
516535
ptr = opts.realloc(ptr, src_code_units, 2, dst_byte_length)
536+
trap_if(ptr != align_to(ptr, 2))
537+
trap_if(ptr + dst_byte_length > len(opts.memory))
517538
return (ptr, dst_byte_length)
518539

519540
#
@@ -522,6 +543,8 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
522543
src_byte_length = 2 * src_code_units
523544
trap_if(src_byte_length > MAX_STRING_BYTE_LENGTH)
524545
ptr = opts.realloc(0, 0, 2, src_byte_length)
546+
trap_if(ptr != align_to(ptr, 2))
547+
trap_if(ptr + src_byte_length > len(opts.memory))
525548
encoded = src.encode('utf-16-le')
526549
opts.memory[ptr : ptr+len(encoded)] = encoded
527550
if any(ord(c) >= (1 << 8) for c in src):
@@ -531,6 +554,7 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
531554
for i in range(latin1_size):
532555
opts.memory[ptr + i] = opts.memory[ptr + 2*i]
533556
ptr = opts.realloc(ptr, src_byte_length, 1, latin1_size)
557+
trap_if(ptr + latin1_size > len(opts.memory))
534558
return (ptr, latin1_size)
535559

536560
#
@@ -840,6 +864,7 @@ def lift(opts, max_flat, vi, ts):
840864
ptr = vi.next('i32')
841865
tuple_type = Tuple(ts)
842866
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
867+
trap_if(ptr + size(tuple_type) > len(opts.memory))
843868
return list(load(opts, ptr, tuple_type).values())
844869
else:
845870
return [ lift_flat(opts, vi, t) for t in ts ]
@@ -856,6 +881,7 @@ def lower(opts, max_flat, vs, ts, out_param = None):
856881
else:
857882
ptr = out_param.next('i32')
858883
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
884+
trap_if(ptr + size(tuple_type) > len(opts.memory))
859885
store(opts, tuple_value, tuple_type, ptr)
860886
return [ Value('i32', ptr) ]
861887
else:

0 commit comments

Comments
 (0)