Skip to content

Commit 74789ff

Browse files
tf-transform-team, tfx-copybara
tf-transform-team
authored and committed
Handle invalid inputs for mutual information computation where mathematically undefined values (e.g. LOG(0)) are returned.
PiperOrigin-RevId: 454202540
1 parent 8d60ea3 commit 74789ff

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

tensorflow_transform/info_theory.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ def calculate_partial_mutual_information(n_ij, x_i, y_j, n):
8484
Returns:
8585
Mutual information for the cell x=i, y=j.
8686
"""
87-
if n_ij == 0:
87+
if n_ij == 0 or x_i == 0 or y_j == 0:
8888
return 0
8989
return n_ij * ((log2(n_ij) + log2(n)) -
9090
(log2(x_i) + log2(y_j)))

tensorflow_transform/info_theory_test.py

+21
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,27 @@ def test_calculate_partial_expected_mutual_information(
137137
col_count=8,
138138
total_count=16,
139139
expected_mi=0),
140+
dict(
141+
testcase_name='invalid_input_zero_cell_count',
142+
cell_count=4,
143+
row_count=0,
144+
col_count=8,
145+
total_count=8,
146+
expected_mi=0),
147+
dict(
148+
testcase_name='invalid_input_zero_row_count',
149+
cell_count=4,
150+
row_count=0,
151+
col_count=8,
152+
total_count=8,
153+
expected_mi=0),
154+
dict(
155+
testcase_name='invalid_input_zero_col_count',
156+
cell_count=4,
157+
row_count=8,
158+
col_count=0,
159+
total_count=8,
160+
expected_mi=0),
140161
)
141162
def test_mutual_information(self, cell_count, row_count, col_count,
142163
total_count, expected_mi):

0 commit comments

Comments
 (0)