Skip to content

Commit

Permalink
Less permissive one
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebBell committed Oct 26, 2024
1 parent b3d233f commit f3e8fb8
Showing 1 changed file with 2 additions and 25 deletions.
27 changes: 2 additions & 25 deletions tests/test_numerics_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def check_inv(matrix, rtol=None):
expected[combined_relative_mask] = 0.0



# Check both directions
numpy_zeros = (expected == 0.0)
our_zeros = (result == 0.0)
Expand All @@ -109,30 +108,8 @@ def check_inv(matrix, rtol=None):
# Where numpy has zeros but we don't; no cases require it but it makes sense to do
result[mask_exact_zeros] = np.where(np.abs(result[mask_exact_zeros]) < thresh, 0.0, result[mask_exact_zeros])

# Where we have zeros but numpy doesn't - this is the one we discovered
expected[mask_exact_zeros] = np.where(np.abs(expected[mask_exact_zeros]) < thresh, 0.0, expected[mask_exact_zeros])

# numpy_values_at_our_zeros = expected[our_zeros]
# expected[our_zeros] = np.where(
# np.abs(numpy_values_at_our_zeros) < thresh,
# 0.0,
# numpy_values_at_our_zeros
# )

# We also need to check against the values we get in the inverse; it is helpful
# to zero out anything too close to "zero" relative to the values used in the matrix
# This is very necessary, and was needed when testing on different CPU architectures
inv_norm = np.max(np.sum(np.abs(result), axis=1))
if cond < 1e10:
zero_thresh = 100*thresh
elif cond < 1e14:
zero_thresh = 1000*thresh
else:
zero_thresh = 10000*thresh
trivial_relative_to_norm = np.where(np.abs(result)/inv_norm < zero_thresh)
result[trivial_relative_to_norm] = 0.0
trivial_relative_to_norm = np.where(np.abs(expected)/inv_norm < zero_thresh)
expected[trivial_relative_to_norm] = 0.0
# Where we have zeros but numpy doesn't - this is the one we discovered. Apply the check only to `numpy_zeros`
expected[numpy_zeros] = np.where(np.abs(expected[numpy_zeros]) < thresh, 0.0, expected[numpy_zeros])

if rtol is None:
rtol = get_rtol(matrix)
Expand Down

0 comments on commit f3e8fb8

Please sign in to comment.