Skip to content

Commit

Permalink
create input class
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 23, 2024
1 parent 27a9fe9 commit b20206d
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 117 deletions.
8 changes: 8 additions & 0 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ def computed_time_dependent_forcings(self):
def accumulations(self):
return self._metadata.accumulations

def default_namer(self, *args, **kwargs):
"""
Return a callable that can be used to name fields.
In that case, return the namer that was used to create the
training dataset.
"""
return self._metadata.default_namer(*args, **kwargs)

###########################################################################

@cached_property
Expand Down
7 changes: 6 additions & 1 deletion src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
from earthkit.data.utils.dates import to_datetime

from ..inputs.grib import GribInput
from ..precisions import PRECISIONS
from ..runners.cli import CLIRunner
from . import Command
Expand Down Expand Up @@ -76,8 +77,12 @@ def run(self, args):
args.date = to_datetime(args.date)

runner = CLIRunner(args.path, device=args.device, precision=args.precision)

input_fields = runner.retrieve_input_fields(args.date, args.use_grib_paramid)
input_state = runner.create_input_state(input_fields)

input = GribInput(runner.checkpoint)

input_state = input.create_input_state(input_fields)

_dump(input_state)

Expand Down
16 changes: 16 additions & 0 deletions src/anemoi/inference/inputs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


class Input:
"""_summary_"""

def __init__(self, checkpoint, verbose=True):
self.checkpoint = checkpoint
self._verbose = verbose
125 changes: 125 additions & 0 deletions src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from collections import defaultdict

import numpy as np
from anemoi.utils.humanize import plural
from earthkit.data.indexing.fieldlist import FieldArray

from . import Input

LOG = logging.getLogger(__name__)


class EkdInput(Input):
"""
Handles earthkit-data FieldList as input
"""

def __init__(self, checkpoint, *, namer=None, verbose=True):
super().__init__(checkpoint, verbose)
self._namer = namer if namer is not None else checkpoint.default_namer()
assert callable(self._namer), type(self._namer)

def create_input_state(self, input_fields, date=None, dtype=np.float32, flatten=True):

input_state = dict()

if date is None:
date = input_fields.order_by(valid_datetime="ascending")[-1].datetime()["valid_time"]
LOG.info("start_datetime not provided, using %s as start_datetime", date.isoformat())

dates = [date + h for h in self.checkpoint.lagged]
date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}

input_state["date"] = date
fields = input_state["fields"] = dict()

input_fields = self.filter_and_sort(input_fields, dates)

check = defaultdict(set)

first = True
for field in input_fields:

if first:
first = False
input_state["latitudes"], input_state["longitudes"] = field.grid_points()

name, valid_datetime = field.metadata("name"), field.metadata("valid_datetime")
if name not in fields:
fields[name] = np.full(
shape=(len(dates), self.checkpoint.number_of_grid_points),
fill_value=np.nan,
dtype=dtype,
)

date_idx = date_to_index[valid_datetime]

try:
fields[name][date_idx] = field.to_numpy(dtype=dtype, flatten=flatten)
except ValueError:
LOG.error("Error with field %s: expected shape=%s, got shape=%s", name, fields[name].shape, field.shape)
LOG.error("dates %s", dates)
LOG.error("number_of_grid_points %s", self.checkpoint.number_of_grid_points)
raise

if date_idx in check[name]:
LOG.error("Duplicate dates for %s: %s", name, date_idx)
LOG.error("Expected %s", list(date_to_index.keys()))
LOG.error("Got %s", list(check[name]))
raise ValueError(f"Duplicate dates for {name}")

check[name].add(date_idx)

for name, idx in check.items():
if len(idx) != len(dates):
LOG.error("Missing dates for %s: %s", name, idx)
LOG.error("Expected %s", list(date_to_index.keys()))
LOG.error("Got %s", list(idx))
raise ValueError(f"Missing dates for {name}")

# self.add_initial_forcings_to_input_state(input_state)

return input_state

def filter_and_sort(self, data, dates):
typed_variables = self.checkpoint.typed_variables
diagnostic_variables = self.checkpoint.diagnostic_variables

def _name(field, _, original_metadata):
return self._namer(field, original_metadata)

data = FieldArray([f.copy(name=_name) for f in data])

variable_from_input = [
v.name for v in typed_variables.values() if v.is_from_input and v.name not in diagnostic_variables
]

valid_datetime = [_.isoformat() for _ in dates]
LOG.info("Selecting fields %s %s", len(data), valid_datetime)

data = data.sel(name=variable_from_input, valid_datetime=valid_datetime).order_by("name", "valid_datetime")

expected = len(variable_from_input) * len(dates)

if len(data) != expected:
nvars = plural(len(variable_from_input), "variable")
ndates = plural(len(dates), "date")
nfields = plural(expected, "field")
msg = f"Expected ({nvars}) x ({ndates}) = {nfields}, got {len(data)}"
LOG.error("%s", msg)
# TODO: print a report
raise ValueError(msg)

assert len(data) == len(variable_from_input) * len(dates)

return data
18 changes: 18 additions & 0 deletions src/anemoi/inference/inputs/grib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging

from .ekd import EkdInput

LOG = logging.getLogger(__name__)


class GribInput(EkdInput):
pass
34 changes: 34 additions & 0 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


import logging
import warnings
from collections import defaultdict
from functools import cached_property

Expand Down Expand Up @@ -165,6 +166,39 @@ def accumulations(self):
"""Return the indices of the variables that are accumulations"""
return [v.name for v in self.typed_variables.values() if v.is_accumulation]

###########################################################################
# Default namer
###########################################################################

def default_namer(self, *args, **kwargs):
"""
Return a callable that can be used to name earthkit-data fields.
In that case, return the namer that was used to create the
training dataset.
"""

assert len(args) == 0, args
assert len(kwargs) == 0, kwargs

def namer(field, metadata):
warnings.warn("TEMPORARY CODE: Use the remapping in the metadata")
param, levelist, levtype = (
metadata.get("param"),
metadata.get("levelist"),
metadata.get("levtype"),
)

# Bug in eccodes that returns levelist for single level fields in GRIB2
if levtype in ("sfc", "o2d"):
levelist = None

if levelist is None:
return param

return f"{param}_{levelist}"

return namer

###########################################################################
# Data retrieval
###########################################################################
Expand Down
114 changes: 1 addition & 113 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,13 @@


import logging
import warnings
from collections import defaultdict
from functools import cached_property

import numpy as np
import torch
from anemoi.transform.grids.unstructured import UnstructuredGridFieldList
from anemoi.utils.dates import frequency_to_timedelta
from anemoi.utils.humanize import plural
from anemoi.utils.timer import Timer
from earthkit.data.indexing.fieldlist import FieldArray

from .checkpoint import Checkpoint
from .precisions import PRECISIONS
Expand Down Expand Up @@ -74,72 +70,11 @@ def __init__(self, checkpoint, *, accumulations=True, device: str, precision: st
self.postprocess = Accumulator(accumulations)

def run(self, *, input_state, lead_time):
lead_time = frequency_to_timedelta(lead_time)

input_tensor = self.prepare_input_tensor(input_state)
yield from self.postprocess(self.forecast(lead_time, input_tensor, input_state))

def create_input_state(self, input_fields, date=None, dtype=np.float32, flatten=True):

input_state = dict()

if date is None:
date = input_fields.order_by(valid_datetime="ascending")[-1].datetime()["valid_time"]
LOG.info("start_datetime not provided, using %s as start_datetime", date.isoformat())

dates = [date + h for h in self.checkpoint.lagged]
date_to_index = {d.isoformat(): i for i, d in enumerate(dates)}

input_state["date"] = date
fields = input_state["fields"] = dict()

input_fields = self.filter_and_sort(input_fields, dates)

check = defaultdict(set)

first = True
for field in input_fields:

if first:
first = False
input_state["latitudes"], input_state["longitudes"] = field.grid_points()

name, valid_datetime = field.metadata("name"), field.metadata("valid_datetime")
if name not in fields:
fields[name] = np.full(
shape=(len(dates), self.checkpoint.number_of_grid_points),
fill_value=np.nan,
dtype=dtype,
)

date_idx = date_to_index[valid_datetime]

try:
fields[name][date_idx] = field.to_numpy(dtype=dtype, flatten=flatten)
except ValueError:
LOG.error("Error with field %s: expected shape=%s, got shape=%s", name, fields[name].shape, field.shape)
LOG.error("dates %s", dates)
LOG.error("number_of_grid_points %s", self.checkpoint.number_of_grid_points)
raise

if date_idx in check[name]:
LOG.error("Duplicate dates for %s: %s", name, date_idx)
LOG.error("Expected %s", list(date_to_index.keys()))
LOG.error("Got %s", list(check[name]))
raise ValueError(f"Duplicate dates for {name}")

check[name].add(date_idx)

for name, idx in check.items():
if len(idx) != len(dates):
LOG.error("Missing dates for %s: %s", name, idx)
LOG.error("Expected %s", list(date_to_index.keys()))
LOG.error("Got %s", list(idx))
raise ValueError(f"Missing dates for {name}")

# self.add_initial_forcings_to_input_state(input_state)

return input_state

def add_initial_forcings_to_input_state(self, input_state):
latitudes = input_state["latitudes"]
longitudes = input_state["longitudes"]
Expand Down Expand Up @@ -282,53 +217,6 @@ def forecast(self, lead_time, input_tensor_numpy, input_state):
forcing = torch.from_numpy(forcing).to(self.device)
input_tensor_torch[:, -1, :, forcing_mask] = forcing

def filter_and_sort(self, data, dates):
typed_variables = self.checkpoint.typed_variables
diagnostic_variables = self.checkpoint.diagnostic_variables

def _name(field, key, original_metadata):
warnings.warn("TEMPORARY CODE: Use the remapping in the metadata")
param, levelist, levtype = (
original_metadata.get("param"),
original_metadata.get("levelist"),
original_metadata.get("levtype"),
)

# Bug in eccodes that returns levelist for single level fields in GRIB2
if levtype in ("sfc", "o2d"):
levelist = None

if levelist is None:
return param

return f"{param}_{levelist}"

data = FieldArray([f.copy(name=_name) for f in data])

variable_from_input = [
v.name for v in typed_variables.values() if v.is_from_input and v.name not in diagnostic_variables
]

valid_datetime = [_.isoformat() for _ in dates]
LOG.info("Selecting fields %s %s", len(data), valid_datetime)

data = data.sel(name=variable_from_input, valid_datetime=valid_datetime).order_by("name", "valid_datetime")

expected = len(variable_from_input) * len(dates)

if len(data) != expected:
nvars = plural(len(variable_from_input), "variable")
ndates = plural(len(dates), "date")
nfields = plural(expected, "field")
msg = f"Expected ({nvars}) x ({ndates}) = {nfields}, got {len(data)}"
LOG.error("%s", msg)
# TODO: print a report
raise ValueError(msg)

assert len(data) == len(variable_from_input) * len(dates)

return data

def compute_forcings(self, *, latitudes, longitudes, dates, variables):
import earthkit.data as ekd

Expand Down
Loading

0 comments on commit b20206d

Please sign in to comment.