From ca99f5fe56bface09be758fd13bf589821b3eb1d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 25 Oct 2024 16:39:41 +0100 Subject: [PATCH] update --- docs/apis/level1.rst | 35 ++++++++++++++++ docs/apis/level2.rst | 7 ++++ docs/apis/level3.rst | 2 + docs/cli/run.rst | 2 + docs/conf.py | 4 +- src/anemoi/inference/runner.py | 73 ++++++++++++++++++++++++++++++++++ 6 files changed, 121 insertions(+), 2 deletions(-) diff --git a/docs/apis/level1.rst b/docs/apis/level1.rst index a82fc86..a420b27 100644 --- a/docs/apis/level1.rst +++ b/docs/apis/level1.rst @@ -14,3 +14,38 @@ illustrate this: .. literalinclude:: code/state.py :language: python + +The field names are the ones that were provided when running the +:ref:`training `, which were the names given +to fields when creating the training :ref:`dataset +`. + +******** + States +******** + +A `state` is a Python :py:class:`dictionary` with the following keys: + +- ``date``: a :py:class:`datetime.datetime` object that represents the + date at which the state is valid. +- ``latitudes``: a NumPy array with the list of latitudes that matches + the data values of the fields. +- ``longitudes``: a NumPy array with the corresponding list of + longitudes. It must have the same size as the latitudes array. +- ``fields``: a :py:class:`dictionary` that maps field names to + their data. + +Each field is given as a NumPy array. If the model is +:py:attr:`multi-step +`, it will +need to be initialised with fields from two or more dates: the values +must be two-dimensional arrays with the shape ``(number-of-dates, +number-of-grid-points)``. Otherwise the values can be one-dimensional +arrays. The first dimension is expected to represent each date in +ascending order, and the ``date`` entry of the state must be the last +one. + +As it iterates, the model will produce new states with the same format. +The ``date`` will represent the forecasted date, and the fields will +have the forecasted values as NumPy arrays.
These arrays will have one +dimension (the number of grid points), even if the model is multi-step. diff --git a/docs/apis/level2.rst b/docs/apis/level2.rst index 40c3d04..d366665 100644 --- a/docs/apis/level2.rst +++ b/docs/apis/level2.rst @@ -3,3 +3,10 @@ ##################### Object oriented API ##################### + +- Runner +- Input +- Output +- Postprocessor +- Checkpoint +- Metadata diff --git a/docs/apis/level3.rst b/docs/apis/level3.rst index 435dc5f..34f1752 100644 --- a/docs/apis/level3.rst +++ b/docs/apis/level3.rst @@ -3,3 +3,5 @@ ################## Command line API ################## + +See the :ref:`run-cli` for more information. diff --git a/docs/cli/run.rst b/docs/cli/run.rst index 867df85..53ae64a 100644 --- a/docs/cli/run.rst +++ b/docs/cli/run.rst @@ -1,3 +1,5 @@ +.. _run-cli: + run ========== diff --git a/docs/conf.py b/docs/conf.py index 17c48b5..8d14e27 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,7 +30,7 @@ project = "Anemoi Inference" -author = "ECMWF" +author = "Anemoi contributors" year = datetime.datetime.now().year if year == 2024: @@ -38,7 +38,7 @@ else: years = "2024-%s" % (year,) -copyright = "%s, ECMWF" % (years,) +copyright = "%s, Anemoi contributors" % (years,) try: diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index bb1dc19..5f5cbe8 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -6,6 +6,7 @@ # nor does it submit to any jurisdiction.
def validate_input_state(self, input_state):
    """Validate a user-supplied input state and return a normalised copy.

    ``run()`` calls this method on the state it receives before doing
    anything else. The caller's dictionary and its ``fields`` dictionary
    are never modified: both are shallow-copied first.

    Parameters
    ----------
    input_state : dict
        Must contain ``date`` (``datetime.datetime``), ``latitudes`` and
        ``longitudes`` (1-D ``numpy.ndarray`` of equal length) and
        ``fields`` (``dict`` mapping a field name to its array).
        Each field must have the shape
        ``(len(self.checkpoint.lagged), number_of_grid_points)``.
        A 1-D field is first reshaped to ``(1, number_of_grid_points)``,
        so 1-D fields are only accepted by single-step models.

    Returns
    -------
    dict
        A copy of ``input_state`` in which every field is 2-D and a
        ``reference_date`` entry (equal to ``date``) has been added.

    Raises
    ------
    ValueError
        If the state is not a dict, an entry is missing or of the wrong
        type, the coordinates are not 1-D or differ in length, a field
        has the wrong shape or contains infinities, or a field contains
        NaNs while ``self.allow_nans`` is false.

    Notes
    -----
    ``self.allow_nans`` is tri-state (it defaults to ``None`` in
    ``__init__``). When it is ``None`` and NaNs are found, a warning is
    logged and the attribute is set to ``True`` on the instance.
    """
    if not isinstance(input_state, dict):
        raise ValueError("Input state must be a dictionary")

    expected_types = dict(
        date=datetime.datetime,
        latitudes=np.ndarray,
        longitudes=np.ndarray,
        fields=dict,
    )

    for key, klass in expected_types.items():
        if key not in input_state:
            raise ValueError(f"Input state must contain a `{key}` entry")

        if not isinstance(input_state[key], klass):
            raise ValueError(
                f"Input state entry `{key}` is type {type(input_state[key])}, expected {klass} instead"
            )

    # Detach from the user's input so we can modify it
    input_state = input_state.copy()
    fields = input_state["fields"] = input_state["fields"].copy()

    for latlon in ("latitudes", "longitudes"):
        # Check the number of dimensions, not the number of points.
        if input_state[latlon].ndim != 1:
            raise ValueError(f"Input state entry `{latlon}` must be 1D, shape is {input_state[latlon].shape}")

    nlat = len(input_state["latitudes"])
    nlon = len(input_state["longitudes"])
    if nlat != nlon:
        raise ValueError(f"Size mismatch latitudes={nlat}, longitudes={nlon}")

    number_of_grid_points = nlat

    # One row per entry of the checkpoint's `lagged` sequence.
    multi_step = len(self.checkpoint.lagged)

    expected_shape = (multi_step, number_of_grid_points)

    with_nans = []

    # Iterate over a snapshot because entries are reassigned in the loop.
    for name, field in list(fields.items()):

        # Allow for 1D fields if multi_step is 1
        if len(field.shape) == 1:
            field = fields[name] = field.reshape(1, field.shape[0])

        if field.shape != expected_shape:
            raise ValueError(
                f"Field `{name}` has the wrong shape. Expected {expected_shape}, got {field.shape}"
            )

        if np.isinf(field).any():
            raise ValueError(f"Field `{name}` contains infinities")

        if np.isnan(field).any():
            with_nans.append(name)

    if with_nans:
        msg = f"NaNs found in the following variables: {sorted(with_nans)}"
        if self.allow_nans is None:
            # No explicit policy was given: warn once, then accept NaNs.
            LOG.warning(msg)
            self.allow_nans = True

        if not self.allow_nans:
            raise ValueError(msg)

    # Needed for some output object, such as GribOutput, to compute `step`
    input_state["reference_date"] = input_state["date"]

    return input_state