Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:ecmwf/anemoi-inference into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 30, 2024
2 parents 9ee6b9a + 8ac1cfa commit 468a9d3
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# anemoi-utils
# anemoi-inference

**DISCLAIMER**
This project is **BETA** and will be **Experimental** for the foreseeable future.
Expand Down
8 changes: 7 additions & 1 deletion src/anemoi/inference/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import zipfile
from functools import cached_property

from anemoi.utils.checkpoints import has_metadata
from anemoi.utils.checkpoints import load_metadata

from .metadata import Metadata
Expand All @@ -30,7 +31,12 @@ def __repr__(self):

def __getattr__(self, name):
if self._metadata is None:
self._metadata = Metadata.from_metadata(load_metadata(self.path))
try:
self._metadata = Metadata.from_metadata(load_metadata(self.path))
except ValueError:
if has_metadata(self.path):
raise
self._metadata = Metadata.from_metadata(None)

return getattr(self._metadata, name)

Expand Down
5 changes: 5 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@


def from_versions(checkpoint_version, dataset_version):
from .version_0_0_0 import Version_0_0_0
from .version_0_1_0 import Version_0_1_0
from .version_0_2_0 import Version_0_2_0

VERSIONS = {
("0.0.0", "0.0.0"): Version_0_0_0,
("1.0.0", "0.1.0"): Version_0_1_0,
("1.0.0", "0.2.0"): Version_0_2_0,
}
Expand Down Expand Up @@ -56,6 +58,9 @@ def to_dict(self):

@classmethod
def from_metadata(cls, metadata):
if metadata is None:
metadata = dict(version="0.0.0", dataset=dict(version="0.0.0"))

if isinstance(metadata["dataset"], list):
from .patch import list_to_dict

Expand Down
14 changes: 14 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ def drop_fill(metadata):
return metadata


def select_fill(metadata):
metadata["variables"] = [x for x in metadata["forward"]["variables"] if x in metadata["select"]]
return metadata


def rename_fill(metadata, select):

rename = metadata["rename"]
Expand Down Expand Up @@ -49,6 +54,15 @@ def patch(a, b):
}
)

if "select" in a:
return select_fill(
{
"action": "select",
"select": a["select"],
"forward": zarr_fill({"action": "zarr", "attrs": b}),
}
)

if "rename" in a:
return rename_fill(
{
Expand Down
242 changes: 242 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_0_0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

from . import Metadata

LOG = logging.getLogger(__name__)


class Version_0_0_0(Metadata):
"""
Reference for very old checkpoints
Will not work and need to be updated
"""

def __init__(self, metadata):
super().__init__(metadata)

def dump(self, indent=0):
print("Version_0_0_0: Not implemented")

# Input
area = [90, 0, -90, 360]
grid = [0.25, 0.25]
param_sfc = [
"z",
"sp",
"msl",
"lsm",
"sst",
"sdor",
"slor",
"10u",
"10v",
"2t",
"2d",
]
param_level_pl = (
["q", "t", "u", "v", "w", "z"],
[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
)

ordering = [
"q_50",
"q_100",
"q_150",
"q_200",
"q_250",
"q_300",
"q_400",
"q_500",
"q_600",
"q_700",
"q_850",
"q_925",
"q_1000",
"t_50",
"t_100",
"t_150",
"t_200",
"t_250",
"t_300",
"t_400",
"t_500",
"t_600",
"t_700",
"t_850",
"t_925",
"t_1000",
"u_50",
"u_100",
"u_150",
"u_200",
"u_250",
"u_300",
"u_400",
"u_500",
"u_600",
"u_700",
"u_850",
"u_925",
"u_1000",
"v_50",
"v_100",
"v_150",
"v_200",
"v_250",
"v_300",
"v_400",
"v_500",
"v_600",
"v_700",
"v_850",
"v_925",
"v_1000",
"w_50",
"w_100",
"w_150",
"w_200",
"w_250",
"w_300",
"w_400",
"w_500",
"w_600",
"w_700",
"w_850",
"w_925",
"w_1000",
"z_50",
"z_100",
"z_150",
"z_200",
"z_250",
"z_300",
"z_400",
"z_500",
"z_600",
"z_700",
"z_850",
"z_925",
"z_1000",
"sp",
"msl",
"sst",
"10u",
"10v",
"2t",
"2d",
"z",
"lsm",
"sdor",
"slor",
]

param_format = {"param_level": "{param}{levelist}"}

computed_constants = [
"cos_latitude",
"cos_longitude",
"sin_latitude",
"sin_longitude",
]
computed_constants_mask = []

computer_forcing = [
"cos_julian_day",
"cos_local_time",
"sin_julian_day",
"sin_local_time",
"insolation",
]

@property
def variables(self):
return self.ordering + self.computed_constants + self.forcing_params

@property
def num_input_features(self):
raise NotImplementedError()

@property
def data_to_model(self):
raise NotImplementedError()

@property
def model_to_data(self):
raise NotImplementedError()

###########################################################################
@property
def order_by(self):
return dict(
valid_datetime="ascending",
param_level=self.ordering,
remapping={"param_level": "{param}_{levelist}"},
)

@property
def select(self):
return dict(
param_level=self.variables,
remapping={"param_level": "{param}_{levelist}"},
)

###########################################################################

@property
def constants_from_input(self):
raise NotImplementedError()

@property
def constants_from_input_mask(self):
raise NotImplementedError()

@property
def constant_data_from_input_mask(self):
raise NotImplementedError()

###########################################################################

@property
def prognostic_input_mask(self):
raise NotImplementedError()

@property
def prognostic_data_input_mask(self):
raise NotImplementedError()

@property
def prognostic_output_mask(self):
raise NotImplementedError()

@property
def diagnostic_output_mask(self):
raise NotImplementedError()

@property
def diagnostic_params(self):
raise NotImplementedError()

@property
def prognostic_params(self):
raise NotImplementedError()

###########################################################################
@property
def precision(self):
raise NotImplementedError()

@property
def multi_step(self):
raise NotImplementedError()

@property
def imputable_variables(self):
raise NotImplementedError()
3 changes: 3 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ def patch_metadata(self):
pl.remove([param, level])
if [param, level] in ml:
ml.remove([param, level])

def dump(self, indent=0):
print("Version_0_1_0: Not implemented")
19 changes: 9 additions & 10 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def param_level_ml(self):

class ZarrRequest(DataRequest):
def __init__(self, metadata):
super().__init__(metadata)
self.attributes = metadata["attrs"]
self.request = self.attributes["data_request"]

Expand Down Expand Up @@ -143,10 +144,13 @@ class StatisticsRequest(Forward):

class RenameRequest(Forward):

# Drop variables
# No need to rename anything as self.metadata["variables"] is already up to date

@property
def variables_with_nans(self):
raise NotImplementedError()
return sorted(self.forward.variables_with_nans)
rename = self.metadata["rename"]
return sorted([rename.get(x, x) for x in self.forward.variables_with_nans])


class MultiRequest(Forward):
Expand Down Expand Up @@ -254,17 +258,12 @@ def variables_with_nans(self):

class DropRequest(SelectRequest):

@property
def variables(self):
raise NotImplementedError()
# Drop variables
# No need to drop anything as self.metadata["variables"] is already up to date

@property
def variables_with_nans(self):
result = set()
for dataset in self.metadata["datasets"]:
result.extend(dataset.variables_with_nans)

return sorted(result)
return [x for x in self.forward.variables_with_nans if x in self.variables]


def data_request(specific):
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/inference/commands/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def run(self, args):
c.dump()
return

c.dump()

print("area:", c.area)
print("computed_constants_mask:", c.computed_constants_mask)
print("computed_constants:", c.computed_constants)
Expand Down

0 comments on commit 468a9d3

Please sign in to comment.