Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 25, 2024
1 parent ca99f5f commit cc0c653
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 12 deletions.
6 changes: 4 additions & 2 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def add_arguments(self, command_parser):
"--icon-grid", help="NetCDF containing the ICON grid (e.g. icon_grid_0026_R03B07_G.nc)."
)

command_parser.add_argument("--allow-nans", help="Allow NaNs in the output.", action="store_true")

command_parser.add_argument("path", help="Path to the checkpoint.")

def run(self, args):
Expand All @@ -56,7 +58,7 @@ def run(self, args):

args.lead_time = as_timedelta(args.lead_time)

runner = CLIRunner(args.path, device=args.device, precision=args.precision)
runner = CLIRunner(args.path, device=args.device, precision=args.precision, allow_nans=args.allow_nans)

if args.icon_grid is not None:
if args.input is None:
Expand All @@ -70,7 +72,7 @@ def run(self, args):
input = MarsInput(runner.checkpoint, use_grib_paramid=args.use_grib_paramid)

if args.output is not None:
output = GribFileOutput(args.output, runner.checkpoint)
output = GribFileOutput(args.output, runner.checkpoint, allow_nans=args.allow_nans)
else:
output = PrinterOutput(runner.checkpoint)

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/inference/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ def load_forcings(self, state, date):

fields = retrieve(requests=requests, grid=self.grid, area=self.area)

warnings.warn("TEMPORARY CODE: Fields need to be sorted by name")
warnings.warn("🚧 TEMPORARY CODE 🚧: Fields need to be sorted by name")

return fields.to_numpy(dtype=np.float32, flatten=True)
2 changes: 1 addition & 1 deletion src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def default_namer(self, *args, **kwargs):
assert len(kwargs) == 0, kwargs

def namer(field, metadata):
warnings.warn("TEMPORARY CODE: Use the remapping in the metadata")
warnings.warn("🚧 TEMPORARY CODE 🚧: Use the remapping in the metadata")
param, levelist, levtype = (
metadata.get("param"),
metadata.get("levelist"),
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/inference/outputs/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class GribOutput(Output):
Handles grib
"""

def __init__(self, checkpoint, *, verbose=True):
def __init__(self, checkpoint, *, allow_nans=False, verbose=True):
super().__init__(checkpoint, verbose=verbose)
self._first = True
self.typed_variables = self.checkpoint.typed_variables
self.allow_nans = allow_nans

def write_initial_state(self, state):
state.setdefault("reference_date", state["date"])
warnings.warn("🚧 TEMPORARY CODE 🚧: write_initial_state not yet implemented")

def write_state(self, state):
state.setdefault("reference_date", state["date"])
Expand All @@ -48,7 +50,7 @@ def write_state(self, state):
for name, value in state["fields"].items():
variable = self.typed_variables[name]
if variable.is_accumulation:
warnings.warn("TEMPORARY CODE: accumaulations are not supported yet")
warnings.warn("🚧 TEMPORARY CODE 🚧: accumaulations are not supported yet")
continue

keys = {}
Expand Down
13 changes: 10 additions & 3 deletions src/anemoi/inference/outputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging

import earthkit.data as ekd
import numpy as np

from .grib import GribOutput

Expand All @@ -21,10 +22,16 @@ class GribFileOutput(GribOutput):
Handles grib files
"""

def __init__(self, path, checkpoint, *, verbose=True, **kwargs):
super().__init__(checkpoint, verbose=verbose)
def __init__(self, path, checkpoint, *, allow_nans=False, verbose=True, **kwargs):
super().__init__(checkpoint, allow_nans=allow_nans, verbose=verbose)
self.path = path
self.output = ekd.new_grib_output(self.path, split_output=True, **kwargs)

def write_message(self, message, *args, **kwargs):
self.output.write(message, *args, **kwargs)
try:
self.output.write(message, *args, check_nans=self.allow_nans, **kwargs)
except Exception as e:
LOG.error("Error writing message to %s: %s", self.path, e)
if np.isnan(message.data).any():
LOG.error("Message contains NaNs (%s)", kwargs)
raise e
6 changes: 3 additions & 3 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def validate_input_state(self, input_state):

EXPECT = dict(date=datetime.datetime, latitudes=np.ndarray, longitudes=np.ndarray, fields=dict)

for key, klass in EXPECT.item():
for key, klass in EXPECT.items():
if key not in input_state:
raise ValueError(f"Input state must contain a `{key}` enytry")

Expand All @@ -295,7 +295,7 @@ def validate_input_state(self, input_state):
fields = input_state["fields"] = input_state["fields"].copy()

for latlon in ("latitudes", "longitudes"):
if len(input_state[latlon]) != 1:
if len(input_state[latlon].shape) != 1:
raise ValueError(f"Input state entry `{latlon}` must be 1D, shape is {input_state[latlon].shape}")

nlat = len(input_state["latitudes"])
Expand All @@ -321,7 +321,7 @@ def validate_input_state(self, input_state):
if field.shape != expected_shape:
raise ValueError(f"Field `name` has the wrong shape. Expected {expected_shape}, got {field.shape}")

if not np.isinf(field).any():
if np.isinf(field).any():
raise ValueError(f"Field `{name}` contains infinities")

if np.isnan(field).any():
Expand Down

0 comments on commit cc0c653

Please sign in to comment.