From 74789ff3818f01e58df88bf67306a0e89bf11457 Mon Sep 17 00:00:00 2001 From: tf-transform-team Date: Fri, 10 Jun 2022 11:10:23 -0700 Subject: [PATCH] Handle invalid inputs for mutual information computation where mathematically undefined values (e.g. LOG(0)) are returned. PiperOrigin-RevId: 454202540 --- tensorflow_transform/info_theory.py | 2 +- tensorflow_transform/info_theory_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tensorflow_transform/info_theory.py b/tensorflow_transform/info_theory.py index 751963d7..f536d105 100644 --- a/tensorflow_transform/info_theory.py +++ b/tensorflow_transform/info_theory.py @@ -84,7 +84,7 @@ def calculate_partial_mutual_information(n_ij, x_i, y_j, n): Returns: Mutual information for the cell x=i, y=j. """ - if n_ij == 0: + if n_ij == 0 or x_i == 0 or y_j == 0: return 0 return n_ij * ((log2(n_ij) + log2(n)) - (log2(x_i) + log2(y_j))) diff --git a/tensorflow_transform/info_theory_test.py b/tensorflow_transform/info_theory_test.py index 05bf4138..b8136b4f 100644 --- a/tensorflow_transform/info_theory_test.py +++ b/tensorflow_transform/info_theory_test.py @@ -137,6 +137,27 @@ def test_calculate_partial_expected_mutual_information( col_count=8, total_count=16, expected_mi=0), + dict( + testcase_name='invalid_input_zero_cell_count', + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0), + dict( + testcase_name='invalid_input_zero_row_count', + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0), + dict( + testcase_name='invalid_input_zero_col_count', + cell_count=4, + row_count=8, + col_count=0, + total_count=8, + expected_mi=0), ) def test_mutual_information(self, cell_count, row_count, col_count, total_count, expected_mi):