Skip to content

Commit

Permalink
Merge pull request #286 from HERA-Team/switch-to-matvis
Browse files Browse the repository at this point in the history
refactor: use matvis insteaf of vis_cpu
  • Loading branch information
steven-murray authored Dec 11, 2023
2 parents cbe4efe + 9c66df9 commit 45fc8cc
Show file tree
Hide file tree
Showing 12 changed files with 567 additions and 567 deletions.
16 changes: 9 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ repos:
- flake8-bugbear
- flake8-comprehensions
- flake8-print
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
Expand Down
2 changes: 1 addition & 1 deletion config_examples/simulator.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
simulator: VisCPU
simulator: MatVis
precision: 2
ref_time: mean
correct_source_positions: true
4 changes: 2 additions & 2 deletions hera_sim/tests/test_beams.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
stokes_matrix,
)
from hera_sim.defaults import defaults
from hera_sim.visibilities import ModelData, VisCPU, VisibilitySimulation
from hera_sim.visibilities import MatVis, ModelData, VisibilitySimulation

np.seterr(invalid="ignore")

Expand Down Expand Up @@ -118,7 +118,7 @@ def run_sim(
# calculate source fluxes for hera_sim
flux = (freqs[:, np.newaxis] / freqs[0]) ** spectral_index * flux

simulator = VisCPU(
simulator = MatVis(
use_gpu=use_gpu,
mpi_comm=DummyMPIComm() if use_mpi else None,
precision=2,
Expand Down
20 changes: 10 additions & 10 deletions hera_sim/tests/test_compare_pyuvsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hera_sim import io
from hera_sim.beams import PolyBeam
from hera_sim.visibilities import ModelData, VisCPU, VisibilitySimulation
from hera_sim.visibilities import MatVis, ModelData, VisibilitySimulation

nfreq = 3
ntime = 20
Expand Down Expand Up @@ -150,25 +150,25 @@ def get_beams(beam_type, polarized):
(100, "PolyBeam", True),
],
)
def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polarized):
"""Compare vis_cpu and pyuvsim simulated visibilities."""
def test_compare_matvis_with_pyuvsim(uvdata_allpols, nsource, beam_type, polarized):
"""Compare matvis and pyuvsim simulated visibilities."""
sky_model = get_sky_model(uvdata_allpols, nsource)

# Beam models
beams = get_beams(beam_type=beam_type, polarized=polarized)
beam_dict = {str(i): 0 for i in range(nants)}

# ---------------------------------------------------------------------------
# (1) Run vis_cpu
# (1) Run matvis
# ---------------------------------------------------------------------------
# Trim unwanted polarizations
uvdata_viscpu = copy.deepcopy(uvdata_allpols)
uvdata_matvis = copy.deepcopy(uvdata_allpols)

if not polarized:
uvdata_viscpu.select(polarizations=["ee"], inplace=True)
uvdata_matvis.select(polarizations=["ee"], inplace=True)

# Construct simulator object and run
simulator = VisCPU(
simulator = MatVis(
ref_time=Time("2018-08-31T04:02:30.11", format="isot", scale="utc"),
use_gpu=False,
)
Expand All @@ -182,13 +182,13 @@ def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polariz

sim = VisibilitySimulation(
data_model=ModelData(
uvdata=uvdata_viscpu, sky_model=sky_model, beams=vis_cpu_beams
uvdata=uvdata_matvis, sky_model=sky_model, beams=vis_cpu_beams
),
simulator=simulator,
)

sim.simulate()
uvd_viscpu = sim.uvdata
uvd_matvis = sim.uvdata

# ---------------------------------------------------------------------------
# (2) Run pyuvsim
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polariz
print("Baseline: ", i, j)
np.testing.assert_allclose(
uvd_uvsim.get_data((i, j, "xx")),
uvd_viscpu.get_data((i, j, "xx")),
uvd_matvis.get_data((i, j, "xx")),
atol=atol,
rtol=rtol,
)
40 changes: 20 additions & 20 deletions hera_sim/tests/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from astropy import units
from astropy.coordinates.angles import Latitude, Longitude
from astropy.units import rad, sday
from matvis import HAVE_GPU
from pathlib import Path
from pyradiosky import SkyModel
from pyuvsim.analyticbeam import AnalyticBeam
Expand All @@ -16,20 +17,19 @@
from hera_sim.beams import PolyBeam
from hera_sim.defaults import defaults
from hera_sim.visibilities import (
MatVis,
ModelData,
UVSim,
VisCPU,
VisibilitySimulation,
load_simulator_from_yaml,
)
from vis_cpu import HAVE_GPU

SIMULATORS = (VisCPU, UVSim)
SIMULATORS = (MatVis, UVSim)

if HAVE_GPU:

class VisGPU(VisCPU):
"""Simple mock class to make testing VisCPU with use_gpu=True easier"""
class VisGPU(MatVis):
"""Simple mock class to make testing MatVis with use_gpu=True easier"""

def __init__(self, *args, **kwargs):
super().__init__(*args, use_gpu=True, ref_time="min", **kwargs)
Expand Down Expand Up @@ -111,7 +111,7 @@ def sky_modelJD(uvdataJD):
def test_JD(uvdata, uvdataJD, sky_model):
model_data = ModelData(sky_model=sky_model, uvdata=uvdata)

vis = VisCPU()
vis = MatVis()

sim1 = VisibilitySimulation(data_model=model_data, simulator=vis).simulate()

Expand All @@ -125,7 +125,7 @@ def test_JD(uvdata, uvdataJD, sky_model):

def test_vis_cpu_estimate_memory(uvdata, uvdataJD, sky_model):
model_data = ModelData(sky_model=sky_model, uvdata=uvdata)
vis = VisCPU()
vis = MatVis()
mem = vis.estimate_memory(model_data)
assert mem > 0

Expand Down Expand Up @@ -262,7 +262,7 @@ def test_shapes(uvdata, simulator):
@pytest.mark.parametrize("precision, cdtype", [(1, np.complex64), (2, complex)])
def test_dtypes(uvdata, precision, cdtype):
sky = create_uniform_sky(np.unique(uvdata.freq_array))
vis = VisCPU(precision=precision)
vis = MatVis(precision=precision)

# If data_array is empty, then we never create new vis, and the returned value
# is literally the data array, so we should expect to get complex128 regardless.
Expand Down Expand Up @@ -356,19 +356,19 @@ def test_single_source_autocorr_past_horizon(uvdata, simulator):
assert np.abs(np.mean(v)) == 0


def test_viscpu_coordinate_correction(uvdata2):
def test_matvis_coordinate_correction(uvdata2):
sim = VisibilitySimulation(
data_model=ModelData(
uvdata=uvdata2,
sky_model=zenith_sky_model(uvdata2),
),
simulator=VisCPU(
simulator=MatVis(
correct_source_positions=True, ref_time="2018-08-31T04:02:30.11"
),
)

# Apply correction
# viscpu.correct_point_source_pos(obstime="2018-08-31T04:02:30.11", frame="icrs")
# matvis.correct_point_source_pos(obstime="2018-08-31T04:02:30.11", frame="icrs")
v = sim.simulate().copy()
assert np.all(~np.isnan(v))

Expand All @@ -377,7 +377,7 @@ def test_viscpu_coordinate_correction(uvdata2):
uvdata=uvdata2,
sky_model=zenith_sky_model(uvdata2),
),
simulator=VisCPU(
simulator=MatVis(
correct_source_positions=True,
ref_time=apt.Time("2018-08-31T04:02:30.11", format="isot", scale="utc"),
),
Expand Down Expand Up @@ -526,7 +526,7 @@ def test_vis_cpu_pol(polarization_array, xfail):
)

beam = PolyBeam(polarized=False)
simulator = VisCPU()
simulator = MatVis()

if xfail:
with pytest.raises(KeyError):
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_vis_cpu_stokespol(uvdata_linear, sky_model):
with pytest.raises(ValueError):
VisibilitySimulation(
data_model=ModelData(uvdata=uvdata_linear, sky_model=sky_model),
simulator=VisCPU(),
simulator=MatVis(),
)


Expand All @@ -585,10 +585,10 @@ def test_str_uvdata(uvdata, sky_model, tmp_path):
assert model_data.uvdata.Nants_data == uvdata.Nants_data


def test_ref_time_viscpu(uvdata2):
vc_mean = VisCPU(ref_time="mean")
vc_min = VisCPU(ref_time="min")
vc_max = VisCPU(ref_time="max")
def test_ref_time_matvis(uvdata2):
vc_mean = MatVis(ref_time="mean")
vc_min = MatVis(ref_time="min")
vc_max = MatVis(ref_time="max")

sky_model = half_sky_model(uvdata2)

Expand All @@ -615,10 +615,10 @@ def test_load_from_yaml(tmpdir):
example_dir = Path(__file__).parent.parent.parent / "config_examples"

simulator = load_simulator_from_yaml(example_dir / "simulator.yaml")
assert isinstance(simulator, VisCPU)
assert isinstance(simulator, MatVis)
assert simulator.ref_time == "mean"

sim2 = VisCPU.from_yaml(example_dir / "simulator.yaml")
sim2 = MatVis.from_yaml(example_dir / "simulator.yaml")

assert sim2.ref_time == simulator.ref_time
assert sim2.diffuse_ability == simulator.diffuse_ability
Expand Down
6 changes: 3 additions & 3 deletions hera_sim/tests/test_vis_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_vis_cli(tmp_path_factory):
parser,
[
str(cfg),
str(DATA_PATH / "viscpu.yaml"),
str(DATA_PATH / "matvis_cpu.yaml"),
"--compress",
str(outdir / "compression-cache.npy"),
"--normalize_beams",
Expand All @@ -85,12 +85,12 @@ def test_vis_cli_dry(tmp_path_factory):
parser,
[
str(cfg),
str(DATA_PATH / "viscpu.yaml"),
str(DATA_PATH / "matvis_cpu.yaml"),
"--compress",
str(outdir / "compression-cache.npy"),
"--dry",
"--object_name",
"viscpu",
"matvis",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
simulator: VisCPU
simulator: MatVis
precision: 2
ref_time: mean
correct_source_positions: true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
simulator: VisCPU
simulator: MatVis
precision: 2
ref_time: mean
correct_source_positions: true
Expand Down
2 changes: 1 addition & 1 deletion hera_sim/visibilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


try:
from .vis_cpu import VisCPU
from .matvis import MatVis
except (ImportError, NameError): # pragma: no cover
pass

Expand Down
Loading

0 comments on commit 45fc8cc

Please sign in to comment.