From 9bbd1c95015ef930a45e41ae0573e9f7dcfbe5aa Mon Sep 17 00:00:00 2001 From: Greg DeVosNouri Date: Sun, 22 Sep 2024 19:46:40 -0700 Subject: [PATCH] change name --- darts/models/__init__.py | 7 ++++--- darts/models/components/embed.py | 2 -- darts/models/forecasting/__init__.py | 2 +- .../{time_net_model.py => times_net_model.py} | 15 +++++++-------- .../forecasting/test_global_forecasting_models.py | 4 ++-- .../forecasting/test_historical_forecasts.py | 4 ++-- .../forecasting/test_probabilistic_models.py | 2 +- .../forecasting/test_torch_forecasting_model.py | 6 ++++-- 8 files changed, 21 insertions(+), 21 deletions(-) rename darts/models/forecasting/{time_net_model.py => times_net_model.py} (98%) diff --git a/darts/models/__init__.py b/darts/models/__init__.py index 4e32fb07fe..0836d9cb7d 100644 --- a/darts/models/__init__.py +++ b/darts/models/__init__.py @@ -33,6 +33,7 @@ from darts.models.forecasting.regression_model import RegressionModel from darts.models.forecasting.tbats_model import BATS, TBATS from darts.models.forecasting.theta import FourTheta, Theta +from darts.models.forecasting.times_net_model import TimesNetModel from darts.models.forecasting.varima import VARIMA try: @@ -50,7 +51,7 @@ from darts.models.forecasting.tcn_model import TCNModel from darts.models.forecasting.tft_model import TFTModel from darts.models.forecasting.tide_model import TiDEModel - from darts.models.forecasting.time_net_model import TimeNetModel + from darts.models.forecasting.times_net_model import TimesNetModel from darts.models.forecasting.transformer_model import TransformerModel from darts.models.forecasting.tsmixer_model import TSMixerModel except ModuleNotFoundError: @@ -72,7 +73,7 @@ TFTModel = NotImportedModule(module_name="(Py)Torch", warn=False) TiDEModel = NotImportedModule(module_name="(Py)Torch", warn=False) TransformerModel = NotImportedModule(module_name="(Py)Torch", warn=False) - TimeNetModel = NotImportedModule(module_name="(Py)Torch", warn=False) + TimesNetModel = NotImportedModule(module_name="(Py)Torch", warn=False) TSMixerModel = NotImportedModule(module_name="(Py)Torch", warn=False) try: @@ -153,7 +154,7 @@ "TFTModel", "TiDEModel", "TransformerModel", - "TimeNetModel", + "TimesNetModel", "TSMixerModel", "Prophet", "CatBoostModel", diff --git a/darts/models/components/embed.py b/darts/models/components/embed.py index 4c18d6690b..28859c35b6 100644 --- a/darts/models/components/embed.py +++ b/darts/models/components/embed.py @@ -2,8 +2,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils import weight_norm class PositionalEmbedding(nn.Module): diff --git a/darts/models/forecasting/__init__.py b/darts/models/forecasting/__init__.py index ddea38e60a..3c1cb01306 100644 --- a/darts/models/forecasting/__init__.py +++ b/darts/models/forecasting/__init__.py @@ -42,7 +42,7 @@ - :class:`~darts.models.forecasting.nhits.NHiTSModel` - :class:`~darts.models.forecasting.tcn_model.TCNModel` - :class:`~darts.models.forecasting.transformer_model.TransformerModel` - - :class:`~darts.models.forecasting.time_net_model.TimeNetModel` + - :class:`~darts.models.forecasting.time_net_model.TimesNetModel` - :class:`~darts.models.forecasting.tft_model.TFTModel` - :class:`~darts.models.forecasting.dlinear.DLinearModel` - :class:`~darts.models.forecasting.nlinear.NLinearModel` diff --git a/darts/models/forecasting/time_net_model.py b/darts/models/forecasting/times_net_model.py similarity index 98% rename from darts/models/forecasting/time_net_model.py rename to darts/models/forecasting/times_net_model.py index 9e837a684f..91bba8ebec 100644 --- a/darts/models/forecasting/time_net_model.py +++ b/darts/models/forecasting/times_net_model.py @@ -1,11 +1,10 @@ from typing import Tuple -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from darts.models.forecasting.embed import DataEmbedding +from darts.models.components.embed import DataEmbedding from darts.models.forecasting.pl_forecasting_module import ( PLPastCovariatesModule, io_processor, @@ -107,7 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return res -class _TimeNetModule(PLPastCovariatesModule): +class _TimesNetModule(PLPastCovariatesModule): def __init__( self, input_dim: int, @@ -166,7 +165,7 @@ def forward(self, x_in: Tuple) -> torch.Tensor: return y -class TimeNetModel(PastCovariatesTorchModel): +class TimesNetModel(PastCovariatesTorchModel): def __init__( self, input_chunk_length: int, @@ -179,7 +178,7 @@ def __init__( **kwargs, ): """ - TimeNet model for time series forecasting. + TimesNet model for time series forecasting. Parameters: ----------- @@ -331,13 +330,13 @@ def encode_year(idx): Examples -------- >>> from darts.datasets import WeatherDataset - >>> from darts.models import TimeNetModel + >>> from darts.models import TimesNetModel >>> series = WeatherDataset().load() >>> # predicting atmospheric pressure >>> target = series['p (mbar)'][:100] >>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100) >>> past_cov = series['rain (mm)'][:100] - >>> model = TimeNetModel( + >>> model = TimesNetModel( >>> input_chunk_length=6, >>> output_chunk_length=6, >>> n_epochs=20 @@ -368,7 +367,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module: output_dim = train_sample[-1].shape[1] nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters - return _TimeNetModule( + return _TimesNetModule( input_dim=input_dim, output_dim=output_dim, nr_params=nr_params, diff --git a/darts/tests/models/forecasting/test_global_forecasting_models.py b/darts/tests/models/forecasting/test_global_forecasting_models.py index d2d043e748..9caaca1718 100644 --- a/darts/tests/models/forecasting/test_global_forecasting_models.py +++ b/darts/tests/models/forecasting/test_global_forecasting_models.py @@ -33,9 +33,9 @@ TCNModel, TFTModel, TiDEModel, + TimesNetModel, TransformerModel, TSMixerModel, - TimeNetModel, ) from darts.models.forecasting.torch_forecasting_model import ( DualCovariatesTorchModel, @@ -161,7 +161,7 @@ 60.0, ), ( - TimeNetModel, + TimesNetModel, { "n_epochs": 10, "pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"], diff --git a/darts/tests/models/forecasting/test_historical_forecasts.py b/darts/tests/models/forecasting/test_historical_forecasts.py index a92c830c94..962009728e 100644 --- a/darts/tests/models/forecasting/test_historical_forecasts.py +++ b/darts/tests/models/forecasting/test_historical_forecasts.py @@ -36,9 +36,9 @@ TCNModel, TFTModel, TiDEModel, + TimesNetModel, TransformerModel, TSMixerModel, - TimeNetModel, ) from darts.utils.likelihood_models import GaussianLikelihood, QuantileRegression @@ -184,7 +184,7 @@ "PastCovariates", ), ( - TimeNetModel, + TimesNetModel, { "input_chunk_length": IN_LEN, "output_chunk_length": OUT_LEN, diff --git a/darts/tests/models/forecasting/test_probabilistic_models.py b/darts/tests/models/forecasting/test_probabilistic_models.py index 2186494037..080eff9e7b 100644 --- a/darts/tests/models/forecasting/test_probabilistic_models.py +++ b/darts/tests/models/forecasting/test_probabilistic_models.py @@ -34,9 +34,9 @@ TCNModel, TFTModel, TiDEModel, + TimeNetModel, TransformerModel, TSMixerModel, - TimeNetModel, ) from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel from darts.utils.likelihood_models import ( diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 9ef83d6ba0..fea580dfb0 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -45,9 +45,10 @@ TCNModel, TFTModel, TiDEModel, + TimesNetModel, TransformerModel, TSMixerModel, - TimeNetModel, + iTransformerModel, ) from darts.models.components.layer_norm_variants import RINorm from darts.models.forecasting.global_baseline_models import _GlobalNaiveModel @@ -103,7 +104,8 @@ (TFTModel, {"add_relative_index": 2, **kwargs, **tft_light_kwargs}), (TiDEModel, kwargs), (TransformerModel, dict(kwargs, **trafo_light_kwargs)), - (TimeNetModel, kwargs), + (iTransformerModel, kwargs), + (TimesNetModel, kwargs), (TSMixerModel, kwargs), (GlobalNaiveSeasonal, kwargs), (GlobalNaiveAggregate, kwargs),