From b0ec8a603d1ca015d68d7494dacb48a2b2912017 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 11 Mar 2024 13:12:48 +0000 Subject: [PATCH 1/2] test case for asarray --- tests/test_common.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 66076bfe..8b0a7ae7 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -42,6 +42,17 @@ def test_device(library): x2 = to_device(x, dev) assert device(x) == device(x2) +@pytest.mark.parametrize("target_library,func", is_functions.items()) +@pytest.mark.parametrize("source_library", is_functions.keys()) +def test_asarray(source_library, target_library, func): + src_lib = import_(source_library, wrapper=True) + tgt_lib = import_(target_library, wrapper=True) + is_tgt_type = globals()[func] + + a = src_lib.asarray([1, 2, 3]) + b = tgt_lib.asarray(a) + + assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library): From 192da0a261fb35d73374c385f9f44034a529c7b3 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 11 Mar 2024 13:13:15 +0000 Subject: [PATCH 2/2] Partial fix --- array_api_compat/common/_aliases.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8792aa2e..00ff8ab5 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -349,7 +349,6 @@ def _asarray( obj = np.asarray(obj).copy() #print(obj) return xp.array(obj, dtype=dtype, **copy_kwargs) - return obj return xp.asarray(obj, dtype=dtype, **kwargs)