Skip to content

Commit

Permalink
Closes #7
Browse files Browse the repository at this point in the history
- Replaces climetlab with earthkit
  • Loading branch information
HCookie committed Sep 4, 2024
1 parent 6b32c35 commit 424877d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added
- earthkit-data replaces climetlab

### Changed

### Removed
- climetlab


## [0.1.10] Fix missing constants

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dynamic = [
dependencies = [
"anemoi-utils>=0.3",
"anytree",
"earthkit-data>=0.9",
"numpy",
"pyyaml",
"semantic-version",
Expand Down
27 changes: 17 additions & 10 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@


def forcing_and_constants(source, date, param):
import climetlab as cml
import earthkit

ds = cml.load_source(
"constants",
ds = earthkit.data.from_source(
"forcings",
source,
date=date,
param=param,
Expand Down Expand Up @@ -119,7 +119,7 @@ def run(

LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged))

input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False)
input_fields_numpy = input_fields.to_numpy(dtype=np.float32)
print(input_fields_numpy.shape)

input_fields_numpy = input_fields_numpy.reshape(
Expand Down Expand Up @@ -226,6 +226,13 @@ def run(
if start_datetime is None:
start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].datetime()

if isinstance(start_datetime, dict):
start_datetime = start_datetime[
"base_time"
] # With earthkit, time is a dictionary with 'base_time' and 'valid_time'
else:
raise Exception()

constants = forcing_and_constants(
source=input_fields[:1],
param=self.checkpoint.computed_constants,
Expand Down Expand Up @@ -290,14 +297,14 @@ def run(

# Write dynamic fields
def get_most_recent_datetime(input_fields):
datetimes = [f.valid_datetime() for f in input_fields]
datetimes = [f.datetime()["valid_time"] for f in input_fields]
latest = datetimes[-1]
for d in datetimes:
assert d <= latest, (datetimes, d, latest)
return latest

most_recent_datetime = get_most_recent_datetime(input_fields)
reference_fields = [f for f in input_fields if f.valid_datetime() == most_recent_datetime]
reference_fields = [f for f in input_fields if f.datetime()["valid_time"] == most_recent_datetime]
precip_template = reference_fields[self.checkpoint.variable_to_index["2t"]]

accumulated_output = np.zeros(
Expand Down Expand Up @@ -335,8 +342,8 @@ def get_most_recent_datetime(input_fields):

for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)):
template = reference_fields[m]
assert template.valid_datetime() == most_recent_datetime, (
template.valid_datetime(),
assert template.datetime()["valid_time"] == most_recent_datetime, (
template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
Expand All @@ -350,8 +357,8 @@ def get_most_recent_datetime(input_fields):
if len(diagnostic_output_mask):
for n, param in enumerate(self.checkpoint.diagnostic_params):
accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n])
assert precip_template.valid_datetime() == most_recent_datetime, (
precip_template.valid_datetime(),
assert precip_template.datetime()["valid_time"] == most_recent_datetime, (
precip_template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
Expand Down

0 comments on commit 424877d

Please sign in to comment.