Skip to content

Commit 88b4f28

Browse files
Merge pull request python#10 from darkk/master
Fix lz4.frame.decompress crash caused by incorrect realloc() usage
2 parents d0ef857 + 216303e commit 88b4f28

File tree

5 files changed

+111
-67
lines changed

5 files changed

+111
-67
lines changed

lz4/block/_block.c

+6
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,12 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
283283
PyErr_Format (PyExc_ValueError, "Corrupt input at byte %d", -output_size);
284284
Py_CLEAR (py_dest);
285285
}
286+
else if ((size_t)output_size != dest_size)
287+
{
288+
// IMHO, it's better to fail explicitly than to allow fishy data to pass through.
289+
PyErr_Format (PyExc_ValueError, "Decompressor writes %d bytes, but %zu are claimed", output_size, dest_size);
290+
Py_CLEAR (py_dest);
291+
}
286292

287293
return py_dest;
288294
}

lz4/frame/_frame.c

+34-59
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ typedef unsigned int uint32_t;
6262
#define inline
6363
#endif
6464

65+
static const char * capsule_name = "_frame.LZ4F_cctx";
66+
static void destruct_compression_context (PyObject * py_context);
6567
struct compression_context
6668
{
6769
LZ4F_compressionContext_t compression_context;
@@ -114,57 +116,26 @@ create_compression_context (PyObject * Py_UNUSED (self),
114116
return NULL;
115117
}
116118

117-
return PyCapsule_New (context, NULL, NULL);
119+
return PyCapsule_New (context, capsule_name, destruct_compression_context);
118120
}
119121

120-
/****************************
121-
* free_compression_context *
122-
****************************/
123-
PyDoc_STRVAR(free_compression_context__doc,
124-
"free_compression_context(context)\n\n" \
125-
"Releases the resources held by a compression context previously\n" \
126-
"created with create_compression_context.\n\n" \
127-
"Args:\n" \
128-
" context (cCtx): Compression context.\n"
129-
);
130-
131-
static PyObject *
132-
free_compression_context (PyObject * Py_UNUSED (self), PyObject * args,
133-
PyObject * keywds)
122+
static void
123+
destruct_compression_context (PyObject * py_context)
134124
{
135-
PyObject *py_context = NULL;
136-
static char *kwlist[] = { "context", NULL };
137-
struct compression_context *context;
138-
LZ4F_errorCode_t result;
139-
140-
if (!PyArg_ParseTupleAndKeywords (args, keywds, "O", kwlist, &py_context))
141-
{
142-
return NULL;
143-
}
144-
145-
context =
146-
(struct compression_context *) PyCapsule_GetPointer (py_context, NULL);
147-
if (!context)
148-
{
149-
PyErr_Format (PyExc_ValueError, "No compression context supplied");
150-
return NULL;
151-
}
125+
struct compression_context *context =
126+
#ifndef PyCapsule_Type
127+
PyCapsule_GetPointer (py_context, capsule_name);
128+
// That's always true as there are no PyCapsule_SetPointer calls.
129+
#else
130+
py_context; // compatibility with 2.6 via capsulethunk
131+
#endif
152132

153133
Py_BEGIN_ALLOW_THREADS
154-
result =
155-
LZ4F_freeCompressionContext (context->compression_context);
134+
LZ4F_freeCompressionContext (context->compression_context);
135+
// That's always LZ4F_OK_NoError as free() is `void free()` and it's just a wrapper.
156136
Py_END_ALLOW_THREADS
157137

158-
if (LZ4F_isError (result))
159-
{
160-
PyErr_Format (PyExc_RuntimeError,
161-
"LZ4F_freeCompressionContext failed with code: %s",
162-
LZ4F_getErrorName (result));
163-
return NULL;
164-
}
165138
PyMem_Free (context);
166-
167-
Py_RETURN_NONE;
168139
}
169140

170141
/******************
@@ -185,9 +156,10 @@ free_compression_context (PyObject * Py_UNUSED (self), PyObject * args,
185156
" - BLOCKMODE_LINKED or 1: linked mode\n\n" \
186157
" The default is BLOCKMODE_INDEPENDENT.\n" \
187158
" compression_level (int): Specifies the level of compression used.\n" \
188-
" Values between 0-16 are valid, with 0 (default) being the\n" \
189-
" lowest compression, and 16 the highest. Values above 16 will\n" \
190-
" be treated as 16. Values betwee 3-6 are recommended.\n" \
159+
" Values between 0-16 are valid, with 0 (default) being the\n" \
160+
" lowest compression (0-2 are the same value), and 16 the highest.\n" \
161+
" Values above 16 will be treated as 16.\n" \
162+
" Values between 4-9 are recommended.\n" \
191163
" The following module constants are provided as a convenience:\n\n" \
192164
" - COMPRESSIONLEVEL_MIN: Minimum compression (0, the default)\n" \
193165
" - COMPRESSIONLEVEL_MINHC: Minimum high-compression mode (3)\n" \
@@ -381,7 +353,7 @@ compress_begin (PyObject * Py_UNUSED (self), PyObject * args,
381353
preferences.frameInfo.contentSize = source_size;
382354

383355
context =
384-
(struct compression_context *) PyCapsule_GetPointer (py_context, NULL);
356+
(struct compression_context *) PyCapsule_GetPointer (py_context, capsule_name);
385357

386358
if (!context || !context->compression_context)
387359
{
@@ -448,7 +420,7 @@ compress_update (PyObject * Py_UNUSED (self), PyObject * args,
448420
}
449421

450422
context =
451-
(struct compression_context *) PyCapsule_GetPointer (py_context, NULL);
423+
(struct compression_context *) PyCapsule_GetPointer (py_context, capsule_name);
452424
if (!context || !context->compression_context)
453425
{
454426
PyErr_Format (PyExc_ValueError, "No compression context supplied");
@@ -542,7 +514,7 @@ compress_end (PyObject * Py_UNUSED (self), PyObject * args, PyObject * keywds)
542514
}
543515

544516
context =
545-
(struct compression_context *) PyCapsule_GetPointer (py_context, NULL);
517+
(struct compression_context *) PyCapsule_GetPointer (py_context, capsule_name);
546518
if (!context || !context->compression_context)
547519
{
548520
PyErr_SetString (PyExc_ValueError, "No compression context supplied");
@@ -805,6 +777,7 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * keywds)
805777
}
806778

807779
destination_written += destination_write;
780+
source_cursor += source_read;
808781

809782
if (result == 0)
810783
{
@@ -813,11 +786,10 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * keywds)
813786

814787
if (destination_written == destination_size)
815788
{
816-
/* Destination_buffer is full, so need to expand it. We'll expand
817-
it by the approximate size needed from the return value - see
818-
LZ4 docs. */
819-
destination_size += result;
820-
if (!PyMem_Realloc(destination_buffer, destination_size))
789+
/* Destination_buffer is full, so need to expand it. */
790+
destination_size *= 2;
791+
char * nextgen = PyMem_Realloc(destination_buffer, destination_size);
792+
if (!nextgen)
821793
{
822794
LZ4F_freeDecompressionContext (context);
823795
Py_BLOCK_THREADS
@@ -826,14 +798,14 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * keywds)
826798
PyMem_Free (destination_buffer);
827799
return NULL;
828800
}
801+
destination_buffer = nextgen;
829802
}
830803
/* Data still remaining to be decompressed, so increment the source and
831804
destination cursor locations, and reset source_read and
832805
destination_write ready for the next iteration. Important to
833806
re-initialize destination_cursor here (as opposed to simply
834807
incrementing it) so we're pointing to the realloc'd memory location. */
835808
destination_cursor = destination_buffer + destination_written;
836-
source_cursor += source_read;
837809
destination_write = destination_size - destination_written;
838810
source_read = source_end - source_cursor;
839811
}
@@ -850,6 +822,13 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * keywds)
850822
LZ4F_getErrorName (result));
851823
return NULL;
852824
}
825+
else if (source_cursor != source_end)
826+
{
827+
PyMem_Free (destination_buffer);
828+
PyErr_Format (PyExc_ValueError,
829+
"Extra data: %zd trailing bytes", source_end - source_cursor);
830+
return NULL;
831+
}
853832

854833
py_dest = PyBytes_FromStringAndSize (destination_buffer, destination_written);
855834

@@ -869,10 +848,6 @@ static PyMethodDef module_methods[] =
869848
"create_compression_context", (PyCFunction) create_compression_context,
870849
METH_VARARGS | METH_KEYWORDS, create_compression_context__doc
871850
},
872-
{
873-
"free_compression_context", (PyCFunction) free_compression_context,
874-
METH_VARARGS | METH_KEYWORDS, free_compression_context__doc
875-
},
876851
{
877852
"compress", (PyCFunction) compress,
878853
METH_VARARGS | METH_KEYWORDS, compress__doc

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ def version_scheme(version):
1111
version = guess_next_dev_version(version)
1212
return version.lstrip("v")
1313

14-
LZ4_VERSION = "r131"
14+
LZ4_VERSION = "1.7.4.2"
1515

1616
def library_is_installed(libname):
1717
# Check to see if we have a library called 'libname' installed on the

tests/test_block.py

+33
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,39 @@ def roundtriphc(x):
170170
assert data == out
171171
pool.close()
172172

173+
def test_block_format(self):
174+
data = lz4.compress(b'A' * 64)
175+
self.assertEqual(data[:4], b'\x40\0\0\0')
176+
self.assertEqual(lz4.decompress(data[4:], uncompressed_size=64), b'A' * 64)
177+
178+
class TestLZ4BlockModern(unittest.TestCase):
179+
def test_decompress_ui32_overflow(self):
180+
data = lz4.compress(b'A' * 64)
181+
with self.assertRaisesRegexp(OverflowError, r'^signed integer is greater than maximum$'):
182+
lz4.decompress(data[4:], uncompressed_size=((1<<32) + 64))
183+
184+
def test_decompress_without_leak(self):
185+
# Verify that hand-crafted packet does not leak uninitialized(?) memory.
186+
data = lz4.compress(b'A' * 64)
187+
with self.assertRaisesRegexp(ValueError, r'^Decompressor writes 64 bytes, but 79 are claimed$'):
188+
lz4.decompress(b'\x4f' + data[1:])
189+
with self.assertRaisesRegexp(ValueError, r'^Decompressor writes 64 bytes, but 79 are claimed$'):
190+
lz4.decompress(data[4:], uncompressed_size=79)
191+
192+
def test_decompress_with_trailer(self):
193+
data = b'A' * 64
194+
comp = lz4.compress(data)
195+
with self.assertRaisesRegexp(ValueError, r'^Corrupt input at byte'):
196+
self.assertEqual(data, lz4.block.decompress(comp + b'A'))
197+
with self.assertRaisesRegexp(ValueError, r'^Corrupt input at byte'):
198+
self.assertEqual(data, lz4.block.decompress(comp + comp))
199+
with self.assertRaisesRegexp(ValueError, r'^Corrupt input at byte'):
200+
self.assertEqual(data, lz4.block.decompress(comp + comp[4:]))
201+
202+
if sys.version_info < (2, 7):
203+
# Poor man's unittest.TestCase.skip for Python 2.6
204+
del TestLZ4BlockModern
205+
173206
if __name__ == '__main__':
174207
unittest.main()
175208

tests/test_frame.py

+37-7
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
11
import lz4.frame as lz4frame
22
import unittest
33
import os
4+
import sys
45
from multiprocessing.pool import ThreadPool
56

67
class TestLZ4Frame(unittest.TestCase):
78
def test_create_and_free_compression_context(self):
89
context = lz4frame.create_compression_context()
910
self.assertNotEqual(context, None)
10-
lz4frame.free_compression_context(context)
1111

1212
def test_compress(self):
1313
input_data = b"2099023098234882923049823094823094898239230982349081231290381209380981203981209381238901283098908123109238098123"
@@ -25,7 +25,28 @@ def test_compress_begin_update_end(self):
2525
compressed += lz4frame.compress_update(context, input_data[:chunk_size])
2626
compressed += lz4frame.compress_update(context, input_data[chunk_size:])
2727
compressed += lz4frame.compress_end(context)
28-
lz4frame.free_compression_context(context)
28+
decompressed = lz4frame.decompress(compressed)
29+
self.assertEqual(input_data, decompressed)
30+
31+
def test_compress_huge_with_size(self):
32+
context = lz4frame.create_compression_context()
33+
input_data = b"2099023098234882923049823094823094898239230982349081231290381209380981203981209381238901283098908123109238098123" * 4096
34+
chunk_size = int((len(input_data)/2)+1)
35+
compressed = lz4frame.compress_begin(context, source_size=len(input_data))
36+
compressed += lz4frame.compress_update(context, input_data[:chunk_size])
37+
compressed += lz4frame.compress_update(context, input_data[chunk_size:])
38+
compressed += lz4frame.compress_end(context)
39+
decompressed = lz4frame.decompress(compressed)
40+
self.assertEqual(input_data, decompressed)
41+
42+
def test_compress_huge_without_size(self):
43+
context = lz4frame.create_compression_context()
44+
input_data = b"2099023098234882923049823094823094898239230982349081231290381209380981203981209381238901283098908123109238098123" * 4096
45+
chunk_size = int((len(input_data)/2)+1)
46+
compressed = lz4frame.compress_begin(context)
47+
compressed += lz4frame.compress_update(context, input_data[:chunk_size])
48+
compressed += lz4frame.compress_update(context, input_data[chunk_size:])
49+
compressed += lz4frame.compress_end(context)
2950
decompressed = lz4frame.decompress(compressed)
3051
self.assertEqual(input_data, decompressed)
3152

@@ -86,7 +107,6 @@ def test_compress_begin_update_end_no_auto_flush(self):
86107
compressed += lz4frame.compress_update(context, input_data[:chunk_size])
87108
compressed += lz4frame.compress_update(context, input_data[chunk_size:])
88109
compressed += lz4frame.compress_end(context)
89-
lz4frame.free_compression_context(context)
90110
decompressed = lz4frame.decompress(compressed)
91111
self.assertEqual(input_data, decompressed)
92112

@@ -105,7 +125,6 @@ def test_compress_begin_update_end_no_auto_flush_2(self):
105125
end = start + chunk_size
106126

107127
compressed += lz4frame.compress_end(context)
108-
lz4frame.free_compression_context(context)
109128
decompressed = lz4frame.decompress(compressed)
110129
self.assertEqual(input_data, decompressed)
111130

@@ -130,7 +149,6 @@ def test_compress_begin_update_end_not_defaults(self):
130149
end = start + chunk_size
131150

132151
compressed += lz4frame.compress_end(context)
133-
lz4frame.free_compression_context(context)
134152
decompressed = lz4frame.decompress(compressed)
135153
self.assertEqual(input_data, decompressed)
136154

@@ -155,7 +173,6 @@ def test_compress_begin_update_end_no_auto_flush_not_defaults(self):
155173
end = start + chunk_size
156174

157175
compressed += lz4frame.compress_end(context)
158-
lz4frame.free_compression_context(context)
159176
decompressed = lz4frame.decompress(compressed)
160177
self.assertEqual(input_data, decompressed)
161178

@@ -216,7 +233,6 @@ def roundtrip(x):
216233
end = start + chunk_size
217234

218235
compressed += lz4frame.compress_end(context)
219-
lz4frame.free_compression_context(context)
220236
decompressed = lz4frame.decompress(compressed)
221237
return decompressed
222238

@@ -225,5 +241,19 @@ def roundtrip(x):
225241
pool.close()
226242
self.assertEqual(data, out)
227243

244+
class TestLZ4FrameModern(unittest.TestCase):
245+
def test_decompress_trailer(self):
246+
input_data = b"2099023098234882923049823094823094898239230982349081231290381209380981203981209381238901283098908123109238098123"
247+
compressed = lz4frame.compress(input_data)
248+
with self.assertRaisesRegexp(ValueError, r'^Extra data: 64 trailing bytes'):
249+
lz4frame.decompress(compressed + b'A'*64)
250+
# This API does not support frame concatenation!
251+
with self.assertRaisesRegexp(ValueError, r'^Extra data: \d+ trailing bytes'):
252+
lz4frame.decompress(compressed + compressed)
253+
254+
if sys.version_info < (2, 7):
255+
# Poor man's unittest.TestCase.skip for Python 2.6
256+
del TestLZ4FrameModern
257+
228258
if __name__ == '__main__':
229259
unittest.main()

0 commit comments

Comments
 (0)