Skip to content

Commit

Permalink
change name
Browse files Browse the repository at this point in the history
  • Loading branch information
gdevos010 committed Sep 23, 2024
1 parent 3cf2e54 commit 9bbd1c9
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 21 deletions.
7 changes: 4 additions & 3 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.tbats_model import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.times_net_model import TimesNetModel
from darts.models.forecasting.varima import VARIMA

try:
Expand All @@ -50,7 +51,7 @@
from darts.models.forecasting.tcn_model import TCNModel
from darts.models.forecasting.tft_model import TFTModel
from darts.models.forecasting.tide_model import TiDEModel
from darts.models.forecasting.time_net_model import TimeNetModel
from darts.models.forecasting.times_net_model import TimesNetModel
from darts.models.forecasting.transformer_model import TransformerModel
from darts.models.forecasting.tsmixer_model import TSMixerModel
except ModuleNotFoundError:
Expand All @@ -72,7 +73,7 @@
TFTModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TiDEModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TransformerModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TimeNetModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TimesNetModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TSMixerModel = NotImportedModule(module_name="(Py)Torch", warn=False)

try:
Expand Down Expand Up @@ -153,7 +154,7 @@
"TFTModel",
"TiDEModel",
"TransformerModel",
"TimeNetModel",
"TimesNetModel",
"TSMixerModel",
"Prophet",
"CatBoostModel",
Expand Down
2 changes: 0 additions & 2 deletions darts/models/components/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


class PositionalEmbedding(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
- :class:`~darts.models.forecasting.nhits.NHiTSModel`
- :class:`~darts.models.forecasting.tcn_model.TCNModel`
- :class:`~darts.models.forecasting.transformer_model.TransformerModel`
- :class:`~darts.models.forecasting.time_net_model.TimeNetModel`
- :class:`~darts.models.forecasting.time_net_model.TimesNetModel`
- :class:`~darts.models.forecasting.tft_model.TFTModel`
- :class:`~darts.models.forecasting.dlinear.DLinearModel`
- :class:`~darts.models.forecasting.nlinear.NLinearModel`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from darts.models.forecasting.embed import DataEmbedding
from darts.models.components.embed import DataEmbedding
from darts.models.forecasting.pl_forecasting_module import (
PLPastCovariatesModule,
io_processor,
Expand Down Expand Up @@ -107,7 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return res


class _TimeNetModule(PLPastCovariatesModule):
class _TimesNetModule(PLPastCovariatesModule):
def __init__(
self,
input_dim: int,
Expand Down Expand Up @@ -166,7 +165,7 @@ def forward(self, x_in: Tuple) -> torch.Tensor:
return y


class TimeNetModel(PastCovariatesTorchModel):
class TimesNetModel(PastCovariatesTorchModel):
def __init__(
self,
input_chunk_length: int,
Expand All @@ -179,7 +178,7 @@ def __init__(
**kwargs,
):
"""
TimeNet model for time series forecasting.
TimesNet model for time series forecasting.
Parameters:
-----------
Expand Down Expand Up @@ -331,13 +330,13 @@ def encode_year(idx):
Examples
--------
>>> from darts.datasets import WeatherDataset
>>> from darts.models import TimeNetModel
>>> from darts.models import TimesNetModel
>>> series = WeatherDataset().load()
>>> # predicting atmospheric pressure
>>> target = series['p (mbar)'][:100]
>>> # optionally, use past observed rainfall (pretending to be unknown beyond index 100)
>>> past_cov = series['rain (mm)'][:100]
>>> model = TimeNetModel(
>>> model = TimesNetModel(
>>> input_chunk_length=6,
>>> output_chunk_length=6,
>>> n_epochs=20
Expand Down Expand Up @@ -368,7 +367,7 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
output_dim = train_sample[-1].shape[1]
nr_params = 1 if self.likelihood is None else self.likelihood.num_parameters

return _TimeNetModule(
return _TimesNetModule(
input_dim=input_dim,
output_dim=output_dim,
nr_params=nr_params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
TCNModel,
TFTModel,
TiDEModel,
TimesNetModel,
TransformerModel,
TSMixerModel,
TimeNetModel,
)
from darts.models.forecasting.torch_forecasting_model import (
DualCovariatesTorchModel,
Expand Down Expand Up @@ -161,7 +161,7 @@
60.0,
),
(
TimeNetModel,
TimesNetModel,
{
"n_epochs": 10,
"pl_trainer_kwargs": tfm_kwargs["pl_trainer_kwargs"],
Expand Down
4 changes: 2 additions & 2 deletions darts/tests/models/forecasting/test_historical_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
TCNModel,
TFTModel,
TiDEModel,
TimesNetModel,
TransformerModel,
TSMixerModel,
TimeNetModel,
)
from darts.utils.likelihood_models import GaussianLikelihood, QuantileRegression

Expand Down Expand Up @@ -184,7 +184,7 @@
"PastCovariates",
),
(
TimeNetModel,
TimesNetModel,
{
"input_chunk_length": IN_LEN,
"output_chunk_length": OUT_LEN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
TCNModel,
TFTModel,
TiDEModel,
TimeNetModel,
TransformerModel,
TSMixerModel,
TimeNetModel,
)
from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel
from darts.utils.likelihood_models import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
TCNModel,
TFTModel,
TiDEModel,
TimesNetModel,
TransformerModel,
TSMixerModel,
TimeNetModel,
iTransformerModel,
)
from darts.models.components.layer_norm_variants import RINorm
from darts.models.forecasting.global_baseline_models import _GlobalNaiveModel
Expand Down Expand Up @@ -103,7 +104,8 @@
(TFTModel, {"add_relative_index": 2, **kwargs, **tft_light_kwargs}),
(TiDEModel, kwargs),
(TransformerModel, dict(kwargs, **trafo_light_kwargs)),
(TimeNetModel, kwargs),
(iTransformerModel, kwargs),
(TimesNetModel, kwargs),
(TSMixerModel, kwargs),
(GlobalNaiveSeasonal, kwargs),
(GlobalNaiveAggregate, kwargs),
Expand Down

0 comments on commit 9bbd1c9

Please sign in to comment.