Sample weights #120

Closed · wants to merge 7 commits
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# spotlight results files
*_results.txt

*~
*#*

@@ -18,3 +21,4 @@

# IDE
tags
cscope.out
47 changes: 36 additions & 11 deletions spotlight/factorization/implicit.py
@@ -183,7 +183,7 @@ def _check_input(self, user_ids, item_ids, allow_items_none=False):

def fit(self, interactions, verbose=False):
"""
Fit the model.
Fit the model using sample weights.

When called repeatedly, model fitting will resume from
the point at which training stopped in the previous fit
@@ -198,9 +198,11 @@ def fit(self, interactions, verbose=False):
verbose: bool
Output additional information about current epoch and loss.
"""

user_ids = interactions.user_ids.astype(np.int64)
item_ids = interactions.item_ids.astype(np.int64)
sample_weights = None
if interactions.weights is not None:
sample_weights = interactions.weights.astype(np.float32)

if not self._initialized:
self._initialize(interactions)
@@ -209,22 +211,41 @@

for epoch_num in range(self._n_iter):

users, items = shuffle(user_ids,
item_ids,
random_state=self._random_state)
users, items, sample_weights = shuffle(
user_ids,
item_ids,
sample_weights,
random_state=self._random_state
)

user_ids_tensor = gpu(torch.from_numpy(users),
self._use_cuda)
item_ids_tensor = gpu(torch.from_numpy(items),
self._use_cuda)
sample_weights_tensor = None
if sample_weights is not None:
sample_weights_tensor = gpu(
torch.from_numpy(sample_weights),
self._use_cuda
)

epoch_loss = 0.0

for (minibatch_num,
(batch_user,
batch_item)) in enumerate(minibatch(user_ids_tensor,
item_ids_tensor,
batch_size=self._batch_size)):
for (
minibatch_num,
(
batch_user,
batch_item,
batch_sample_weights
)
) in enumerate(
minibatch(
user_ids_tensor,
item_ids_tensor,
sample_weights_tensor,
batch_size=self._batch_size
)
):

positive_prediction = self._net(batch_user, batch_item)

@@ -236,7 +257,11 @@ def fit(self, interactions, verbose=False):

self._optimizer.zero_grad()

loss = self._loss_func(positive_prediction, negative_prediction)
loss = self._loss_func(
positive_prediction,
negative_prediction,
sample_weights=batch_sample_weights
)
epoch_loss += loss.item()

loss.backward()
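
Not part of the diff: a minimal usage sketch of how per-interaction weights would reach fit() on this branch. It assumes the Interactions constructor's optional weights argument (already present on master) and standard ImplicitFactorizationModel settings; the ids and weight values are made up for illustration.

import numpy as np

from spotlight.factorization.implicit import ImplicitFactorizationModel
from spotlight.interactions import Interactions

user_ids = np.array([0, 0, 1, 2, 2], dtype=np.int32)
item_ids = np.array([1, 3, 2, 0, 3], dtype=np.int32)
# Down-weight the first interaction, double the weight of the last one.
weights = np.array([0.2, 1.0, 1.0, 1.0, 2.0], dtype=np.float32)

interactions = Interactions(user_ids, item_ids, weights=weights,
                            num_users=3, num_items=4)

model = ImplicitFactorizationModel(loss='bpr', n_iter=1)
model.fit(interactions, verbose=True)

With weights=None the branch falls back to the unweighted behaviour, since sample_weights stays None all the way through shuffle, minibatch and the loss call.
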
140 changes: 107 additions & 33 deletions spotlight/losses.py
@@ -15,7 +15,44 @@
from spotlight.torch_utils import assert_no_grad


def pointwise_loss(positive_predictions, negative_predictions, mask=None):
def _weighted_loss(loss, sample_weights=None, mask=None):
"""Sample weight and mask handler for loss functions.
If both sample_weights and mask are specified, sample_weights will override
as one may zero-out, as well as scale, certain entries via the weights.

Parameters
----------

loss: tensor
Tensor with element-wise losses from one of the loss functions in this
file.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.

Returns
-------

loss, float
The mean value of the loss function.
"""
if sample_weights is not None:
loss = loss * sample_weights
return loss.sum() / sample_weights.sum()

if mask is not None:
mask = mask.float()
loss = loss * mask
return loss.sum() / mask.sum()

return loss.mean()
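
For illustration only (not part of the diff): a standalone copy of the helper, showing how the three reduction paths differ on a toy loss vector using plain PyTorch tensors.

import torch


def weighted_loss(loss, sample_weights=None, mask=None):
    # Standalone copy of the helper above, for a quick numerical check.
    if sample_weights is not None:
        loss = loss * sample_weights
        return loss.sum() / sample_weights.sum()
    if mask is not None:
        mask = mask.float()
        loss = loss * mask
        return loss.sum() / mask.sum()
    return loss.mean()


elementwise = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Plain mean: 2.5
print(weighted_loss(elementwise))
# Weighted mean: (1*0 + 2*1 + 3*1 + 4*2) / (0 + 1 + 1 + 2) = 3.25;
# a zero weight drops an entry exactly as a mask would, and a larger
# weight scales up that entry's contribution.
print(weighted_loss(elementwise,
                    sample_weights=torch.tensor([0.0, 1.0, 1.0, 2.0])))
# Masked mean over the last three entries: (2 + 3 + 4) / 3 = 3.0
print(weighted_loss(elementwise, mask=torch.tensor([0, 1, 1, 1])))
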


def pointwise_loss(
positive_predictions, negative_predictions,
sample_weights=None, mask=None):
"""
Logistic loss function.

@@ -26,6 +63,8 @@ def pointwise_loss(positive_predictions, negative_predictions, mask=None):
Tensor containing predictions for known positive items.
negative_predictions: tensor
Tensor containing predictions for sampled negative items.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.
@@ -42,15 +81,12 @@

loss = (positives_loss + negatives_loss)

if mask is not None:
mask = mask.float()
loss = loss * mask
return loss.sum() / mask.sum()

return loss.mean()
return _weighted_loss(loss, sample_weights, mask)


def bpr_loss(positive_predictions, negative_predictions, mask=None):
def bpr_loss(
positive_predictions, negative_predictions,
sample_weights=None, mask=None):
"""
Bayesian Personalised Ranking [1]_ pairwise loss function.

@@ -61,6 +97,8 @@ def bpr_loss(positive_predictions, negative_predictions, mask=None):
Tensor containing predictions for known positive items.
negative_predictions: tensor
Tensor containing predictions for sampled negative items.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.
@@ -82,15 +120,12 @@ def bpr_loss(positive_predictions, negative_predictions, mask=None):
loss = (1.0 - F.sigmoid(positive_predictions -
negative_predictions))

if mask is not None:
mask = mask.float()
loss = loss * mask
return loss.sum() / mask.sum()

return loss.mean()
return _weighted_loss(loss, sample_weights, mask)


def hinge_loss(positive_predictions, negative_predictions, mask=None):
def hinge_loss(
positive_predictions, negative_predictions,
sample_weights=None, mask=None):
"""
Hinge pairwise loss function.

@@ -101,6 +136,8 @@ def hinge_loss(positive_predictions, negative_predictions, mask=None):
Tensor containing predictions for known positive items.
negative_predictions: tensor
Tensor containing predictions for sampled negative items.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.
@@ -116,22 +153,19 @@ def hinge_loss(positive_predictions, negative_predictions, mask=None):
positive_predictions +
1.0, 0.0)

if mask is not None:
mask = mask.float()
loss = loss * mask
return loss.sum() / mask.sum()

return loss.mean()
return _weighted_loss(loss, sample_weights, mask)


def adaptive_hinge_loss(positive_predictions, negative_predictions, mask=None):
def adaptive_hinge_loss(
positive_predictions, negative_predictions,
sample_weights=None, mask=None):
"""
Adaptive hinge pairwise loss function. Takes a set of predictions
for implicitly negative items, and selects those that are highest,
thus sampling those negatives that are closes to violating the
thus sampling those negatives that are closest to violating the
ranking implicit in the pattern of user interactions.

Approximates the idea of weighted approximate-rank pairwise loss
Approximates the idea of Weighted Approximate-Rank Pairwise (WARP) loss
introduced in [2]_

Parameters
@@ -143,6 +177,8 @@ def adaptive_hinge_loss(positive_predictions, negative_predictions, mask=None):
Iterable of tensors containing predictions for sampled negative items.
More tensors increase the likelihood of finding ranking-violating
pairs, but risk overfitting.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.
@@ -163,10 +199,17 @@

highest_negative_predictions, _ = torch.max(negative_predictions, 0)

return hinge_loss(positive_predictions, highest_negative_predictions.squeeze(), mask=mask)
return hinge_loss(
positive_predictions,
highest_negative_predictions.squeeze(),
sample_weights=sample_weights,
mask=mask
)
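
Not part of the diff: a toy illustration of the adaptive selection step above. Stacking several rounds of negative predictions and taking the element-wise max picks, per positive, the hardest sampled negative before the hinge is applied.

import torch

# Three rounds of sampled negative predictions for the same two positives.
negatives = torch.tensor([[0.1, 0.6],
                          [0.4, 0.2],
                          [0.3, 0.5]])

hardest, _ = torch.max(negatives, 0)
# hardest == tensor([0.4, 0.6]): for each positive, the negative closest
# to violating the ranking, which the hinge loss then penalises.
print(hardest)
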


def regression_loss(observed_ratings, predicted_ratings):
def regression_loss(
observed_ratings, predicted_ratings,
sample_weights=None, mask=None):
"""
Regression loss.

@@ -177,6 +220,11 @@
Tensor containing observed ratings.
predicted_ratings: tensor
Tensor containing rating predictions.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.

Returns
-------
@@ -186,11 +234,14 @@
"""

assert_no_grad(observed_ratings)
loss = (observed_ratings - predicted_ratings) ** 2

return ((observed_ratings - predicted_ratings) ** 2).mean()
return _weighted_loss(loss, sample_weights, mask)


def poisson_loss(observed_ratings, predicted_ratings):
def poisson_loss(
observed_ratings, predicted_ratings,
sample_weights=None, mask=None):
"""
Poisson loss.

@@ -201,6 +252,11 @@
Tensor containing observed ratings.
predicted_ratings: tensor
Tensor containing rating predictions.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.

Returns
-------
@@ -210,11 +266,14 @@
"""

assert_no_grad(observed_ratings)
loss = predicted_ratings - observed_ratings * torch.log(predicted_ratings)

return (predicted_ratings - observed_ratings * torch.log(predicted_ratings)).mean()
return _weighted_loss(loss, sample_weights, mask)


def logistic_loss(observed_ratings, predicted_ratings):
def logistic_loss(
observed_ratings, predicted_ratings,
sample_weights=None, mask=None):
"""
Logistic loss for explicit data.

@@ -226,6 +285,11 @@
should be +1 or -1 for this loss function.
predicted_ratings: tensor
Tensor containing rating predictions.
sample_weights: tensor, optional
Tensor containing weights to scale the loss by.
mask: tensor, optional
A binary tensor used to zero the loss from some entries
of the loss tensor.

Returns
-------
@@ -239,6 +303,16 @@
# Convert target classes from (-1, 1) to (0, 1)
observed_ratings = torch.clamp(observed_ratings, 0, 1)

return F.binary_cross_entropy_with_logits(predicted_ratings,
observed_ratings,
size_average=True)
if sample_weights is not None or mask is not None:
loss = F.binary_cross_entropy_with_logits(
predicted_ratings,
observed_ratings,
size_average=False
)
return _weighted_loss(loss, sample_weights, mask)

return F.binary_cross_entropy_with_logits(
predicted_ratings,
observed_ratings,
size_average=True
)
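
Again not part of the diff: a hedged sketch of calling one of the pairwise losses with per-sample weights, assuming the hinge_loss signature introduced in this file (sample_weights as an optional keyword argument).

import torch

from spotlight.losses import hinge_loss

positive = torch.tensor([0.9, 0.2, 0.5])
negative = torch.tensor([0.1, 0.4, 0.5])

# Unweighted: every (positive, negative) pair contributes equally.
print(hinge_loss(positive, negative))

# Weighted: the second pair, which violates the margin the most,
# counts twice as much, so the reported loss increases.
weights = torch.tensor([1.0, 2.0, 1.0])
print(hinge_loss(positive, negative, sample_weights=weights))
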
6 changes: 3 additions & 3 deletions spotlight/torch_utils.py
@@ -29,14 +29,14 @@ def minibatch(*tensors, **kwargs):
yield tensor[i:i + batch_size]
else:
for i in range(0, len(tensors[0]), batch_size):
yield tuple(x[i:i + batch_size] for x in tensors)
yield tuple(x[i:i + batch_size] if x is not None else None for x in tensors)


def shuffle(*arrays, **kwargs):

random_state = kwargs.get('random_state')

if len(set(len(x) for x in arrays)) != 1:
if len(set(len(x) for x in arrays if x is not None)) != 1:
raise ValueError('All inputs to shuffle must have '
'the same length.')

@@ -49,7 +49,7 @@ def shuffle(*arrays, **kwargs):
if len(arrays) == 1:
return arrays[0][shuffle_indices]
else:
return tuple(x[shuffle_indices] for x in arrays)
return tuple(x[shuffle_indices] if x is not None else None for x in arrays)


def assert_no_grad(variable):
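
A small standalone check (not part of the diff) of the None passthrough added to shuffle() and minibatch() in this branch: a missing weights array comes back as None, while the remaining arrays keep their shared permutation and batching.

import numpy as np
import torch

from spotlight.torch_utils import minibatch, shuffle

users = np.arange(6, dtype=np.int64)
items = np.arange(6, dtype=np.int64)[::-1].copy()

shuffled_users, shuffled_items, shuffled_weights = shuffle(
    users, items, None, random_state=np.random.RandomState(42))

assert shuffled_weights is None
# The user/item pairing survives the shared permutation: each pair
# still sums to 5.
assert np.all(shuffled_users + shuffled_items == 5)

for batch_users, batch_items, batch_weights in minibatch(
        torch.from_numpy(shuffled_users),
        torch.from_numpy(shuffled_items),
        None,
        batch_size=4):
    assert batch_weights is None
    assert len(batch_users) == len(batch_items)
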