diff --git a/msgpack/_unpacker.pyx b/msgpack/_unpacker.pyx index 23f6478f..a689102f 100644 --- a/msgpack/_unpacker.pyx +++ b/msgpack/_unpacker.pyx @@ -8,12 +8,14 @@ from cpython.bytes cimport ( ) from cpython.buffer cimport ( Py_buffer, + PyObject_CheckBuffer, PyBuffer_Release, PyObject_GetBuffer, PyBUF_SIMPLE, ) from cpython.mem cimport PyMem_Malloc, PyMem_Free from cpython.object cimport PyCallable_Check +from cpython.exc cimport PyErr_WarnEx cdef extern from "Python.h": ctypedef struct PyObject @@ -129,12 +131,27 @@ def unpackb(object packed, object object_hook=None, object list_hook=None, cdef Py_ssize_t off = 0 cdef int ret - cdef char* buf + cdef Py_buffer view + cdef char* buf = NULL cdef Py_ssize_t buf_len cdef char* cenc = NULL cdef char* cerr = NULL - - PyObject_AsReadBuffer(packed, &buf, &buf_len) + cdef char buffer_supported = 0 + + if PyObject_CheckBuffer(packed): + buffer_supported = 1 + if PyObject_GetBuffer(packed, &view, PyBUF_SIMPLE) == 0: + if view.itemsize != 1: + PyBuffer_Release(&view) + raise ValueError("cannot unpack from multi-byte object") + buf_len = view.len + buf = view.buf + else: + PyObject_AsReadBuffer(packed, &buf, &buf_len) + PyErr_WarnEx(DeprecationWarning, + "Unpacking %s requires old buffer protocol, " + "which will be removed in msgpack 1.0." % type(packed), + 1) if encoding is not None: if isinstance(encoding, unicode): @@ -150,6 +167,8 @@ def unpackb(object packed, object object_hook=None, object list_hook=None, use_list, cenc, cerr, max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len) ret = unpack_construct(&ctx, buf, buf_len, &off) + if buffer_supported: + PyBuffer_Release(&view); if ret == 1: obj = unpack_data(&ctx) if off < buf_len: diff --git a/msgpack/fallback.py b/msgpack/fallback.py index abed3d9e..88bc84ab 100644 --- a/msgpack/fallback.py +++ b/msgpack/fallback.py @@ -3,6 +3,7 @@ import sys import array import struct +import warnings if sys.version_info[0] == 3: PY3 = True @@ -246,16 +247,25 @@ def __init__(self, file_like=None, read_size=0, use_list=True, raise TypeError("`ext_hook` is not callable") def feed(self, next_bytes): - if isinstance(next_bytes, array.array): - next_bytes = next_bytes.tostring() - elif isinstance(next_bytes, bytearray): - next_bytes = bytes(next_bytes) assert self._fb_feeding - if (self._fb_buf_n + len(next_bytes) - self._fb_sloppiness - > self._max_buffer_size): + try: + view = memoryview(next_bytes) + except TypeError: + # try to use legacy buffer protocol if 2.7, otherwise re-raise + if not PY3: + view = memoryview(buffer(next_bytes)) + warnings.warn("Unpacking %s requires old buffer protocol, " + "which will be removed in msgpack 1.0." % type(next_bytes), + DeprecationWarning) + else: + raise + if view.itemsize != 1: + raise ValueError("cannot unpack from multi-byte object") + L = len(view) + if self._fb_buf_n + L - self._fb_sloppiness > self._max_buffer_size: raise BufferFull - self._fb_buf_n += len(next_bytes) - self._fb_buffers.append(next_bytes) + self._fb_buf_n += L + self._fb_buffers.append(view) def _fb_sloppy_consume(self): """ Gets rid of some of the used parts of the buffer. """ @@ -322,9 +332,10 @@ def _fb_read(self, n, write_bytes=None): return buffs[self._fb_buf_i][self._fb_buf_o - n:self._fb_buf_o] # The remaining cases. - ret = b'' - while len(ret) != n: - sliced = n - len(ret) + ret = [] + n_read = 0 + while n_read != n: + sliced = n - n_read if self._fb_buf_i == len(buffs): if self._fb_feeding: break @@ -334,34 +345,39 @@ def _fb_read(self, n, write_bytes=None): tmp = self.file_like.read(to_read) if not tmp: break - buffs.append(tmp) - self._fb_buf_n += len(tmp) + tmpview = memoryview(tmp) + assert tmpview.itemsize == 1 + buffs.append(tmpview) + self._fb_buf_n += len(tmpview) continue - ret += buffs[self._fb_buf_i][self._fb_buf_o:self._fb_buf_o + sliced] + to_append = buffs[self._fb_buf_i][self._fb_buf_o:self._fb_buf_o + sliced] + n_read += len(to_append) + ret.append(to_append) self._fb_buf_o += sliced if self._fb_buf_o >= len(buffs[self._fb_buf_i]): self._fb_buf_o = 0 self._fb_buf_i += 1 + ret = b''.join([view.tobytes() for view in ret]) if len(ret) != n: self._fb_rollback() raise OutOfData if write_bytes is not None: write_bytes(ret) - return ret + return memoryview(ret) def _read_header(self, execute=EX_CONSTRUCT, write_bytes=None): typ = TYPE_IMMEDIATE n = 0 obj = None - c = self._fb_read(1, write_bytes) + c = self._fb_read(1, write_bytes).tobytes() b = ord(c) - if b & 0b10000000 == 0: + if b & 0b10000000 == 0: obj = b elif b & 0b11100000 == 0b11100000: obj = struct.unpack("b", c)[0] elif b & 0b11100000 == 0b10100000: n = b & 0b00011111 - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() typ = TYPE_RAW if n > self._max_str_len: raise UnpackValueError("%s exceeds max_str_len(%s)", n, self._max_str_len) @@ -386,37 +402,37 @@ def _read_header(self, execute=EX_CONSTRUCT, write_bytes=None): n = struct.unpack("B", self._fb_read(1, write_bytes))[0] if n > self._max_bin_len: raise UnpackValueError("%s exceeds max_bin_len(%s)" % (n, self._max_bin_len)) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xc5: typ = TYPE_BIN n = struct.unpack(">H", self._fb_read(2, write_bytes))[0] if n > self._max_bin_len: raise UnpackValueError("%s exceeds max_bin_len(%s)" % (n, self._max_bin_len)) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xc6: typ = TYPE_BIN n = struct.unpack(">I", self._fb_read(4, write_bytes))[0] if n > self._max_bin_len: raise UnpackValueError("%s exceeds max_bin_len(%s)" % (n, self._max_bin_len)) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xc7: # ext 8 typ = TYPE_EXT L, n = struct.unpack('Bb', self._fb_read(2, write_bytes)) if L > self._max_ext_len: raise UnpackValueError("%s exceeds max_ext_len(%s)" % (L, self._max_ext_len)) - obj = self._fb_read(L, write_bytes) + obj = self._fb_read(L, write_bytes).tobytes() elif b == 0xc8: # ext 16 typ = TYPE_EXT L, n = struct.unpack('>Hb', self._fb_read(3, write_bytes)) if L > self._max_ext_len: raise UnpackValueError("%s exceeds max_ext_len(%s)" % (L, self._max_ext_len)) - obj = self._fb_read(L, write_bytes) + obj = self._fb_read(L, write_bytes).tobytes() elif b == 0xc9: # ext 32 typ = TYPE_EXT L, n = struct.unpack('>Ib', self._fb_read(5, write_bytes)) if L > self._max_ext_len: raise UnpackValueError("%s exceeds max_ext_len(%s)" % (L, self._max_ext_len)) - obj = self._fb_read(L, write_bytes) + obj = self._fb_read(L, write_bytes).tobytes() elif b == 0xca: obj = struct.unpack(">f", self._fb_read(4, write_bytes))[0] elif b == 0xcb: @@ -467,19 +483,19 @@ def _read_header(self, execute=EX_CONSTRUCT, write_bytes=None): n = struct.unpack("B", self._fb_read(1, write_bytes))[0] if n > self._max_str_len: raise UnpackValueError("%s exceeds max_str_len(%s)", n, self._max_str_len) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xda: typ = TYPE_RAW n = struct.unpack(">H", self._fb_read(2, write_bytes))[0] if n > self._max_str_len: raise UnpackValueError("%s exceeds max_str_len(%s)", n, self._max_str_len) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xdb: typ = TYPE_RAW n = struct.unpack(">I", self._fb_read(4, write_bytes))[0] if n > self._max_str_len: raise UnpackValueError("%s exceeds max_str_len(%s)", n, self._max_str_len) - obj = self._fb_read(n, write_bytes) + obj = self._fb_read(n, write_bytes).tobytes() elif b == 0xdc: n = struct.unpack(">H", self._fb_read(2, write_bytes))[0] if n > self._max_array_len: diff --git a/test/test_buffer.py b/test/test_buffer.py index 5a71f904..87f359f9 100644 --- a/test/test_buffer.py +++ b/test/test_buffer.py @@ -18,3 +18,12 @@ def test_unpack_bytearray(): assert [b'foo', b'bar'] == obj expected_type = bytes assert all(type(s) == expected_type for s in obj) + + +def test_unpack_memoryview(): + buf = bytearray(packb(('foo', 'bar'))) + view = memoryview(buf) + obj = unpackb(view, use_list=1) + assert [b'foo', b'bar'] == obj + expected_type = bytes + assert all(type(s) == expected_type for s in obj)