diff --git a/adafruit_ina3221.py b/adafruit_ina3221.py index 698ac36..d81f68b 100644 --- a/adafruit_ina3221.py +++ b/adafruit_ina3221.py @@ -23,16 +23,11 @@ * Adafruit's Bus Device library: https://github.com/adafruit/Adafruit_CircuitPython_BusDevice """ -import time - -from adafruit_bus_device.i2c_device import I2CDevice - try: from typing import Any, List - - from busio import I2C except ImportError: pass +from adafruit_bus_device.i2c_device import I2CDevice __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_INA3221.git" @@ -43,18 +38,10 @@ # Register Definitions CONFIGURATION = 0x00 -SHUNTVOLTAGE_CH1 = 0x01 -BUSVOLTAGE_CH1 = 0x02 -SHUNTVOLTAGE_CH2 = 0x03 -BUSVOLTAGE_CH2 = 0x04 -SHUNTVOLTAGE_CH3 = 0x05 -BUSVOLTAGE_CH3 = 0x06 -CRITICAL_ALERT_LIMIT_CH1 = 0x07 -WARNING_ALERT_LIMIT_CH1 = 0x08 -CRITICAL_ALERT_LIMIT_CH2 = 0x09 -WARNING_ALERT_LIMIT_CH2 = 0x0A -CRITICAL_ALERT_LIMIT_CH3 = 0x0B -WARNING_ALERT_LIMIT_CH3 = 0x0C +SHUNTVOLTAGE_REGS = [0x01, 0x03, 0x05] +BUSVOLTAGE_REGS = [0x02, 0x04, 0x06] +CRITICAL_ALERT_LIMIT_REGS = [0x07, 0x09, 0x0B] +WARNING_ALERT_LIMIT_REGS = [0x08, 0x0A, 0x0C] SHUNTVOLTAGE_SUM = 0x0D SHUNTVOLTAGE_SUM_LIMIT = 0x0E MASK_ENABLE = 0x0F @@ -75,6 +62,39 @@ CRITICAL_CH2 = 1 << 8 CRITICAL_CH1 = 1 << 9 +# Default value for Adafruit-breakout +DEFAULT_SHUNT_RESISTANCE = 0.05 + +# Precision (LSB) of bus-voltage and shunt-voltage +BUS_V_LSB = 0.008 # 8mV +SHUNT_V_LSB = 0.000040 # 40µV + + +def _mask(offset, len, read=True): + """return mask for reading or writing""" + if read: + return ((1 << len) - 1) << offset + else: + return ~(((1 << len) - 1) << offset) & 0xFFFF + + +def _to_signed(val: int, shift: int, bits: int): + """convert value to signed int and shift result""" + if val & (1 << (bits - 1)): + val -= 1 << (bits - 1) # remove sign + val = (1 << bits - 1) - 1 - val # bitwise not + return -(val >> shift) + return val >> shift + + +def _to_2comp(val: int, shift: int, bits: int): + """convert value to twos complement, shifting as necessary""" + if val > 0: + return val << shift + val = (-val) << shift + val = (1 << bits - 1) - val # bitwise not plus 1 + return val + (1 << (bits - 1)) + class AVG_MODE: """Enumeration for the averaging mode options in INA3221. @@ -152,66 +172,128 @@ class INA3221Channel: """Represents a single channel of the INA3221. Args: - parent (Any): The parent INA3221 instance managing the I2C communication. + device (Any): The device INA3221 instance managing the I2C communication. channel (int): The channel number (1, 2, or 3) for this instance. """ - def __init__(self, parent: Any, channel: int) -> None: - self._parent = parent + def __init__(self, device: Any, channel: int) -> None: + self._device = device self._channel = channel + self._shunt_resistance = DEFAULT_SHUNT_RESISTANCE + self._enabled = False + + def enable(self, flag: bool = True) -> None: + """Enable/disable this channel""" + # enable bits in the configuration-register: 14-12 + self._device._set_register_bits(CONFIGURATION, 14 - self._channel, 1, int(flag)) + self._enabled = flag + + @property + def enabled(self) -> bool: + """return buffered enable-state""" + return self._enabled @property def bus_voltage(self) -> float: """Bus voltage in volts.""" - return self._parent._bus_voltage(self._channel) + reg_addr = BUSVOLTAGE_REGS[self._channel] + raw_value = self._device._get_register_bits(reg_addr, 0, 16) + raw_value = _to_signed(raw_value, 3, 16) + return raw_value * BUS_V_LSB @property def shunt_voltage(self) -> float: """Shunt voltage in millivolts.""" - return self._parent._shunt_voltage(self._channel) + reg_addr = SHUNTVOLTAGE_REGS[self._channel] + raw_value = self._device._get_register_bits(reg_addr, 0, 16) + raw_value = _to_signed(raw_value, 3, 16) + return raw_value * SHUNT_V_LSB * 1000 @property def shunt_resistance(self) -> float: """Shunt resistance in ohms.""" - return self._parent._shunt_resistance[self._channel] + return self._shunt_resistance @shunt_resistance.setter def shunt_resistance(self, value: float) -> None: - self._parent._shunt_resistance[self._channel] = value + self._shunt_resistance = value @property - def current_amps(self) -> float: - """Returns the current in amperes. + def current(self) -> float: + """Returns the current in mA The current is calculated using the formula: I = Vshunt / Rshunt. If the shunt voltage is NaN (e.g., no valid measurement), it returns NaN. """ - shunt_voltage = self.shunt_voltage + shunt_voltage = self.shunt_voltage # this is in mV if shunt_voltage != shunt_voltage: # Check for NaN return float("nan") - return shunt_voltage / self.shunt_resistance + return shunt_voltage / self._shunt_resistance + + @property + def critical_alert_threshold(self) -> float: + """Critical-Alert threshold in amperes + + Returns: + float: The current critical alert threshold in amperes. + """ + reg_addr = CRITICAL_ALERT_LIMIT_REGS[self._channel] + threshold = self._device._get_register_bits(reg_addr, 3, 13) + return threshold * SHUNT_V_LSB / self._shunt_resistance + + @critical_alert_threshold.setter + def critical_alert_threshold(self, current: float) -> None: + threshold = int(current * self._shunt_resistance / SHUNT_V_LSB) + reg_addr = CRITICAL_ALERT_LIMIT_REGS[self._channel] + self._device._set_register_bits(reg_addr, 3, 13, threshold) + + @property + def warning_alert_threshold(self) -> float: + """Warning-Alert threshold in amperes + + Returns: + float: The current warning alert threshold in amperes. + """ + reg_addr = WARNING_ALERT_LIMIT_REGS[self._channel] + threshold = self._device._get_register_bits(reg_addr, 3, 13) + return threshold / self._shunt_resistance + + @warning_alert_threshold.setter + def warning_alert_threshold(self, current: float) -> None: + threshold = int(current * self._shunt_resistance) + reg_addr = WARNING_ALERT_LIMIT_REGS[self._channel] + self._device._set_register_bits(reg_addr, 3, 13, threshold) + + @property + def summation_channel(self) -> bool: + """Status of summation channel""" + return self._device._get_register_bits(MASK_ENABLE, 14 - self._channel, 1) + + @summation_channel.setter + def summation_channel(self, value: bool) -> None: + """set value of summation control""" + self._device._set_register_bits(MASK_ENABLE, 14 - self._channel, 1, int(value)) class INA3221: """Driver for the INA3221 device with three channels.""" - def __init__(self, i2c, address: int = DEFAULT_ADDRESS) -> None: + def __init__(self, i2c, address: int = DEFAULT_ADDRESS, enable: List = [0, 1, 2]) -> None: """Initializes the INA3221 class over I2C Args: i2c (I2C): The I2C bus to which the INA3221 is connected. address (int, optional): The I2C address of the INA3221. Defaults to DEFAULT_ADDRESS. + enable(List[int], optional): channels to initialize at start (default: all) """ self.i2c_dev = I2CDevice(i2c, address) - self._shunt_resistance: List[float] = [0.05, 0.05, 0.05] # Default shunt resistances self.reset() self.channels: List[INA3221Channel] = [INA3221Channel(self, i) for i in range(3)] for i in range(3): - self.enable_channel(i) + self.channels[i].enable(i in enable) self.mode: int = MODE.SHUNT_BUS_CONT self.shunt_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS self.bus_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS - # Set the default sampling rate (averaging mode) to 64 samples self.averaging_mode: int = AVG_MODE.AVG_64_SAMPLES def __getitem__(self, channel: int) -> INA3221Channel: @@ -233,29 +315,7 @@ def reset(self) -> None: Returns: None """ - config = self._read_register(CONFIGURATION, 2) - config = bytearray(config) - config[0] |= 0x80 # Set the reset bit - return self._write_register(CONFIGURATION, config) - - def enable_channel(self, channel: int) -> None: - """Enable a specific channel of the INA3221. - - Args: - channel (int): The channel number to enable (0, 1, or 2). - - Raises: - ValueError: If the channel number is invalid (must be 0, 1, or 2). - """ - if channel > 2: - raise ValueError("Invalid channel number. Must be 0, 1, or 2.") - - config = self._read_register(CONFIGURATION, 2) - config_value = (config[0] << 8) | config[1] - config_value |= 1 << (14 - channel) # Set the bit for the specific channel - high_byte = (config_value >> 8) & 0xFF - low_byte = config_value & 0xFF - self._write_register(CONFIGURATION, bytes([high_byte, low_byte])) + self._set_register_bits(CONFIGURATION, 15, 1, 1) @property def die_id(self) -> int: @@ -286,17 +346,13 @@ def mode(self) -> int: 4: Alternate power down mode, 5: Continuous shunt voltage measurement, 6: Continuous bus voltage measurement, 7: Continuous shunt and bus voltage measurements """ - config = self._read_register(CONFIGURATION, 2) - return config[1] & 0x07 + return self._get_register_bits(CONFIGURATION, 0, 3) @mode.setter def mode(self, value: int) -> None: if not 0 <= value <= 7: raise ValueError("Mode must be a 3-bit value (0-7).") - config = self._read_register(CONFIGURATION, 2) - config = bytearray(config) - config[1] = (config[1] & 0xF8) | value - self._write_register(CONFIGURATION, config) + self._set_register_bits(CONFIGURATION, 0, 3, value) @property def shunt_voltage_conv_time(self) -> int: @@ -307,17 +363,13 @@ def shunt_voltage_conv_time(self) -> int: 0: 140µs, 1: 204µs, 2: 332µs, 3: 588µs, 4: 1ms, 5: 2ms, 6: 4ms, 7: 8ms """ - config = self._read_register(CONFIGURATION, 2) - return (config[1] >> 4) & 0x07 + return self._get_register_bits(CONFIGURATION, 3, 3) @shunt_voltage_conv_time.setter def shunt_voltage_conv_time(self, conv_time: int) -> None: if conv_time < 0 or conv_time > 7: raise ValueError("Conversion time must be between 0 and 7") - config = self._read_register(CONFIGURATION, 2) - config = bytearray(config) - config[1] = (config[1] & 0x8F) | (conv_time << 4) - self._write_register(CONFIGURATION, config) + self._set_register_bits(CONFIGURATION, 3, 3, int(conv_time)) @property def bus_voltage_conv_time(self) -> int: @@ -328,19 +380,13 @@ def bus_voltage_conv_time(self) -> int: 0: 140µs, 1: 204µs, 2: 332µs, 3: 588µs, 4: 1ms, 5: 2ms, 6: 4ms, 7: 8ms """ - config = self._read_register(CONFIGURATION, 2) - return (config[0] >> 3) & 0x07 # Bits 12-14 are the bus voltage conversion time + return self._get_register_bits(CONFIGURATION, 6, 3) @bus_voltage_conv_time.setter def bus_voltage_conv_time(self, conv_time: int) -> None: if conv_time < 0 or conv_time > 7: raise ValueError("Conversion time must be between 0 and 7") - - config = self._read_register(CONFIGURATION, 2) - config = bytearray(config) - config[0] = config[0] & 0xC7 - config[0] = config[0] | (conv_time << 3) - self._write_register(CONFIGURATION, config) + self._set_register_bits(CONFIGURATION, 6, 3, int(conv_time)) @property def averaging_mode(self) -> int: @@ -352,65 +398,13 @@ def averaging_mode(self) -> int: 3: 64_SAMPLES, 4: 128_SAMPLES, 5: 256_SAMPLES, 6: 512_SAMPLES, 7: 1024_SAMPLES """ - config = self._read_register(CONFIGURATION, 2) - return (config[1] >> 1) & 0x07 + return self._get_register_bits(CONFIGURATION, 9, 3) @averaging_mode.setter def averaging_mode(self, mode: int) -> None: - config = self._read_register(CONFIGURATION, 2) - config = bytearray(config) - config[1] = (config[1] & 0xF1) | (mode << 1) - self._write_register(CONFIGURATION, config) - - @property - def critical_alert_threshold(self) -> float: - """Critical-Alert threshold in amperes - - Returns: - float: The current critical alert threshold in amperes. - """ - if self._channel > 2: - raise ValueError("Invalid channel number. Must be 0, 1, or 2.") - - reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel - result = self._parent._read_register(reg_addr, 2) - threshold = int.from_bytes(result, "big") - return (threshold >> 3) * 40e-6 / self.shunt_resistance - - @critical_alert_threshold.setter - def critical_alert_threshold(self, current: float) -> None: - if self._channel > 2: - raise ValueError("Invalid channel number. Must be 0, 1, or 2.") - - threshold = int(current * self.shunt_resistance / 40e-6 * 8) - reg_addr = CRITICAL_ALERT_LIMIT_CH1 + 2 * self._channel - threshold_bytes = threshold.to_bytes(2, "big") - self._parent._write_register(reg_addr, threshold_bytes) - - @property - def warning_alert_threshold(self) -> float: - """Warning-Alert threshold in amperes - - Returns: - float: The current warning alert threshold in amperes. - """ - if self._channel > 2: - raise ValueError("Invalid channel number. Must be 0, 1, or 2.") - - reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel - result = self._parent._read_register(reg_addr, 2) - threshold = int.from_bytes(result, "big") - return threshold / (self.shunt_resistance * 8) - - @warning_alert_threshold.setter - def warning_alert_threshold(self, current: float) -> None: - if self._channel > 2: - raise ValueError("Invalid channel number. Must be 0, 1, or 2.") - - threshold = int(current * self.shunt_resistance * 8) - reg_addr = WARNING_ALERT_LIMIT_CH1 + self._channel - threshold_bytes = threshold.to_bytes(2, "big") - self._parent._write_register(reg_addr, threshold_bytes) + if mode < 0 or mode > 7: + raise ValueError("Averaging mode must be between 0 and 7") + self._set_register_bits(CONFIGURATION, 9, 3, int(mode)) @property def flags(self) -> int: @@ -420,51 +414,7 @@ def flags(self) -> int: int: The current flag indicators from the Mask/Enable register, masked for relevant flag bits. """ - result = self._read_register(MASK_ENABLE, 2) - flags = int.from_bytes(result, "big") - - # Mask to keep only relevant flag bits - mask = ( - CONV_READY - | TIMECONT_ALERT - | POWER_VALID - | WARN_CH3 - | WARN_CH2 - | WARN_CH1 - | SUMMATION - | CRITICAL_CH3 - | CRITICAL_CH2 - | CRITICAL_CH1 - ) - - return flags & mask - - @property - def summation_channels(self) -> tuple: - """Status of summation channels (ch1, ch2, ch3) - - Returns: - tuple: A tuple of three boolean values indicating the status - of summation channels (ch1, ch2, ch3). - """ - result = self._read_register(MASK_ENABLE, 2) - mask_enable = int.from_bytes(result, "big") - ch1 = bool((mask_enable >> 14) & 0x01) - ch2 = bool((mask_enable >> 13) & 0x01) - ch3 = bool((mask_enable >> 12) & 0x01) - - return ch1, ch2, ch3 - - @summation_channels.setter - def summation_channels(self, channels: tuple) -> None: - if len(channels) != 3: - raise ValueError("Must pass a tuple of three boolean values (ch1, ch2, ch3)") - ch1, ch2, ch3 = channels - scc_value = (ch1 << 2) | (ch2 << 1) | (ch3 << 0) - result = self._read_register(MASK_ENABLE, 2) - mask_enable = int.from_bytes(result, "big") - mask_enable = (mask_enable & ~(0x07 << 12)) | (scc_value << 12) - self._write_register(MASK_ENABLE, mask_enable.to_bytes(2, "big")) + return self._read_register_bits(MASK_ENABLE, 0, 10) @property def power_valid_limits(self) -> tuple: @@ -472,61 +422,40 @@ def power_valid_limits(self) -> tuple: Returns: tuple: A tuple containing the lower and upper voltage limits - in volts as (vlimitlow, vlimithigh). + in volts as (lower_limit, upper_limit). """ - low_limit_result = self._read_register(POWERVALID_LOWERLIMIT, 2) - vlimitlow = int.from_bytes(low_limit_result, "big") * 8e-3 - high_limit_result = self._read_register(POWERVALID_UPPERLIMIT, 2) - vlimithigh = int.from_bytes(high_limit_result, "big") * 8e-3 - return vlimitlow, vlimithigh + raw_value = self._device._get_register_bits(POWERVALID_LOWERLIMIT, 0, 16) + lower_limit = _to_signed(raw_value, 3, 16) * 8e-3 + raw_value = self._device._get_register_bits(POWERVALID_UPPERLIMIT, 0, 16) + upper_limit = _to_signed(raw_value, 3, 16) * 8e-3 + return lower_limit, upper_limit @power_valid_limits.setter def power_valid_limits(self, limits: tuple) -> None: if len(limits) != 2: raise ValueError("Must provide both lower and upper voltage limits.") - vlimitlow, vlimithigh = limits - low_limit_value = int(vlimitlow * 1000) - high_limit_value = int(vlimithigh * 1000) - low_limit_bytes = low_limit_value.to_bytes(2, "big") - self._write_register(POWERVALID_LOWERLIMIT, low_limit_bytes) - high_limit_bytes = high_limit_value.to_bytes(2, "big") - self._write_register(POWERVALID_UPPERLIMIT, high_limit_bytes) - - def _to_signed(self, val, bits): - if val & (1 << (bits - 1)): - val -= 1 << bits - return val - - def _shunt_voltage(self, channel): - if channel > 2: - raise ValueError("Must be channel 0, 1 or 2") - reg_address = [SHUNTVOLTAGE_CH1, SHUNTVOLTAGE_CH2, SHUNTVOLTAGE_CH3][channel] - result = self._read_register(reg_address, 2) - raw_value = int.from_bytes(result, "big") - raw_value = self._to_signed(raw_value, 16) - - return (raw_value >> 3) * 40e-6 - - def _bus_voltage(self, channel): - if channel > 2: - raise ValueError("Must be channel 0, 1 or 2") - - reg_address = [BUSVOLTAGE_CH1, BUSVOLTAGE_CH2, BUSVOLTAGE_CH3][channel] - result = self._read_register(reg_address, 2) - raw_value = int.from_bytes(result, "big") - voltage = (raw_value >> 3) * 8e-3 - - return voltage - - def _current_amps(self, channel): - if channel >= 3: - raise ValueError("Must be channel 0, 1 or 2") - - shunt_voltage = self._shunt_voltage(channel) - if shunt_voltage != shunt_voltage: - raise ValueError("Must be channel 0, 1 or 2") - - return shunt_voltage / self._shunt_resistance[channel] + # convert to mV and twos-complement + lower_limit = _to_2comp(int(limits[0] * 1000), 3, 16) + upper_limit = _to_2comp(int(limits[1] * 1000), 3, 16) + self._device._set_register_bits(POWERVALID_LOWERLIMIT, 0, 16, lower_limit) + self._device._set_register_bits(POWERVALID_UPPERLIMIT, 0, 16, upper_limit) + + def _get_register_bits(self, reg, offset, len): + """return given bits from register""" + value = self._read_register(reg, 2) + value = (value[0] << 8) | value[1] # swap bytes + mask = _mask(offset, len, read=True) + return (value & mask) >> offset + + def _set_register_bits(self, reg, offset, len, value): + """set given bits of register""" + old = self._read_register(reg, 2) + old = (old[0] << 8) | old[1] # swap bytes + mask = _mask(offset, len, read=False) + new = (old & mask) | value << offset + high_byte = (new >> 8) & 0xFF + low_byte = new & 0xFF + self._write_register(reg, bytes([high_byte, low_byte])) def _write_register(self, reg, data): with self.i2c_dev: @@ -538,7 +467,7 @@ def _read_register(self, reg, length): with self.i2c_dev: self.i2c_dev.write(bytes([reg])) self.i2c_dev.readinto(result) - except OSError as e: - print(f"I2C error: {e}") + except OSError as ex: + print(f"I2C error: {ex}") return None return result diff --git a/examples/ina3221_simpletest.py b/examples/ina3221_simpletest.py index 23c31cd..5952507 100644 --- a/examples/ina3221_simpletest.py +++ b/examples/ina3221_simpletest.py @@ -8,17 +8,17 @@ from adafruit_ina3221 import INA3221 i2c = board.I2C() -ina = INA3221(i2c) +ina = INA3221(i2c, enable=[0, 1, 2]) while True: for i in range(3): bus_voltage = ina[i].bus_voltage shunt_voltage = ina[i].shunt_voltage - current = ina[i].current_amps * 1000 + current = ina[i].current print(f"Channel {i + 1}:") print(f" Bus Voltage: {bus_voltage:.6f} V") - print(f" Shunt Voltage: {shunt_voltage:.6f} V") + print(f" Shunt Voltage: {shunt_voltage:.6f} mV") print(f" Current: {current:.6f} mA") print("-" * 30)