diff --git a/can/interfaces/slcan.py b/can/interfaces/slcan.py index ac90a2600..035969d8b 100644 --- a/can/interfaces/slcan.py +++ b/can/interfaces/slcan.py @@ -9,7 +9,12 @@ import logging from can import BusABC, Message -from ..exceptions import CanInterfaceNotImplementedError, CanOperationError +from ..exceptions import ( + CanInterfaceNotImplementedError, + CanInitializationError, + CanOperationError, + error_check, +) from can import typechecking @@ -62,8 +67,6 @@ def __init__( **kwargs: Any, ) -> None: """ - :raise ValueError: if both *bitrate* and *btr* are set - :param str channel: port of underlying serial or usb device (e.g. ``/dev/ttyUSB0``, ``COM8``, ...) Must not be empty. Can also end with ``@115200`` (or similarly) to specify the baudrate. @@ -79,30 +82,37 @@ def __init__( Time to wait in seconds after opening serial connection :param rtscts: turn hardware handshake (RTS/CTS) on and off + + :raise ValueError: if both ``bitrate`` and ``btr`` are set or the channel is invalid + :raise CanInterfaceNotImplementedError: if the serial module is missing + :raise CanInitializationError: if the underlying serial connection could not be established """ if serial is None: raise CanInterfaceNotImplementedError("The serial module is not installed") if not channel: # if None or empty - raise TypeError("Must specify a serial port.") + raise ValueError("Must specify a serial port.") if "@" in channel: (channel, baudrate) = channel.split("@") ttyBaudrate = int(baudrate) - self.serialPortOrig = serial.serial_for_url( - channel, baudrate=ttyBaudrate, rtscts=rtscts - ) + + with error_check(exception_type=CanInitializationError): + self.serialPortOrig = serial.serial_for_url( + channel, baudrate=ttyBaudrate, rtscts=rtscts + ) self._buffer = bytearray() time.sleep(sleep_after_open) - if bitrate is not None and btr is not None: - raise ValueError("Bitrate and btr mutually exclusive.") - if bitrate is not None: - self.set_bitrate(bitrate) - if btr is not None: - self.set_bitrate_reg(btr) - self.open() + with error_check(exception_type=CanInitializationError): + if bitrate is not None and btr is not None: + raise ValueError("Bitrate and btr mutually exclusive.") + if bitrate is not None: + self.set_bitrate(bitrate) + if btr is not None: + self.set_bitrate_reg(btr) + self.open() super().__init__( channel, ttyBaudrate=115200, bitrate=None, rtscts=False, **kwargs @@ -110,17 +120,19 @@ def __init__( def set_bitrate(self, bitrate: int) -> None: """ - :raise ValueError: if both *bitrate* is not among the possible values - :param bitrate: Bitrate in bit/s + + :raise ValueError: if ``bitrate`` is not among the possible values """ - self.close() if bitrate in self._BITRATES: - self._write(self._BITRATES[bitrate]) + bitrate_code = self._BITRATES[bitrate] else: bitrates = ", ".join(str(k) for k in self._BITRATES.keys()) raise ValueError(f"Invalid bitrate, choose one of {bitrates}.") + + self.close() + self._write(bitrate_code) self.open() def set_bitrate_reg(self, btr: str) -> None: @@ -133,33 +145,38 @@ def set_bitrate_reg(self, btr: str) -> None: self.open() def _write(self, string: str) -> None: - self.serialPortOrig.write(string.encode() + self.LINE_TERMINATOR) - self.serialPortOrig.flush() + with error_check("Could not write to serial device"): + self.serialPortOrig.write(string.encode() + self.LINE_TERMINATOR) + self.serialPortOrig.flush() def _read(self, timeout: Optional[float]) -> Optional[str]: - # first read what is already in receive buffer - while self.serialPortOrig.in_waiting: - self._buffer += self.serialPortOrig.read() - # if we still don't have a complete message, do a blocking read - start = time.time() - time_left = timeout - while not (ord(self._OK) in self._buffer or ord(self._ERROR) in self._buffer): - self.serialPortOrig.timeout = time_left - byte = self.serialPortOrig.read() - if byte: - self._buffer += byte - # if timeout is None, try indefinitely - if timeout is None: - continue - # try next one only if there still is time, and with - # reduced timeout - else: - time_left = timeout - (time.time() - start) - if time_left > 0: + with error_check("Could not read from serial device"): + # first read what is already in receive buffer + while self.serialPortOrig.in_waiting: + self._buffer += self.serialPortOrig.read() + # if we still don't have a complete message, do a blocking read + start = time.time() + time_left = timeout + while not ( + ord(self._OK) in self._buffer or ord(self._ERROR) in self._buffer + ): + self.serialPortOrig.timeout = time_left + byte = self.serialPortOrig.read() + if byte: + self._buffer += byte + # if timeout is None, try indefinitely + if timeout is None: continue + # try next one only if there still is time, and with + # reduced timeout else: - return None + time_left = timeout - (time.time() - start) + if time_left > 0: + continue + else: + return None + # return first message for i in range(len(self._buffer)): if self._buffer[i] == ord(self._OK) or self._buffer[i] == ord(self._ERROR): @@ -170,8 +187,9 @@ def _read(self, timeout: Optional[float]) -> Optional[str]: def flush(self) -> None: del self._buffer[:] - while self.serialPortOrig.in_waiting: - self.serialPortOrig.read() + with error_check("Could not flush"): + while self.serialPortOrig.in_waiting: + self.serialPortOrig.read() def open(self) -> None: self._write("O") @@ -247,7 +265,8 @@ def send(self, msg: Message, timeout: Optional[float] = None) -> None: def shutdown(self) -> None: self.close() - self.serialPortOrig.close() + with error_check("Could not close serial socket"): + self.serialPortOrig.close() def fileno(self) -> int: try: diff --git a/test/test_rotating_loggers.py b/test/test_rotating_loggers.py index dede154ab..d900f4f23 100644 --- a/test/test_rotating_loggers.py +++ b/test/test_rotating_loggers.py @@ -6,11 +6,8 @@ import os from pathlib import Path -import tempfile from unittest.mock import Mock -import pytest - import can from .data.example_data import generate_message @@ -90,7 +87,7 @@ def test_rotate_without_rotator(self, tmp_path): assert os.path.exists(source) is False assert os.path.exists(dest) is False - logger_instance._get_new_writer(source) + logger_instance._writer = logger_instance._get_new_writer(source) logger_instance.stop() assert os.path.exists(source) is True @@ -113,7 +110,7 @@ def test_rotate_with_rotator(self, tmp_path): assert os.path.exists(source) is False assert os.path.exists(dest) is False - logger_instance._get_new_writer(source) + logger_instance._writer = logger_instance._get_new_writer(source) logger_instance.stop() assert os.path.exists(source) is True