Skip to content

Commit

Permalink
Ruff and reinstate moving_object tests on ephemerides
Browse files Browse the repository at this point in the history
  • Loading branch information
rhiannonlynne committed Oct 9, 2024
1 parent 11d5024 commit 771209a
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 47 deletions.
2 changes: 1 addition & 1 deletion rubin_sim/moving_objects/cheby_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def read_coefficients(self, cheby_fits_file):
if not os.path.isfile(cheby_fits_file):
raise IOError("Could not find cheby_fits_file at %s" % (cheby_fits_file))
# Read the coefficients file.
coeffs = pd.read_table(cheby_fits_file, sep="\s+")
coeffs = pd.read_table(cheby_fits_file, sep=r"\s+")
# The header line provides information on the number of
# coefficients for each parameter.
datacols = coeffs.columns.values
Expand Down
4 changes: 2 additions & 2 deletions rubin_sim/moving_objects/orbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def read_orbits(self, orbit_file, delim=None, skiprows=None):
"COMPCODE",
)
# First use names_com, and then change if required.
orbits = pd.read_csv(orbit_file, sep="\s+", header=None, names=names_com)
orbits = pd.read_csv(orbit_file, sep=r"\s+", header=None, names=names_com)

if orbits["FORMAT"][0] == "KEP":
orbits.columns = names_kep
Expand All @@ -397,7 +397,7 @@ def read_orbits(self, orbit_file, delim=None, skiprows=None):

else:
if delim is None:
orbits = pd.read_csv(orbit_file, sep="\s+", skiprows=skiprows, names=names)
orbits = pd.read_csv(orbit_file, sep=r"\s+", skiprows=skiprows, names=names)
else:
orbits = pd.read_csv(orbit_file, sep=delim, skiprows=skiprows, names=names)

Expand Down
36 changes: 24 additions & 12 deletions tests/moving_objects/test_chebyfits.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,26 @@ def test_precompute_multipliers(self):
self.assertTrue(key in self.cheb.multipliers)

def test_set_segment_length(self):
# Expect MBAs with standard ngran and tolerance to have length ~2.0 days.
# Expect MBAs with standard ngran and tolerance to have
# length ~2.0 days.
self.cheb.calc_segment_length()
self.assertAlmostEqual(self.cheb.length, 2.0)
# Test that we can set it to other values which fit into the 30 day window.
# Test that we can set it to other values which fit into
# the 30 day window.
self.cheb.calc_segment_length(length=1.5)
self.assertEqual(self.cheb.length, 1.5)
# Test that we if we try to set it to a value which does not fit into the 30 day window,
# Test that we if we try to set it to a value which does not
# fit into the 30 day window,
# that the actual value used is different - and smaller.
self.cheb.calc_segment_length(length=1.9)
self.assertTrue(self.cheb.length < 1.9)
# Test that we get a warning about the residuals if we try to set the length to be too long.
# Test that we get a warning about the residuals if we try
# to set the length to be too long.
with warnings.catch_warnings(record=True) as w:
self.cheb.calc_segment_length(length=5.0)
self.assertTrue(len(w), 1)
# Now check granularity works for other orbit types (which would have other standard lengths).
# Now check granularity works for other orbit types
# (which would have other standard lengths).
# Check for multiple orbit types.
for orbit_file in [
"test_orbitsMBA.s3m",
Expand All @@ -73,7 +78,8 @@ def test_set_segment_length(self):
pos_resid, ratio = cheb._test_residuals(cheb.length)
self.assertTrue(pos_resid < sky_tolerance)
self.assertEqual((cheb.length * 100) % 1, 0)
# print('final', orbit_file, sky_tolerance, pos_resid, cheb.length, ratio)
# print('final', orbit_file, sky_tolerance, pos_resid,
# cheb.length, ratio)
# And check for challenging 'impactors'.
for orbit_file in ["test_orbitsImpactors.s3m"]:
self.orbits.read_orbits(os.path.join(self.testdir, orbit_file))
Expand All @@ -85,7 +91,8 @@ def test_set_segment_length(self):
cheb.calc_segment_length()
pos_resid, ratio = cheb._test_residuals(cheb.length)
self.assertTrue(pos_resid < sky_tolerance)
# print('final', orbit_file, sky_tolerance, pos_resid, cheb.length, ratio)
# print('final', orbit_file, sky_tolerance, pos_resid,
# cheb.length, ratio)

@unittest.skip("Skipping because it has a strange platform-dependent failure")
def test_segments(self):
Expand All @@ -108,7 +115,8 @@ def test_segments(self):
for k in coeff_keys:
self.assertTrue(k in self.cheb.coeffs.keys())
# And in this case, we had a 30 day timespan with 1 day segments
# (one day segments should be more than enough to meet 2.5mas tolerance, so not subdivided)
# (one day segments should be more than enough to meet
# 2.5mas tolerance, so not subdivided)
self.assertEqual(len(self.cheb.coeffs["t_start"]), 30 * len(self.orbits))
# And we used 14 coefficients for ra and dec.
self.assertEqual(len(self.cheb.coeffs["ra"][0]), 14)
Expand Down Expand Up @@ -144,7 +152,8 @@ def test_run_through(self):
t_start = self.orbits.orbits.epoch.iloc[0]
interval = 30
cheb = ChebyFits(self.orbits, t_start, interval, ngran=64, sky_tolerance=2.5, n_decimal=10)
# Set granularity. Use an value that will be too long, to trigger recursion below.
# Set granularity. Use an value that will be too long,
# to trigger recursion below.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cheb.calc_segment_length(length=10.0)
Expand All @@ -156,17 +165,20 @@ def test_run_through(self):
resid_name = os.path.join(self.scratch_dir, "resid2.txt")
failed_name = os.path.join(self.scratch_dir, "failed2.txt")
cheb.write(coeff_name, resid_name, failed_name)
# Test that the segments for each individual object fit together start/end.
# Test that the segments for each individual object fit
# together start/end.
for k in cheb.coeffs:
cheb.coeffs[k] = np.array(cheb.coeffs[k])
for obj_id in np.unique(cheb.coeffs["obj_id"]):
condition = cheb.coeffs["obj_id"] == obj_id
te_prev = t_start
for ts, te in zip(cheb.coeffs["t_start"][condition], cheb.coeffs["t_end"][condition]):
# Test that the start of the current interval = the end of the previous interval.
# Test that the start of the current interval =
# the end of the previous interval.
self.assertEqual(te_prev, ts)
te_prev = te
# Test that the end of the last interval is equal to the end of the total interval
# Test that the end of the last interval is equal to the end
# of the total interval
self.assertEqual(te, t_start + interval)


Expand Down
11 changes: 7 additions & 4 deletions tests/moving_objects/test_chebyshevutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from rubin_sim.moving_objects import chebeval, chebfit, make_cheb_matrix, make_cheb_matrix_only_x
from rubin_sim.moving_objects import chebeval, chebfit, make_cheb_matrix


class TestChebgrid(unittest.TestCase):
Expand All @@ -22,7 +22,8 @@ def test_eval(self):
yy_w_vel, vv = chebeval(np.linspace(-1, 1, 17), p)
yy_wout_vel, vv = chebeval(np.linspace(-1, 1, 17), p, do_velocity=False)
self.assertTrue(np.allclose(yy_wout_vel, yy_w_vel))
# Test that we get a nan for a value outside the range of the 'interval', if mask=True
# Test that we get a nan for a value outside the range of the
# 'interval', if mask=True
yy_w_vel, vv = chebeval(np.linspace(-2, 1, 17), p, mask=True)
self.assertTrue(
np.isnan(yy_w_vel[0]),
Expand All @@ -42,7 +43,8 @@ def test_ends_locked(self):
self.assertAlmostEqual(vv[-1], dy[-1], places=13)

def test_accuracy(self):
"""If n_poly is greater than number of values being fit, then fit should be exact."""
"""If n_poly is greater than number of values being fit,
then fit should be exact."""
x = np.linspace(0, np.pi, 9)
y = np.sin(x)
dy = np.cos(x)
Expand All @@ -53,7 +55,8 @@ def test_accuracy(self):
self.assertLess(np.sum(resid), 1e-13)

def test_accuracy_prefit_c1c2(self):
"""If n_poly is greater than number of values being fit, then fit should be exact."""
"""If n_poly is greater than number of values being fit,
then fit should be exact."""
NPOINTS = 8
NPOLY = 16
x = np.linspace(0, np.pi, NPOINTS + 1)
Expand Down
32 changes: 20 additions & 12 deletions tests/moving_objects/test_chebyvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ def test_set_coeff(self):
self.assertTrue(k in cheby_values.coeffs)
self.assertTrue(isinstance(cheby_values.coeffs[k], np.ndarray))
self.assertEqual(len(np.unique(cheby_values.coeffs["obj_id"])), len(self.orbits))
# This will only be true for carefully selected length/orbit type, where subdivision did not occur.
# This will only be true for carefully selected length/orbit type,
# where subdivision did not occur.
# For the test MBAs, a len=1day will work.
# For the test NEOs, a len=0.25 day will work (with 2.5mas skyTol).
# self.assertEqual(len(cheby_values.coeffs['tStart']),
# (self.interval / self.set_length) * len(self.orbits))
# (self.interval / self.set_length) * len(self.orbits))
self.assertEqual(len(cheby_values.coeffs["ra"][0]), self.n_coeffs)
self.assertTrue("meanRA" in cheby_values.coeffs)
self.assertTrue("meanDec" in cheby_values.coeffs)
Expand All @@ -90,9 +91,11 @@ def test_read_coeffs(self):
# Can't test strings with np.test.assert_almost_equal.
np.testing.assert_equal(cheby_values.coeffs[k], cheby_values2.coeffs[k])
else:
# All of these will only be accurate to 2 less decimal places than they are
# print out with in chebyFits. Since vmag, delta and elongation only use 7
# decimal places, this means we can test to 5 decimal places for those.
# All of these will only be accurate to 2 less decimal places
# than they are print out with in chebyFits. Since vmag,
# delta and elongation only use 7
# decimal places, this means we can test to 5 decimal
# places for those.
np.testing.assert_allclose(cheby_values.coeffs[k], cheby_values2.coeffs[k], rtol=0, atol=1e-5)

def test_get_ephemerides(self):
Expand Down Expand Up @@ -142,12 +145,13 @@ def test_get_ephemerides(self):
)
self.assertTrue(
np.isnan(ephemerides["ra"][0]),
msg="Expected Nan for out of range ephemeris, got %.2e" % (ephemerides["ra"][0]),
msg=f"Expected Nan for out of range ephemeris, got {ephemerides['ra'][0]}",
)


class TestJPLValues(unittest.TestCase):
# Test the interpolation-generated RA/Dec values against JPL generated RA/Dec values.
# Test the interpolation-generated RA/Dec values against JPL
# generated RA/Dec values.
# The resulting errors should be similar to the errors reported
# from testEphemerides when testing against JPL values.
def setUp(self):
Expand All @@ -156,14 +160,15 @@ def setUp(self):
self.jpl_dir = os.path.join(get_data_dir(), "tests", "jpl_testdata")
self.orbits.read_orbits(os.path.join(self.jpl_dir, "S0_n747.des"), skiprows=1)
# Read JPL ephems.
self.jpl = pd.read_table(os.path.join(self.jpl_dir, "807_n747.txt"), delim_whitespace=True)
self.jpl = pd.read_table(os.path.join(self.jpl_dir, "807_n747.txt"), sep=r"\s+")
self.jpl["obj_id"] = self.jpl["objId"]
# Add times in TAI and UTC, because.
t = Time(self.jpl["epoch_mjd"], format="mjd", scale="utc")
self.jpl["mjdTAI"] = t.tai.mjd
self.jpl["mjdUTC"] = t.utc.mjd
self.jpl = self.jpl.to_records(index=False)
# Generate interpolation coefficients for the time period in the JPL catalog.
# Generate interpolation coefficients for the time period
# in the JPL catalog.
self.scratch_dir = tempfile.mkdtemp(dir=ROOT, prefix="TestJPLValues-")
self.coeff_file = os.path.join(self.scratch_dir, "test_coeffs")
self.resid_file = os.path.join(self.scratch_dir, "test_resids")
Expand Down Expand Up @@ -202,7 +207,8 @@ def tearDown(self):
shutil.rmtree(self.scratch_dir)

def test_ra_dec(self):
# We won't compare Vmag, because this also needs information on trailing losses.
# We won't compare Vmag, because this also needs
# information on trailing losses.
times = np.unique(self.jpl["mjdTAI"])
delta_ra = np.zeros(len(times), float)
delta_dec = np.zeros(len(times), float)
Expand All @@ -219,8 +225,10 @@ def test_ra_dec(self):
delta_ra[i] = d_ra.max()
delta_dec[i] = d_dec.max()
# Should be (given OOrb direct prediction):
# Much of the time we're closer than 1mas, but there are a few which hit higher values.
# This is consistent with the errors/values reported by oorb directly in testEphemerides.
# Much of the time we're closer than 1mas, but there are a
# few which hit higher values.
# This is consistent with the errors/values reported by oorb
# directly in testEphemerides.

# # XXX--units?
print("max JPL errors", delta_ra.max(), delta_dec.max())
Expand Down
36 changes: 21 additions & 15 deletions tests/moving_objects/test_ephemerides.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def test_convert_to_oorb_array(self):

def test_convert_from_oorb_array(self):
# Check that we can convert orbital elements TO oorb format and back
# without losing info (except ObjId -- we will lose that unless we use updateOrbits.)
# without losing info
# (except ObjId -- we will lose that unless we use updateOrbits.)
self.ephems._convert_to_oorb_elem(self.orbits.orbits, self.orbits.orb_format)
new_orbits = Orbits()
new_orbits.set_orbits(self.orbits.orbits)
Expand Down Expand Up @@ -116,8 +117,8 @@ def test_ephemeris(self):
time_scale="UTC",
by_object=False,
)
# Temp removing this as it is giving an intermittent fail. Not sure why
# np.testing.assert_equal(ephs_all, ephs)
for key in ephs_all.dtype.names:
np.testing.assert_almost_equal(ephs_all[key], ephs[key])
# Reset ephems to use KEP Orbits, and calculate new ephemerides.
self.ephems.set_orbits(self.orbits_kep)
oorb_ephs = self.ephems._generate_oorb_ephs_basic(eph_times, obscode=807, eph_mode="N")
Expand All @@ -135,26 +136,29 @@ def test_ephemeris(self):
time_scale="UTC",
by_object=False,
)
# Also seems to be an intermitent fail
# np.testing.assert_equal(ephsAllKEP, ephsKEP)
# Check that ephemerides calculated from the different (COM/KEP) orbits are almost equal.
# for column in ephs.dtype.names:
# np.testing.assert_allclose(ephs[column], ephsKEP[column], rtol=0, atol=1e-7)
# Check that the wrapped method using KEP elements and the wrapped method using COM elements match.
# for column in ephsAll.dtype.names:
# np.testing.assert_allclose(ephsAllKEP[column], ephsAll[column], rtol=0, atol=1e-7)
for key in ephs_all_kep.dtype.names:
np.testing.assert_almost_equal(ephs_all_kep[key], ephs_kep[key])
# Check that ephemerides calculated from the different (COM/KEP)
# orbits are almost equal
for column in ephs.dtype.names:
np.testing.assert_allclose(ephs[column], ephs_kep[column], rtol=1e-5, atol=1e-4)
# Check that the wrapped method using KEP elements and the wrapped
# method using COM elements match.
for column in ephs_all.dtype.names:
np.testing.assert_allclose(ephs_all_kep[column], ephs_all[column], rtol=1e-5, atol=1e-4)


class TestJPLValues(unittest.TestCase):
"""Test the oorb generated RA/Dec values against JPL generated RA/Dec values."""
"""Test the oorb generated RA/Dec values against
JPL generated RA/Dec values."""

def setUp(self):
# Read orbits.
self.orbits = Orbits()
self.jpl_dir = os.path.join(get_data_dir(), "tests", "jpl_testdata")
self.orbits.read_orbits(os.path.join(self.jpl_dir, "S0_n747.des"), skiprows=1)
# Read JPL ephems.
self.jpl = pd.read_csv(os.path.join(self.jpl_dir, "807_n747.txt"), sep="\s+")
self.jpl = pd.read_csv(os.path.join(self.jpl_dir, "807_n747.txt"), sep=r"\s+")
# Temp key fix
self.jpl["obj_id"] = self.jpl["objId"]
# Add times in TAI and UTC, because.
Expand All @@ -167,7 +171,8 @@ def tear_down(self):
del self.jpl

def test_ra_dec(self):
# We won't compare Vmag, because this also needs information on trailing losses.
# We won't compare Vmag, because this also needs information
# on trailing losses.
times = self.jpl["mjdUTC"].unique()
delta_ra = np.zeros(len(times), float)
delta_dec = np.zeros(len(times), float)
Expand All @@ -193,7 +198,8 @@ def test_ra_dec(self):
# Convert to mas
delta_ra *= 3600.0 * 1000.0
delta_dec *= 3600.0 * 1000.0
# Much of the time we're closer than 1mas, but there are a few which hit higher values.
# Much of the time we're closer than 1mas,
# but there are a few which hit higher values.
print("max JPL errors", np.max(delta_ra), np.max(delta_dec))
print("std JPL errors", np.std(delta_ra), np.std(delta_dec))
self.assertLess(np.max(delta_ra), 25)
Expand Down
2 changes: 1 addition & 1 deletion tests/moving_objects/test_orbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_set_orbits(self):
self.assertEqual(len(new_orbits), 1)
self.assertEqual(new_orbits.orb_format, "COM")
assert_frame_equal(new_orbits.orbits, suborbits)
# Test that we can set the orbits using a numpy array with many objects.
# Test that we can set the orbits using a numpy array of many objects.
numpyorbits = orbits.orbits.to_records(index=False)
new_orbits = Orbits()
new_orbits.set_orbits(numpyorbits)
Expand Down

0 comments on commit 771209a

Please sign in to comment.