Skip to content

Commit d33c801

Browse files
Optimize array_equivalent for NDFrame.equals (#35328)
1 parent c94f602 commit d33c801

File tree

3 files changed

+113
-46
lines changed

3 files changed

+113
-46
lines changed

Diff for: pandas/core/dtypes/missing.py

+64-32
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,9 @@ def _isna_compat(arr, fill_value=np.nan) -> bool:
355355
return True
356356

357357

358-
def array_equivalent(left, right, strict_nan: bool = False) -> bool:
358+
def array_equivalent(
359+
left, right, strict_nan: bool = False, dtype_equal: bool = False
360+
) -> bool:
359361
"""
360362
True if two arrays, left and right, have equal non-NaN elements, and NaNs
361363
in corresponding locations. False otherwise. It is assumed that left and
@@ -368,6 +370,12 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
368370
left, right : ndarrays
369371
strict_nan : bool, default False
370372
If True, consider NaN and None to be different.
373+
dtype_equal : bool, default False
374+
Whether `left` and `right` are known to have the same dtype
375+
according to `is_dtype_equal`. Some methods like `BlockManager.equals`
376+
require that the dtypes match. Setting this to ``True`` can improve
377+
performance, but will give different results for arrays that are
378+
equal but have different dtypes.
371379
372380
Returns
373381
-------
@@ -391,43 +399,28 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
391399
if left.shape != right.shape:
392400
return False
393401

402+
if dtype_equal:
403+
# fastpath when we require that the dtypes match (BlockManager.equals)
404+
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
405+
return _array_equivalent_float(left, right)
406+
elif is_datetimelike_v_numeric(left.dtype, right.dtype):
407+
return False
408+
elif needs_i8_conversion(left.dtype):
409+
return _array_equivalent_datetimelike(left, right)
410+
elif is_string_dtype(left.dtype):
411+
# TODO: fastpath for pandas' StringDtype
412+
return _array_equivalent_object(left, right, strict_nan)
413+
else:
414+
return np.array_equal(left, right)
415+
416+
# Slow path when we allow comparing different dtypes.
394417
# Object arrays can contain None, NaN and NaT.
395418
# string dtypes must come to this path for NumPy 1.7.1 compat
396419
if is_string_dtype(left.dtype) or is_string_dtype(right.dtype):
397-
398-
if not strict_nan:
399-
# isna considers NaN and None to be equivalent.
400-
return lib.array_equivalent_object(
401-
ensure_object(left.ravel()), ensure_object(right.ravel())
402-
)
403-
404-
for left_value, right_value in zip(left, right):
405-
if left_value is NaT and right_value is not NaT:
406-
return False
407-
408-
elif left_value is libmissing.NA and right_value is not libmissing.NA:
409-
return False
410-
411-
elif isinstance(left_value, float) and np.isnan(left_value):
412-
if not isinstance(right_value, float) or not np.isnan(right_value):
413-
return False
414-
else:
415-
try:
416-
if np.any(np.asarray(left_value != right_value)):
417-
return False
418-
except TypeError as err:
419-
if "Cannot compare tz-naive" in str(err):
420-
# tzawareness compat failure, see GH#28507
421-
return False
422-
elif "boolean value of NA is ambiguous" in str(err):
423-
return False
424-
raise
425-
return True
420+
return _array_equivalent_object(left, right, strict_nan)
426421

427422
# NaNs can occur in float and complex arrays.
428423
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):
429-
430-
# empty
431424
if not (np.prod(left.shape) and np.prod(right.shape)):
432425
return True
433426
return ((left == right) | (isna(left) & isna(right))).all()
@@ -452,6 +445,45 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
452445
return np.array_equal(left, right)
453446

454447

448+
def _array_equivalent_float(left, right):
449+
return ((left == right) | (np.isnan(left) & np.isnan(right))).all()
450+
451+
452+
def _array_equivalent_datetimelike(left, right):
453+
return np.array_equal(left.view("i8"), right.view("i8"))
454+
455+
456+
def _array_equivalent_object(left, right, strict_nan):
457+
if not strict_nan:
458+
# isna considers NaN and None to be equivalent.
459+
return lib.array_equivalent_object(
460+
ensure_object(left.ravel()), ensure_object(right.ravel())
461+
)
462+
463+
for left_value, right_value in zip(left, right):
464+
if left_value is NaT and right_value is not NaT:
465+
return False
466+
467+
elif left_value is libmissing.NA and right_value is not libmissing.NA:
468+
return False
469+
470+
elif isinstance(left_value, float) and np.isnan(left_value):
471+
if not isinstance(right_value, float) or not np.isnan(right_value):
472+
return False
473+
else:
474+
try:
475+
if np.any(np.asarray(left_value != right_value)):
476+
return False
477+
except TypeError as err:
478+
if "Cannot compare tz-naive" in str(err):
479+
# tzawareness compat failure, see GH#28507
480+
return False
481+
elif "boolean value of NA is ambiguous" in str(err):
482+
return False
483+
raise
484+
return True
485+
486+
455487
def _infer_fill_value(val):
456488
"""
457489
infer the fill value for the nan/NaT from the provided

Diff for: pandas/core/internals/managers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,7 @@ def equals(self, other: "BlockManager") -> bool:
14361436
return array_equivalent(left, right)
14371437

14381438
for i in range(len(self.items)):
1439-
# Check column-wise, return False if any column doesnt match
1439+
# Check column-wise, return False if any column doesn't match
14401440
left = self.iget_values(i)
14411441
right = other.iget_values(i)
14421442
if not is_dtype_equal(left.dtype, right.dtype):
@@ -1445,7 +1445,7 @@ def equals(self, other: "BlockManager") -> bool:
14451445
if not left.equals(right):
14461446
return False
14471447
else:
1448-
if not array_equivalent(left, right):
1448+
if not array_equivalent(left, right, dtype_equal=True):
14491449
return False
14501450
return True
14511451

Diff for: pandas/tests/dtypes/test_missing.py

+47-12
Original file line numberDiff line numberDiff line change
@@ -300,50 +300,80 @@ def test_period(self):
300300
tm.assert_series_equal(notna(s), ~exp)
301301

302302

303-
def test_array_equivalent():
304-
assert array_equivalent(np.array([np.nan, np.nan]), np.array([np.nan, np.nan]))
303+
@pytest.mark.parametrize("dtype_equal", [True, False])
304+
def test_array_equivalent(dtype_equal):
305305
assert array_equivalent(
306-
np.array([np.nan, 1, np.nan]), np.array([np.nan, 1, np.nan])
306+
np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), dtype_equal=dtype_equal
307+
)
308+
assert array_equivalent(
309+
np.array([np.nan, 1, np.nan]),
310+
np.array([np.nan, 1, np.nan]),
311+
dtype_equal=dtype_equal,
307312
)
308313
assert array_equivalent(
309314
np.array([np.nan, None], dtype="object"),
310315
np.array([np.nan, None], dtype="object"),
316+
dtype_equal=dtype_equal,
311317
)
312318
# Check the handling of nested arrays in array_equivalent_object
313319
assert array_equivalent(
314320
np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"),
315321
np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"),
322+
dtype_equal=dtype_equal,
316323
)
317324
assert array_equivalent(
318325
np.array([np.nan, 1 + 1j], dtype="complex"),
319326
np.array([np.nan, 1 + 1j], dtype="complex"),
327+
dtype_equal=dtype_equal,
320328
)
321329
assert not array_equivalent(
322330
np.array([np.nan, 1 + 1j], dtype="complex"),
323331
np.array([np.nan, 1 + 2j], dtype="complex"),
332+
dtype_equal=dtype_equal,
333+
)
334+
assert not array_equivalent(
335+
np.array([np.nan, 1, np.nan]),
336+
np.array([np.nan, 2, np.nan]),
337+
dtype_equal=dtype_equal,
338+
)
339+
assert not array_equivalent(
340+
np.array(["a", "b", "c", "d"]), np.array(["e", "e"]), dtype_equal=dtype_equal
341+
)
342+
assert array_equivalent(
343+
Float64Index([0, np.nan]), Float64Index([0, np.nan]), dtype_equal=dtype_equal
324344
)
325345
assert not array_equivalent(
326-
np.array([np.nan, 1, np.nan]), np.array([np.nan, 2, np.nan])
346+
Float64Index([0, np.nan]), Float64Index([1, np.nan]), dtype_equal=dtype_equal
347+
)
348+
assert array_equivalent(
349+
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]), dtype_equal=dtype_equal
350+
)
351+
assert not array_equivalent(
352+
DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]), dtype_equal=dtype_equal
353+
)
354+
assert array_equivalent(
355+
TimedeltaIndex([0, np.nan]),
356+
TimedeltaIndex([0, np.nan]),
357+
dtype_equal=dtype_equal,
327358
)
328-
assert not array_equivalent(np.array(["a", "b", "c", "d"]), np.array(["e", "e"]))
329-
assert array_equivalent(Float64Index([0, np.nan]), Float64Index([0, np.nan]))
330-
assert not array_equivalent(Float64Index([0, np.nan]), Float64Index([1, np.nan]))
331-
assert array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]))
332-
assert not array_equivalent(DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]))
333-
assert array_equivalent(TimedeltaIndex([0, np.nan]), TimedeltaIndex([0, np.nan]))
334359
assert not array_equivalent(
335-
TimedeltaIndex([0, np.nan]), TimedeltaIndex([1, np.nan])
360+
TimedeltaIndex([0, np.nan]),
361+
TimedeltaIndex([1, np.nan]),
362+
dtype_equal=dtype_equal,
336363
)
337364
assert array_equivalent(
338365
DatetimeIndex([0, np.nan], tz="US/Eastern"),
339366
DatetimeIndex([0, np.nan], tz="US/Eastern"),
367+
dtype_equal=dtype_equal,
340368
)
341369
assert not array_equivalent(
342370
DatetimeIndex([0, np.nan], tz="US/Eastern"),
343371
DatetimeIndex([1, np.nan], tz="US/Eastern"),
372+
dtype_equal=dtype_equal,
344373
)
374+
# The rest are not dtype_equal
345375
assert not array_equivalent(
346-
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern")
376+
DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan], tz="US/Eastern"),
347377
)
348378
assert not array_equivalent(
349379
DatetimeIndex([0, np.nan], tz="CET"),
@@ -353,6 +383,11 @@ def test_array_equivalent():
353383
assert not array_equivalent(DatetimeIndex([0, np.nan]), TimedeltaIndex([0, np.nan]))
354384

355385

386+
def test_array_equivalent_different_dtype_but_equal():
387+
# Unclear if this is exposed anywhere in the public-facing API
388+
assert array_equivalent(np.array([1, 2]), np.array([1.0, 2.0]))
389+
390+
356391
@pytest.mark.parametrize(
357392
"lvalue, rvalue",
358393
[

0 commit comments

Comments
 (0)