diff --git a/adafruit_minimqtt/adafruit_minimqtt.py b/adafruit_minimqtt/adafruit_minimqtt.py index 0a898998..6727e05b 100755 --- a/adafruit_minimqtt/adafruit_minimqtt.py +++ b/adafruit_minimqtt/adafruit_minimqtt.py @@ -175,6 +175,35 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self.deinit() + def _sock_exact_recv(self, bufsize): + """Reads _exact_ number of bytes from the connected socket. Will only return + string with the exact number of bytes requested. + + The semantics of native socket receive is that it returns no more than the + specified number of bytes (i.e. max size). However, it makes no guarantees in + terms of the minimum size of the buffer, which could be 1 byte. This is a + wrapper for socket recv() to ensure that no less than the expected number of + bytes is returned or trigger a timeout exception. + + :param int bufsize: number of bytes to receive + """ + stamp = time.monotonic() + rc = self._sock.recv(bufsize) + to_read = bufsize - len(rc) + assert to_read >= 0 + read_timeout = self.keep_alive + while to_read > 0: + recv = self._sock.recv(to_read) + to_read -= len(recv) + rc += recv + if time.monotonic() - stamp > read_timeout: + raise MMQTTException( + "Unable to receive {} bytes within {} seconds.".format( + to_read, read_timeout + ) + ) + return rc + def deinit(self): """De-initializes the MQTT client and disconnects from the mqtt broker.""" self.disconnect() @@ -351,7 +380,7 @@ def connect(self, clean_session=True): while True: op = self._wait_for_msg() if op == 32: - rc = self._sock.recv(3) + rc = self._sock_exact_recv(3) assert rc[0] == 0x02 if rc[2] != 0x00: raise MMQTTException(CONNACK_ERRORS[rc[2]]) @@ -366,32 +395,38 @@ def disconnect(self): self.is_connected() if self.logger is not None: self.logger.debug("Sending DISCONNECT packet to broker") - self._sock.send(MQTT_DISCONNECT) + try: + self._sock.send(MQTT_DISCONNECT) + except RuntimeError as e: + if self.logger: + self.logger.warning("Unable to send DISCONNECT packet: {}".format(e)) if self.logger is not None: self.logger.debug("Closing socket") self._sock.close() self._is_connected = False - self._subscribed_topics = None + self._subscribed_topics = [] if self.on_disconnect is not None: self.on_disconnect(self, self.user_data, 0) def ping(self): """Pings the MQTT Broker to confirm if the broker is alive or if there is an active network connection. + Returns response codes of any messages received while waiting for PINGRESP. """ self.is_connected() - if self.logger is not None: + if self.logger: self.logger.debug("Sending PINGREQ") self._sock.send(MQTT_PINGREQ) - if self.logger is not None: - self.logger.debug("Checking PINGRESP") - while True: - op = self._wait_for_msg(0.5) - if op == 208: - ping_resp = self._sock.recv(2) - if ping_resp[0] != 0x00: - raise MMQTTException("PINGRESP not returned from broker.") - return + ping_timeout = self.keep_alive + stamp = time.monotonic() + rc, rcs = None, [] + while rc != MQTT_PINGRESP: + rc = self._wait_for_msg() + if rc: + rcs.append(rc) + if time.monotonic() - stamp > ping_timeout: + raise MMQTTException("PINGRESP not returned from broker.") + return rcs # pylint: disable=too-many-branches, too-many-statements def publish(self, topic, msg, retain=False, qos=0): @@ -486,9 +521,9 @@ def publish(self, topic, msg, retain=False, qos=0): while True: op = self._wait_for_msg() if op == 0x40: - sz = self._sock.recv(1) + sz = self._sock_exact_recv(1) assert sz == b"\x02" - rcv_pid = self._sock.recv(2) + rcv_pid = self._sock_exact_recv(2) rcv_pid = rcv_pid[0] << 0x08 | rcv_pid[1] if pid == rcv_pid: if self.on_publish is not None: @@ -571,7 +606,7 @@ def subscribe(self, topic, qos=0): while True: op = self._wait_for_msg() if op == 0x90: - rc = self._sock.recv(4) + rc = self._sock_exact_recv(4) assert rc[1] == packet[2] and rc[2] == packet[3] if rc[3] == 0x80: raise MMQTTException("SUBACK Failure!") @@ -634,7 +669,7 @@ def unsubscribe(self, topic): while True: op = self._wait_for_msg() if op == 176: - return_code = self._sock.recv(3) + return_code = self._sock_exact_recv(3) assert return_code[0] == 0x02 # [MQTT-3.32] assert ( @@ -671,6 +706,7 @@ def reconnect(self, resub_topics=True): def loop(self): """Non-blocking message loop. Use this method to check incoming subscription messages. + Returns response codes of any messages received. """ if self._timestamp == 0: self._timestamp = time.monotonic() @@ -682,10 +718,12 @@ def loop(self): "KeepAlive period elapsed - \ requesting a PINGRESP from the server..." ) - self.ping() + rcs = self.ping() self._timestamp = 0 + return rcs self._sock.settimeout(0.1) - return self._wait_for_msg() + rc = self._wait_for_msg() + return [rc] if rc else None def _wait_for_msg(self, timeout=30): """Reads and processes network events. @@ -694,24 +732,30 @@ def _wait_for_msg(self, timeout=30): res = self._sock.recv(1) self._sock.settimeout(timeout) if res in [None, b""]: + # If we get here, it means that there is nothing to be received return None - if res == MQTT_PINGRESP: - sz = self._sock.recv(1)[0] - assert sz == 0 - return None + if res[0] == MQTT_PINGRESP: + if self.logger: + self.logger.debug("Checking PINGRESP") + sz = self._sock_exact_recv(1)[0] + if sz != 0x00: + raise MMQTTException( + "Unexpected PINGRESP returned from broker: {}.".format(sz) + ) + return MQTT_PINGRESP if res[0] & 0xF0 != 0x30: return res[0] sz = self._recv_len() - topic_len = self._sock.recv(2) + topic_len = self._sock_exact_recv(2) topic_len = (topic_len[0] << 8) | topic_len[1] - topic = self._sock.recv(topic_len) + topic = self._sock_exact_recv(topic_len) topic = str(topic, "utf-8") sz -= topic_len + 2 if res[0] & 0x06: - pid = self._sock.recv(2) + pid = self._sock_exact_recv(2) pid = pid[0] << 0x08 | pid[1] sz -= 0x02 - msg = self._sock.recv(sz) + msg = self._sock_exact_recv(sz) self._handle_on_message(self, topic, str(msg, "utf-8")) if res[0] & 0x06 == 0x02: pkt = bytearray(b"\x40\x02\0\0") @@ -725,7 +769,7 @@ def _recv_len(self): n = 0 sh = 0 while True: - b = self._sock.recv(1)[0] + b = self._sock_exact_recv(1)[0] n |= (b & 0x7F) << sh if not b & 0x80: return n