From bffdf0b8aa94d6966d6d4d9eee611259ea8dc853 Mon Sep 17 00:00:00 2001
From: Elvis Dohmatob
Date: Tue, 13 Mar 2018 06:55:17 +0100
Subject: [PATCH] ENH: using sklearn.utils for creating RNGs and generating
 minibatches

---
 spotlight/cross_validation.py       |  7 ++-----
 spotlight/datasets/synthetic.py     |  5 +++--
 spotlight/factorization/implicit.py |  4 +++-
 spotlight/sampling.py               |  6 ++----
 spotlight/sequence/implicit.py      |  4 +++-
 spotlight/torch_utils.py            | 14 +++++++-------
 tests/test_torch_utils.py           | 16 ++++++++++++++++
 7 files changed, 36 insertions(+), 20 deletions(-)
 create mode 100644 tests/test_torch_utils.py

diff --git a/spotlight/cross_validation.py b/spotlight/cross_validation.py
index 912ffbeb..ccb378ec 100644
--- a/spotlight/cross_validation.py
+++ b/spotlight/cross_validation.py
@@ -4,7 +4,7 @@
 
 import numpy as np
 
-from sklearn.utils import murmurhash3_32
+from sklearn.utils import murmurhash3_32, check_random_state
 
 from spotlight.interactions import Interactions
 
@@ -36,10 +36,7 @@ def shuffle_interactions(interactions,
     interactions: :class:`spotlight.interactions.Interactions`
         The shuffled interactions.
     """
-
-    if random_state is None:
-        random_state = np.random.RandomState()
-
+    random_state = check_random_state(random_state)
     shuffle_indices = np.arange(len(interactions.user_ids))
     random_state.shuffle(shuffle_indices)
 
diff --git a/spotlight/datasets/synthetic.py b/spotlight/datasets/synthetic.py
index dec2d850..80b8361e 100644
--- a/spotlight/datasets/synthetic.py
+++ b/spotlight/datasets/synthetic.py
@@ -6,6 +6,8 @@
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 from spotlight.interactions import Interactions
 
 
@@ -108,8 +110,7 @@ def generate_sequential(num_users=100,
         instance of the interactions class
     """
 
-    if random_state is None:
-        random_state = np.random.RandomState()
+    random_state = check_random_state(random_state)
 
     transition_matrix = _build_transition_matrix(
         num_items - 1,
diff --git a/spotlight/factorization/implicit.py b/spotlight/factorization/implicit.py
index 59b57cf4..1f470956 100644
--- a/spotlight/factorization/implicit.py
+++ b/spotlight/factorization/implicit.py
@@ -4,6 +4,8 @@
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 import torch
 
 import torch.optim as optim
@@ -104,7 +106,7 @@ def __init__(self,
         self._representation = representation
         self._sparse = sparse
         self._optimizer_func = optimizer_func
-        self._random_state = random_state or np.random.RandomState()
+        self._random_state = check_random_state(random_state)
         self._num_negative_samples = num_negative_samples
 
         self._num_users = None
diff --git a/spotlight/sampling.py b/spotlight/sampling.py
index a2a43cb2..c3a51d00 100644
--- a/spotlight/sampling.py
+++ b/spotlight/sampling.py
@@ -3,6 +3,7 @@
 """
 
 import numpy as np
+from sklearn.utils import check_random_state
 
 
 def sample_items(num_items, shape, random_state=None):
@@ -27,10 +28,7 @@ def sample_items(num_items, shape, random_state=None):
     items: np.array of shape [shape]
         Sampled item ids.
     """
-
-    if random_state is None:
-        random_state = np.random.RandomState()
-
+    random_state = check_random_state(random_state)
     items = random_state.randint(0, num_items, shape, dtype=np.int64)
 
     return items
diff --git a/spotlight/sequence/implicit.py b/spotlight/sequence/implicit.py
index 509be5d3..31a32dcc 100644
--- a/spotlight/sequence/implicit.py
+++ b/spotlight/sequence/implicit.py
@@ -5,6 +5,8 @@
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 import torch
 
 import torch.optim as optim
@@ -119,7 +121,7 @@ def __init__(self,
         self._use_cuda = use_cuda
         self._sparse = sparse
         self._optimizer_func = optimizer_func
-        self._random_state = random_state or np.random.RandomState()
+        self._random_state = check_random_state(random_state)
         self._num_negative_samples = num_negative_samples
 
         self._num_items = None
diff --git a/spotlight/torch_utils.py b/spotlight/torch_utils.py
index f425fa30..71ccd79e 100644
--- a/spotlight/torch_utils.py
+++ b/spotlight/torch_utils.py
@@ -1,5 +1,7 @@
 import numpy as np
 
+from sklearn.utils import check_random_state, gen_batches
+
 import torch
 
 
@@ -20,16 +22,15 @@ def cpu(tensor):
 
 
 def minibatch(*tensors, **kwargs):
-
     batch_size = kwargs.get('batch_size', 128)
 
     if len(tensors) == 1:
         tensor = tensors[0]
-        for i in range(0, len(tensor), batch_size):
-            yield tensor[i:i + batch_size]
+        for minibatch_indices in gen_batches(len(tensor), batch_size):
+            yield tensor[minibatch_indices]
     else:
-        for i in range(0, len(tensors[0]), batch_size):
-            yield tuple(x[i:i + batch_size] for x in tensors)
+        for minibatch_indices in gen_batches(len(tensors[0]), batch_size):
+            yield tuple(x[minibatch_indices] for x in tensors)
 
 
 def shuffle(*arrays, **kwargs):
@@ -40,8 +41,7 @@ def shuffle(*arrays, **kwargs):
         raise ValueError('All inputs to shuffle must have '
                          'the same length.')
 
-    if random_state is None:
-        random_state = np.random.RandomState()
+    random_state = check_random_state(random_state)
 
     shuffle_indices = np.arange(len(arrays[0]))
     random_state.shuffle(shuffle_indices)
diff --git a/tests/test_torch_utils.py b/tests/test_torch_utils.py
new file mode 100644
index 00000000..625198fb
--- /dev/null
+++ b/tests/test_torch_utils.py
@@ -0,0 +1,16 @@
+import numpy as np
+import torch
+from spotlight.torch_utils import minibatch
+
+
+def test_minibatch():
+    for n in [1, 12]:
+        data = torch.randn(n, 6)
+        s = 0
+        ss = 0
+        for x in minibatch(data, batch_size=3):
+            s += np.prod(x.size())
+            ss += x.sum()
+            assert x.size(1) == data.size(1)
+        assert s == np.prod(data.size())
+        assert ss == data.sum()