Rename label_id and kwargs
- `label_id` was misleading since it is actually a list of token ids
  related to a label, not a scalar value. Also, the general process of
  generating logits is not related to labels at all but rather just to
  tokens.

- `kwargs` was named to match the transformers `generate` convention,
  but it is meant to be passed through to `generate` and is therefore,
  in the context of `generate_logits`, a model input. This should help
  the reader distinguish between the expected input (`token_ids`) and
  the model input (`model_inputs`).
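
For illustration, a minimal sketch of how the renamed arguments fit together
(the tokenizer, prompt, and `cached_model` object here are assumed for the
example and are not part of this commit):

    # token_ids: the token ids that make up one label, e.g. "positive"
    token_ids = tokenizer("positive")["input_ids"]
    # model_inputs: everything that is passed through to model.generate
    model_inputs = tokenizer("Classify: great movie!", return_tensors="pt")
    logits = cached_model.generate_logits(token_ids=token_ids, **model_inputs)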
ottonemo committed Jul 26, 2023
1 parent 1c34aca commit da062b0
Showing 1 changed file with 39 additions and 30 deletions.
69 changes: 39 additions & 30 deletions skorch/llm/classifier.py
@@ -169,9 +169,9 @@ def _extend_inputs(inputs, extra):
 
 class _LogitsRecorder(LogitsProcessor):
     """Helper class to record logits and force the given label token ids"""
-    def __init__(self, label_ids, tokenizer):
+    def __init__(self, token_ids, tokenizer):
         self.recorded_scores = []
-        self.label_ids = label_ids
+        self.token_ids = token_ids
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids, scores):
@@ -180,7 +180,7 @@ def __call__(self, input_ids, scores):
         # we pull the logits to CPU because they are not used as input,
         # therefore there is no device mismatch and we save a bit of GPU memory
         self.recorded_scores.append(scores[0].clone().cpu())
         mask = torch.ones(scores.size(), dtype=torch.bool)
-        mask[0, self.label_ids[idx]] = False
+        mask[0, self.token_ids[idx]] = False
         scores[mask] = -float('inf')
         return scores
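
To make the masking step above concrete, a self-contained sketch with a toy
vocabulary in place of real model scores (shapes and the forced id are made
up for the example):

    import torch

    scores = torch.randn(1, 8)    # stand-in for a (batch=1, vocab=8) score row
    forced_id = 3                 # stand-in for self.token_ids[idx]
    mask = torch.ones(scores.size(), dtype=torch.bool)
    mask[0, forced_id] = False    # keep only the desired token id
    scores[mask] = -float('inf')  # every other token can no longer be sampled
    assert scores.argmax(dim=-1).item() == forced_id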

@@ -227,30 +227,34 @@ class _CFGuidance(LogitsProcessor):
     """
 
-    def __init__(self, model, tokenizer, label_ids, gamma=1.5):
+    def __init__(self, model, tokenizer, token_ids, gamma=1.5):
         self.model = model
         self.tokenizer = tokenizer
         self.gamma = gamma
-        self.label_ids = label_ids
-        self.recorded_scores = []
+        self.token_ids = token_ids
+        self.token_position = 0
 
     def __call__(self, input_ids, scores):
-        idx = len(self.recorded_scores)
+        idx = self.token_position
 
         P_wi_wjic = scores
 
-        model_input = {
-            'input_ids': torch.tensor(self.label_ids)[None, :].to(self.model.device),
-            'attention_mask': torch.tensor([1] * len(self.label_ids))[None, :].to(self.model.device)
+        model_inputs = {
+            'input_ids': torch.tensor(self.token_ids)[None, :].to(self.model.device),
+            'attention_mask': torch.tensor([1] * len(self.token_ids))[None, :].to(self.model.device)
         }
-        model_output = self.model.generate(**model_input, output_scores=True, return_dict_in_generate=True)
-        P_wi_wji = model_output.scores[idx]
+        model_output = self.model.generate(
+            **model_inputs,
+            max_new_tokens=len(self.token_ids),
+            output_scores=True,
+            return_dict_in_generate=True)
+        P_wi_wji = model_output.scores[idx]
 
-        # we pull the logits to CPU because they are not used as input,
-        # therefore there is no device mismatch and we save a bit of GPU memory
-        # TODO remove this by a counter since we're only using the position
-        self.recorded_scores.append(scores[0].clone().cpu())
+        # We assume that this logits processor is called in `generate_logits`
+        # which invokes the model for each new token in the given sequence
+        # (self.token_ids). Thus we need to keep track where we are.
+        self.token_position += 1
 
         scores = P_wi_wji + self.gamma * (P_wi_wjic - P_wi_wji)
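
The last line mixes the unconditioned scores (P_wi_wji) with the
prompt-conditioned scores (P_wi_wjic); gamma == 1 returns the conditioned
scores unchanged and gamma > 1 amplifies the prompt's influence. A toy
example with made-up tensors:

    import torch

    P_wi_wji = torch.tensor([[0.2, 0.1, 0.7]])   # scores without the prompt
    P_wi_wjic = torch.tensor([[0.1, 0.6, 0.3]])  # scores given the prompt
    gamma = 1.5
    guided = P_wi_wji + gamma * (P_wi_wjic - P_wi_wji)
    # gamma == 0 recovers P_wi_wji, gamma == 1 recovers P_wi_wjic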

@@ -310,54 +314,59 @@ def set_cache(self, kwargs, label_id, scores):
             key = str(input_id)
             self.cache[key] = score
 
-    def generate_logits(self, *, label_id, **kwargs):
+    def generate_logits(self, *, token_ids, **model_inputs):
+        """Generate logits for given token ids based on a given model input.
+        The model is forced to only generate the logits for the given token
+        ids - all other options are weighted down so that they are impossible
+        to be sampled.
+        """
         self._total_calls += 1  # mainly for debugging
 
         recorded_logits = []
-        logits_cached = self.get_cache(kwargs)
+        logits_cached = self.get_cache(model_inputs)
         while logits_cached is not None:
-            if label_id[0] == self.tokenizer.eos_token_id:
+            if token_ids[0] == self.tokenizer.eos_token_id:
                 # don't extend with eos_token -- it is already there at the end,
                 # we don't need it twice
                 break
 
             recorded_logits.append(logits_cached)
-            kwargs = _extend_inputs(kwargs, label_id[:1])
-            label_id = label_id[1:]
-            logits_cached = self.get_cache(kwargs)
+            model_inputs = _extend_inputs(model_inputs, token_ids[:1])
+            token_ids = token_ids[1:]
+            logits_cached = self.get_cache(model_inputs)
 
-        if not label_id:
+        if not token_ids:
             # the whole generation was cached
             return recorded_logits
 
-        if label_id[0] == self.tokenizer.pad_token_id:
+        if token_ids[0] == self.tokenizer.pad_token_id:
             # no need to generate on pad tokens
             return recorded_logits
 
         self._uncached_calls += 1  # mainly for debugging
 
         recorder = _LogitsRecorder(
-            label_ids=label_id,
+            token_ids=token_ids,
             tokenizer=self.tokenizer,
         )
         processors = [recorder]
 
         if self.cfg_gamma is not None:
             guidance = _CFGuidance(
                 model=self.model,
-                label_ids=label_id,
                 tokenizer=self.tokenizer,
+                token_ids=token_ids,
                 gamma=self.cfg_gamma,
             )
             processors.insert(0, guidance)
 
         self.model.generate(
             logits_processor=processors,
             # TODO: should this be the max len of all labels?
-            max_new_tokens=len(label_id),
-            **kwargs
+            max_new_tokens=len(token_ids),
+            **model_inputs
         )
-        self.set_cache(kwargs, label_id, recorder.recorded_scores)
+        self.set_cache(model_inputs, token_ids, recorder.recorded_scores)
         return recorded_logits + recorder.recorded_scores[:]
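
The while loop at the top of `generate_logits` consumes cached per-step
logits one token at a time: every cache hit extends the model input by that
token and shortens `token_ids`, so only the uncached tail reaches the real
`generate` call. A standalone sketch of that walk, with `get_cache` and
`extend_inputs` assumed to behave like the methods above:

    def walk_cache(get_cache, extend_inputs, token_ids, model_inputs):
        # collect cached logits until the first cache miss
        recorded = []
        cached = get_cache(model_inputs)
        while cached is not None and token_ids:
            recorded.append(cached)
            # treat the first remaining token as generated and advance
            model_inputs = extend_inputs(model_inputs, token_ids[:1])
            token_ids = token_ids[1:]
            cached = get_cache(model_inputs)
        # whatever remains in token_ids still needs a real generate call
        return recorded, token_ids, model_inputs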


@@ -489,7 +498,7 @@ def _predict_one(self, text):
 
         probas_all_labels = []
         for label_id in self.label_ids_:
-            logits = self.cached_model_.generate_logits(label_id=label_id, **inputs)
+            logits = self.cached_model_.generate_logits(token_ids=label_id, **inputs)
             logits = torch.vstack(logits)
             probas = torch.nn.functional.softmax(logits, dim=-1)
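
Each row of `probas` is the distribution over the vocabulary at one
generation step. One way to read off the probability of each label token at
its own step (toy values; the actual continuation of `_predict_one` is
collapsed in this view):

    import torch

    logits = [torch.randn(32), torch.randn(32)]  # toy per-step logits, vocab=32
    label_id = [5, 17]                           # toy token ids of one label
    probas = torch.nn.functional.softmax(torch.vstack(logits), dim=-1)
    proba_per_token = probas[torch.arange(len(label_id)), torch.tensor(label_id)]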

