ENH: use sklearn.utils to create RNGs and generate minibatches
dohmatob committed Mar 13, 2018
1 parent 2285a04 commit bffdf0b
Showing 7 changed files with 36 additions and 20 deletions.
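
The pattern applied throughout is sklearn.utils.check_random_state, which maps None to the global numpy RandomState singleton, an int to a freshly seeded RandomState, and an existing RandomState through unchanged. A minimal sketch of the idiom being replaced (the seed values below are illustrative only):

import numpy as np
from sklearn.utils import check_random_state

# Each call returns a numpy RandomState instance:
check_random_state(None)                      # the global numpy RandomState singleton
check_random_state(42)                        # a new RandomState seeded with 42
check_random_state(np.random.RandomState(0))  # passed through unchanged

# So the repeated three-line idiom
#     if random_state is None:
#         random_state = np.random.RandomState()
# collapses to a single call:
random_state = check_random_state(None)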
7 changes: 2 additions & 5 deletions spotlight/cross_validation.py
@@ -4,7 +4,7 @@

import numpy as np

-from sklearn.utils import murmurhash3_32
+from sklearn.utils import murmurhash3_32, check_random_state

from spotlight.interactions import Interactions

@@ -36,10 +36,7 @@ def shuffle_interactions(interactions,
    interactions: :class:`spotlight.interactions.Interactions`
        The shuffled interactions.
    """
-
-    if random_state is None:
-        random_state = np.random.RandomState()
-
+    random_state = check_random_state(random_state)
    shuffle_indices = np.arange(len(interactions.user_ids))
    random_state.shuffle(shuffle_indices)

5 changes: 3 additions & 2 deletions spotlight/datasets/synthetic.py
@@ -6,6 +6,8 @@

import numpy as np

+from sklearn.utils import check_random_state
+
from spotlight.interactions import Interactions


@@ -108,8 +110,7 @@ def generate_sequential(num_users=100,
        instance of the interactions class
    """

-    if random_state is None:
-        random_state = np.random.RandomState()
+    random_state = check_random_state(random_state)

    transition_matrix = _build_transition_matrix(
        num_items - 1,
4 changes: 3 additions & 1 deletion spotlight/factorization/implicit.py
@@ -4,6 +4,8 @@

import numpy as np

+from sklearn.utils import check_random_state
+
import torch

import torch.optim as optim
@@ -104,7 +106,7 @@ def __init__(self,
        self._representation = representation
        self._sparse = sparse
        self._optimizer_func = optimizer_func
-        self._random_state = random_state or np.random.RandomState()
+        self._random_state = check_random_state(random_state)
        self._num_negative_samples = num_negative_samples

        self._num_users = None
6 changes: 2 additions & 4 deletions spotlight/sampling.py
@@ -3,6 +3,7 @@
"""

import numpy as np
from sklearn.utils import check_random_state


def sample_items(num_items, shape, random_state=None):
@@ -27,10 +28,7 @@ def sample_items(num_items, shape, random_state=None):
    items: np.array of shape [shape]
        Sampled item ids.
    """
-
-    if random_state is None:
-        random_state = np.random.RandomState()
-
+    random_state = check_random_state(random_state)
    items = random_state.randint(0, num_items, shape, dtype=np.int64)

    return items
4 changes: 3 additions & 1 deletion spotlight/sequence/implicit.py
@@ -5,6 +5,8 @@

import numpy as np

+from sklearn.utils import check_random_state
+
import torch

import torch.optim as optim
@@ -119,7 +121,7 @@ def __init__(self,
        self._use_cuda = use_cuda
        self._sparse = sparse
        self._optimizer_func = optimizer_func
-        self._random_state = random_state or np.random.RandomState()
+        self._random_state = check_random_state(random_state)
        self._num_negative_samples = num_negative_samples

        self._num_items = None
14 changes: 7 additions & 7 deletions spotlight/torch_utils.py
@@ -1,5 +1,7 @@
import numpy as np

+from sklearn.utils import check_random_state, gen_batches
+
import torch


@@ -20,16 +22,15 @@ def cpu(tensor):


def minibatch(*tensors, **kwargs):
-
    batch_size = kwargs.get('batch_size', 128)

    if len(tensors) == 1:
        tensor = tensors[0]
-        for i in range(0, len(tensor), batch_size):
-            yield tensor[i:i + batch_size]
+        for minibatch_indices in gen_batches(len(tensor), batch_size):
+            yield tensor[minibatch_indices]
    else:
-        for i in range(0, len(tensors[0]), batch_size):
-            yield tuple(x[i:i + batch_size] for x in tensors)
+        for minibatch_indices in gen_batches(len(tensors[0]), batch_size):
+            yield tuple(x[minibatch_indices] for x in tensors)


def shuffle(*arrays, **kwargs):
@@ -40,8 +41,7 @@ def shuffle(*arrays, **kwargs):
        raise ValueError('All inputs to shuffle must have '
                         'the same length.')

-    if random_state is None:
-        random_state = np.random.RandomState()
+    random_state = check_random_state(random_state)

    shuffle_indices = np.arange(len(arrays[0]))
    random_state.shuffle(shuffle_indices)
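The rewritten minibatch relies on sklearn.utils.gen_batches, which yields slice objects covering range(n) in contiguous chunks of at most batch_size, so each tensor can be indexed directly with the yielded slice. A small sketch of that behaviour (the sizes are arbitrary):

import torch
from sklearn.utils import gen_batches

# gen_batches yields slice objects, not index arrays:
list(gen_batches(7, 3))
# -> [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]

# Indexing a tensor with each slice gives contiguous minibatches:
data = torch.randn(7, 4)
batches = [data[s] for s in gen_batches(len(data), 3)]
assert sum(b.size(0) for b in batches) == data.size(0)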
16 changes: 16 additions & 0 deletions tests/test_torch_utils.py
@@ -0,0 +1,16 @@
+import numpy as np
+import torch
+from spotlight.torch_utils import minibatch
+
+
+def test_minibatch():
+    for n in [1, 12]:
+        data = torch.randn(n, 6)
+        s = 0
+        ss = 0
+        for x in minibatch(data, batch_size=3):
+            s += np.prod(x.size())
+            ss += x.sum()
+            assert x.size(1) == data.size(1)
+        assert s == np.prod(data.size())
+        assert ss == data.sum()
