Skip to content

Commit

Permalink
Handle invalid inputs for mutual information computation where mathem…
Browse files Browse the repository at this point in the history
…atically undefined values (e.g. LOG(0)) are returned.

PiperOrigin-RevId: 454202540
  • Loading branch information
tf-transform-team authored and tfx-copybara committed Jun 10, 2022
1 parent 8d60ea3 commit 74789ff
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow_transform/info_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def calculate_partial_mutual_information(n_ij, x_i, y_j, n):
Returns:
Mutual information for the cell x=i, y=j.
"""
if n_ij == 0:
if n_ij == 0 or x_i == 0 or y_j == 0:
return 0
return n_ij * ((log2(n_ij) + log2(n)) -
(log2(x_i) + log2(y_j)))
Expand Down
21 changes: 21 additions & 0 deletions tensorflow_transform/info_theory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ def test_calculate_partial_expected_mutual_information(
col_count=8,
total_count=16,
expected_mi=0),
dict(
testcase_name='invalid_input_zero_cell_count',
cell_count=4,
row_count=0,
col_count=8,
total_count=8,
expected_mi=0),
dict(
testcase_name='invalid_input_zero_row_count',
cell_count=4,
row_count=0,
col_count=8,
total_count=8,
expected_mi=0),
dict(
testcase_name='invalid_input_zero_col_count',
cell_count=4,
row_count=8,
col_count=0,
total_count=8,
expected_mi=0),
)
def test_mutual_information(self, cell_count, row_count, col_count,
total_count, expected_mi):
Expand Down

0 comments on commit 74789ff

Please sign in to comment.