@@ -201,6 +201,7 @@ class Opts:
201
201
202
202
def load (opts , ptr , t ):
203
203
assert (ptr == align_to (ptr , alignment (t )))
204
+ assert (ptr + size (t ) <= len (opts .memory ))
204
205
match despecialize (t ):
205
206
case Bool () : return convert_int_to_bool (load_int (opts , ptr , 1 ))
206
207
case U8 () : return load_int (opts , ptr , 1 )
@@ -223,7 +224,6 @@ def load(opts, ptr, t):
223
224
#
224
225
225
226
def load_int (opts , ptr , nbytes , signed = False ):
226
- trap_if (ptr + nbytes > len (opts .memory ))
227
227
return int .from_bytes (opts .memory [ptr : ptr + nbytes ], 'little' , signed = signed )
228
228
229
229
#
@@ -272,19 +272,23 @@ def load_string(opts, ptr):
272
272
def load_string_from_range (opts , ptr , tagged_code_units ):
273
273
match opts .string_encoding :
274
274
case 'utf8' :
275
+ alignment = 1
275
276
byte_length = tagged_code_units
276
277
encoding = 'utf-8'
277
278
case 'utf16' :
279
+ alignment = 2
278
280
byte_length = 2 * tagged_code_units
279
281
encoding = 'utf-16-le'
280
282
case 'latin1+utf16' :
283
+ alignment = 2
281
284
if bool (tagged_code_units & UTF16_TAG ):
282
285
byte_length = 2 * (tagged_code_units ^ UTF16_TAG )
283
286
encoding = 'utf-16-le'
284
287
else :
285
288
byte_length = tagged_code_units
286
289
encoding = 'latin-1'
287
290
291
+ trap_if (ptr != align_to (ptr , alignment ))
288
292
trap_if (ptr + byte_length > len (opts .memory ))
289
293
try :
290
294
s = opts .memory [ptr : ptr + byte_length ].decode (encoding )
@@ -358,6 +362,7 @@ def unpack_flags_from_int(i, labels):
358
362
359
363
def store (opts , v , t , ptr ):
360
364
assert (ptr == align_to (ptr , alignment (t )))
365
+ assert (ptr + size (t ) <= len (opts .memory ))
361
366
match despecialize (t ):
362
367
case Bool () : store_int (opts , int (bool (v )), ptr , 1 )
363
368
case U8 () : store_int (opts , v , ptr , 1 )
@@ -380,7 +385,6 @@ def store(opts, v, t, ptr):
380
385
#
381
386
382
387
def store_int (opts , v , ptr , nbytes , signed = False ):
383
- trap_if (ptr + nbytes > len (opts .memory ))
384
388
opts .memory [ptr : ptr + nbytes ] = int .to_bytes (v , nbytes , 'little' , signed = signed )
385
389
386
390
#
@@ -447,6 +451,8 @@ def store_string_copy(opts, src, src_code_units, dst_code_unit_size, dst_alignme
447
451
dst_byte_length = dst_code_unit_size * src_code_units
448
452
trap_if (dst_byte_length > MAX_STRING_BYTE_LENGTH )
449
453
ptr = opts .realloc (0 , 0 , dst_alignment , dst_byte_length )
454
+ trap_if (ptr != align_to (ptr , dst_alignment ))
455
+ trap_if (ptr + dst_byte_length > len (opts .memory ))
450
456
encoded = src .encode (dst_encoding )
451
457
assert (dst_byte_length == len (encoded ))
452
458
opts .memory [ptr : ptr + len (encoded )] = encoded
@@ -465,15 +471,18 @@ def store_latin1_to_utf8(opts, src, src_code_units):
465
471
def store_string_to_utf8 (opts , src , src_code_units , worst_case_size ):
466
472
assert (src_code_units <= MAX_STRING_BYTE_LENGTH )
467
473
ptr = opts .realloc (0 , 0 , 1 , src_code_units )
474
+ trap_if (ptr + src_code_units > len (opts .memory ))
468
475
encoded = src .encode ('utf-8' )
469
476
assert (src_code_units <= len (encoded ))
470
477
opts .memory [ptr : ptr + src_code_units ] = encoded [0 : src_code_units ]
471
478
if src_code_units < len (encoded ):
472
479
trap_if (worst_case_size > MAX_STRING_BYTE_LENGTH )
473
480
ptr = opts .realloc (ptr , src_code_units , 1 , worst_case_size )
481
+ trap_if (ptr + worst_case_size > len (opts .memory ))
474
482
opts .memory [ptr + src_code_units : ptr + len (encoded )] = encoded [src_code_units : ]
475
483
if worst_case_size > len (encoded ):
476
484
ptr = opts .realloc (ptr , worst_case_size , 1 , len (encoded ))
485
+ trap_if (ptr + len (encoded ) > len (opts .memory ))
477
486
return (ptr , len (encoded ))
478
487
479
488
#
@@ -482,10 +491,14 @@ def store_utf8_to_utf16(opts, src, src_code_units):
482
491
worst_case_size = 2 * src_code_units
483
492
trap_if (worst_case_size > MAX_STRING_BYTE_LENGTH )
484
493
ptr = opts .realloc (0 , 0 , 2 , worst_case_size )
494
+ trap_if (ptr != align_to (ptr , 2 ))
495
+ trap_if (ptr + worst_case_size > len (opts .memory ))
485
496
encoded = src .encode ('utf-16-le' )
486
497
opts .memory [ptr : ptr + len (encoded )] = encoded
487
498
if len (encoded ) < worst_case_size :
488
499
ptr = opts .realloc (ptr , worst_case_size , 2 , len (encoded ))
500
+ trap_if (ptr != align_to (ptr , 2 ))
501
+ trap_if (ptr + len (encoded ) > len (opts .memory ))
489
502
code_units = int (len (encoded ) / 2 )
490
503
return (ptr , code_units )
491
504
@@ -494,6 +507,8 @@ def store_utf8_to_utf16(opts, src, src_code_units):
494
507
def store_string_to_latin1_or_utf16 (opts , src , src_code_units ):
495
508
assert (src_code_units <= MAX_STRING_BYTE_LENGTH )
496
509
ptr = opts .realloc (0 , 0 , 2 , src_code_units )
510
+ trap_if (ptr != align_to (ptr , 2 ))
511
+ trap_if (ptr + src_code_units > len (opts .memory ))
497
512
dst_byte_length = 0
498
513
for usv in src :
499
514
if ord (usv ) < (1 << 8 ):
@@ -503,17 +518,23 @@ def store_string_to_latin1_or_utf16(opts, src, src_code_units):
503
518
worst_case_size = 2 * src_code_units
504
519
trap_if (worst_case_size > MAX_STRING_BYTE_LENGTH )
505
520
ptr = opts .realloc (ptr , src_code_units , 2 , worst_case_size )
521
+ trap_if (ptr != align_to (ptr , 2 ))
522
+ trap_if (ptr + worst_case_size > len (opts .memory ))
506
523
for j in range (dst_byte_length - 1 , - 1 , - 1 ):
507
524
opts .memory [ptr + 2 * j ] = opts .memory [ptr + j ]
508
525
opts .memory [ptr + 2 * j + 1 ] = 0
509
526
encoded = src .encode ('utf-16-le' )
510
527
opts .memory [ptr + 2 * dst_byte_length : ptr + len (encoded )] = encoded [2 * dst_byte_length : ]
511
528
if worst_case_size > len (encoded ):
512
529
ptr = opts .realloc (ptr , worst_case_size , 2 , len (encoded ))
530
+ trap_if (ptr != align_to (ptr , 2 ))
531
+ trap_if (ptr + len (encoded ) > len (opts .memory ))
513
532
tagged_code_units = int (len (encoded ) / 2 ) | UTF16_TAG
514
533
return (ptr , tagged_code_units )
515
534
if dst_byte_length < src_code_units :
516
535
ptr = opts .realloc (ptr , src_code_units , 2 , dst_byte_length )
536
+ trap_if (ptr != align_to (ptr , 2 ))
537
+ trap_if (ptr + dst_byte_length > len (opts .memory ))
517
538
return (ptr , dst_byte_length )
518
539
519
540
#
@@ -522,6 +543,8 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
522
543
src_byte_length = 2 * src_code_units
523
544
trap_if (src_byte_length > MAX_STRING_BYTE_LENGTH )
524
545
ptr = opts .realloc (0 , 0 , 2 , src_byte_length )
546
+ trap_if (ptr != align_to (ptr , 2 ))
547
+ trap_if (ptr + src_byte_length > len (opts .memory ))
525
548
encoded = src .encode ('utf-16-le' )
526
549
opts .memory [ptr : ptr + len (encoded )] = encoded
527
550
if any (ord (c ) >= (1 << 8 ) for c in src ):
@@ -531,6 +554,7 @@ def store_probably_utf16_to_latin1_or_utf16(opts, src, src_code_units):
531
554
for i in range (latin1_size ):
532
555
opts .memory [ptr + i ] = opts .memory [ptr + 2 * i ]
533
556
ptr = opts .realloc (ptr , src_byte_length , 1 , latin1_size )
557
+ trap_if (ptr + latin1_size > len (opts .memory ))
534
558
return (ptr , latin1_size )
535
559
536
560
#
@@ -840,6 +864,7 @@ def lift(opts, max_flat, vi, ts):
840
864
ptr = vi .next ('i32' )
841
865
tuple_type = Tuple (ts )
842
866
trap_if (ptr != align_to (ptr , alignment (tuple_type )))
867
+ trap_if (ptr + size (tuple_type ) > len (opts .memory ))
843
868
return list (load (opts , ptr , tuple_type ).values ())
844
869
else :
845
870
return [ lift_flat (opts , vi , t ) for t in ts ]
@@ -856,6 +881,7 @@ def lower(opts, max_flat, vs, ts, out_param = None):
856
881
else :
857
882
ptr = out_param .next ('i32' )
858
883
trap_if (ptr != align_to (ptr , alignment (tuple_type )))
884
+ trap_if (ptr + size (tuple_type ) > len (opts .memory ))
859
885
store (opts , tuple_value , tuple_type , ptr )
860
886
return [ Value ('i32' , ptr ) ]
861
887
else :
0 commit comments