Skip to content

Commit e3c73c6

Browse files
committed
more refactoring and corrections
1 parent 63052d3 commit e3c73c6

File tree

1 file changed

+108
-134
lines changed

1 file changed

+108
-134
lines changed

adafruit_ina3221.py

Lines changed: 108 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
# Register Definitions
4040
CONFIGURATION = 0x00
4141
SHUNTVOLTAGE_REGS = [0x01, 0x03, 0x05]
42-
BUSTVOLTAGE_REGS = [0x02, 0x04, 0x06]
42+
BUSVOLTAGE_REGS = [0x02, 0x04, 0x06]
4343
CRITICAL_ALERT_LIMIT_REGS = [0x07, 0x09, 0x0B]
4444
WARNING_ALERT_LIMIT_REGS = [0x08, 0x0A, 0x0C]
4545
SHUNTVOLTAGE_SUM = 0x0D
@@ -62,6 +62,35 @@
6262
CRITICAL_CH2 = 1 << 8
6363
CRITICAL_CH1 = 1 << 9
6464

65+
# Default value for Adafruit-breakout
66+
DEFAULT_SHUNT_RESISTANCE = 0.05
67+
68+
# Precision (LSB) of bus-voltage and shunt-voltage
69+
BUS_V_LSB = 0.008 # 8mV
70+
SHUNT_V_LSB = 0.000040 # 40µV
71+
72+
def _mask(offset,len,read=True):
73+
""" return mask for reading or writing """
74+
if read:
75+
return ((1<<len)-1)<<offset
76+
else:
77+
return ~(((1<<len)-1)<<offset) & 0xFFFF
78+
79+
def _to_signed(val: int, shift: int, bits: int):
80+
""" convert value to signed int and shift result """
81+
if val & (1 << (bits - 1)):
82+
val -= 1 << (bits-1) # remove sign
83+
val = (1 << bits-1) - 1 - val # bitwise not
84+
return - (val >> shift)
85+
return val >> shift
86+
87+
def _to_2comp(val: int, shift: int, bits: int):
88+
""" convert value to twos complement, shifting as necessary """
89+
if val > 0:
90+
return val << shift
91+
val = (-val) << shift
92+
val = (1 << bits-1) - val # bitwise not plus 1
93+
return val + (1 << (bits-1))
6594

6695
class AVG_MODE:
6796
"""Enumeration for the averaging mode options in INA3221.
@@ -146,18 +175,15 @@ class INA3221Channel:
146175
def __init__(self, device: Any, channel: int) -> None:
147176
self._device = device
148177
self._channel = channel
149-
self._shunt_resistance = 0.5
178+
self._shunt_resistance = DEFAULT_SHUNT_RESISTANCE
150179
self._enabled = False
151180

152-
def enable(self) -> None:
153-
"""Enable this channel"""
154-
config = self._device._read_register(CONFIGURATION, 2)
155-
config_value = (config[0] << 8) | config[1]
156-
config_value |= 1 << (14 - self._channel) # Set the bit for the specific channel
157-
high_byte = (config_value >> 8) & 0xFF
158-
low_byte = config_value & 0xFF
159-
self._device._write_register(CONFIGURATION, bytes([high_byte, low_byte]))
160-
self._enabled = True
181+
def enable(self, flag: bool = True) -> None:
182+
"""Enable/disable this channel"""
183+
# enable bits in the configuration-register: 14-12
184+
self._device._set_register_bits(CONFIGURATION,
185+
14-self._channel,1,int(flag))
186+
self._enabled = flag
161187

162188
@property
163189
def enabled(self) -> bool:
@@ -167,22 +193,18 @@ def enabled(self) -> bool:
167193
@property
168194
def bus_voltage(self) -> float:
169195
"""Bus voltage in volts."""
170-
reg_address = BUSVOLTAGE_REGS[self._channel]
171-
result = self._device._read_register(reg_address, 2)
172-
raw_value = int.from_bytes(result, "big")
173-
voltage = (raw_value >> 3) * 8e-3
174-
return voltage
196+
reg_addr = BUSVOLTAGE_REGS[self._channel]
197+
raw_value = self._device._get_register_bits(reg_addr,0,16)
198+
raw_value = _to_signed(raw_value,3,16)
199+
return raw_value * BUS_V_LSB
175200

176201
@property
177202
def shunt_voltage(self) -> float:
178203
"""Shunt voltage in millivolts."""
179-
reg_address = SHUNTVOLTAGE_REGS[self._channel]
180-
result = self._device._read_register(reg_address, 2)
181-
raw_value = int.from_bytes(result, "big")
182-
raw_value = (
183-
raw_value - 0x10000 if raw_value & 0x8000 else raw_value
184-
) # convert to signed int16
185-
return (raw_value >> 3) * 40e-6
204+
reg_addr = SHUNTVOLTAGE_REGS[self._channel]
205+
raw_value = self._device._get_register_bits(reg_addr,0,16)
206+
raw_value = _to_signed(raw_value,3,16)
207+
return raw_value * SHUNT_V_LSB * 1000
186208

187209
@property
188210
def shunt_resistance(self) -> float:
@@ -194,13 +216,13 @@ def shunt_resistance(self, value: float) -> None:
194216
self._shunt_resistance = value
195217

196218
@property
197-
def current_amps(self) -> float:
198-
"""Returns the current in amperes.
219+
def current(self) -> float:
220+
"""Returns the current in mA
199221
200222
The current is calculated using the formula: I = Vshunt / Rshunt.
201223
If the shunt voltage is NaN (e.g., no valid measurement), it returns NaN.
202224
"""
203-
shunt_voltage = self.shunt_voltage
225+
shunt_voltage = self.shunt_voltage # this is in mV
204226
if shunt_voltage != shunt_voltage: # Check for NaN
205227
return float("nan")
206228
return shunt_voltage / self._shunt_resistance
@@ -213,16 +235,14 @@ def critical_alert_threshold(self) -> float:
213235
float: The current critical alert threshold in amperes.
214236
"""
215237
reg_addr = CRITICAL_ALERT_LIMIT_REGS[self._channel]
216-
result = self._device._read_register(reg_addr, 2)
217-
threshold = int.from_bytes(result, "big")
218-
return (threshold >> 3) * 40e-6 / self._shunt_resistance
238+
threshold = self._device._get_register_bits(reg_addr,3,13)
239+
return threshold * SHUNT_V_LSB / self._shunt_resistance
219240

220241
@critical_alert_threshold.setter
221242
def critical_alert_threshold(self, current: float) -> None:
222-
threshold = int(current * self._shunt_resistance / 40e-6 * 8)
243+
threshold = int(current * self._shunt_resistance / SHUNT_V_LSB)
223244
reg_addr = CRITICAL_ALERT_LIMIT_REGS[self._channel]
224-
threshold_bytes = threshold.to_bytes(2, "big")
225-
self._device._write_register(reg_addr, threshold_bytes)
245+
self._device._set_register_bits(reg_addr,3,13,threshold)
226246

227247
@property
228248
def warning_alert_threshold(self) -> float:
@@ -232,17 +252,24 @@ def warning_alert_threshold(self) -> float:
232252
float: The current warning alert threshold in amperes.
233253
"""
234254
reg_addr = WARNING_ALERT_LIMIT_REGS[self._channel]
235-
result = self._device._read_register(reg_addr, 2)
236-
threshold = int.from_bytes(result, "big")
237-
return threshold / (self._shunt_resistance * 8)
255+
threshold = self._device._get_register_bits(reg_addr,3,13)
256+
return threshold / self._shunt_resistance
238257

239258
@warning_alert_threshold.setter
240259
def warning_alert_threshold(self, current: float) -> None:
241-
threshold = int(current * self._shunt_resistance * 8)
260+
threshold = int(current * self._shunt_resistance)
242261
reg_addr = WARNING_ALERT_LIMIT_REGS[self._channel]
243-
threshold_bytes = threshold.to_bytes(2, "big")
244-
self._device._write_register(reg_addr, threshold_bytes)
262+
self._device._set_register_bits(reg_addr,3,13,threshold)
245263

264+
@property
265+
def summation_channel(self) -> bool:
266+
"""Status of summation channel """
267+
return self._device._get_register_bits(MASK_ENABLE,14-self._channel,1)
268+
269+
@summation_channel.setter
270+
def summation_channel(self, value: bool) -> None:
271+
""" set value of summation control """
272+
self._device._set_register_bits(MASK_ENABLE,14-self._channel,1,int(value))
246273

247274
class INA3221:
248275
"""Driver for the INA3221 device with three channels."""
@@ -252,17 +279,17 @@ def __init__(self, i2c, address: int = DEFAULT_ADDRESS, enable: List = [0, 1, 2]
252279
Args:
253280
i2c (I2C): The I2C bus to which the INA3221 is connected.
254281
address (int, optional): The I2C address of the INA3221. Defaults to DEFAULT_ADDRESS.
282+
enable(List[int], optional): channels to initialize at start (default: all)
255283
"""
256284
self.i2c_dev = I2CDevice(i2c, address)
257285
self.reset()
258286

259287
self.channels: List[INA3221Channel] = [INA3221Channel(self, i) for i in range(3)]
260-
for i in enable:
261-
self.channels[i].enable()
288+
for i in range(3):
289+
self.channels[i].enable(i in enable)
262290
self.mode: int = MODE.SHUNT_BUS_CONT
263291
self.shunt_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS
264292
self.bus_voltage_conv_time: int = CONV_TIME.CONV_TIME_8MS
265-
# Set the default sampling rate (averaging mode) to 64 samples
266293
self.averaging_mode: int = AVG_MODE.AVG_64_SAMPLES
267294

268295
def __getitem__(self, channel: int) -> INA3221Channel:
@@ -284,10 +311,7 @@ def reset(self) -> None:
284311
Returns:
285312
None
286313
"""
287-
config = self._read_register(CONFIGURATION, 2)
288-
config = bytearray(config)
289-
config[0] |= 0x80 # Set the reset bit
290-
return self._write_register(CONFIGURATION, config)
314+
self._set_register_bits(CONFIGURATION,15,1,1)
291315

292316
@property
293317
def die_id(self) -> int:
@@ -318,17 +342,13 @@ def mode(self) -> int:
318342
4: Alternate power down mode, 5: Continuous shunt voltage measurement,
319343
6: Continuous bus voltage measurement, 7: Continuous shunt and bus voltage measurements
320344
"""
321-
config = self._read_register(CONFIGURATION, 2)
322-
return config[1] & 0x07
345+
return self._get_register_bits(CONFIGURATION,0,3)
323346

324347
@mode.setter
325348
def mode(self, value: int) -> None:
326349
if not 0 <= value <= 7:
327350
raise ValueError("Mode must be a 3-bit value (0-7).")
328-
config = self._read_register(CONFIGURATION, 2)
329-
config = bytearray(config)
330-
config[1] = (config[1] & 0xF8) | value
331-
self._write_register(CONFIGURATION, config)
351+
self._set_register_bits(CONFIGURATION,0,3,value)
332352

333353
@property
334354
def shunt_voltage_conv_time(self) -> int:
@@ -339,17 +359,13 @@ def shunt_voltage_conv_time(self) -> int:
339359
0: 140µs, 1: 204µs, 2: 332µs, 3: 588µs,
340360
4: 1ms, 5: 2ms, 6: 4ms, 7: 8ms
341361
"""
342-
config = self._read_register(CONFIGURATION, 2)
343-
return (config[1] >> 4) & 0x07
362+
return self._get_register_bits(CONFIGURATION,3,3)
344363

345364
@shunt_voltage_conv_time.setter
346365
def shunt_voltage_conv_time(self, conv_time: int) -> None:
347366
if conv_time < 0 or conv_time > 7:
348367
raise ValueError("Conversion time must be between 0 and 7")
349-
config = self._read_register(CONFIGURATION, 2)
350-
config = bytearray(config)
351-
config[1] = (config[1] & 0x8F) | (conv_time << 4)
352-
self._write_register(CONFIGURATION, config)
368+
self._set_register_bits(CONFIGURATION,3,3,int(conv_time))
353369

354370
@property
355371
def bus_voltage_conv_time(self) -> int:
@@ -360,19 +376,13 @@ def bus_voltage_conv_time(self) -> int:
360376
0: 140µs, 1: 204µs, 2: 332µs, 3: 588µs,
361377
4: 1ms, 5: 2ms, 6: 4ms, 7: 8ms
362378
"""
363-
config = self._read_register(CONFIGURATION, 2)
364-
return (config[0] >> 3) & 0x07 # Bits 12-14 are the bus voltage conversion time
379+
return self._get_register_bits(CONFIGURATION,6,3)
365380

366381
@bus_voltage_conv_time.setter
367382
def bus_voltage_conv_time(self, conv_time: int) -> None:
368383
if conv_time < 0 or conv_time > 7:
369384
raise ValueError("Conversion time must be between 0 and 7")
370-
371-
config = self._read_register(CONFIGURATION, 2)
372-
config = bytearray(config)
373-
config[0] = config[0] & 0xC7
374-
config[0] = config[0] | (conv_time << 3)
375-
self._write_register(CONFIGURATION, config)
385+
self._set_register_bits(CONFIGURATION,6,3,int(conv_time))
376386

377387
@property
378388
def averaging_mode(self) -> int:
@@ -384,15 +394,13 @@ def averaging_mode(self) -> int:
384394
3: 64_SAMPLES, 4: 128_SAMPLES, 5: 256_SAMPLES,
385395
6: 512_SAMPLES, 7: 1024_SAMPLES
386396
"""
387-
config = self._read_register(CONFIGURATION, 2)
388-
return (config[1] >> 1) & 0x07
397+
return self._get_register_bits(CONFIGURATION,9,3)
389398

390399
@averaging_mode.setter
391400
def averaging_mode(self, mode: int) -> None:
392-
config = self._read_register(CONFIGURATION, 2)
393-
config = bytearray(config)
394-
config[1] = (config[1] & 0xF1) | (mode << 1)
395-
self._write_register(CONFIGURATION, config)
401+
if mode < 0 or mode > 7:
402+
raise ValueError("Averaging mode must be between 0 and 7")
403+
self._set_register_bits(CONFIGURATION,9,3,int(mode))
396404

397405
@property
398406
def flags(self) -> int:
@@ -402,82 +410,48 @@ def flags(self) -> int:
402410
int: The current flag indicators from the Mask/Enable register,
403411
masked for relevant flag bits.
404412
"""
405-
result = self._read_register(MASK_ENABLE, 2)
406-
flags = int.from_bytes(result, "big")
407-
408-
# Mask to keep only relevant flag bits
409-
mask = (
410-
CONV_READY
411-
| TIMECONT_ALERT
412-
| POWER_VALID
413-
| WARN_CH3
414-
| WARN_CH2
415-
| WARN_CH1
416-
| SUMMATION
417-
| CRITICAL_CH3
418-
| CRITICAL_CH2
419-
| CRITICAL_CH1
420-
)
421-
422-
return flags & mask
423-
424-
@property
425-
def summation_channels(self) -> tuple:
426-
"""Status of summation channels (ch1, ch2, ch3)
427-
428-
Returns:
429-
tuple: A tuple of three boolean values indicating the status
430-
of summation channels (ch1, ch2, ch3).
431-
"""
432-
result = self._read_register(MASK_ENABLE, 2)
433-
mask_enable = int.from_bytes(result, "big")
434-
ch1 = bool((mask_enable >> 14) & 0x01)
435-
ch2 = bool((mask_enable >> 13) & 0x01)
436-
ch3 = bool((mask_enable >> 12) & 0x01)
437-
438-
return ch1, ch2, ch3
439-
440-
@summation_channels.setter
441-
def summation_channels(self, channels: tuple) -> None:
442-
if len(channels) != 3:
443-
raise ValueError("Must pass a tuple of three boolean values (ch1, ch2, ch3)")
444-
ch1, ch2, ch3 = channels
445-
scc_value = (ch1 << 2) | (ch2 << 1) | (ch3 << 0)
446-
result = self._read_register(MASK_ENABLE, 2)
447-
mask_enable = int.from_bytes(result, "big")
448-
mask_enable = (mask_enable & ~(0x07 << 12)) | (scc_value << 12)
449-
self._write_register(MASK_ENABLE, mask_enable.to_bytes(2, "big"))
413+
return self._read_register_bits(MASK_ENABLE,0,10)
450414

451415
@property
452416
def power_valid_limits(self) -> tuple:
453417
"""Power-Valid upper and lower voltage limits in volts.
454418
455419
Returns:
456420
tuple: A tuple containing the lower and upper voltage limits
457-
in volts as (vlimitlow, vlimithigh).
421+
in volts as (lower_limit, upper_limit).
458422
"""
459-
low_limit_result = self._read_register(POWERVALID_LOWERLIMIT, 2)
460-
vlimitlow = int.from_bytes(low_limit_result, "big") * 8e-3
461-
high_limit_result = self._read_register(POWERVALID_UPPERLIMIT, 2)
462-
vlimithigh = int.from_bytes(high_limit_result, "big") * 8e-3
463-
return vlimitlow, vlimithigh
423+
raw_value = self._device._get_register_bits(POWERVALID_LOWERLIMIT,0,16)
424+
lower_limit = _to_signed(raw_value,3,16) * 8e-3
425+
raw_value = self._device._get_register_bits(POWERVALID_UPPERLIMIT,0,16)
426+
upper_limit = _to_signed(raw_value,3,16) * 8e-3
427+
return lower_limit, upper_limit
464428

465429
@power_valid_limits.setter
466430
def power_valid_limits(self, limits: tuple) -> None:
467431
if len(limits) != 2:
468432
raise ValueError("Must provide both lower and upper voltage limits.")
469-
vlimitlow, vlimithigh = limits
470-
low_limit_value = int(vlimitlow * 1000)
471-
high_limit_value = int(vlimithigh * 1000)
472-
low_limit_bytes = low_limit_value.to_bytes(2, "big")
473-
self._write_register(POWERVALID_LOWERLIMIT, low_limit_bytes)
474-
high_limit_bytes = high_limit_value.to_bytes(2, "big")
475-
self._write_register(POWERVALID_UPPERLIMIT, high_limit_bytes)
476-
477-
def _to_signed(self, val, bits):
478-
if val & (1 << (bits - 1)):
479-
val -= 1 << bits
480-
return val
433+
# convert to mV and twos-complement
434+
lower_limit = _to_2comp(int(limits[0] * 1000),3,16)
435+
upper_limit = _to_2comp(int(limits[1] * 1000),3,16)
436+
self._device._set_register_bits(POWERVALID_LOWERLIMIT,0,16,lower_limit)
437+
self._device._set_register_bits(POWERVALID_UPPERLIMIT,0,16,upper_limit)
438+
439+
def _get_register_bits(self,reg,offset,len):
440+
""" return given bits from register """
441+
value = self._read_register(reg, 2)
442+
value = (value[0] << 8) | value[1] # swap bytes
443+
mask = _mask(offset,len,read=True)
444+
return (value & mask) >> offset
445+
446+
def _set_register_bits(self,reg,offset,len,value):
447+
""" set given bits of register """
448+
old = self._read_register(reg, 2)
449+
old = (old[0] << 8) | old[1] # swap bytes
450+
mask = _mask(offset,len,read=False)
451+
new = (old & mask) | value << offset
452+
high_byte = (new >> 8) & 0xFF
453+
low_byte = new & 0xFF
454+
self._write_register(reg, bytes([high_byte, low_byte]))
481455

482456
def _write_register(self, reg, data):
483457
with self.i2c_dev:

0 commit comments

Comments
 (0)