Skip to content

implemented argmax and argmin methods using OpenVINO opset. #21071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ NumpyDtypeTest::test_absolute_bool
NumpyDtypeTest::test_add_
NumpyDtypeTest::test_all
NumpyDtypeTest::test_any
NumpyDtypeTest::test_argmax
NumpyDtypeTest::test_argmin
NumpyDtypeTest::test_argpartition
NumpyDtypeTest::test_array
NumpyDtypeTest::test_bitwise
Expand Down Expand Up @@ -81,8 +79,6 @@ NumpyDtypeTest::test_square_bool
HistogramTest
NumpyOneInputOpsCorrectnessTest::test_all
NumpyOneInputOpsCorrectnessTest::test_any
NumpyOneInputOpsCorrectnessTest::test_argmax
NumpyOneInputOpsCorrectnessTest::test_argmin
NumpyOneInputOpsCorrectnessTest::test_argpartition
NumpyOneInputOpsCorrectnessTest::test_array
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
Expand Down
49 changes: 47 additions & 2 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,56 @@ def arctanh(x):


def argmax(x, axis=None, keepdims=False):
raise NotImplementedError("`argmax` is not supported with openvino backend")
x = get_ov_output(x)
x_shape = x.get_partial_shape()
rank = x_shape.rank.get_length()
if rank == 0:
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
if axis is None:
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
axis = 0
else:
if axis < 0:
axis = rank + axis
sorted_indices = ov_opset.topk(
x,
k=1,
axis=axis,
mode="max",
sort="value",

).output(1)

if not keepdims:
sorted_indices = ov_opset.squeeze(sorted_indices, [axis]).output(0)
return OpenVINOKerasTensor(sorted_indices)


def argmin(x, axis=None, keepdims=False):
raise NotImplementedError("`argmin` is not supported with openvino backend")
x = get_ov_output(x)
x_shape = x.get_partial_shape()
rank = x_shape.rank.get_length()
if rank == 0:
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
if axis is None:
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
axis = 0
else:
if axis < 0:
axis = rank + axis
sorted_indices = ov_opset.topk(
x,
k=1,
axis=axis,
mode="min",
sort="value",
).output(1)

if not keepdims:
sorted_indices = ov_opset.squeeze(sorted_indices, [axis]).output(0)
return OpenVINOKerasTensor(sorted_indices)


def argsort(x, axis=-1):
Expand Down
Loading