Remove AdapterSpec from metrics #2244

Draft
yifanmai wants to merge 1 commit into main from yifanmai/fix-remove-adapter-spec-from-metrics
Conversation

yifanmai (Collaborator)

This removes the coupling between the adapter and the metrics, allowing the metrics to be computed using only the requests and results from the model clients.
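
For illustration, a minimal sketch of what a decoupled metric could look like. All names and types below are simplified stand-ins, not the actual HELM classes; the point is only that the metric sees requests and results, with no AdapterSpec or ScenarioState argument.

    # Sketch only: RequestState and Stat are simplified stand-ins for the real classes.
    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass(frozen=True)
    class RequestState:
        completion_text: str
        reference_index: Optional[int] = None


    @dataclass
    class Stat:
        name: str
        values: List[float] = field(default_factory=list)

        def add(self, value: float) -> "Stat":
            self.values.append(value)
            return self


    def compute_request_state_metrics(request_states: List[RequestState]) -> List[Stat]:
        """Compute stats purely from the request states, without any adapter state."""
        stats = [Stat("num_requests").add(len(request_states))]
        stats.append(
            Stat("num_references").add(
                sum(1 for rs in request_states if rs.reference_index is not None)
            )
        )
        return stats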

@yifanmai yifanmai force-pushed the yifanmai/fix-remove-adapter-spec-from-metrics branch from ba53f57 to e23220b Compare January 17, 2024 01:14
@yifanmai yifanmai marked this pull request as draft January 17, 2024 06:09
yifanmai (Collaborator, Author)

Converting to draft because this requires some manual testing.

@brianwgoldman (Contributor) left a comment:


This PR is also removing ScenarioState from metrics. Consider updating the name/description to mention that.

reference_stats: Dict[ReferenceKey, ReferenceStat] = {}
for request_state in reference_request_states:
    assert request_state.reference_index is not None and request_state.request_mode is not None
    reference_key = ReferenceKey(request_state.reference_index, request_state.request_mode)
    reference_stats[reference_key] = compute_logprob_and_length(request_state, window_service)

if adapter_spec.method in [ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL, ADAPT_RANKING_BINARY]:
    is_calibrated = any([request_state.request_mode == "calibration" for request_state in reference_request_states])
Contributor:

Why "any" here but using reference_request_states[0] to decide model_deployment?

If we are asserting in both cases that they are universal values, maybe we should write a helper to do that assertion?
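
One possible shape for such a helper (the name, signature, and the attribute access in the usage comment are hypothetical, sketched only to illustrate the suggestion):

    from typing import Callable, Iterable, TypeVar

    T = TypeVar("T")
    V = TypeVar("V")


    def get_unique_value(items: Iterable[T], getter: Callable[[T], V], description: str) -> V:
        """Assert that every item yields the same value and return that value."""
        values = {getter(item) for item in items}
        assert len(values) == 1, f"Expected a single {description}, got {values}"
        return values.pop()


    # Hypothetical usage, assuming request states expose the model deployment:
    # model_deployment = get_unique_value(
    #     reference_request_states, lambda rs: rs.request.model_deployment, "model deployment"
    # )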

@@ -294,20 +280,14 @@ def compute_request_state_metrics(
     stats: List[Stat] = []

     stats.append(Stat(MetricName("num_references")).add(len(request_state.instance.references)))

-    # Copy from adapter spec
-    stats.append(Stat(MetricName("num_train_trials")).add(adapter_spec.num_train_trials))
Contributor:

Is this Stat not needed?

for context, request_states in grouped_request_states.items():
    for stat in self.evaluate_instances(request_states):

for request_state in trial_request_states:
    grouped_request_states[MetricContext.from_instance(request_state.instance)].append(request_state)
Contributor:

This has potential behavior change since it can include request_states that have non-None reference_index.
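
If the previous behavior is the intended one, a guard along these lines (a sketch based on the snippet above, not code from the PR) would keep reference requests out of the grouping:

    for request_state in trial_request_states:
        # Skip reference/calibration requests so that only generation requests
        # are grouped per context, matching the pre-PR behavior.
        if request_state.reference_index is not None:
            continue
        grouped_request_states[MetricContext.from_instance(request_state.instance)].append(request_state)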

if request_state.reference_index is None:
    instance_to_request_state_set[instance].generation_states.append(request_state)
else:
    instance_to_request_state_set[instance].references_states.append(request_state)
Contributor:

Previously the reference_states were ordered by reference_index. Is that still guaranteed? Does it matter if the order changes?
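
If the order does matter downstream, an explicit sort (a sketch reusing the names from the snippet above) would restore the old guarantee:

    for request_state_set in instance_to_request_state_set.values():
        # Reference states were previously ordered by reference_index; sorting
        # makes that explicit regardless of the order the requests arrive in.
        request_state_set.references_states.sort(key=lambda rs: rs.reference_index)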

@@ -166,7 +149,7 @@ def evaluate(

         # Compute per-instance stats
         per_instance_stats: List[PerInstanceStats] = []
-        for instance, stats in zip(scenario_state.instances, results):
+        for instance, stats in zip(instances, results):
Contributor:

I think switching this to zip(request_state_sets, results) would make it less fragile and more clear that we are putting the input and output of the parallel map back together.
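
One way to express that pairing (a sketch of a variant of the suggestion that zips the dict items; the construction of PerInstanceStats is elided because its exact signature is not shown here):

    per_instance_stats: List[PerInstanceStats] = []
    # Zip the inputs of the parallel map back with its outputs so an instance
    # and its stats cannot be misaligned by a change in iteration order.
    for (instance, request_state_set), stats in zip(instance_to_request_state_set.items(), results):
        ...  # build the PerInstanceStats entry for `instance` from `stats`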

@@ -352,3 +333,19 @@ def add_context(stat: Stat, context: MetricContext) -> Stat:
     return Stat(
         replace(stat.name, split=context.split, sub_split=context.sub_split, perturbation=context.perturbation)
     ).merge(stat)


+def get_num_train_trials(request_states: List[RequestState]) -> int:
Contributor:

There doesn't appear to be any method calling this. Is it left over from a previous iteration?

        instance_to_request_state_set[instance].generation_states.append(request_state)
    else:
        instance_to_request_state_set[instance].references_states.append(request_state)
request_state_sets: List[RequestStateSet] = list(instance_to_request_state_set.values())
Contributor:

Order here can also change. Maybe we want an OrderedDict?
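
A minimal sketch of that change (the key and value types are taken from the snippets above; note that a plain dict already preserves insertion order on Python 3.7+, so this mainly makes the intent explicit):

    from collections import OrderedDict

    # Keyed by instance in first-seen order, so iterating over .values() follows
    # the order in which instances were first encountered, not the order the
    # request states happen to arrive in.
    instance_to_request_state_set: "OrderedDict[Instance, RequestStateSet]" = OrderedDict()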
