Skip to content

FAI-889: Allow non-string categorical feature domains #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions src/trustyai/model/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
from typing import Optional, Tuple, List, Union

from jpype import _jclass

from org.kie.trustyai.explainability.model.domain import (
FeatureDomain,
NumericalFeatureDomain,
CategoricalFeatureDomain,
CategoricalNumericalFeatureDomain,
ObjectFeatureDomain,
EmptyFeatureDomain,
)


def feature_domain(
values: Optional[Union[Tuple, List[str]]]
) -> Optional[FeatureDomain]:
def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDomain]:
r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a
particular feature, which is useful when constraining a counterfactual explanation to ensure it
only recovers valid inputs. For example, if we had a feature that described a person's age, we
Expand All @@ -22,13 +23,18 @@ def feature_domain(

Parameters
----------
values : Optional[Union[Tuple, List[str]]]
values : Optional[Union[Tuple, List]]
The valid values of the feature. If `values` takes the form of:

* **A tuple of floats or integers:** The feature domain will be a continuous range from
``values[0]`` to ``values[1]``.
* **A list of strings:** The feature domain will be categorical, where `values` contains
all possible valid feature values.
* **A list of floats or integers:**: The feature domain will be a *numeric* categorical,
where `values` contains all possible valid feature values.
* **A list of strings:** The feature domain will be a *string* categorical, where `values`
contains all possible valid feature values.
* **A list of objects:** The feature domain will be an *object* categorical, where `values`
contains all possible valid feature values. These may present an issue if the objects
are not natively Java serializable.

Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held
fixed during the counterfactual explanation.
Expand All @@ -43,12 +49,30 @@ def feature_domain(
if not values:
domain = EmptyFeatureDomain.create()
else:
if isinstance(values[0], (float, int)):
domain = NumericalFeatureDomain.create(values[0], values[1])
elif isinstance(values[0], str):
domain = CategoricalFeatureDomain.create(
_jclass.JClass("java.util.Arrays").asList(values)
if isinstance(values, tuple):
assert isinstance(values[0], (float, int)) and isinstance(
values[1], (float, int)
)
assert len(values) == 2, (
"Tuples passed as domain values must only contain"
" two values that define the (minimum, maximum) of the domain"
)
domain = NumericalFeatureDomain.create(values[0], values[1])

elif isinstance(values, list):
print(values[0], isinstance(values[0], str))
java_array = _jclass.JClass("java.util.Arrays").asList(values)
if isinstance(values[0], bool) and isinstance(values[1], bool):
domain = ObjectFeatureDomain.create(java_array)
elif isinstance(values[0], (float, int)) and isinstance(
values[1], (float, int)
):
domain = CategoricalNumericalFeatureDomain.create(java_array)
elif isinstance(values[0], str):
domain = CategoricalFeatureDomain.create(java_array)
else:
domain = ObjectFeatureDomain.create(java_array)

else:
domain = EmptyFeatureDomain.create()
return domain
34 changes: 32 additions & 2 deletions tests/general/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,37 @@ def test_numeric_domain_tuple():
assert jdomain.getUpperBound() == 1000.0


def test_categorical_numeric_domain_list():
"""Test create numeric domain from list"""
domain = [0, 1000]
jdomain = feature_domain(domain)
assert jdomain.getCategories().size() == 2
assert jdomain.getCategories().containsAll(domain)

domain = [0.0, 1000.0]
jdomain = feature_domain(domain)
assert jdomain.getCategories().size() == 2
assert jdomain.getCategories().containsAll(domain)


def test_categorical_object_domain_list():
"""Test create object domain from list"""
domain = [True, False]
jdomain = feature_domain(domain)
assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain"
assert jdomain.getCategories().size() == 2
assert jdomain.getCategories().containsAll(domain)


def test_categorical_object_domain_list_2():
"""Test create object domain from list"""
domain = [b"test", b"test2"]
jdomain = feature_domain(domain)
assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain"
assert jdomain.getCategories().size() == 2
assert jdomain.getCategories().containsAll(domain)


def test_empty_domain():
"""Test empty domain"""
domain = feature_domain(None)
Expand All @@ -45,7 +76,7 @@ def test_empty_domain():

def test_categorical_domain_tuple():
"""Test create categorical domain from tuple and list"""
domain = ("foo", "bar", "baz")
domain = ["foo", "bar", "baz"]
jdomain = feature_domain(domain)
assert jdomain.getCategories().size() == 3
assert jdomain.getCategories().containsAll(list(domain))
Expand All @@ -55,7 +86,6 @@ def test_categorical_domain_tuple():
assert jdomain.getCategories().size() == 3
assert jdomain.getCategories().containsAll(domain)


def test_feature_function():
"""Test helper method to create features"""
f1 = feature(name="f-1", value=1.0, dtype="number")
Expand Down