Commit: use earthkit-data

b8raoult committed Sep 14, 2024
1 parent 9b28685 commit f55a164
Showing 10 changed files with 164 additions and 136 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -167,3 +167,4 @@ cython_debug/
 ?.*
 *.png
 *.pny
+_version.py
69 changes: 69 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,69 @@
+repos:
+
+# Empty notebooks
+- repo: local
+  hooks:
+  - id: clear-notebooks-output
+    name: clear-notebooks-output
+    files: tools/.*\.ipynb$
+    stages: [commit]
+    language: python
+    entry: jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace
+    additional_dependencies: [jupyter]
+
+
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.4.0
+  hooks:
+  - id: check-yaml # Check YAML files for syntax errors only
+    args: [--unsafe, --allow-multiple-documents]
+  - id: debug-statements # Check for debugger imports and py37+ breakpoint()
+  - id: end-of-file-fixer # Ensure files end in a newline
+  - id: trailing-whitespace # Trailing whitespace checker
+  - id: no-commit-to-branch # Prevent committing to main / master
+  - id: check-added-large-files # Check for large files added to git
+  - id: check-merge-conflict # Check for files that contain merge conflict markers
+
+- repo: https://github.com/psf/black-pre-commit-mirror
+  rev: 24.1.1
+  hooks:
+  - id: black
+    args: [--line-length=120]
+
+- repo: https://github.com/pycqa/isort
+  rev: 5.13.2
+  hooks:
+  - id: isort
+    args:
+    - -l 120
+    - --force-single-line-imports
+    - --profile black
+
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.3.0
+  hooks:
+  - id: ruff
+    exclude: '(dev/.*|.*_)\.py$'
+    args:
+    - --line-length=120
+    - --fix
+    - --exit-non-zero-on-fix
+    - --preview
+
+- repo: https://github.com/sphinx-contrib/sphinx-lint
+  rev: v0.9.1
+  hooks:
+  - id: sphinx-lint
+
+# For now, we use it. But it does not support a lot of sphinx features
+- repo: https://github.com/dzhu/rstfmt
+  rev: v0.0.14
+  hooks:
+  - id: rstfmt
+
+- repo: https://github.com/b8raoult/pre-commit-docconvert
+  rev: "0.1.4"
+  hooks:
+  - id: docconvert
+    args: ["numpy"]
67 changes: 67 additions & 0 deletions pyproject.toml
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/
+
+[build-system]
+requires = [
+  "setuptools>=60",
+  "setuptools-scm>=8",
+]
+
+[project]
+name = "ai-models-graphcast"
+
+description = "An ai-models plugin to run Deepmind's graphcast model"
+keywords = [
+  "ai",
+  "tools",
+]
+
+license = { file = "LICENSE" }
+authors = [
+  { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "[email protected]" },
+]
+
+requires-python = ">=3.10"
+
+classifiers = [
+  "Development Status :: 4 - Beta",
+  "Intended Audience :: Developers",
+  "License :: OSI Approved :: Apache Software License",
+  "Operating System :: OS Independent",
+  "Programming Language :: Python :: 3 :: Only",
+  "Programming Language :: Python :: 3.9",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
+  "Programming Language :: Python :: Implementation :: CPython",
+  "Programming Language :: Python :: Implementation :: PyPy",
+]
+
+dynamic = [
+  "version",
+]
+
+# JAX requirements are in requirements.txt
+
+dependencies = [
+  "ai-models>=0.4.0",
+  "dm-tree",
+  "dm-haiku==0.0.10",
+]
+
+optional-dependencies.dev = [
+  "pre-commit",
+]
+urls.Repository = "https://github.com/ecmwf-lab/ai-models-graphcast"
+entry-points."ai_models.model".graphcast = "ai_models_graphcast.model:model"
+
+[tool.setuptools_scm]
+version_file = "src/ai_models_graphcast/_version.py"
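The entry-points table above is what lets the ai-models driver discover this plugin at runtime, and [tool.setuptools_scm] writes the generated version module that the .gitignore change above keeps out of git. A hedged sketch of entry-point discovery on Python 3.10+ (the actual lookup inside ai-models may differ):

from importlib.metadata import entry_points

# Iterate over all installed "ai_models.model" plugins.
for ep in entry_points(group="ai_models.model"):
    if ep.name == "graphcast":
        model = ep.load()  # resolves "ai_models_graphcast.model:model"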
68 changes: 0 additions & 68 deletions setup.py

This file was deleted.

File renamed without changes.
File renamed without changes.
27 changes: 9 additions & 18 deletions ai_models_graphcast/input.py → src/ai_models_graphcast/input.py
@@ -10,7 +10,7 @@
 import logging
 from collections import defaultdict
 
-import climetlab as cml
+import earthkit.data as ekd
 import numpy as np
 import xarray as xr
 
@@ -46,8 +46,8 @@ def forcing_variables_numpy(sample, forcing_variables, dates):
     Returns:
         torch.Tensor: Tensor with constants
     """
-    ds = cml.load_source(
-        "constants",
+    ds = ekd.from_source(
+        "forcings",
         sample,
         date=dates,
         param=forcing_variables,
@@ -74,16 +74,13 @@ def create_training_xarray(
 ):
     time_deltas = [
         datetime.timedelta(hours=h)
-        for h in lagged
-        + [hour for hour in range(hour_steps, lead_time + hour_steps, hour_steps)]
+        for h in lagged + [hour for hour in range(hour_steps, lead_time + hour_steps, hour_steps)]
     ]
 
-    all_datetimes = [start_date() + time_delta for time_delta in time_deltas]
+    all_datetimes = [start_date + time_delta for time_delta in time_deltas]
 
     with timer("Creating forcing variables"):
-        forcing_numpy = forcing_variables_numpy(
-            fields_sfc, forcing_variables, all_datetimes
-        )
+        forcing_numpy = forcing_variables_numpy(fields_sfc, forcing_variables, all_datetimes)
 
     with timer("Converting GRIB to xarray"):
         # Create Input dataset
@@ -118,9 +115,7 @@ def create_training_xarray(
                 data_vars[CF_NAME_SFC[param]] = (["lat", "lon"], fields[0].to_numpy())
                 continue
 
-            data = np.stack(
-                [field.to_numpy(dtype=np.float32) for field in fields]
-            ).reshape(
+            data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
                 1,
                 len(given_datetimes),
                 len(lat),
@@ -141,9 +136,7 @@ def create_training_xarray(
             data_vars[CF_NAME_SFC[param]] = (["batch", "time", "lat", "lon"], data)
 
         for param, fields in pl.items():
-            data = np.stack(
-                [field.to_numpy(dtype=np.float32) for field in fields]
-            ).reshape(
+            data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
                 1,
                 len(given_datetimes),
                 len(levels),
@@ -188,9 +181,7 @@ def create_training_xarray(

with timer("Reindexing"):
# And we want the grid south to north
training_xarray = training_xarray.reindex(
lat=sorted(training_xarray.lat.values)
)
training_xarray = training_xarray.reindex(lat=sorted(training_xarray.lat.values))

if constants:
# Add geopotential_at_surface and land_sea_mask back in
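The substantive change in input.py is the swap from climetlab's "constants" source to earthkit-data's "forcings" source; the remaining hunks are formatting. A sketch of the new call, with argument values taken from the diff and inline comments that are our gloss rather than the author's:

import earthkit.data as ekd

ds = ekd.from_source(
    "forcings",               # computed forcing fields (replaces climetlab's "constants")
    sample,                   # a template fieldlist supplying the grid and metadata
    date=dates,               # the datetimes to generate forcings for
    param=forcing_variables,  # names of the requested forcing variables
)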
45 changes: 13 additions & 32 deletions ai_models_graphcast/model.py → src/ai_models_graphcast/model.py
@@ -7,12 +7,10 @@


 import dataclasses
-import datetime
 import functools
 import gc
 import logging
 import os
-from functools import cached_property
 
 import xarray
 from ai_models.model import Model
@@ -26,14 +24,12 @@
 try:
     import haiku as hk
     import jax
-    from graphcast import (
-        autoregressive,
-        casting,
-        checkpoint,
-        data_utils,
-        graphcast,
-        normalization,
-    )
+    from graphcast import autoregressive
+    from graphcast import casting
+    from graphcast import checkpoint
+    from graphcast import data_utils
+    from graphcast import graphcast
+    from graphcast import normalization
 except ModuleNotFoundError as e:
     msg = "You need to install Graphcast from git to use this model. See README.md for details."
     LOG.error(msg)
@@ -88,9 +84,7 @@ def __init__(self, **kwargs):
         self.lagged = [-6, 0]
         self.params = None
         self.ordering = self.param_sfc + [
-            f"{param}{level}"
-            for param in self.param_level_pl[0]
-            for level in self.param_level_pl[1]
+            f"{param}{level}" for param in self.param_level_pl[0] for level in self.param_level_pl[1]
         ]
 
         # Jax doesn't seem to like passing configs as args through the jit. Passing it
@@ -119,17 +113,11 @@ def load_model(self):
         def get_path(filename):
             return os.path.join(self.assets, filename)
 
-        diffs_stddev_by_level = xarray.load_dataset(
-            get_path(self.download_files[1])
-        ).compute()
+        diffs_stddev_by_level = xarray.load_dataset(get_path(self.download_files[1])).compute()
 
-        mean_by_level = xarray.load_dataset(
-            get_path(self.download_files[2])
-        ).compute()
+        mean_by_level = xarray.load_dataset(get_path(self.download_files[2])).compute()
 
-        stddev_by_level = xarray.load_dataset(
-            get_path(self.download_files[3])
-        ).compute()
+        stddev_by_level = xarray.load_dataset(get_path(self.download_files[3])).compute()
 
         def construct_wrapped_graphcast(model_config, task_config):
             """Constructs and wraps the GraphCast Predictor."""
@@ -183,13 +171,7 @@ def run_forward(
LOG.info("Model license: %s", self.ckpt.license)

jax.jit(self._with_configs(run_forward.init))
self.model = self._drop_state(
self._with_params(jax.jit(self._with_configs(run_forward.apply)))
)

@cached_property
def start_date(self) -> "datetime":
return self.all_fields.order_by(valid_datetime="descending")[0].datetime
self.model = self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply))))

def run(self):
# We ignore 'tp' so that we make sure that step 0 is a field of zero values
@@ -205,7 +187,7 @@ def run(self):
             fields_sfc=self.fields_sfc,
             fields_pl=self.fields_pl,
             lagged=self.lagged,
-            start_date=self.start_date,
+            start_date=self.start_datetime,
             hour_steps=self.hour_steps,
             lead_time=self.lead_time,
             forcing_variables=self.forcing_variables,
@@ -226,8 +208,7 @@
         ) = data_utils.extract_inputs_targets_forcings(
             training_xarray,
             target_lead_times=[
-                f"{int(delta.days * 24 + delta.seconds/3600):d}h"
-                for delta in time_deltas[len(self.lagged) :]
+                f"{int(delta.days * 24 + delta.seconds/3600):d}h" for delta in time_deltas[len(self.lagged) :]
             ],
             **dataclasses.asdict(self.task_config),
         )
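The target_lead_times expression in the last hunk renders each timedelta as the hour string that GraphCast's data_utils expects. A quick worked example:

import datetime

delta = datetime.timedelta(hours=18)
label = f"{int(delta.days * 24 + delta.seconds/3600):d}h"  # days and seconds -> whole hours
assert label == "18h"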