diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8792aa2e..00ff8ab5 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -349,7 +349,6 @@ def _asarray( obj = np.asarray(obj).copy() #print(obj) return xp.array(obj, dtype=dtype, **copy_kwargs) - return obj return xp.asarray(obj, dtype=dtype, **kwargs) diff --git a/tests/test_common.py b/tests/test_common.py index 66076bfe..8b0a7ae7 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -42,6 +42,17 @@ def test_device(library): x2 = to_device(x, dev) assert device(x) == device(x2) +@pytest.mark.parametrize("target_library,func", is_functions.items()) +@pytest.mark.parametrize("source_library", is_functions.keys()) +def test_asarray(source_library, target_library, func): + src_lib = import_(source_library, wrapper=True) + tgt_lib = import_(target_library, wrapper=True) + is_tgt_type = globals()[func] + + a = src_lib.asarray([1, 2, 3]) + b = tgt_lib.asarray(a) + + assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"]) def test_to_device_host(library):