Skip to content

Commit

Permalink
black
Browse files (browse the repository at this point in the history)
  • Loading branch information
SimonWittner committed Dec 22, 2023
1 parent 139a97f commit 30aa303
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 13 deletions.
2 changes: 1 addition & 1 deletion neuralprophet/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,5 +623,5 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None):
config_regressors=model.config_regressors,
config_missing=model.config_missing,
prediction_frequency=prediction_frequency,
config_train=model.config_train
config_train=model.config_train,
)
40 changes: 30 additions & 10 deletions neuralprophet/time_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,15 @@ def __init__(self, df, name, **kwargs):
]
self.kwargs = kwargs

learning_rate = kwargs['config_train'].learning_rate
if kwargs['predict_mode'] or (learning_rate is None) or kwargs['config_lagged_regressors'] or kwargs['config_country_holidays'] or kwargs['config_events'] or kwargs['prediction_frequency']:
learning_rate = kwargs["config_train"].learning_rate
if (
kwargs["predict_mode"]
or (learning_rate is None)
or kwargs["config_lagged_regressors"]
or kwargs["config_country_holidays"]
or kwargs["config_events"]
or kwargs["prediction_frequency"]
):
inputs, targets = tabularize_univariate_datetime(df, **kwargs)
self.init_after_tabularized(inputs, targets)
self.filter_samples_after_init(kwargs["prediction_frequency"])
Expand Down Expand Up @@ -102,15 +109,22 @@ def __getitem__(self, index):
Targets to be predicted of same length as each of the model inputs, dims: (num_samples, n_forecasts)
"""
# TODO: Drop config_train from self!
learning_rate = self.kwargs['config_train'].learning_rate
if self.kwargs['predict_mode'] or (learning_rate is None) or self.kwargs['config_lagged_regressors'] or self.kwargs['config_country_holidays'] or self.kwargs['config_events'] or self.kwargs['prediction_frequency']:
learning_rate = self.kwargs["config_train"].learning_rate
if (
self.kwargs["predict_mode"]
or (learning_rate is None)
or self.kwargs["config_lagged_regressors"]
or self.kwargs["config_country_holidays"]
or self.kwargs["config_events"]
or self.kwargs["prediction_frequency"]
):
sample = self.samples[index]
targets = self.targets[index]
meta = self.meta
return sample, targets, meta
else:
start_idx = index
end_idx = start_idx + self.kwargs.get('n_lags') + self.kwargs.get('n_forecasts')
end_idx = start_idx + self.kwargs.get("n_lags") + self.kwargs.get("n_forecasts")

df_slice = self.df.iloc[start_idx:end_idx]

Expand Down Expand Up @@ -139,7 +153,6 @@ def drop_nan_init(self, drop_missing):
number of steps to predict
"""


def drop_nan_after_init(self, df, predict_steps, drop_missing):
"""Checks if inputs/targets contain any NaN values and drops them, if user opts to.
Parameters
Expand Down Expand Up @@ -361,13 +374,20 @@ def tabularize_univariate_datetime(
Targets to be predicted of same length as each of the model inputs, dims: (num_samples, n_forecasts)
"""
max_lags = get_max_num_lags(config_lagged_regressors, n_lags)
#n_samples = len(df) - max_lags + 1 - n_forecasts
#TODO
# n_samples = len(df) - max_lags + 1 - n_forecasts
# TODO
learning_rate = config_train.learning_rate
if predict_mode or (learning_rate is None) or config_lagged_regressors or config_country_holidays or config_events or prediction_frequency:
if (
predict_mode
or (learning_rate is None)
or config_lagged_regressors
or config_country_holidays
or config_events
or prediction_frequency
):
n_samples = len(df) - max_lags + 1 - n_forecasts
else:
n_samples=1
n_samples = 1

# data is stored in OrderedDict
inputs = OrderedDict({})
Expand Down
3 changes: 2 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,7 @@ def test_unused_future_regressors():
m.add_lagged_regressor("cost")
m.fit(df, freq="D")


def test_on_the_fly_sampling():
start_date = "2022-10-16 00:00:00"
end_date = "2022-12-30 00:00:00"
Expand All @@ -1715,5 +1716,5 @@ def test_on_the_fly_sampling():
df.loc[3, "y"] = np.nan

m = NeuralProphet(epochs=1, learning_rate=0.01)
m.fit(df, freq='H')
m.fit(df, freq="H")
metrics = m.predict(df)
13 changes: 12 additions & 1 deletion tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,18 @@ def test_too_many_NaN():
df["ID"] = "__df__"
# Check if ValueError is thrown, if NaN values remain after auto-imputing
with pytest.raises(ValueError):
time_dataset.TimeDataset(df, "name", predict_mode=False, config_missing=config_missing, config_lagged_regressors=None, config_country_holidays=None, config_events=None, config_train=config_train, predict_steps=1, prediction_frequency=None)
time_dataset.TimeDataset(
df,
"name",
predict_mode=False,
config_missing=config_missing,
config_lagged_regressors=None,
config_country_holidays=None,
config_events=None,
config_train=config_train,
predict_steps=1,
prediction_frequency=None,
)


def test_future_df_with_nan():
Expand Down

0 comments on commit 30aa303

Please sign in to comment.