diff --git a/tests/test_message.py b/tests/test_message.py index 7ba87fa5..220b0cc1 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,3 +1,5 @@ +from uamqp.message import MessageProperties, MessageHeader, Message, constants, errors, c_uamqp +import pickle import pytest from uamqp.message import MessageProperties, Message, SequenceBody, DataBody, ValueBody @@ -33,6 +35,117 @@ def test_message_properties(): properties.user_id = 'werid/0\0\1\t\n' assert properties.user_id == b'werid/0\0\1\t\n' +def send_complete_callback(result, error): + # helper for test below not in test, b/c results in: + # AttributeError: Can't pickle local object + print(result) + print(error) + + +def test_message_pickle(): + properties = MessageProperties() + properties.message_id = '2' + properties.user_id = '1' + properties.to = 'dkfj' + properties.subject = 'dsljv' + properties.reply_to = "kdjfk" + properties.correlation_id = 'ienag' + properties.content_type = 'b' + properties.content_encoding = '39ru' + properties.absolute_expiry_time = 24 + properties.creation_time = 10 + properties.group_id = '3irow' + properties.group_sequence = 39 + properties.reply_to_group_id = '39rud' + + header = MessageHeader() + header.delivery_count = 3 + header.time_to_live = 5 + header.first_acquirer = 'dkfj' + header.durable = True + header.priority = 4 + + data_message = Message(body=[b'testmessage1', b'testmessage2']) + pickled = pickle.loads(pickle.dumps(data_message)) + body = list(pickled.get_data()) + assert len(body) == 2 + assert body == [b'testmessage1', b'testmessage2'] + + sequence_message = Message( + body=[[1234.56, b'testmessage2', True], [-1234.56, {b'key': b'value'}, False]], + body_type=MessageBodyType.Sequence + ) + pickled = pickle.loads(pickle.dumps(sequence_message)) + body = list(pickled.get_data()) + assert len(body) == 2 + assert body == [[1234.56, b'testmessage2', True], [-1234.56, {b'key': b'value'}, False]] + + value_message = Message( + body={b'key': [1, b'str', False]}, + body_type=MessageBodyType.Value + ) + pickled = pickle.loads(pickle.dumps(value_message)) + body = pickled.get_data() + assert body == {b'key': [1, b'str', False]} + + error = errors.MessageModified(False, False, {b'key': b'value'}) + pickled_error = pickle.loads(pickle.dumps(error)) + assert pickled_error._annotations == {b'key': b'value'} # pylint: disable=protected-access + + message = Message(body="test", properties=properties, header=header) + message.on_send_complete = send_complete_callback + message.footer = {'a':2} + message.state = constants.MessageState.ReceivedSettled + + pickled = pickle.loads(pickle.dumps(message)) + assert list(message.get_data()) == [b"test"] + assert message.footer == pickled.footer + assert message.state == pickled.state + assert message.application_properties == pickled.application_properties + assert message.annotations == pickled.annotations + assert message.delivery_annotations == pickled.delivery_annotations + assert message.settled == pickled.settled + assert message.properties.message_id == pickled.properties.message_id + assert message.properties.user_id == pickled.properties.user_id + assert message.properties.to == pickled.properties.to + assert message.properties.subject == pickled.properties.subject + assert message.properties.reply_to == pickled.properties.reply_to + assert message.properties.correlation_id == pickled.properties.correlation_id + assert message.properties.content_type == pickled.properties.content_type + assert message.properties.content_encoding == pickled.properties.content_encoding + assert message.properties.absolute_expiry_time == pickled.properties.absolute_expiry_time + assert message.properties.creation_time == pickled.properties.creation_time + assert message.properties.group_id == pickled.properties.group_id + assert message.properties.group_sequence == pickled.properties.group_sequence + assert message.properties.reply_to_group_id == pickled.properties.reply_to_group_id + assert message.header.delivery_count == pickled.header.delivery_count + assert message.header.time_to_live == pickled.header.time_to_live + assert message.header.first_acquirer == pickled.header.first_acquirer + assert message.header.durable == pickled.header.durable + assert message.header.priority == pickled.header.priority + + # send with message param + settler = errors.MessageAlreadySettled + internal_message = c_uamqp.create_message() + internal_message.add_body_data(b"hi") + message_w_message_param = Message( + message=internal_message, + settler=settler, + delivery_no=1 + ) + pickled = pickle.loads(pickle.dumps(message_w_message_param)) + message_data = str(message_w_message_param.get_data()) + pickled_data = str(pickled.get_data()) + + assert message_data == pickled_data + assert message_w_message_param.footer == pickled.footer + assert message_w_message_param.state == pickled.state + assert message_w_message_param.application_properties == pickled.application_properties + assert message_w_message_param.annotations == pickled.annotations + assert message_w_message_param.delivery_annotations == pickled.delivery_annotations + assert message_w_message_param.settled == pickled.settled + assert pickled.delivery_no == 1 + assert type(pickled._settler()) == type(settler()) # pylint: disable=protected-access def test_message_auto_body_type(): single_data = b'!@#$%^&*()_+1234567890' diff --git a/uamqp/constants.py b/uamqp/constants.py index ce7b6af8..3c7b2417 100644 --- a/uamqp/constants.py +++ b/uamqp/constants.py @@ -188,3 +188,10 @@ class MessageBodyType(Enum): Data = c_uamqp.MessageBodyType.DataType Value = c_uamqp.MessageBodyType.ValueType Sequence = c_uamqp.MessageBodyType.SequenceType + + +BODY_TYPE_C_PYTHON_MAP = { + c_uamqp.MessageBodyType.DataType.value: MessageBodyType.Data, + c_uamqp.MessageBodyType.SequenceType.value: MessageBodyType.Sequence, + c_uamqp.MessageBodyType.ValueType.value: MessageBodyType.Value +} diff --git a/uamqp/errors.py b/uamqp/errors.py index daf87092..0415236f 100644 --- a/uamqp/errors.py +++ b/uamqp/errors.py @@ -259,6 +259,9 @@ def __init__(self): response = "Invalid operation: this message is already settled." super(MessageAlreadySettled, self).__init__(response) + def __reduce__(self): + return (self.__class__, ()) + class MessageAccepted(MessageResponse): pass @@ -267,6 +270,8 @@ class MessageAccepted(MessageResponse): class MessageRejected(MessageResponse): def __init__(self, condition=None, description=None, encoding='UTF-8', info=None): + self._encoding = encoding + self._info = info if condition: self.error_condition = condition.encode(encoding) if isinstance(condition, six.text_type) else condition else: @@ -282,6 +287,9 @@ def __init__(self, condition=None, description=None, encoding='UTF-8', info=None self.error_info = utils.data_factory(info, encoding=encoding) if info else None super(MessageRejected, self).__init__() + def __reduce__(self): + return (self.__class__, (self.error_condition, self.error_description, self._encoding, self._info)) + class MessageReleased(MessageResponse): pass @@ -292,11 +300,16 @@ class MessageModified(MessageResponse): def __init__(self, failed, undeliverable, annotations=None, encoding='UTF-8'): self.failed = failed self.undeliverable = undeliverable + self._encoding = encoding + self._annotations = annotations if annotations and not isinstance(annotations, dict): raise TypeError("Disposition annotations must be a dictionary.") self.annotations = utils.data_factory(annotations, encoding=encoding) if annotations else None super(MessageModified, self).__init__() + def __reduce__(self): + return (self.__class__, (self.failed, self.undeliverable, self._annotations, self._encoding)) + class ErrorResponse(object): diff --git a/uamqp/message.py b/uamqp/message.py index 9aee5b3d..1906ff93 100644 --- a/uamqp/message.py +++ b/uamqp/message.py @@ -1,8 +1,8 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # pylint: disable=too-many-lines @@ -76,20 +76,23 @@ class Message(object): :type delivery_annotations: dict """ - def __init__(self, - body=None, - properties=None, - application_properties=None, - annotations=None, - header=None, - msg_format=None, - message=None, - settler=None, - delivery_no=None, - encoding='UTF-8', - body_type=None, - footer=None, - delivery_annotations=None): + + def __init__( + self, + body=None, + properties=None, + application_properties=None, + annotations=None, + header=None, + msg_format=None, + message=None, + settler=None, + delivery_no=None, + encoding='UTF-8', + body_type=None, + footer=None, + delivery_annotations=None + ): self.state = constants.MessageState.WaitingToBeSent self.idle_time = 0 self.retries = 0 @@ -133,6 +136,31 @@ def __init__(self, self._header = header self._footer = footer + def __getstate__(self): + state = self.__dict__.copy() + state["state"] = self.state.value + state["_message"] = None + state["_body_type"] = self._body.type.value if self._body else None + if isinstance(self._body, (DataBody, SequenceBody)): + state["_body"] = list(self._body.data) + elif isinstance(self._body, ValueBody): + state["_body"] = self._body.data + + return state + + def __setstate__(self, state): + state["state"] = constants.MessageState(state.get("state")) + self.__dict__.update(state) + + body = state.get("_body") + body_type = constants.BODY_TYPE_C_PYTHON_MAP.get(state.get("_body_type")) + self._message = c_uamqp.create_message() + if body: + if not body_type: + self._auto_set_body(body) + else: + self._set_body_by_body_type(body, body_type) + @property def properties(self): if self._need_further_parse: @@ -168,7 +196,8 @@ def footer(self, value): if value and not isinstance(value, dict): raise TypeError("Footer must be a dictionary") footer_props = c_uamqp.create_footer( - utils.data_factory(value, encoding=self._encoding)) + utils.data_factory(value, encoding=self._encoding) + ) self._message.footer = footer_props self._footer = value @@ -227,8 +256,12 @@ def _parse_message_properties(self): if self._need_further_parse: _props = self._message.properties if _props: - _logger.debug("Parsing received message properties %r.", self.delivery_no) - self._properties = MessageProperties(properties=_props, encoding=self._encoding) + _logger.debug( + "Parsing received message properties %r.", self.delivery_no + ) + self._properties = MessageProperties( + properties=_props, encoding=self._encoding + ) _header = self._message.header if _header: _logger.debug("Parsing received message header %r.", self.delivery_no) @@ -239,15 +272,23 @@ def _parse_message_properties(self): self._footer = _footer.map _app_props = self._message.application_properties if _app_props: - _logger.debug("Parsing received message application properties %r.", self.delivery_no) + _logger.debug( + "Parsing received message application properties %r.", + self.delivery_no, + ) self._application_properties = _app_props.map _ann = self._message.message_annotations if _ann: - _logger.debug("Parsing received message annotations %r.", self.delivery_no) + _logger.debug( + "Parsing received message annotations %r.", self.delivery_no + ) self._annotations = _ann.map _delivery_ann = self._message.delivery_annotations if _delivery_ann: - _logger.debug("Parsing received message delivery annotations %r.", self.delivery_no) + _logger.debug( + "Parsing received message delivery annotations %r.", + self.delivery_no, + ) self._delivery_annotations = _delivery_ann.map self._need_further_parse = False @@ -333,19 +374,23 @@ def _populate_message_attributes(self, c_message): if self.application_properties: if not isinstance(self.application_properties, dict): raise TypeError("Application properties must be a dictionary.") - amqp_props = utils.data_factory(self.application_properties, encoding=self._encoding) + amqp_props = utils.data_factory( + self.application_properties, encoding=self._encoding + ) c_message.application_properties = amqp_props if self.annotations: if not isinstance(self.annotations, dict): raise TypeError("Message annotations must be a dictionary.") ann_props = c_uamqp.create_message_annotations( - utils.data_factory(self.annotations, encoding=self._encoding)) + utils.data_factory(self.annotations, encoding=self._encoding) + ) c_message.message_annotations = ann_props if self.delivery_annotations: if not isinstance(self.delivery_annotations, dict): raise TypeError("Delivery annotations must be a dictionary.") delivery_ann_props = c_uamqp.create_delivery_annotations( - utils.data_factory(self.delivery_annotations, encoding=self._encoding)) + utils.data_factory(self.delivery_annotations, encoding=self._encoding) + ) c_message.delivery_annotations = delivery_ann_props if self.header: c_message.header = self.header.get_header_obj() @@ -353,7 +398,8 @@ def _populate_message_attributes(self, c_message): if not isinstance(self.footer, dict): raise TypeError("Footer must be a dictionary.") footer = c_uamqp.create_footer( - utils.data_factory(self.footer, encoding=self._encoding)) + utils.data_factory(self.footer, encoding=self._encoding) + ) c_message.footer = footer @property @@ -474,7 +520,8 @@ def reject(self, condition=None, description=None, info=None): condition=condition, description=description, info=info, - encoding=self._encoding) + encoding=self._encoding, + ) self._settler(self._response) self.state = constants.MessageState.ReceivedSettled return True @@ -518,10 +565,8 @@ def modify(self, failed, deliverable, annotations=None): """ if self._can_settle_message(): self._response = errors.MessageModified( - failed, - deliverable, - annotations=annotations, - encoding=self._encoding) + failed, deliverable, annotations=annotations, encoding=self._encoding + ) self._settler(self._response) self.state = constants.MessageState.ReceivedSettled return True @@ -580,14 +625,16 @@ class BatchMessage(Message): max_message_length = constants.MAX_MESSAGE_LENGTH_BYTES size_offset = 0 - def __init__(self, - data=None, - properties=None, - application_properties=None, - annotations=None, - header=None, - multi_messages=False, - encoding='UTF-8'): + def __init__( + self, + data=None, + properties=None, + application_properties=None, + annotations=None, + header=None, + multi_messages=False, + encoding="UTF-8", + ): # pylint: disable=super-init-not-called self._multi_messages = multi_messages self._body_gen = data @@ -605,12 +652,14 @@ def _create_batch_message(self): :rtype: ~uamqp.message.Message """ - return Message(body=[], - properties=self.properties, - annotations=self.annotations, - msg_format=self.batch_format, - header=self.header, - encoding=self._encoding) + return Message( + body=[], + properties=self.properties, + annotations=self.annotations, + msg_format=self.batch_format, + header=self.header, + encoding=self._encoding, + ) def _multi_message_generator(self): """Generate multiple ~uamqp.message.Message objects from a single data @@ -627,7 +676,9 @@ def _multi_message_generator(self): message_size = new_message.get_message_encoded_size() + self.size_offset body_size = 0 if unappended_message_bytes: - new_message._body.append(unappended_message_bytes) # pylint: disable=protected-access + new_message._body.append( # pylint: disable=protected-access + unappended_message_bytes + ) body_size += len(unappended_message_bytes) try: for data in self._body_gen: @@ -640,13 +691,18 @@ def _multi_message_generator(self): internal_uamqp_message = data try: # uamqp Message - if not internal_uamqp_message.application_properties and self.application_properties: - internal_uamqp_message.application_properties = self.application_properties + if ( + not internal_uamqp_message.application_properties + and self.application_properties + ): + internal_uamqp_message.application_properties = ( + self.application_properties + ) message_bytes = internal_uamqp_message.encode_message() except AttributeError: # raw data wrap_message = Message( body=internal_uamqp_message, - application_properties=self.application_properties + application_properties=self.application_properties, ) message_bytes = wrap_message.encode_message() body_size += len(message_bytes) @@ -655,7 +711,9 @@ def _multi_message_generator(self): unappended_message_bytes = message_bytes yield new_message raise StopIteration() - new_message._body.append(message_bytes) # pylint: disable=protected-access + new_message._body.append( # pylint: disable=protected-access + message_bytes + ) except StopIteration: _logger.debug("Sent partial message.") continue @@ -689,11 +747,19 @@ def gather(self): internal_uamqp_message = data try: # uamqp Message - if not internal_uamqp_message.application_properties and self.application_properties: - internal_uamqp_message.application_properties = self.application_properties + if ( + not internal_uamqp_message.application_properties + and self.application_properties + ): + internal_uamqp_message.application_properties = ( + self.application_properties + ) message_bytes = internal_uamqp_message.encode_message() except AttributeError: # raw data - wrap_message = Message(body=internal_uamqp_message, application_properties=self.application_properties) + wrap_message = Message( + body=internal_uamqp_message, + application_properties=self.application_properties, + ) message_bytes = wrap_message.encode_message() body_size += len(message_bytes) if (body_size + message_size) > self.max_message_length: @@ -742,22 +808,24 @@ class MessageProperties(object): :vartype reply_to_group_id: """ - def __init__(self, - message_id=None, - user_id=None, - to=None, - subject=None, - reply_to=None, - correlation_id=None, - content_type=None, - content_encoding=None, - absolute_expiry_time=None, - creation_time=None, - group_id=None, - group_sequence=None, - reply_to_group_id=None, - properties=None, - encoding='UTF-8'): + def __init__( + self, + message_id=None, + user_id=None, + to=None, + subject=None, + reply_to=None, + correlation_id=None, + content_type=None, + content_encoding=None, + absolute_expiry_time=None, + creation_time=None, + group_id=None, + group_sequence=None, + reply_to_group_id=None, + properties=None, + encoding="UTF-8", + ): self._encoding = encoding if properties: self._message_id = properties.message_id @@ -789,21 +857,33 @@ def __init__(self, self.reply_to_group_id = reply_to_group_id def __str__(self): - return str({ - 'message_id': self.message_id, - 'user_id': self.user_id, - 'to': self.to, - 'subject': self.subject, - 'reply_to': self.reply_to, - 'correlation_id': self.correlation_id, - 'content_type': self.content_type, - 'content_encoding': self.content_encoding, - 'absolute_expiry_time': self.absolute_expiry_time, - 'creation_time': self.creation_time, - 'group_id': self.group_id, - 'group_sequence': self.group_sequence, - 'reply_to_group_id': self.reply_to_group_id - }) + return str( + { + "message_id": self.message_id, + "user_id": self.user_id, + "to": self.to, + "subject": self.subject, + "reply_to": self.reply_to, + "correlation_id": self.correlation_id, + "content_type": self.content_type, + "content_encoding": self.content_encoding, + "absolute_expiry_time": self.absolute_expiry_time, + "creation_time": self.creation_time, + "group_id": self.group_id, + "group_sequence": self.group_sequence, + "reply_to_group_id": self.reply_to_group_id, + } + ) + + def __getstate__(self): + state = self._get_properties_dict() + state["_encoding"] = self._encoding + return state + + def __setstate__(self, state): + self._encoding = state.pop("_encoding") + for key, val in state.items(): + self.__setattr__(key, val) @property def message_id(self): @@ -836,7 +916,9 @@ def user_id(self, value): # user_id is type of binary according to the spec. # convert byte string into bytearray then wrap the data into c_uamqp.BinaryValue. if value is not None: - self._user_id = utils.data_factory(bytearray(value), encoding=self._encoding) + self._user_id = utils.data_factory( + bytearray(value), encoding=self._encoding + ) else: self._user_id = None @@ -974,25 +1056,42 @@ def _set_attr(self, attr, properties): if attr_value is not None: setattr(properties, attr, attr_value) + def _get_properties_dict(self): + return { + "message_id": self.message_id, + "user_id": self.user_id, + "to": self.to, + "subject": self.subject, + "reply_to": self.reply_to, + "correlation_id": self.correlation_id, + "content_type": self.content_type, + "content_encoding": self.content_encoding, + "absolute_expiry_time": self.absolute_expiry_time, + "creation_time": self.creation_time, + "group_id": self.group_id, + "group_sequence": self.group_sequence, + "reply_to_group_id": self.reply_to_group_id, + } + def get_properties_obj(self): """Get the underlying C reference from this object. :rtype: uamqp.c_uamqp.cProperties """ properties = c_uamqp.cProperties() - self._set_attr('message_id', properties) - self._set_attr('user_id', properties) - self._set_attr('to', properties) - self._set_attr('subject', properties) - self._set_attr('reply_to', properties) - self._set_attr('correlation_id', properties) - self._set_attr('content_type', properties) - self._set_attr('content_encoding', properties) - self._set_attr('absolute_expiry_time', properties) - self._set_attr('creation_time', properties) - self._set_attr('group_id', properties) - self._set_attr('group_sequence', properties) - self._set_attr('reply_to_group_id', properties) + self._set_attr("message_id", properties) + self._set_attr("user_id", properties) + self._set_attr("to", properties) + self._set_attr("subject", properties) + self._set_attr("reply_to", properties) + self._set_attr("correlation_id", properties) + self._set_attr("content_type", properties) + self._set_attr("content_encoding", properties) + self._set_attr("absolute_expiry_time", properties) + self._set_attr("creation_time", properties) + self._set_attr("group_id", properties) + self._set_attr("group_sequence", properties) + self._set_attr("reply_to_group_id", properties) return properties @@ -1001,7 +1100,7 @@ class MessageBody(object): not be used directly. """ - def __init__(self, c_message, encoding='UTF-8'): + def __init__(self, c_message, encoding="UTF-8"): self._message = c_message self._encoding = encoding @@ -1235,13 +1334,15 @@ def __init__(self, header=None): self.priority = header.priority def __str__(self): - return str({ - 'delivery_count': self.delivery_count, - 'time_to_live': self.time_to_live, - 'first_acquirer': self.first_acquirer, - 'durable': self.durable, - 'priority': self.priority - }) + return str( + { + "delivery_count": self.delivery_count, + "time_to_live": self.time_to_live, + "first_acquirer": self.first_acquirer, + "durable": self.durable, + "priority": self.priority, + } + ) def get_header_obj(self): """Get the underlying C reference from this object.