[ADD] PyTorch: Tolstoi Char RNN #40

Open · wants to merge 17 commits into base: develop
4 changes: 4 additions & 0 deletions .gitignore
@@ -19,6 +19,10 @@ docs/_build/
# Jupyter Notebook
.ipynb_checkpoints

# PyCharm
.idea/
.coverage

# Distribution / packaging
.Python
env/
24 changes: 15 additions & 9 deletions deepobs/pytorch/datasets/tolstoi.py
@@ -7,8 +7,8 @@
import torch
from torch.utils import data as dat

from .. import config
from . import dataset
from ...config import get_data_dir


class tolstoi(dataset.DataSet):
@@ -43,10 +43,9 @@ def __init__(self, batch_size, seq_length=50, train_eval_size=653237):
self._train_eval_size = train_eval_size
super(tolstoi, self).__init__(batch_size)

def _make_dataloader(self, filepath):
# Load the array of character ids, determine the number of batches that
# can be produced, given batch size and sequence length
arr = np.load(filepath)
def _make_tolstoi_dataloader(self, arr):
# determine the number of batches that can be produced, given batch size
# and sequence length
num_batches = int(
np.floor((np.size(arr) - 1) / (self._batch_size * self._seq_length))
)
@@ -79,8 +78,8 @@ def _make_dataloader(self, filepath):
return dataset

def _make_train_dataloader(self):
filepath = os.path.join(config.get_data_dir(), "tolstoi", "train.npy")
return self._make_dataloader(filepath)
filepath = os.path.join(get_data_dir(), "tolstoi", "train.npy")
return self._make_tolstoi_dataloader(np.load(filepath))

def _make_train_eval_dataloader(self):
indices = np.arange(
Expand All @@ -90,5 +89,12 @@ def _make_train_eval_dataloader(self):
return dat.TensorDataset(train_eval_set[0], train_eval_set[1])

def _make_test_dataloader(self):
filepath = os.path.join(config.get_data_dir(), "tolstoi", "test.npy")
return self._make_dataloader(filepath)
filepath = os.path.join(get_data_dir(), "tolstoi", "test.npy")
return self._make_tolstoi_dataloader(np.load(filepath))

def _make_train_and_valid_dataloader(self):
filepath = os.path.join(get_data_dir(), "tolstoi", "train.npy")
data = np.load(filepath)
valid_data = data[0: self._train_eval_size]
train_data = data[self._train_eval_size:]
return self._make_tolstoi_dataloader(valid_data), self._make_tolstoi_dataloader(train_data)
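
For context, the middle of `_make_tolstoi_dataloader` is collapsed in this diff. Below is a minimal sketch of the usual char-RNN batching scheme it implements; the helper name and exact tensor handling are assumptions for illustration, not the PR's verbatim code:

```python
import numpy as np
import torch
from torch.utils import data as dat


def make_char_batches(arr, batch_size=50, seq_length=50):
    """Sketch: turn a flat array of character ids into (input, target) pairs.

    Targets are the inputs shifted by one character; each dataset item is a
    pre-built (batch_size, seq_length) block of character ids.
    """
    num_batches = int(np.floor((np.size(arr) - 1) / (batch_size * seq_length)))
    used = num_batches * batch_size * seq_length
    x = arr[:used].reshape(batch_size, -1)        # inputs, one long row per stream
    y = arr[1:used + 1].reshape(batch_size, -1)   # targets, shifted by one character
    # cut the streams into seq_length-sized chunks along the time axis
    x_seqs = np.array(np.split(x, num_batches, axis=1))
    y_seqs = np.array(np.split(y, num_batches, axis=1))
    return dat.TensorDataset(torch.from_numpy(x_seqs).long(),
                             torch.from_numpy(y_seqs).long())
```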
1 change: 1 addition & 0 deletions deepobs/pytorch/testproblems/__init__.py
@@ -21,3 +21,4 @@
from .svhn_3c3d import svhn_3c3d
from .svhn_wrn164 import svhn_wrn164
from .testproblem import TestProblem
from .tolstoi_char_rnn import tolstoi_char_rnn
2 changes: 1 addition & 1 deletion deepobs/pytorch/testproblems/testproblem.py
@@ -143,7 +143,7 @@ def forward_func():
loss = self.loss_function(reduction=reduction)(outputs, labels)

_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
total += labels.numel()
correct += (predicted == labels).sum().item()

accuracy = correct / total
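
The switch from `labels.size(0)` to `labels.numel()` matters for sequence problems: labels now have shape `(batch_size, seq_length)`, so counting only the batch dimension would overstate the accuracy. A small self-contained illustration (shapes assumed from the char-RNN setup, not taken from this file):

```python
import torch

batch_size, vocab_size, seq_length = 4, 83, 50
outputs = torch.randn(batch_size, vocab_size, seq_length)  # logits, class dim = 1
labels = torch.randint(0, vocab_size, (batch_size, seq_length))

_, predicted = torch.max(outputs.data, 1)       # shape (batch_size, seq_length)
correct = (predicted == labels).sum().item()    # counts every predicted character
total = labels.numel()                          # 200, whereas labels.size(0) is only 4
accuracy = correct / total
```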
17 changes: 13 additions & 4 deletions deepobs/pytorch/testproblems/testproblems_modules.py
@@ -713,26 +713,35 @@ def __init__(self, seq_len, hidden_dim, vocab_size, num_layers):
self.embedding = nn.Embedding(
num_embeddings=vocab_size, embedding_dim=hidden_dim
)
self.dropout = nn.Dropout(p=0.2)
self.lstm = nn.LSTM(
input_size=hidden_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
dropout=0.2,
dropout=0.36,  # TensorFlow applies two dropouts with keep_prob=0.8 each, so the equivalent single rate is 1 - 0.8 * 0.8 = 0.36
batch_first=True,
)
# deactivate the redundant bias: PyTorch LSTMs have two bias vectors per layer
# (bias_ih and bias_hh) while the TensorFlow reference uses only one, so
# bias_ih is zeroed and frozen for both layers
self.lstm.bias_ih_l0.data = torch.zeros_like(self.lstm.bias_ih_l0)
self.lstm.bias_ih_l1.data = torch.zeros_like(self.lstm.bias_ih_l1)
self.lstm.bias_ih_l0.requires_grad = False
self.lstm.bias_ih_l1.requires_grad = False

self.dense = nn.Linear(in_features=hidden_dim, out_features=vocab_size)
# TODO init layers?

def forward(self, x, state=None):
"""state is a tuple for hidden and cell state for initialisation of the lstm"""
x = self.embedding(x)
x = self.dropout(x)
# if no state is provided, default the state to zeros
if state is None:
x, new_state = self.lstm(x)
else:
x, new_state = self.lstm(x, state)
x = self.dense(x)
return x, new_state
x = self.dropout(x)
output = self.dense(x)
output = output.transpose(1, 2)
return output # , new_state


class net_quadratic_deep(nn.Sequential):
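A quick shape check for the rewritten `net_char_rnn` forward pass may be helpful here. This is a sketch assuming the constructor arguments shown in the hunk above; the output is transposed to `(batch, vocab_size, seq_len)` so it can be fed to `nn.CrossEntropyLoss` directly:

```python
import torch
from torch import nn
from deepobs.pytorch.testproblems.testproblems_modules import net_char_rnn

net = net_char_rnn(seq_len=50, hidden_dim=128, vocab_size=83, num_layers=2)
x = torch.randint(0, 83, (8, 50))              # (batch_size, seq_len) character ids
logits = net(x)                                # (8, 83, 50) after the transpose
labels = torch.randint(0, 83, (8, 50))
loss = nn.CrossEntropyLoss()(logits, labels)   # class dimension is dim 1
print(logits.shape, loss.item())
```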
59 changes: 59 additions & 0 deletions deepobs/pytorch/testproblems/tolstoi_char_rnn.py
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
"""A vanilla RNN architecture for Tolstoi."""
from torch import nn

from deepobs.pytorch.testproblems.testproblem import WeightRegularizedTestproblem
from .testproblems_modules import net_char_rnn
from ..datasets.tolstoi import tolstoi


class tolstoi_char_rnn(WeightRegularizedTestproblem):
"""DeepOBS test problem class for a two-layer LSTM for character-level language
modelling (Char RNN) on Tolstoi's War and Peace.

Some network characteristics:

- ``128`` hidden units per LSTM cell
- sequence length ``50``
- cell state is automatically stored in variables between subsequent steps
- when the phase placeholder switches its value from one step to the next,
the cell state is reset to zero (i.e. the state is zeroed after each round
of evaluation; it is therefore important to choose the evaluation interval
such that evaluation happens after a full epoch)

Working training parameters are:

- batch size ``50``
- ``200`` epochs
- SGD with a learning rate of :math:`\\approx 0.1` works
Comment on lines +24 to +28

Author: This is copied from TensorFlow and not verified.
In TensorFlow, the CrossEntropyLoss takes the mean across the time axis and the sum across the batch axis.
Such an option does not exist in PyTorch; the only options are "sum" or "mean" over both axes. Currently "mean" is chosen.
In this case, the learning rate should be a factor of batch_size bigger, because the gradients are a factor of batch_size smaller.

Owner: Could you instead use "sum" and divide by seq_length (or whatever the variable name for the width of the time axis is)?
It would be great if running, e.g., SGD with lr=0.1 produced similar results in PyTorch and TensorFlow.

Author: This is exactly the idea we discussed in person. However, it turns out that this didn't work.
The division by seq_length must happen only after the CrossEntropyLoss, therefore it cannot be part of the model.
I see two possibilities:

  • introduce a custom CrossEntropyLoss that divides after applying CrossEntropyLoss
  • leave it as it is

In my opinion, both options are quite bad.

Owner: Easiest would be to change the definition of the loss in the TensorFlow version to something compatible with PyTorch...
Let me think about this and I will address and merge it once I find time for DeepOBS again.
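
For reference, a minimal sketch of the first option (a wrapper that applies `nn.CrossEntropyLoss(reduction="sum")` and then divides by the sequence length, matching TensorFlow's mean-over-time / sum-over-batch convention). The class name and interface are hypothetical and not part of this PR:

```python
import torch
from torch import nn


class SeqCrossEntropyLoss(nn.Module):
    """Hypothetical helper: cross-entropy summed over all characters, divided by seq_length.

    This reproduces the TensorFlow scaling (mean over the time axis, sum over
    the batch axis) instead of PyTorch's plain "mean" or "sum" reductions.
    """

    def __init__(self, seq_length):
        super().__init__()
        self.seq_length = seq_length
        self._ce = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, logits, targets):
        # logits: (batch, vocab_size, seq_length), targets: (batch, seq_length)
        return self._ce(logits, targets) / self.seq_length


loss_fn = SeqCrossEntropyLoss(seq_length=50)
loss = loss_fn(torch.randn(8, 83, 50), torch.randint(0, 83, (8, 50)))
```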


Args:
batch_size (int): Batch size to use.
l2_reg (float): L2-regularization factor. L2-Regularization (weight decay)
is used on the weights but not the biases.
Defaults to ``5e-4``.

Attributes:
data: The dataset used by the test problem (datasets.DataSet instance).
loss_function: The loss function for this test problem.
net: The torch module (the neural network) that is trained.
"""

def __init__(self, batch_size, l2_reg=0.0005):
"""Create a new char_rnn test problem instance on Tolstoi.

Args:
batch_size (int): Batch size to use.
l2_reg (float): L2-regularization factor. L2-Regularization (weight decay)
is used on the weights but not the biases.
Defaults to ``5e-4``.
"""
super(tolstoi_char_rnn, self).__init__(batch_size, l2_reg)

def set_up(self):
"""Set up the Char RNN test problem on Tolstoi."""
self.data = tolstoi(self._batch_size)
self.loss_function = nn.CrossEntropyLoss
self.net = net_char_rnn(hidden_dim=128, num_layers=2, seq_len=50, vocab_size=83)
self.net.to(self._device)
self.regularization_groups = self.get_regularization_groups()
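
A rough smoke test of the new test problem, using only the `loss_function` and `net` attributes documented above. This sketch assumes the prepared Tolstoi data is available in the configured data directory and that the problem is set up on the CPU; it is an illustration, not the DeepOBS runner API:

```python
import torch
from deepobs.pytorch.testproblems import tolstoi_char_rnn

tp = tolstoi_char_rnn(batch_size=50)
tp.set_up()                                    # builds data, loss_function, and net

x = torch.randint(0, 83, (50, 50))             # dummy (batch_size, seq_len) character ids
y = torch.randint(0, 83, (50, 50))
logits = tp.net(x)                             # (50, 83, 50)
loss = tp.loss_function(reduction="mean")(logits, y)
loss.backward()
```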
11 changes: 8 additions & 3 deletions tests/test_testproblems.py
@@ -19,7 +19,7 @@

# Basic Settings of the Test
BATCH_SIZE = 8
NR_PT_TESTPROBLEMS = 20
NR_PT_TESTPROBLEMS = 21
NR_TF_TESTPROBLEMS = 27
DEVICES = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
FRAMEWORKS = ["pytorch", "tensorflow"]
@@ -147,8 +147,13 @@ def _check_parameters(tproblem, framework):
num_param = []

if framework == "pytorch":
for parameter in tproblem.net.parameters():
num_param.append(parameter.numel())
for name, parameter in tproblem.net.named_parameters():
if not parameter.requires_grad:  # skip frozen parameters (e.g. the deactivated LSTM bias)
continue
elif "weight_hh_l" in name:
# PyTorch stores the LSTM input-hidden and hidden-hidden weights separately;
# merge them into one entry to match TensorFlow's fused LSTM kernel
num_param[-1] += parameter.numel()
else:
num_param.append(parameter.numel())
elif framework == "tensorflow":
num_param = [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]
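
To see why the hidden-hidden weights are merged above, here is a small standalone check of how PyTorch names LSTM parameters (independent of the DeepOBS test suite):

```python
from torch import nn

lstm = nn.LSTM(input_size=128, hidden_size=128, num_layers=2, batch_first=True)
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (512, 128)
# weight_hh_l0 (512, 128)
# bias_ih_l0 (512,)
# bias_hh_l0 (512,)
# ... and the same four names for layer 1
# TensorFlow's LSTM cell fuses the input-hidden and hidden-hidden matrices into
# a single (256, 512) kernel, so the two PyTorch weight matrices count as one.
```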
