Sample weights #120

Closed · wants to merge 7 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
+# spotlight results files
+*_results.txt
+
 *~
 *#*
 
@@ -18,3 +21,4 @@
 
 # IDE
 tags
+cscope.out
122 changes: 116 additions & 6 deletions spotlight/factorization/implicit.py
@@ -198,6 +198,9 @@ def fit(self, interactions, verbose=False):
         verbose: bool
             Output additional information about current epoch and loss.
         """
+        # Call the weighted fit method if sample weights are specified.
+        if interactions.weights is not None:
+            return self._fit_weighted(interactions, verbose=verbose)
 
         user_ids = interactions.user_ids.astype(np.int64)
         item_ids = interactions.item_ids.astype(np.int64)
@@ -220,11 +223,19 @@

             epoch_loss = 0.0
 
-            for (minibatch_num,
-                 (batch_user,
-                  batch_item)) in enumerate(minibatch(user_ids_tensor,
-                                                      item_ids_tensor,
-                                                      batch_size=self._batch_size)):
+            for (
+                minibatch_num,
+                (
+                    batch_user,
+                    batch_item,
+                )
+            ) in enumerate(
+                minibatch(
+                    user_ids_tensor,
+                    item_ids_tensor,
+                    batch_size=self._batch_size
+                )
+            ):
 
                 positive_prediction = self._net(batch_user, batch_item)
 
@@ -236,7 +247,106 @@

                 self._optimizer.zero_grad()
 
-                loss = self._loss_func(positive_prediction, negative_prediction)
+                loss = self._loss_func(
+                    positive_prediction,
+                    negative_prediction,
+                )
                 epoch_loss += loss.item()
 
                 loss.backward()
                 self._optimizer.step()
 
             epoch_loss /= minibatch_num + 1
 
             if verbose:
                 print('Epoch {}: loss {}'.format(epoch_num, epoch_loss))
 
             if np.isnan(epoch_loss) or epoch_loss == 0.0:
                 raise ValueError('Degenerate epoch loss: {}'
                                  .format(epoch_loss))
 
+    def _fit_weighted(self, interactions, verbose=False):

Owner commented on this line:

I'm still not convinced about having a separate _fit_weighted function, even if it is internal. I think it introduces too much code duplication that will have to be kept in sync.

What about modifying minibatch to look roughly like this:

def minibatch(*tensors, **kwargs):

    batch_size = kwargs.get('batch_size', 128)

    if len(tensors) == 1:
        tensor = tensors[0]
        for i in range(0, len(tensor), batch_size):
            yield tensor[i:i + batch_size]
    else:
        for i in range(0, len(tensors[0]), batch_size):
            yield tuple(x[i:i + batch_size] if x is not None else None for x in tensors)

This way, it emits tensor slices if an argument is a tensor (as before), but also emits None in the tuple if an argument is None.
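
For illustration, a rough sketch (not code from this PR) of how fit could then drive a single loop over both weighted and unweighted data, assuming the modified minibatch above and loss functions that treat sample_weights=None as unweighted; shuffling of the optional weights array is elided here:

weights_tensor = None
if interactions.weights is not None:
    # Hypothetical: move the weights to the GPU alongside the ids.
    weights_tensor = gpu(
        torch.from_numpy(interactions.weights.astype(np.float32)),
        self._use_cuda,
    )

for minibatch_num, (batch_user,
                    batch_item,
                    batch_weights) in enumerate(
        minibatch(user_ids_tensor,
                  item_ids_tensor,
                  weights_tensor,
                  batch_size=self._batch_size)):

    positive_prediction = self._net(batch_user, batch_item)
    negative_prediction = self._get_negative_prediction(batch_user)

    self._optimizer.zero_grad()
    # batch_weights is None when no weights were supplied.
    loss = self._loss_func(positive_prediction,
                           negative_prediction,
                           sample_weights=batch_weights)
    loss.backward()
    self._optimizer.step()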

"""
Fit the model using sample weights.

When called repeatedly, model fitting will resume from
the point at which training stopped in the previous fit
call.

Parameters
----------

interactions: :class:`spotlight.interactions.Interactions`
The input dataset.

verbose: bool
Output additional information about current epoch and loss.
"""
if interactions.weights is None:
raise ValueError('''Sample weights must be specified in the
interactions object. If you don't have sample weights,
use the fit method instead''')

+        user_ids = interactions.user_ids.astype(np.int64)
+        item_ids = interactions.item_ids.astype(np.int64)
+        sample_weights = interactions.weights.astype(np.float32)
+
+        if not self._initialized:
+            self._initialize(interactions)
+
+        self._check_input(user_ids, item_ids)
+
+        for epoch_num in range(self._n_iter):
+
+            users, items, sample_weights = shuffle(
+                user_ids,
+                item_ids,
+                sample_weights,
+                random_state=self._random_state
+            )
+
+            user_ids_tensor = gpu(torch.from_numpy(users),
+                                  self._use_cuda)
+            item_ids_tensor = gpu(torch.from_numpy(items),
+                                  self._use_cuda)
+            sample_weights_tensor = gpu(
+                torch.from_numpy(sample_weights),
+                self._use_cuda
+            )
+
+            epoch_loss = 0.0
+
+            for (
+                minibatch_num,
+                (
+                    batch_user,
+                    batch_item,
+                    batch_sample_weights
+                )
+            ) in enumerate(
+                minibatch(
+                    user_ids_tensor,
+                    item_ids_tensor,
+                    sample_weights_tensor,
+                    batch_size=self._batch_size
+                )
+            ):
+
+                positive_prediction = self._net(batch_user, batch_item)
+
+                if self._loss == 'adaptive_hinge':
+                    negative_prediction = self._get_multiple_negative_predictions(
+                        batch_user, n=self._num_negative_samples)
+                else:
+                    negative_prediction = self._get_negative_prediction(batch_user)
+
+                self._optimizer.zero_grad()
+
+                loss = self._loss_func(
+                    positive_prediction,
+                    negative_prediction,
+                    sample_weights=batch_sample_weights
+                )
+                epoch_loss += loss.item()
+
+                loss.backward()
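
The diff is truncated here. For context, a loss function accepting per-sample weights might look roughly like the sketch below; this is an assumption about the approach, not the PR's actual implementation (spotlight's real loss functions live in spotlight/losses.py):

import torch


def weighted_hinge_loss(positive_predictions, negative_predictions,
                        sample_weights=None):
    # Per-pair hinge loss on (positive, negative) prediction scores.
    loss = torch.clamp(negative_predictions -
                       positive_predictions +
                       1.0, 0.0)

    if sample_weights is not None:
        # Scale each pair's loss by its weight and normalise by the
        # total weight rather than the batch size.
        return (loss * sample_weights).sum() / sample_weights.sum()

    return loss.mean()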
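
Under this PR, supplying a weights array to Interactions routes fit through the weighted path. A minimal usage sketch on toy, randomly generated data (the hyperparameter values are arbitrary, and it assumes the branch's loss functions accept sample_weights):

import numpy as np

from spotlight.interactions import Interactions
from spotlight.factorization.implicit import ImplicitFactorizationModel

# Toy dataset: 100 interactions over 20 users and 50 items.
rng = np.random.RandomState(42)
user_ids = rng.randint(0, 20, size=100).astype(np.int32)
item_ids = rng.randint(0, 50, size=100).astype(np.int32)
weights = rng.rand(100).astype(np.float32)

dataset = Interactions(user_ids, item_ids, weights=weights,
                       num_users=20, num_items=50)

model = ImplicitFactorizationModel(loss='bpr', n_iter=1)
# weights is not None, so fit dispatches to _fit_weighted.
model.fit(dataset, verbose=True)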