Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 25, 2024
1 parent 54a9b12 commit ca99f5f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 2 deletions.
35 changes: 35 additions & 0 deletions docs/apis/level1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,38 @@ illustrate this:

.. literalinclude:: code/state.py
:language: python

The field names are the one that where provided when running the
:ref:`training <anemoi-training:index-page>`, which were the name given
to fields when creating the training :ref:`dataset
<anemoi-datasets:index-page>`.

********
States
********

A `state` is a Python :py:class:`dictionary` with the following keys:

- ``date``: :py:class:`datetime.datetime` object that represent the
date at which the state is valid.
- ``latitudes``: a NumPy array with the list of latitudes that matches
the data values of fields
- ``longitudes``: a NumPy array with the corresponding list of
longitudes. It must have the same size as the latitudes array.
- ``fields``: a :py:class:`dictionary` that maps fields names with
their data.

Each field is given as a NumPy array. If the model is
:py:attr:`multi-step
<anemoi.models.interface.AnemoiModelInterface.multi_step>`, it will
needs to be initialised with fields from two or more dates, the values
must be two dimensions arrays, with the shape ``(number-of-dates,
number-of-grid-points)``, otherwise the values can be a one dimension
array. The first dimension is expected to represent each date in
ascending order, and the ``date`` entry of the state must be the last
one.

As it iterates, the model will produce new states with the same format.
The ``date`` will represent the forecasted date, and the fields would
have the forecasted values as NumPy array. These arrays will be of one
dimensions (the number of grid points), even if the model is multi-step.
7 changes: 7 additions & 0 deletions docs/apis/level2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@
#####################
Object oriented API
#####################

- Runner
- Input
- Output
- Postprocessor
- Checkpoint
- Metadata
2 changes: 2 additions & 0 deletions docs/apis/level3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
##################
Command line API
##################

See the :ref:`run-cli` for more information.
2 changes: 2 additions & 0 deletions docs/cli/run.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.. _run_cli:

run
==========

Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@

project = "Anemoi Inference"

author = "ECMWF"
author = "Anemoi contributors"

year = datetime.datetime.now().year
if year == 2024:
years = "2024"
else:
years = "2024-%s" % (year,)

copyright = "%s, ECMWF" % (years,)
copyright = "%s, Anemoi contributors" % (years,)


try:
Expand Down
73 changes: 73 additions & 0 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# nor does it submit to any jurisdiction.


import datetime
import logging
from functools import cached_property

Expand Down Expand Up @@ -36,13 +37,15 @@ def __init__(
device: str,
precision: str = None,
report_error=False,
allow_nans=None, # can be True of False
verbose: bool = True,
):
self.checkpoint = Checkpoint(checkpoint, verbose=verbose)
self._verbose = verbose
self.device = device
self.precision = precision
self.report_error = report_error
self.allow_nans = allow_nans

# This could also be passed as an argument

Expand All @@ -59,6 +62,8 @@ def __init__(

def run(self, *, input_state, lead_time):

input_state = self.validate_input_state(input_state)

# timers = Timers()

lead_time = to_timedelta(lead_time)
Expand Down Expand Up @@ -268,3 +273,71 @@ def compute_forcings(self, *, latitudes, longitudes, dates, variables):
assert len(ds) == len(variables) * len(dates), (len(ds), len(variables), dates)

return ds

def validate_input_state(self, input_state):

if not isinstance(input_state, dict):
raise ValueError("Input state must be a dictionnary")

EXPECT = dict(date=datetime.datetime, latitudes=np.ndarray, longitudes=np.ndarray, fields=dict)

for key, klass in EXPECT.item():
if key not in input_state:
raise ValueError(f"Input state must contain a `{key}` enytry")

if not isinstance(input_state[key], klass):
raise ValueError(
f"Input state entry `{key}` is type {type(input_state[key])}, expected {klass} instead"
)

# Detach from the user's input so we can modify it
input_state = input_state.copy()
fields = input_state["fields"] = input_state["fields"].copy()

for latlon in ("latitudes", "longitudes"):
if len(input_state[latlon]) != 1:
raise ValueError(f"Input state entry `{latlon}` must be 1D, shape is {input_state[latlon].shape}")

nlat = len(input_state["latitudes"])
nlon = len(input_state["longitudes"])
if nlat != nlon:
raise ValueError(f"Size mismatch latitudes={nlat}, longitudes={nlon}")

number_of_grid_points = nlat

multi_step = len(self.checkpoint.lagged)

expected_shape = (multi_step, number_of_grid_points)

# Check field
with_nans = []

for name, field in list(fields.items()):

# Allow for 1D fields if multi_step is 1
if len(field.shape) == 1:
field = fields[name] = field.reshape(1, field.shape[0])

if field.shape != expected_shape:
raise ValueError(f"Field `name` has the wrong shape. Expected {expected_shape}, got {field.shape}")

if not np.isinf(field).any():
raise ValueError(f"Field `{name}` contains infinities")

if np.isnan(field).any():
with_nans.append(name)

if with_nans:
msg = f"NaNs found in the following variables: {sorted(with_nans)}"
if self.allow_nans is None:
LOG.warning(msg)
self.allow_nans = True

if not self.allow_nans:
raise ValueError(msg)

# Needed for some output object, such as GribOutput, to compute `step`

input_state["reference_date"] = input_state["date"]

return input_state

0 comments on commit ca99f5f

Please sign in to comment.