Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use matvis insteaf of vis_cpu #286

Merged
merged 7 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ repos:
- flake8-bugbear
- flake8-comprehensions
- flake8-print
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
Expand Down
4 changes: 2 additions & 2 deletions hera_sim/tests/test_beams.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
stokes_matrix,
)
from hera_sim.defaults import defaults
from hera_sim.visibilities import ModelData, VisCPU, VisibilitySimulation
from hera_sim.visibilities import MatVis, ModelData, VisibilitySimulation

np.seterr(invalid="ignore")

Expand Down Expand Up @@ -118,7 +118,7 @@ def run_sim(
# calculate source fluxes for hera_sim
flux = (freqs[:, np.newaxis] / freqs[0]) ** spectral_index * flux

simulator = VisCPU(
simulator = MatVis(
use_gpu=use_gpu,
mpi_comm=DummyMPIComm() if use_mpi else None,
precision=2,
Expand Down
16 changes: 8 additions & 8 deletions hera_sim/tests/test_compare_pyuvsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hera_sim import io
from hera_sim.beams import PolyBeam
from hera_sim.visibilities import ModelData, VisCPU, VisibilitySimulation
from hera_sim.visibilities import MatVis, ModelData, VisibilitySimulation

nfreq = 3
ntime = 20
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_beams(beam_type, polarized):
(100, "PolyBeam", True),
],
)
def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polarized):
def test_compare_matvis_with_pyuvsim(uvdata_allpols, nsource, beam_type, polarized):
"""Compare vis_cpu and pyuvsim simulated visibilities."""
piyanatk marked this conversation as resolved.
Show resolved Hide resolved
sky_model = get_sky_model(uvdata_allpols, nsource)

Expand All @@ -162,13 +162,13 @@ def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polariz
# (1) Run vis_cpu
steven-murray marked this conversation as resolved.
Show resolved Hide resolved
# ---------------------------------------------------------------------------
# Trim unwanted polarizations
uvdata_viscpu = copy.deepcopy(uvdata_allpols)
uvdata_matvis = copy.deepcopy(uvdata_allpols)

if not polarized:
uvdata_viscpu.select(polarizations=["ee"], inplace=True)
uvdata_matvis.select(polarizations=["ee"], inplace=True)

# Construct simulator object and run
simulator = VisCPU(
simulator = MatVis(
ref_time=Time("2018-08-31T04:02:30.11", format="isot", scale="utc"),
use_gpu=False,
)
Expand All @@ -182,13 +182,13 @@ def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polariz

sim = VisibilitySimulation(
data_model=ModelData(
uvdata=uvdata_viscpu, sky_model=sky_model, beams=vis_cpu_beams
uvdata=uvdata_matvis, sky_model=sky_model, beams=vis_cpu_beams
),
simulator=simulator,
)

sim.simulate()
uvd_viscpu = sim.uvdata
uvd_matvis = sim.uvdata

# ---------------------------------------------------------------------------
# (2) Run pyuvsim
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_compare_viscpu_with_pyuvsim(uvdata_allpols, nsource, beam_type, polariz
print("Baseline: ", i, j)
np.testing.assert_allclose(
uvd_uvsim.get_data((i, j, "xx")),
uvd_viscpu.get_data((i, j, "xx")),
uvd_matvis.get_data((i, j, "xx")),
atol=atol,
rtol=rtol,
)
38 changes: 19 additions & 19 deletions hera_sim/tests/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
from hera_sim.beams import PolyBeam
from hera_sim.defaults import defaults
from hera_sim.visibilities import (
MatVis,
ModelData,
UVSim,
VisCPU,
VisibilitySimulation,
load_simulator_from_yaml,
)
from vis_cpu import HAVE_GPU

SIMULATORS = (VisCPU, UVSim)
SIMULATORS = (MatVis, UVSim)

if HAVE_GPU:

class VisGPU(VisCPU):
"""Simple mock class to make testing VisCPU with use_gpu=True easier"""
class VisGPU(MatVis):
"""Simple mock class to make testing MatVis with use_gpu=True easier"""

def __init__(self, *args, **kwargs):
super().__init__(*args, use_gpu=True, ref_time="min", **kwargs)
Expand Down Expand Up @@ -111,7 +111,7 @@ def sky_modelJD(uvdataJD):
def test_JD(uvdata, uvdataJD, sky_model):
model_data = ModelData(sky_model=sky_model, uvdata=uvdata)

vis = VisCPU()
vis = MatVis()

sim1 = VisibilitySimulation(data_model=model_data, simulator=vis).simulate()

Expand All @@ -125,7 +125,7 @@ def test_JD(uvdata, uvdataJD, sky_model):

def test_vis_cpu_estimate_memory(uvdata, uvdataJD, sky_model):
model_data = ModelData(sky_model=sky_model, uvdata=uvdata)
vis = VisCPU()
vis = MatVis()
mem = vis.estimate_memory(model_data)
assert mem > 0

Expand Down Expand Up @@ -262,7 +262,7 @@ def test_shapes(uvdata, simulator):
@pytest.mark.parametrize("precision, cdtype", [(1, np.complex64), (2, complex)])
def test_dtypes(uvdata, precision, cdtype):
sky = create_uniform_sky(np.unique(uvdata.freq_array))
vis = VisCPU(precision=precision)
vis = MatVis(precision=precision)

# If data_array is empty, then we never create new vis, and the returned value
# is literally the data array, so we should expect to get complex128 regardless.
Expand Down Expand Up @@ -356,19 +356,19 @@ def test_single_source_autocorr_past_horizon(uvdata, simulator):
assert np.abs(np.mean(v)) == 0


def test_viscpu_coordinate_correction(uvdata2):
def test_matvis_coordinate_correction(uvdata2):
sim = VisibilitySimulation(
data_model=ModelData(
uvdata=uvdata2,
sky_model=zenith_sky_model(uvdata2),
),
simulator=VisCPU(
simulator=MatVis(
correct_source_positions=True, ref_time="2018-08-31T04:02:30.11"
),
)

# Apply correction
# viscpu.correct_point_source_pos(obstime="2018-08-31T04:02:30.11", frame="icrs")
# matvis.correct_point_source_pos(obstime="2018-08-31T04:02:30.11", frame="icrs")
v = sim.simulate().copy()
assert np.all(~np.isnan(v))

Expand All @@ -377,7 +377,7 @@ def test_viscpu_coordinate_correction(uvdata2):
uvdata=uvdata2,
sky_model=zenith_sky_model(uvdata2),
),
simulator=VisCPU(
simulator=MatVis(
correct_source_positions=True,
ref_time=apt.Time("2018-08-31T04:02:30.11", format="isot", scale="utc"),
),
Expand Down Expand Up @@ -526,7 +526,7 @@ def test_vis_cpu_pol(polarization_array, xfail):
)

beam = PolyBeam(polarized=False)
simulator = VisCPU()
simulator = MatVis()

if xfail:
with pytest.raises(KeyError):
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_vis_cpu_stokespol(uvdata_linear, sky_model):
with pytest.raises(ValueError):
VisibilitySimulation(
data_model=ModelData(uvdata=uvdata_linear, sky_model=sky_model),
simulator=VisCPU(),
simulator=MatVis(),
)


Expand All @@ -585,10 +585,10 @@ def test_str_uvdata(uvdata, sky_model, tmp_path):
assert model_data.uvdata.Nants_data == uvdata.Nants_data


def test_ref_time_viscpu(uvdata2):
vc_mean = VisCPU(ref_time="mean")
vc_min = VisCPU(ref_time="min")
vc_max = VisCPU(ref_time="max")
def test_ref_time_matvis(uvdata2):
vc_mean = MatVis(ref_time="mean")
vc_min = MatVis(ref_time="min")
vc_max = MatVis(ref_time="max")

sky_model = half_sky_model(uvdata2)

Expand All @@ -615,10 +615,10 @@ def test_load_from_yaml(tmpdir):
example_dir = Path(__file__).parent.parent.parent / "config_examples"

simulator = load_simulator_from_yaml(example_dir / "simulator.yaml")
assert isinstance(simulator, VisCPU)
assert isinstance(simulator, MatVis)
assert simulator.ref_time == "mean"

sim2 = VisCPU.from_yaml(example_dir / "simulator.yaml")
sim2 = MatVis.from_yaml(example_dir / "simulator.yaml")

assert sim2.ref_time == simulator.ref_time
assert sim2.diffuse_ability == simulator.diffuse_ability
Expand Down
6 changes: 3 additions & 3 deletions hera_sim/tests/test_vis_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_vis_cli(tmp_path_factory):
parser,
[
str(cfg),
str(DATA_PATH / "viscpu.yaml"),
str(DATA_PATH / "matvis_cpu.yaml"),
"--compress",
str(outdir / "compression-cache.npy"),
"--normalize_beams",
Expand All @@ -85,12 +85,12 @@ def test_vis_cli_dry(tmp_path_factory):
parser,
[
str(cfg),
str(DATA_PATH / "viscpu.yaml"),
str(DATA_PATH / "matvis_cpu.yaml"),
"--compress",
str(outdir / "compression-cache.npy"),
"--dry",
"--object_name",
"viscpu",
"matvis",
],
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
simulator: VisCPU
simulator: MatVis
precision: 2
ref_time: mean
correct_source_positions: true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
simulator: VisCPU
simulator: MatVis
precision: 2
ref_time: mean
correct_source_positions: true
Expand Down
2 changes: 1 addition & 1 deletion hera_sim/visibilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


try:
from .vis_cpu import VisCPU
from .matvis import MatVis
except (ImportError, NameError): # pragma: no cover
pass

Expand Down
Loading
Loading