From 8064b6e68bcf3b3a12b17ab7e4cb481a99cbf131 Mon Sep 17 00:00:00 2001 From: Diane Napolitano Date: Tue, 23 Jan 2024 10:08:46 -0500 Subject: [PATCH] Adding two more EI solver unit tests --- tests/test_ei_transition_solver.py | 121 +++++++++++++++-------------- 1 file changed, 61 insertions(+), 60 deletions(-) diff --git a/tests/test_ei_transition_solver.py b/tests/test_ei_transition_solver.py index 646fde0a..6af0098c 100644 --- a/tests/test_ei_transition_solver.py +++ b/tests/test_ei_transition_solver.py @@ -37,66 +37,67 @@ def test_ei_fit_predict(): np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) -# def test_matrix_fit_predict_with_weights(): -# X = np.array( -# [ -# [1, 2], -# [3, 4], -# [5, 6], -# [7, 8], -# [9, 10], -# [11, 12], -# ] -# ) - -# Y = np.array( -# [ -# [2, 3], -# [4, 5], -# [6, 7], -# [8, 9], -# [10, 11], -# [12, 13], -# ] -# ) - -# weights = np.array([500, 250, 125, 62.5, 31.25, 15.625]) - -# expected = np.array([[0.737329, 0.262671], [0.230589, 0.769411]]) - -# tms = TransitionMatrixSolver() -# current = tms.fit_predict(X, Y, weights=weights) -# np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) - - -# def test_matrix_fit_predict_pivoted(): -# X = np.array( -# [ -# [1, 2], -# [3, 4], -# [5, 6], -# [7, 8], -# [9, 10], -# [11, 12], -# ] -# ).T - -# Y = np.array( -# [ -# [2, 3], -# [4, 5], -# [6, 7], -# [8, 9], -# [10, 11], -# [12, 13], -# ] -# ).T - -# expected = np.array([[0.760428, 0.239572], [0.216642, 0.783358]]) - -# tms = TransitionMatrixSolver() -# current = tms.fit_predict(X, Y) -# np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) +def test_ei_fit_predict_with_weights(): + # NOTE: currently, supplying weights to the EI solver does nothing. + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ) + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ) + + weights = np.array([500, 250, 125, 62.5, 31.25, 15.625]) + + expected = np.array([[0.883539, 0.116461], [0.09511, 0.90489]]) + + ei = EITransitionSolver(random_seed=1024) + current = ei.fit_predict(X, Y, weights=weights) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) + + +def test_ei_fit_predict_pivoted(): + X = np.array( + [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + ] + ).T + + Y = np.array( + [ + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + ] + ).T + + expected = np.array([[0.883539, 0.116461], [0.09511, 0.90489]]) + + ei = EITransitionSolver(random_seed=1024) + current = ei.fit_predict(X, Y) + np.testing.assert_allclose(expected, current, rtol=RTOL, atol=ATOL) def test_ei_get_prediction_interval():