Skip to content

Commit d783a2c

Browse files
committed
Fix device kwarg-only argument being used as positional for calls to default_dtypes throughout tests
1 parent aa24801 commit d783a2c

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

Diff for: dpctl/tests/test_tensor_array_api_inspection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_array_api_inspection_default_dtypes():
9696

9797
info = dpt.__array_namespace_info__()
9898
default_dts_nodev = info.default_dtypes()
99-
default_dts_dev = info.default_dtypes(dev)
99+
default_dts_dev = info.default_dtypes(device=dev)
100100

101101
assert (
102102
int_dt == default_dts_nodev["integral"] == default_dts_dev["integral"]

Diff for: dpctl/tests/test_usm_ndarray_indexing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def test_mixed_index_getitem():
491491
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
492492
i1b = dpt.ones(10, dtype="?")
493493
info = x.__array_namespace__().__array_namespace_info__()
494-
ind_dt = info.default_dtypes(x.device)["indexing"]
494+
ind_dt = info.default_dtypes(device=x.device)["indexing"]
495495
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
496496
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
497497
y = x[i0, i1b, i2]
@@ -503,7 +503,7 @@ def test_mixed_index_setitem():
503503
x = dpt.reshape(dpt.arange(1000, dtype="i4"), (10, 10, 10))
504504
i1b = dpt.ones(10, dtype="?")
505505
info = x.__array_namespace__().__array_namespace_info__()
506-
ind_dt = info.default_dtypes(x.device)["indexing"]
506+
ind_dt = info.default_dtypes(device=x.device)["indexing"]
507507
i0 = dpt.asarray([0, 2, 3], dtype=ind_dt)[:, dpt.newaxis]
508508
i2 = dpt.asarray([3, 4, 7], dtype=ind_dt)[:, dpt.newaxis]
509509
v_shape = (3, int(dpt.sum(i1b, dtype="i8")))

Diff for: dpctl/tests/test_usm_ndarray_searchsorted.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def _check(hay_stack, needles, needles_np):
1212
assert hay_stack.ndim == 1
1313

1414
info_ = dpt.__array_namespace_info__()
15-
default_dts_dev = info_.default_dtypes(hay_stack.device)
15+
default_dts_dev = info_.default_dtypes(device=hay_stack.device)
1616
index_dt = default_dts_dev["indexing"]
1717

1818
p_left = dpt.searchsorted(hay_stack, needles, side="left")

0 commit comments

Comments
 (0)