Skip to content

Commit

Permalink
Replace climetlab with earthkit-data (#8)
Browse files Browse the repository at this point in the history
* Closes #7
- Replaces climetlab with earthkit

* Fix: Remove testing exception

* Fix: Address review comments
- ekd instead of earthkit.data
- Use 'valid_time' instead of 'base_time'

* Fix: Address valid_time retrieval comment
  • Loading branch information
HCookie authored Sep 6, 2024
1 parent a334e11 commit 24e27c8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added
- earthkit-data replaces climetlab

### Changed

### Removed
- climetlab


## [0.1.10] Fix missing constants

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ dynamic = [
dependencies = [
"anemoi-utils>=0.3",
"anytree",
"earthkit-data>=0.9",
"numpy",
"pyyaml",
"semantic-version",
Expand Down
23 changes: 12 additions & 11 deletions src/anemoi/inference/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@


def forcing_and_constants(source, date, param):
import climetlab as cml
import earthkit.data as ekd

ds = cml.load_source(
"constants",
ds = ekd.from_source(
"forcings",
source,
date=date,
param=param,
Expand Down Expand Up @@ -119,7 +119,7 @@ def run(

LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged))

input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False)
input_fields_numpy = input_fields.to_numpy(dtype=np.float32)
print(input_fields_numpy.shape)

input_fields_numpy = input_fields_numpy.reshape(
Expand Down Expand Up @@ -224,7 +224,8 @@ def run(
]

if start_datetime is None:
start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].datetime()
start_datetime_str = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime")
start_datetime = datetime.datetime.fromisoformat(start_datetime_str)

constants = forcing_and_constants(
source=input_fields[:1],
Expand Down Expand Up @@ -290,14 +291,14 @@ def run(

# Write dynamic fields
def get_most_recent_datetime(input_fields):
datetimes = [f.valid_datetime() for f in input_fields]
datetimes = [f.datetime()["valid_time"] for f in input_fields]
latest = datetimes[-1]
for d in datetimes:
assert d <= latest, (datetimes, d, latest)
return latest

most_recent_datetime = get_most_recent_datetime(input_fields)
reference_fields = [f for f in input_fields if f.valid_datetime() == most_recent_datetime]
reference_fields = [f for f in input_fields if f.datetime()["valid_time"] == most_recent_datetime]
precip_template = reference_fields[self.checkpoint.variable_to_index["2t"]]

accumulated_output = np.zeros(
Expand Down Expand Up @@ -335,8 +336,8 @@ def get_most_recent_datetime(input_fields):

for n, (m, param) in enumerate(zip(prognostic_data_from_retrieved_fields_mask, prognostic_params)):
template = reference_fields[m]
assert template.valid_datetime() == most_recent_datetime, (
template.valid_datetime(),
assert template.datetime()["valid_time"] == most_recent_datetime, (
template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
Expand All @@ -350,8 +351,8 @@ def get_most_recent_datetime(input_fields):
if len(diagnostic_output_mask):
for n, param in enumerate(self.checkpoint.diagnostic_params):
accumulated_output[n] += np.maximum(0, diagnostic_fields_numpy[:, n])
assert precip_template.valid_datetime() == most_recent_datetime, (
precip_template.valid_datetime(),
assert precip_template.datetime()["valid_time"] == most_recent_datetime, (
precip_template.datetime()["valid_time"],
most_recent_datetime,
)
output_callback(
Expand Down

0 comments on commit 24e27c8

Please sign in to comment.