Skip to content

Commit

Permalink
support for old weights
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed May 31, 2024
1 parent 77cbd11 commit 978b328
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 66 deletions.
2 changes: 2 additions & 0 deletions src/anemoi/inference/checkpoint/metadata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def _computed_forcings(self):
]
)

print("FORCINGS", self._forcing_params())

constants = set(self._forcing_params()) - set(self.constants_from_input) - set(self.computed_constants)

if constants - known:
Expand Down
120 changes: 54 additions & 66 deletions src/anemoi/inference/checkpoint/metadata/version_0_0_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,54 @@ class Version_0_0_0(Metadata):
def __init__(self, metadata):
super().__init__(metadata)

FORCING_PARAMS = [
"z",
"lsm",
"sdor",
"slor",
"cos_latitude",
"cos_longitude",
"sin_latitude",
"sin_longitude",
"cos_julian_day",
"cos_local_time",
"sin_julian_day",
"sin_local_time",
"insolation",
]

indices = dict(
forcing=self._index_of(FORCING_PARAMS),
full=self._index_of(self.variables),
diagnostic=[],
prognostic=self._index_of(self.ordering),
)

config = dict(
data_indices=dict(
data=dict(
input=indices,
output=indices,
),
model=dict(
input=indices,
output=indices,
),
),
config=dict(
data=dict(timestep=6, frequency=6),
training=dict(
multistep_input=2,
precision="32",
),
),
)

self._metadata.update(config)

def _index_of(self, names):
return [self.variable_to_index[name] for name in names]

def dump(self, indent=0):
print("Version_0_0_0: Not implemented")

Expand All @@ -46,6 +94,8 @@ def dump(self, indent=0):
[50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
)

param_level_ml = tuple()

ordering = [
"q_50",
"q_100",
Expand Down Expand Up @@ -146,9 +196,8 @@ def dump(self, indent=0):
"sin_latitude",
"sin_longitude",
]
computed_constants_mask = []

computer_forcing = [
computed_forcing = [
"cos_julian_day",
"cos_local_time",
"sin_julian_day",
Expand All @@ -158,19 +207,11 @@ def dump(self, indent=0):

@property
def variables(self):
return self.ordering + self.computed_constants + self.forcing_params

@property
def num_input_features(self):
raise NotImplementedError()

@property
def data_to_model(self):
raise NotImplementedError()
return self.ordering + self.computed_constants + self.computed_forcing

@property
def model_to_data(self):
raise NotImplementedError()
def variables_with_nans(self):
return []

###########################################################################
@property
Expand All @@ -187,56 +228,3 @@ def select(self):
param_level=self.variables,
remapping={"param_level": "{param}_{levelist}"},
)

###########################################################################

@property
def constants_from_input(self):
raise NotImplementedError()

@property
def constants_from_input_mask(self):
raise NotImplementedError()

@property
def constant_data_from_input_mask(self):
raise NotImplementedError()

###########################################################################

@property
def prognostic_input_mask(self):
raise NotImplementedError()

@property
def prognostic_data_input_mask(self):
raise NotImplementedError()

@property
def prognostic_output_mask(self):
raise NotImplementedError()

@property
def diagnostic_output_mask(self):
raise NotImplementedError()

@property
def diagnostic_params(self):
raise NotImplementedError()

@property
def prognostic_params(self):
raise NotImplementedError()

###########################################################################
@property
def precision(self):
raise NotImplementedError()

@property
def multi_step(self):
raise NotImplementedError()

@property
def imputable_variables(self):
raise NotImplementedError()

0 comments on commit 978b328

Please sign in to comment.