diff --git a/.gitignore b/.gitignore index 22c1ad65..8e898fb3 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,10 @@ docs/_build/ # Jupyter Notebook .ipynb_checkpoints +# PyCharm +.idea/ +.coverage + # Distribution / packaging .Python env/ diff --git a/deepobs/pytorch/datasets/tolstoi.py b/deepobs/pytorch/datasets/tolstoi.py index 562b104a..ffcb23d7 100644 --- a/deepobs/pytorch/datasets/tolstoi.py +++ b/deepobs/pytorch/datasets/tolstoi.py @@ -7,8 +7,8 @@ import torch from torch.utils import data as dat -from .. import config from . import dataset +from ...config import get_data_dir class tolstoi(dataset.DataSet): @@ -43,10 +43,9 @@ def __init__(self, batch_size, seq_length=50, train_eval_size=653237): self._train_eval_size = train_eval_size super(tolstoi, self).__init__(batch_size) - def _make_dataloader(self, filepath): - # Load the array of character ids, determine the number of batches that - # can be produced, given batch size and sequence lengh - arr = np.load(filepath) + def _make_tolstoi_dataloader(self, arr): + # determine the number of batches that can be produced, given batch size + # and sequence length num_batches = int( np.floor((np.size(arr) - 1) / (self._batch_size * self._seq_length)) ) @@ -79,8 +78,8 @@ def _make_dataloader(self, filepath): return dataset def _make_train_dataloader(self): - filepath = os.path.join(config.get_data_dir(), "tolstoi", "train.npy") - return self._make_dataloader(filepath) + filepath = os.path.join(get_data_dir(), "tolstoi", "train.npy") + return self._make_tolstoi_dataloader(np.load(filepath)) def _make_train_eval_dataloader(self): indices = np.arange( @@ -90,5 +89,12 @@ def _make_train_eval_dataloader(self): return dat.TensorDataset(train_eval_set[0], train_eval_set[1]) def _make_test_dataloader(self): - filepath = os.path.join(config.get_data_dir(), "tolstoi", "test.npy") - return self._make_dataloader(filepath) + filepath = os.path.join(get_data_dir(), "tolstoi", "test.npy") + return 
self._make_tolstoi_dataloader(np.load(filepath)) + + def _make_train_and_valid_dataloader(self): + filepath = os.path.join(get_data_dir(), "tolstoi", "train.npy") + data = np.load(filepath) + valid_data = data[0: self._train_eval_size] + train_data = data[self._train_eval_size:] + return self._make_tolstoi_dataloader(valid_data), self._make_tolstoi_dataloader(train_data) diff --git a/deepobs/pytorch/testproblems/__init__.py b/deepobs/pytorch/testproblems/__init__.py index 3b8bd446..9cc5ac2a 100644 --- a/deepobs/pytorch/testproblems/__init__.py +++ b/deepobs/pytorch/testproblems/__init__.py @@ -21,3 +21,4 @@ from .svhn_3c3d import svhn_3c3d from .svhn_wrn164 import svhn_wrn164 from .testproblem import TestProblem +from .tolstoi_char_rnn import tolstoi_char_rnn diff --git a/deepobs/pytorch/testproblems/testproblem.py b/deepobs/pytorch/testproblems/testproblem.py index b661a50e..d6d689b2 100644 --- a/deepobs/pytorch/testproblems/testproblem.py +++ b/deepobs/pytorch/testproblems/testproblem.py @@ -143,7 +143,7 @@ def forward_func(): loss = self.loss_function(reduction=reduction)(outputs, labels) _, predicted = torch.max(outputs.data, 1) - total += labels.size(0) + total += labels.numel() correct += (predicted == labels).sum().item() accuracy = correct / total diff --git a/deepobs/pytorch/testproblems/testproblems_modules.py b/deepobs/pytorch/testproblems/testproblems_modules.py index 682284bd..877a0f8a 100644 --- a/deepobs/pytorch/testproblems/testproblems_modules.py +++ b/deepobs/pytorch/testproblems/testproblems_modules.py @@ -713,26 +713,35 @@ def __init__(self, seq_len, hidden_dim, vocab_size, num_layers): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=hidden_dim ) + self.dropout = nn.Dropout(p=0.2) self.lstm = nn.LSTM( input_size=hidden_dim, hidden_size=hidden_dim, num_layers=num_layers, - dropout=0.2, + dropout=0.36, # tensorflow two dropouts with keep=0.8 each -> dropout=1-0.8*0.8=0.36 batch_first=True, ) + # deactivate redundant bias + 
self.lstm.bias_ih_l0.data = torch.zeros_like(self.lstm.bias_ih_l0, device=self.lstm.bias_ih_l0.device) + self.lstm.bias_ih_l1.data = torch.zeros_like(self.lstm.bias_ih_l1, device=self.lstm.bias_ih_l0.device) + self.lstm.bias_ih_l0.requires_grad = False + self.lstm.bias_ih_l1.requires_grad = False + self.dense = nn.Linear(in_features=hidden_dim, out_features=vocab_size) - # TODO init layers? def forward(self, x, state=None): """state is a tuple for hidden and cell state for initialisation of the lstm""" x = self.embedding(x) + x = self.dropout(x) # if no state is provided, default the state to zeros if state is None: x, new_state = self.lstm(x) else: x, new_state = self.lstm(x, state) - x = self.dense(x) - return x, new_state + x = self.dropout(x) + output = self.dense(x) + output = output.transpose(1, 2) + return output # , new_state class net_quadratic_deep(nn.Sequential): diff --git a/deepobs/pytorch/testproblems/tolstoi_char_rnn.py b/deepobs/pytorch/testproblems/tolstoi_char_rnn.py new file mode 100644 index 00000000..7375d293 --- /dev/null +++ b/deepobs/pytorch/testproblems/tolstoi_char_rnn.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +"""A two-layer LSTM (Char RNN) architecture for Tolstoi.""" +from torch import nn + +from deepobs.pytorch.testproblems.testproblem import WeightRegularizedTestproblem +from .testproblems_modules import net_char_rnn +from ..datasets.tolstoi import tolstoi + + +class tolstoi_char_rnn(WeightRegularizedTestproblem): + """DeepOBS test problem class for a two-layer LSTM for character-level language + modelling (Char RNN) on Tolstoi's War and Peace. 
+ + Some network characteristics: + + - ``128`` hidden units per LSTM cell + - sequence length ``50`` + - cell state is automatically stored in variables between subsequent steps + - when the phase placeholder switches its value from one step to the next, + the cell state is set to its zero value (i.e. the state is reset to zero + after each round of evaluation; it is therefore important to set the + evaluation interval such that we evaluate after a full epoch). + + Working training parameters are: + + - batch size ``50`` + - ``200`` epochs + - SGD with a learning rate of :math:`\\approx 0.1` works + + Args: + batch_size (int): Batch size to use. + l2_reg (float): L2-regularization factor. L2-regularization (weight decay) + is used on the weights but not the biases. + Defaults to ``5e-4``. + + Attributes: + data: The dataset used by the test problem (datasets.DataSet instance). + loss_function: The loss function for this test problem. + net: The torch module (the neural network) that is trained. + """ + + def __init__(self, batch_size, l2_reg=0.0005): + """Create a new char_rnn test problem instance on Tolstoi. + + Args: + batch_size (int): Batch size to use. + l2_reg (float): L2-regularization factor. L2-regularization (weight decay) + is used on the weights but not the biases. + Defaults to ``5e-4``. 
+ """ + super(tolstoi_char_rnn, self).__init__(batch_size, l2_reg) + + def set_up(self): + """Set up the Char RNN test problem on Tolstoi.""" + self.data = tolstoi(self._batch_size) + self.loss_function = nn.CrossEntropyLoss + self.net = net_char_rnn(hidden_dim=128, num_layers=2, seq_len=50, vocab_size=83) + self.net.to(self._device) + self.regularization_groups = self.get_regularization_groups() diff --git a/tests/test_testproblems.py b/tests/test_testproblems.py index 10a8ad95..1fac964f 100644 --- a/tests/test_testproblems.py +++ b/tests/test_testproblems.py @@ -19,7 +19,7 @@ # Basic Settings of the Test BATCH_SIZE = 8 -NR_PT_TESTPROBLEMS = 20 +NR_PT_TESTPROBLEMS = 21 NR_TF_TESTPROBLEMS = 27 DEVICES = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"] FRAMEWORKS = ["pytorch", "tensorflow"] @@ -147,8 +147,13 @@ def _check_parameters(tproblem, framework): num_param = [] if framework == "pytorch": - for parameter in tproblem.net.parameters(): - num_param.append(parameter.numel()) + for name, parameter in tproblem.net.named_parameters(): + if parameter.requires_grad is False: + continue + elif "weight_hh_l" in name: # LSTM parameters counted separately in PyTorch + num_param[-1] += parameter.numel() + else: + num_param.append(parameter.numel()) elif framework == "tensorflow": num_param = [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]