Merge branch 'develop' into feature/validate_datetime
b8raoult authored Sep 8, 2024
2 parents b73cb1e + 24e27c8 commit 02bea59
Showing 14 changed files with 235 additions and 41 deletions.
3 changes: 3 additions & 0 deletions .github/ci-config.yml
@@ -0,0 +1,3 @@
dependency_branch: develop
parallelism_factor: 8
self_build: false # Only for python packages
15 changes: 15 additions & 0 deletions .github/workflows/changelog-pr-update.yml
@@ -0,0 +1,15 @@
name: Check Changelog Update on PR
on:
pull_request:
types: [assigned, opened, synchronize, reopened, labeled, unlabeled]
branches:
- main
- develop
jobs:
Check-Changelog:
name: Check Changelog Action
runs-on: ubuntu-20.04
steps:
- uses: tarides/changelog-check-action@v2
with:
changelog: CHANGELOG.md
43 changes: 43 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,43 @@
name: ci

on:
# Trigger the workflow on push to main or develop, except tag creation
push:
branches:
- 'main'
- 'develop'
tags-ignore:
- '**'
paths:
- "src/**"
- "tests/**"

# Trigger the workflow on pull request
pull_request: ~

# Trigger the workflow manually
workflow_dispatch: ~

# Trigger after public PR approved for CI
pull_request_target:
types: [labeled]

jobs:
# Run CI including downstream packages on self-hosted runners
downstream-ci:
name: downstream-ci
if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main
with:
anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }}
codecov_upload: true
secrets: inherit

# Build downstream packages on HPC
downstream-ci-hpc:
name: downstream-ci-hpc
if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main
with:
anemoi-inference: ecmwf/anemoi-inference@${{ github.event.pull_request.head.sha || github.sha }}
secrets: inherit
10 changes: 10 additions & 0 deletions .github/workflows/label-public-pr.yml
@@ -0,0 +1,10 @@
# Manage labels of pull requests that originate from forks
name: label-public-pr

on:
pull_request_target:
types: [opened, synchronize]

jobs:
label:
uses: ecmwf-actions/reusable-workflows/.github/workflows/label-pr.yml@v2
27 changes: 3 additions & 24 deletions .github/workflows/python-publish.yml
@@ -37,7 +37,7 @@ jobs:
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

@@ -50,27 +50,6 @@
run: pytest

deploy:

if: ${{ github.event_name == 'release' }}
runs-on: ubuntu-latest
needs: [checks, quality]

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build wheel twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m build
twine upload dist/*
uses: ecmwf-actions/reusable-workflows/.github/workflows/cd-pypi.yml@v2
secrets: inherit
22 changes: 22 additions & 0 deletions .github/workflows/readthedocs-pr-update.yml
@@ -0,0 +1,22 @@
name: Read the Docs PR Preview
on:
pull_request_target:
types:
- opened
- synchronize
- reopened
# Execute this action only on PRs that touch
# documentation files.
paths:
- "docs/**"

permissions:
pull-requests: write

jobs:
documentation-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: "anemoi-inference"
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
- id: check-added-large-files # Check for large files added to git
- id: check-merge-conflict # Check for files that contain merge conflict
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
args: [--line-length=120]
@@ -34,7 +34,7 @@ repos:
- --force-single-line-imports
- --profile black
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.6
rev: v0.6.3
hooks:
- id: ruff
# Next line is for documentation code snippets
@@ -65,6 +65,6 @@ repos:
- id: optional-dependencies-all
args: ["--inplace", "--exclude-keys=dev,docs,tests", "--group=dev=all,docs,tests"]
- repo: https://github.com/tox-dev/pyproject-fmt
rev: "2.1.3"
rev: "2.2.1"
hooks:
- id: pyproject-fmt
90 changes: 90 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,90 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased]

### Added
- earthkit-data replaces climetlab

### Changed

### Removed
- climetlab


## [0.1.10] Fix missing constants

### Added
- (GH) Added downstream-ci, readthedocs update check and label public PR workflows

### Changed
- Fix missing constant_fields property to query constants in the checkpoint

## [0.1.9] Patch, Move output finalise to ai-models

### Removed
- output finalise in the plugin

## [0.1.8] Patch, Support for output finalise in the plugin

### Added
- Support for output finalise in the plugin

## [0.1.7] Patch, graph utility

### Added
- graph utility

### Changed
- updated dependencies

## [0.1.6] Patch, update dependencies

### Changed
- updated dependencies

## [0.1.5] Patch, inspect cli tool

### Added
- tests
- inspect cli tool

## [0.1.4] Patch, autocast option

### Added
- add autocast option

## [0.1.3] Patch, support ai-models

### Added
- ai-models and AIModelPlugin

## [0.1.2] Patch

### Added
- dependency group all

## [0.1.0] Initial Release

### Added
- Initial implementation of anemoi-inference

## Git Diffs:
[0.1.10]: https://github.com/ecmwf/anemoi-inference/compare/0.1.9...0.1.10
[0.1.9]: https://github.com/ecmwf/anemoi-inference/compare/0.1.8...0.1.9
[0.1.8]: https://github.com/ecmwf/anemoi-inference/compare/0.1.7...0.1.8
[0.1.7]: https://github.com/ecmwf/anemoi-inference/compare/0.1.6...0.1.7
[0.1.6]: https://github.com/ecmwf/anemoi-inference/compare/0.1.5...0.1.6
[0.1.5]: https://github.com/ecmwf/anemoi-inference/compare/0.1.4...0.1.5
[0.1.4]: https://github.com/ecmwf/anemoi-inference/compare/0.1.3...0.1.4
[0.1.3]: https://github.com/ecmwf/anemoi-inference/compare/0.1.2...0.1.3
[0.1.2]: https://github.com/ecmwf/anemoi-inference/compare/0.1.1...0.1.2
[0.1.1]: https://github.com/ecmwf/anemoi-inference/compare/0.1.0...0.1.1
[0.1.0]: https://github.com/ecmwf/anemoi-inference/releases/tag/0.1.0
1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ dynamic = [
dependencies = [
"anemoi-utils>=0.3",
"anytree",
"earthkit-data>=0.9",
"numpy",
"pyyaml",
"semantic-version",
9 changes: 9 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
@@ -334,3 +334,12 @@ def report_loading_error(self):
LOG.error("Training provenance:\n%s", json.dumps(provenance_training, indent=2))

###########################################################################

@property
def predict_step_shape(self):
return (
1, # Batch size
self.multi_step, # Lagged time steps
self.number_of_grid_points, # Grid points
self.num_input_features, # Fields
)
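The new predict_step_shape property centralises the layout of the model input tensor. A minimal sketch of how a caller might use it to allocate that tensor; the helper name and the `checkpoint` object are illustrative assumptions, not part of this commit:

# Sketch only: allocate the model input from the predict_step_shape property.
# `checkpoint` stands for a loaded checkpoint metadata object exposing it (assumption).
import numpy as np

def empty_predict_input(checkpoint):
    # (batch, lagged time steps, grid points, input fields), as defined above
    return np.zeros(checkpoint.predict_step_shape, dtype=np.float32)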
10 changes: 10 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_1_0.py
@@ -100,3 +100,13 @@ def graph(self, graph):
dataset["attrs"] = dataset.copy()

return ZarrRequest(dataset).graph(graph)

@property
def number_of_grid_points(self):
from .version_0_2_0 import ZarrRequest

dataset = self._dataset.copy()
if "attrs" not in dataset:
dataset["attrs"] = dataset.copy()

return ZarrRequest(dataset).number_of_grid_points
9 changes: 9 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/version_0_2_0.py
@@ -136,6 +136,15 @@ def dump(self, indent=0):
def graph_kids(self):
return []

@property
def number_of_grid_points(self):
if "shape" in self.attributes:
return self.attributes["shape"][-1]
return {
"o96": 40_320,
"n320": 542_080,
}[self.attributes["resolution"].lower()]


class Forward(DataRequest):
@cached_property
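The grid size is read from the dataset shape when available, with a hard-coded fallback for the two resolutions the code currently knows about. The same logic as a standalone sketch, assuming `attributes` is the dataset attributes dictionary used by the property above:

# Sketch only: the shape takes precedence; otherwise map the resolution name to a known grid size.
def number_of_grid_points(attributes: dict) -> int:
    if "shape" in attributes:
        return attributes["shape"][-1]  # last axis is the grid dimension
    return {"o96": 40_320, "n320": 542_080}[attributes["resolution"].lower()]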
4 changes: 4 additions & 0 deletions src/anemoi/inference/plugin.py
@@ -90,6 +90,10 @@ def param_level_pl(self):
def param_level_ml(self):
return self.runner.checkpoint.param_level_ml

@property
def constant_fields(self):
return self.runner.checkpoint.constants_from_input

@property
def grid(self):
return self.runner.checkpoint.grid
27 changes: 13 additions & 14 deletions src/anemoi/inference/runner.py
@@ -36,10 +36,10 @@


def forcing_and_constants(source, date, param):
import climetlab as cml
import earthkit.data as ekd

ds = cml.load_source(
"constants",
ds = ekd.from_source(
"forcings",
source,
date=date,
param=param,
@@ -119,8 +119,9 @@ def run(

LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged))


if start_datetime is None:
start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].datetime()
start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime")

num_fields_per_date = len(input_fields) // len(self.lagged) # assumed

@@ -142,6 +143,7 @@
)

input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False)

print(input_fields_numpy.shape)

input_fields_numpy = input_fields_numpy.reshape(
@@ -285,10 +287,7 @@ def run(

with Timer(f"Loading {self.checkpoint}"):
try:
model = torch.load(
self.checkpoint.path,
map_location=device,
).to(device)
model = torch.load(self.checkpoint.path, map_location=device, weights_only=False).to(device)
except Exception:
self.checkpoint.report_loading_error()
raise
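Passing weights_only=False explicitly keeps the full pickled model object loadable as newer PyTorch releases tighten the default behaviour of torch.load. A minimal sketch of the pattern outside the runner; the checkpoint path is a placeholder:

# Sketch only: load a full model object (not just a state dict) from a checkpoint file.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load("checkpoint.ckpt", map_location=device, weights_only=False).to(device)  # placeholder path
model.eval()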
@@ -312,14 +311,14 @@

# Write dynamic fields
def get_most_recent_datetime(input_fields):
datetimes = [f.valid_datetime() for f in input_fields]
datetimes = [f.datetime()["valid_time"] for f in input_fields]
latest = datetimes[-1]
for d in datetimes:
assert d <= latest, (datetimes, d, latest)
return latest

most_recent_datetime = get_most_recent_datetime(input_fields)
reference_fields = [f for f in input_fields if f.valid_datetime() == most_recent_datetime]
reference_fields = [f for f in input_fields if f.datetime()["valid_time"] == most_recent_datetime]
precip_template = reference_fields[self.checkpoint.variable_to_index["2t"]]

accumulated_output = np.zeros(
@@ -357,8 +356,8 @@ def get_most_recent_datetime(input_fields):

for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)):
template = reference_fields[m]
assert template.valid_datetime() == most_recent_datetime, (
template.valid_datetime(),
assert template.datetime()["valid_time"] == most_recent_datetime, (
template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
@@ -372,8 +371,8 @@
if len(diagnostic_output_mask):
for n, param in enumerate(self.checkpoint.diagnostic_params):
accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n])
assert precip_template.valid_datetime() == most_recent_datetime, (
precip_template.valid_datetime(),
assert precip_template.datetime()["valid_time"] == most_recent_datetime, (
precip_template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
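The remaining changes in this file follow the same migration: climetlab's field.valid_datetime() becomes a lookup in the dictionary returned by earthkit-data's field.datetime(). A short sketch of the accessor used in the assertions above, assuming `field` is an earthkit-data field:

# Sketch only: reading a field's valid time with earthkit-data, as done in this commit.
def valid_time_of(field):
    # datetime() returns a dict such as {"base_time": ..., "valid_time": ...}
    return field.datetime()["valid_time"]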
