Skip to content

Commit

Permalink
Adding unit tests for other values that could be specified to credibl…
Browse files Browse the repository at this point in the history
…e interval and also getting credible interval transitions
  • Loading branch information
dmnapolitano committed Jan 30, 2024
1 parent 8e594c4 commit 0bf5278
Showing 1 changed file with 98 additions and 1 deletion.
99 changes: 98 additions & 1 deletion tests/test_ei_transition_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from elexsolver.EITransitionSolver import EITransitionSolver

# high tolerance to match PyMC's unit tests
# high tolerance due to random sampling
# (which can produce different outcomes on different architectures, despite setting seeds)
RTOL = 1e-02
ATOL = 1e-02

Expand Down Expand Up @@ -140,3 +141,99 @@ def test_ei_credible_interval_percentages():
(current_lower, current_upper) = ei.get_credible_interval(99, transitions=False)
np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL)


def test_ei_credible_interval_percentages_float_interval():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
)

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
)

expected_lower = np.array([[0.037212, 0.356174], [0.424652, 0.117605]])
expected_upper = np.array([[0.643826, 0.962788], [0.882395, 0.575348]])

ei = EITransitionSolver(random_seed=1024, n_samples=10, sampling_chains=1)
_ = ei.fit_predict(X, Y)
(current_lower, current_upper) = ei.get_credible_interval(0.99, transitions=False)
np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL)


def test_ei_credible_interval_invalid():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
)

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
)

ei = EITransitionSolver(random_seed=1024, n_samples=10, sampling_chains=1)
_ = ei.fit_predict(X, Y)

with pytest.raises(ValueError):
ei.get_credible_interval(3467838976, transitions=False)


def test_ei_credible_interval_transitions():
X = np.array(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
[9, 10],
[11, 12],
]
)

Y = np.array(
[
[2, 3],
[4, 5],
[6, 7],
[8, 9],
[10, 11],
[12, 13],
]
)

expected_lower = np.array([[0.017175, 0.164388], [0.228659, 0.063326]])
expected_upper = np.array([[0.29715, 0.444364], [0.475136, 0.309803]])

ei = EITransitionSolver(random_seed=1024, n_samples=10, sampling_chains=1)
_ = ei.fit_predict(X, Y)
(current_lower, current_upper) = ei.get_credible_interval(99, transitions=True)
np.testing.assert_allclose(expected_lower, current_lower, rtol=RTOL, atol=ATOL)
np.testing.assert_allclose(expected_upper, current_upper, rtol=RTOL, atol=ATOL)

0 comments on commit 0bf5278

Please sign in to comment.