diff --git a/embetter/finetune/_contrastive_tuner.py b/embetter/finetune/_contrastive_tuner.py index 0c58cea..f109281 100644 --- a/embetter/finetune/_contrastive_tuner.py +++ b/embetter/finetune/_contrastive_tuner.py @@ -5,7 +5,6 @@ import numpy as np import torch -import torch.nn as nn from dataclasses import dataclass from ._constrastive_learn import ContrastiveLearner @@ -31,7 +30,7 @@ def generate_pairs_batch(labels, n_neg=3): single_example = {} indices = np.arange(len(labels)) for label, grouper in groupby( - ((s, l) for s, l in zip(indices, labels)), key=lambda x: x[1] + ((s, lab) for s, lab in zip(indices, labels)), key=lambda x: x[1] ): lookup[label].extend(list(i[0] for i in grouper)) single_example[label] = len(lookup[label]) == 1