Skip to content

Commit 175bfb2

Browse files
Merge pull request python#38 from pitrou/accept_buffer_objects
Issue python#16: accept arbitrary buffer-compatible objects
2 parents 51a2be4 + 580c35e commit 175bfb2

File tree

2 files changed

+93
-93
lines changed

2 files changed

+93
-93
lines changed

lz4/block/_block.c

+56-73
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,8 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
109109
char *dest, *dest_start;
110110
compression_type comp;
111111
int output_size;
112+
Py_buffer source;
113+
int source_size;
112114

113115
#if IS_PY3
114116
static char *argnames[] = {
@@ -121,36 +123,14 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
121123
NULL
122124
};
123125
int return_bytearray = 0;
124-
PyObject * py_source;
125-
Py_ssize_t source_size;
126-
char * source;
127-
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "O|spiip", argnames,
128-
&py_source,
129-
&mode, &store_size, &acceleration,
130-
&compression, &return_bytearray))
126+
127+
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "y*|siiip", argnames,
128+
&source,
129+
&mode, &store_size, &acceleration, &compression,
130+
&return_bytearray))
131131
{
132132
return NULL;
133133
}
134-
if (PyByteArray_Check(py_source))
135-
{
136-
source = PyByteArray_AsString(py_source);
137-
if (source == NULL)
138-
{
139-
PyErr_SetString (PyExc_ValueError, "Failed to access source bytearray object");
140-
return NULL;
141-
}
142-
source_size = PyByteArray_GET_SIZE(py_source);
143-
}
144-
else
145-
{
146-
source = PyBytes_AsString(py_source);
147-
if (source == NULL)
148-
{
149-
PyErr_SetString (PyExc_ValueError, "Failed to access source object");
150-
return NULL;
151-
}
152-
source_size = PyBytes_GET_SIZE(py_source);
153-
}
154134
#else
155135
static char *argnames[] = {
156136
"source",
@@ -160,16 +140,22 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
160140
"compression",
161141
NULL
162142
};
163-
const char *source;
164-
int source_size;
165-
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "s#|siii", argnames,
166-
&source, &source_size,
143+
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "s*|siii", argnames,
144+
&source,
167145
&mode, &store_size, &acceleration, &compression))
168146
{
169147
return NULL;
170148
}
171149
#endif
172150

151+
source_size = (int) source.len;
152+
if (source.len != (Py_ssize_t) source_size)
153+
{
154+
PyBuffer_Release(&source);
155+
PyErr_Format(PyExc_OverflowError, "Input too large for C 'int'");
156+
return NULL;
157+
}
158+
173159
if (!strncmp (mode, "default", sizeof ("default")))
174160
{
175161
comp = DEFAULT;
@@ -184,6 +170,7 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
184170
}
185171
else
186172
{
173+
PyBuffer_Release(&source);
187174
PyErr_Format (PyExc_ValueError,
188175
"Invalid mode argument: %s. Must be one of: standard, fast, high_compression",
189176
mode);
@@ -207,6 +194,7 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
207194
py_dest = PyByteArray_FromStringAndSize (NULL, total_size);
208195
if (py_dest == NULL)
209196
{
197+
PyBuffer_Release(&source);
210198
return PyErr_NoMemory();
211199
}
212200
dest = PyByteArray_AS_STRING (py_dest);
@@ -216,6 +204,7 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
216204
py_dest = PyBytes_FromStringAndSize (NULL, total_size);
217205
if (py_dest == NULL)
218206
{
207+
PyBuffer_Release(&source);
219208
return PyErr_NoMemory();
220209
}
221210
dest = PyBytes_AS_STRING (py_dest);
@@ -224,6 +213,7 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
224213
py_dest = PyBytes_FromStringAndSize (NULL, total_size);
225214
if (py_dest == NULL)
226215
{
216+
PyBuffer_Release(&source);
227217
return PyErr_NoMemory();
228218
}
229219
dest = PyBytes_AS_STRING (py_dest);
@@ -244,23 +234,26 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
244234
switch (comp)
245235
{
246236
case DEFAULT:
247-
output_size = LZ4_compress_default (source, dest_start, source_size,
237+
output_size = LZ4_compress_default (source.buf, dest_start, source_size,
248238
dest_size);
249239
break;
250240
case FAST:
251-
output_size = LZ4_compress_fast (source, dest_start, source_size,
241+
output_size = LZ4_compress_fast (source.buf, dest_start, source_size,
252242
dest_size, acceleration);
253243
break;
254244
case HIGH_COMPRESSION:
255-
output_size = LZ4_compress_HC (source, dest_start, source_size,
245+
output_size = LZ4_compress_HC (source.buf, dest_start, source_size,
256246
dest_size, compression);
257247
break;
258248
}
259249

250+
Py_END_ALLOW_THREADS
251+
252+
PyBuffer_Release(&source);
253+
260254
if (output_size <= 0)
261255
{
262-
Py_BLOCK_THREADS
263-
PyErr_SetString (PyExc_ValueError, "Compression failed");
256+
PyErr_SetString (PyExc_ValueError, "Compression failed");
264257
Py_CLEAR (py_dest);
265258
return NULL;
266259
}
@@ -270,8 +263,6 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
270263
output_size += hdr_size;
271264
}
272265

273-
Py_END_ALLOW_THREADS
274-
275266
/* Resizes are expensive; tolerate some slop to avoid. */
276267
if (output_size < (dest_size / 4) * 3)
277268
{
@@ -299,83 +290,70 @@ compress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
299290
static PyObject *
300291
decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
301292
{
293+
Py_buffer source;
302294
const char * source_start;
295+
int source_size;
303296
PyObject *py_dest;
304297
char *dest;
305298
int output_size;
306299
size_t dest_size;
307300
int uncompressed_size = -1;
308301

309302
#if IS_PY3
310-
int return_bytearray = 0;
311303
static char *argnames[] = {
312304
"source",
313305
"uncompressed_size",
314306
"return_bytearray",
315307
NULL
316308
};
317-
PyObject * py_source;
318-
Py_ssize_t source_size;
319-
char * source;
320-
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "O|ip", argnames,
321-
&py_source, &uncompressed_size,
309+
int return_bytearray = 0;
310+
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "y*|ip", argnames,
311+
&source, &uncompressed_size,
322312
&return_bytearray))
323313
{
324314
return NULL;
325315
}
326-
if (PyByteArray_Check(py_source))
327-
{
328-
source = PyByteArray_AsString(py_source);
329-
if (source == NULL)
330-
{
331-
PyErr_SetString (PyExc_ValueError, "Failed to access source bytearray object");
332-
return NULL;
333-
}
334-
source_size = PyByteArray_Size(py_source);
335-
}
336-
else
337-
{
338-
source = PyBytes_AsString(py_source);
339-
if (source == NULL)
340-
{
341-
PyErr_SetString (PyExc_ValueError, "Failed to access source object");
342-
return NULL;
343-
}
344-
source_size = PyBytes_Size(py_source);
345-
}
346316
#else
347317
static char *argnames[] = {
348318
"source",
349319
"uncompressed_size",
350320
NULL
351321
};
352-
const char *source;
353-
int source_size = 0;
354-
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "s#|i", argnames,
355-
&source, &source_size, &uncompressed_size))
322+
if (!PyArg_ParseTupleAndKeywords (args, kwargs, "s*|i", argnames,
323+
&source, &uncompressed_size))
356324
{
357325
return NULL;
358326
}
359327
#endif
328+
source_start = (const char *) source.buf;
329+
source_size = (int) source.len;
330+
if (source.len != (Py_ssize_t) source_size)
331+
{
332+
PyBuffer_Release(&source);
333+
PyErr_Format(PyExc_OverflowError, "Input too large for C 'int'");
334+
return NULL;
335+
}
336+
360337
if (uncompressed_size > 0)
361338
{
362339
dest_size = uncompressed_size;
363-
source_start = source;
364340
}
365341
else
366342
{
367343
if (source_size < hdr_size)
368344
{
345+
PyBuffer_Release(&source);
369346
PyErr_SetString (PyExc_ValueError, "Input source data size too small");
370347
return NULL;
371348
}
372-
dest_size = load_le32 (source);
373-
source_start = source + hdr_size;
349+
dest_size = load_le32 (source_start);
350+
source_start += hdr_size;
374351
source_size -= hdr_size;
375352
}
376353

377354
if (dest_size < 0 || dest_size > PY_SSIZE_T_MAX)
378355
{
356+
PyBuffer_Release(&source);
379357
PyErr_Format (PyExc_ValueError, "Invalid size in header: 0x%zu",
380358
dest_size);
381359
return NULL;
@@ -387,6 +365,7 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
387365
py_dest = PyByteArray_FromStringAndSize (NULL, dest_size);
388366
if (py_dest == NULL)
389367
{
368+
PyBuffer_Release(&source);
390369
return PyErr_NoMemory();
391370
}
392371
dest = PyByteArray_AS_STRING (py_dest);
@@ -396,6 +375,7 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
396375
py_dest = PyBytes_FromStringAndSize (NULL, dest_size);
397376
if (py_dest == NULL)
398377
{
378+
PyBuffer_Release(&source);
399379
return PyErr_NoMemory();
400380
}
401381
dest = PyBytes_AS_STRING (py_dest);
@@ -404,6 +384,7 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
404384
py_dest = PyBytes_FromStringAndSize (NULL, dest_size);
405385
if (py_dest == NULL)
406386
{
387+
PyBuffer_Release(&source);
407388
return PyErr_NoMemory();
408389
}
409390
dest = PyBytes_AS_STRING (py_dest);
@@ -416,6 +397,8 @@ decompress (PyObject * Py_UNUSED (self), PyObject * args, PyObject * kwargs)
416397

417398
Py_END_ALLOW_THREADS
418399

400+
PyBuffer_Release(&source);
401+
419402
if (output_size < 0)
420403
{
421404
PyErr_Format (PyExc_ValueError, "Corrupt input at byte %d", -output_size);
@@ -451,7 +434,7 @@ PyDoc_STRVAR(compress__doc,
451434
"Compress source, returning the compressed data as a string.\n" \
452435
"Raises an exception if any error occurs.\n\n" \
453436
"Args:\n" \
454-
" source (str, bytes or bytearray): Data to compress\n" \
437+
" source (str, bytes or buffer-compatible object): Data to compress\n" \
455438
" mode (str): If 'default' or unspecified use the default LZ4\n" \
456439
" compression mode. Set to 'fast' to use the fast compression\n" \
457440
" LZ4 mode at the expense of compression. Set to\n" \
@@ -476,7 +459,7 @@ PyDoc_STRVAR(decompress__doc,
476459
"Decompress source, returning the uncompressed data as a string.\n" \
477460
"Raises an exception if any error occurs.\n\n" \
478461
"Args:\n" \
479-
" source (str, bytes or bytearray): Data to decompress\n\n" \
462+
" source (str, bytes or buffer-compatible object): Data to decompress\n\n" \
480463
" uncompressed_size (int): If not specified or < 0, the uncompressed data\n" \
481464
" size is read from the start of the source block. If specified,\n" \
482465
" it is assumed that the full source data is compressed data.\n" \

tests/test_block.py

+37-20
Original file line number | Diff line number | Diff line change
@@ -214,26 +214,43 @@ def test_decompress_with_trailer(self):
214214
# Poor-man unittest.TestCase.skip for Python 2.6
215215
del TestLZ4BlockModern
216216

217-
class TestLZ4BlockPy3ByteArray(unittest.TestCase):
218-
def test_random_bytearray(self):
219-
DATA = bytearray(os.urandom(128 * 1024)) # Read 128kb
220-
compressed = lz4.block.compress(DATA, return_bytearray=True)
221-
self.assertEqual(type(compressed), bytearray)
222-
decompressed = lz4.block.decompress(compressed, return_bytearray=True)
223-
self.assertEqual(type(decompressed), bytearray)
224-
self.assertEqual(decompressed, DATA)
225-
def test_random_bytes(self):
226-
DATA = bytearray(os.urandom(128 * 1024)) # Read 128kb
227-
compressed = lz4.block.compress(DATA)
228-
self.assertEqual(type(compressed), bytes)
229-
decompressed = lz4.block.decompress(compressed)
230-
self.assertEqual(type(decompressed), bytes)
231-
self.assertEqual(decompressed, DATA)
232-
233-
if sys.version_info < (3, 3):
234-
# Poor-man unittest.TestCase.skip for Python < 3.3
235-
del TestLZ4BlockPy3ByteArray
217+
218+
class TestLZ4BlockBufferObjects(unittest.TestCase):
219+
220+
def test_bytearray(self):
221+
DATA = os.urandom(128 * 1024) # Read 128kb
222+
compressed = lz4.block.compress(DATA)
223+
self.assertEqual(lz4.block.compress(bytearray(DATA)), compressed)
224+
self.assertEqual(lz4.block.decompress(bytearray(compressed)), DATA)
225+
226+
def test_return_bytearray(self):
227+
if sys.version_info < (3,):
228+
return # skip
229+
DATA = os.urandom(128 * 1024) # Read 128kb
230+
compressed = lz4.block.compress(DATA)
231+
b = lz4.block.compress(DATA, return_bytearray=True)
232+
self.assertEqual(type(b), bytearray)
233+
self.assertEqual(bytes(b), compressed)
234+
b = lz4.block.decompress(compressed, return_bytearray=True)
235+
self.assertEqual(type(b), bytearray)
236+
self.assertEqual(bytes(b), DATA)
237+
238+
def test_memoryview(self):
239+
if sys.version_info < (2, 7):
240+
return # skip
241+
DATA = os.urandom(128 * 1024) # Read 128kb
242+
compressed = lz4.block.compress(DATA)
243+
self.assertEqual(lz4.block.compress(memoryview(DATA)), compressed)
244+
self.assertEqual(lz4.block.decompress(memoryview(compressed)), DATA)
245+
246+
def test_unicode(self):
247+
if sys.version_info < (3,):
248+
return # skip
249+
DATA = b'x'
250+
self.assertRaises(TypeError, lz4.block.compress, DATA.decode('latin1'))
251+
self.assertRaises(TypeError, lz4.block.decompress,
252+
lz4.block.compress(DATA).decode('latin1'))
253+
236254

237255
if __name__ == '__main__':
238256
unittest.main()
239-

0 commit comments

Comments
 (0)