Skip to content

Commit

Permalink
add pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 26, 2024
1 parent 2a23a4d commit 5900bed
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 27 deletions.
3 changes: 3 additions & 0 deletions docs/apis/level3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ checkpoint that was trained with one the ICON grid:

This is still work in progress, and content of the YAML configuration
files will change and the examples above may not work in the future.


.. autopydantic_model:: target.usage_model.ExampleSettings
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinxarg.ext",
"sphinxcontrib.autodoc_pydantic",
]

# Add any paths that contain templates here, relative to this directory.
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
"numpy",
"omegaconf",
"packaging",
"pydantic",
"pyyaml",
"semantic-version",
"torch",
Expand All @@ -59,6 +60,7 @@ optional-dependencies.all = [ "anemoi-inference[plugin]" ]
optional-dependencies.dev = [ "anemoi-datasets[all,docs,plugin,tests]" ]

optional-dependencies.docs = [
"autodoc-pydantic",
"nbsphinx",
"pandoc",
"rstfmt",
Expand Down
30 changes: 10 additions & 20 deletions src/anemoi/inference/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#
from __future__ import annotations

import json
import logging
import os

from omegaconf import OmegaConf

from ..config import Configuration
from ..inputs.dataset import DatasetInput
from ..inputs.gribfile import GribFileInput
from ..inputs.icon import IconInput
Expand All @@ -23,24 +26,6 @@

LOG = logging.getLogger(__name__)

DEFAULTS = OmegaConf.create(
{
"checkpoint": "???",
"date": None,
"device": "cuda",
"lead_time": "10d",
"precision": None,
"allow_nans": False,
"icon_grid": None,
"input": None,
"output": None,
"write_initial_state": True,
"use_grib_paramid": False,
"dataset": None,
"env": {},
}
)


class RunCmd(Command):
"""Inspect the contents of a checkpoint file."""
Expand All @@ -54,11 +39,16 @@ def add_arguments(self, command_parser):
def run(self, args):

config = OmegaConf.merge(
DEFAULTS,
OmegaConf.create(Configuration().dict()), # Load default configuration
OmegaConf.load(args.config),
OmegaConf.from_dotlist(args.overrides),
)
LOG.info("Configuration:\n\n%s", OmegaConf.to_yaml(config))

# Validate the configuration

config = Configuration(**config)

LOG.info("Configuration:\n\n%s", json.dumps(config.dict(), indent=4))

for key, value in config.env.items():
os.environ[key] = str(value)
Expand Down
11 changes: 4 additions & 7 deletions src/anemoi/inference/inputs/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ def _(r):
class MarsInput(GribInput):
"""Get input fields from MARS"""

def __init__(self, checkpoint, *, use_grib_paramid=False, **kwargs):
super().__init__(checkpoint)
def __init__(self, runner, *, use_grib_paramid=False, **kwargs):
super().__init__(runner)
self.use_grib_paramid = use_grib_paramid
self.kwargs = kwargs
assert use_grib_paramid

def create_input_state(self, *, date):
if date is None:
Expand All @@ -81,10 +82,6 @@ def _retrieve(self, date):

dates = [date + h for h in self.checkpoint.lagged]

requests = self.checkpoint.mars_requests(
dates=dates,
expver="0001",
use_grib_paramid=self.use_grib_paramid,
)
requests = self.checkpoint.mars_requests(dates=dates, expver="0001", use_grib_paramid=self.use_grib_paramid)

return retrieve(requests, self.checkpoint.grid, self.checkpoint.area, **self.kwargs)
2 changes: 2 additions & 0 deletions src/anemoi/inference/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def area(self):
def mars_requests(self, *, use_grib_paramid=False, variables=all):
"""Return a list of MARS requests for the variables in the dataset"""

assert use_grib_paramid

from anemoi.utils.grib import shortname_to_paramid

for variable, metadata in self.variables_metadata.items():
Expand Down

0 comments on commit 5900bed

Please sign in to comment.