From 13e9dfb309fe5cb250dc730b06b2df3af9f009bd Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 28 Nov 2022 12:46:39 +0000 Subject: [PATCH 1/2] expanded feature domain flexibility --- src/trustyai/model/domain.py | 46 +++++++++++++++++++++++-------- tests/general/test_conversions.py | 34 +++++++++++++++++++++-- 2 files changed, 67 insertions(+), 13 deletions(-) diff --git a/src/trustyai/model/domain.py b/src/trustyai/model/domain.py index 17484e9..e154214 100644 --- a/src/trustyai/model/domain.py +++ b/src/trustyai/model/domain.py @@ -3,17 +3,18 @@ from typing import Optional, Tuple, List, Union from jpype import _jclass + from org.kie.trustyai.explainability.model.domain import ( FeatureDomain, NumericalFeatureDomain, CategoricalFeatureDomain, + CategoricalNumericalFeatureDomain, + ObjectFeatureDomain, EmptyFeatureDomain, ) -def feature_domain( - values: Optional[Union[Tuple, List[str]]] -) -> Optional[FeatureDomain]: +def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDomain]: r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a particular feature, which is useful when constraining a counterfactual explanation to ensure it only recovers valid inputs. For example, if we had a feature that described a person's age, we @@ -22,13 +23,18 @@ def feature_domain( Parameters ---------- - values : Optional[Union[Tuple, List[str]]] + values : Optional[Union[Tuple, List]] The valid values of the feature. If `values` takes the form of: * **A tuple of floats or integers:** The feature domain will be a continuous range from ``values[0]`` to ``values[1]``. - * **A list of strings:** The feature domain will be categorical, where `values` contains - all possible valid feature values. + * **A list of floats or integers:**: The feature domain will be a *numeric* categorical, + where `values` contains all possible valid feature values. + * **A list of strings:** The feature domain will be a *string* categorical, where `values` + contains all possible valid feature values. + * **A list of objects:** The feature domain will be an *object* categorical, where `values` + contains all possible valid feature values. These may present an issue if the objects + are not natively Java serializable. Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held fixed during the counterfactual explanation. @@ -43,12 +49,30 @@ def feature_domain( if not values: domain = EmptyFeatureDomain.create() else: - if isinstance(values[0], (float, int)): - domain = NumericalFeatureDomain.create(values[0], values[1]) - elif isinstance(values[0], str): - domain = CategoricalFeatureDomain.create( - _jclass.JClass("java.util.Arrays").asList(values) + if isinstance(values, tuple): + assert isinstance(values[0], (float, int)) and isinstance( + values[1], (float, int) + ) + assert len(values) == 2, ( + "Tuples passed as domain values must only contain" + " two values that define the (minimum, maximum) of the domain" ) + domain = NumericalFeatureDomain.create(values[0], values[1]) + + elif isinstance(values, list): + print(values[0], isinstance(values[0], str)) + java_array = _jclass.JClass("java.util.Arrays").asList(values) + if isinstance(values[0], bool) and isinstance(values[1], bool): + domain = ObjectFeatureDomain.create(java_array) + elif isinstance(values[0], (float, int)) and isinstance( + values[1], (float, int) + ): + domain = CategoricalNumericalFeatureDomain.create(java_array) + elif isinstance(values[0], str): + domain = CategoricalFeatureDomain.create(java_array) + else: + domain = ObjectFeatureDomain.create(java_array) + else: domain = EmptyFeatureDomain.create() return domain diff --git a/tests/general/test_conversions.py b/tests/general/test_conversions.py index 0d191a1..23d9821 100644 --- a/tests/general/test_conversions.py +++ b/tests/general/test_conversions.py @@ -37,6 +37,37 @@ def test_numeric_domain_tuple(): assert jdomain.getUpperBound() == 1000.0 +def test_categorical_numeric_domain_list(): + """Test create numeric domain from list""" + domain = [0, 1000] + jdomain = feature_domain(domain) + assert jdomain.getCategories().size() == 2 + assert jdomain.getCategories().containsAll(domain) + + domain = [0.0, 1000.0] + jdomain = feature_domain(domain) + assert jdomain.getCategories().size() == 2 + assert jdomain.getCategories().containsAll(domain) + + +def test_categorical_object_domain_list(): + """Test create object domain from list""" + domain = [True, False] + jdomain = feature_domain(domain) + assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" + assert jdomain.getCategories().size() == 2 + assert jdomain.getCategories().containsAll(domain) + + +def test_categorical_object_domain_list_2(): + """Test create object domain from list""" + domain = [b"test", b"test2"] + jdomain = feature_domain(domain) + assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain" + assert jdomain.getCategories().size() == 2 + assert jdomain.getCategories().containsAll(domain) + + def test_empty_domain(): """Test empty domain""" domain = feature_domain(None) @@ -45,7 +76,7 @@ def test_empty_domain(): def test_categorical_domain_tuple(): """Test create categorical domain from tuple and list""" - domain = ("foo", "bar", "baz") + domain = ["foo", "bar", "baz"] jdomain = feature_domain(domain) assert jdomain.getCategories().size() == 3 assert jdomain.getCategories().containsAll(list(domain)) @@ -55,7 +86,6 @@ def test_categorical_domain_tuple(): assert jdomain.getCategories().size() == 3 assert jdomain.getCategories().containsAll(domain) - def test_feature_function(): """Test helper method to create features""" f1 = feature(name="f-1", value=1.0, dtype="number") From 7fbb3fdb17697401171adb81345a4844e31e7e3b Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Tue, 29 Nov 2022 09:37:24 +0000 Subject: [PATCH 2/2] removed vestigial debugging print --- src/trustyai/model/domain.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/trustyai/model/domain.py b/src/trustyai/model/domain.py index e154214..4413a30 100644 --- a/src/trustyai/model/domain.py +++ b/src/trustyai/model/domain.py @@ -60,7 +60,6 @@ def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDoma domain = NumericalFeatureDomain.create(values[0], values[1]) elif isinstance(values, list): - print(values[0], isinstance(values[0], str)) java_array = _jclass.JClass("java.util.Arrays").asList(values) if isinstance(values[0], bool) and isinstance(values[1], bool): domain = ObjectFeatureDomain.create(java_array)