diff --git a/CHANGELOG.md b/CHANGELOG.md index b4a36ee2..25dbc196 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,50 @@ - Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config. +- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282) + +- A new `SliceAggregator` (`"slices"`) is added to allow for slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax as `ContiguousSpanAggregator`. The `__getitem__` method of `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282) + +```python +import inseq +from inseq.data.aggregator import SliceAggregator + +attrib_model = inseq.load_model("gpt2", "attention") +input_prompt = """Instruction: Summarize this article. +Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic. +Summary:""" + +full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic." 
+ +out = attrib_model.attribute(input_prompt, full_output_prompt)[0] + +# These are all equivalent ways to slice only the input text contents +out_sliced = out.aggregate(SliceAggregator, target_spans=(13,73)) +out_sliced = out.aggregate("slices", target_spans=(13,73)) +out_sliced = out[13:73] +``` + +- The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` [#282](https://github.com/inseq-team/inseq/pull/282) + + +```python +import inseq + +attrib_model = inseq.load_model("gpt2", "saliency") + +out_male = attrib_model.attribute( + "The director went home because", + "The director went home because he was tired", + step_scores=["probability"] +)[0] +out_female = attrib_model.attribute( + "The director went home because", + "The director went home because she was tired", + step_scores=["probability"] +)[0] +(out_male - out_female).show() +``` + ## 🔧 Fixes and Refactoring - Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal position in the tensor were set to nan if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)). @@ -14,6 +58,9 @@ - Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)). 
+- The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282) + + ## 📝 Documentation and Tutorials *No changes* diff --git a/Makefile b/Makefile index a03222c5..9df893fa 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ fix-style: .PHONY: check-safety check-safety: - $(PYTHON) -m safety check --full-report + $(PYTHON) -m safety check --full-report -i 70612 .PHONY: lint lint: fix-style check-safety diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py index ec4ad481..32265b09 100644 --- a/inseq/data/aggregator.py +++ b/inseq/data/aggregator.py @@ -357,9 +357,9 @@ def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs assert attr.target_attributions.ndim == 2, attr.target_attributions.shape except AssertionError as e: raise RuntimeError( - f"The aggregated attributions should be 2-dimensional to be visualized. Found dimensions: {e.args[0]}" - "If you're performing intermediate aggregation and don't aim to visualize the output right away, use" - "do_post_aggregation_checks=False in the aggregate method to bypass this check." + f"The aggregated attributions should be 2-dimensional to be visualized.\nFound dimensions: {e.args[0]}" + "\n\nIf you're performing intermediate aggregation and don't aim to visualize the output right away, " + "use do_post_aggregation_checks=False in the aggregate method to bypass this check." 
) from e @staticmethod @@ -530,7 +530,7 @@ def format_spans(spans) -> list[tuple[int, int]]: return [spans] if isinstance(spans[0], int) else spans @classmethod - def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans: Optional[IndexSpan] = None): + def validate_spans(cls, span_sequence: list[TokenWithId], spans: Optional[IndexSpan] = None): if not spans: return allmatch = lambda l, type: all(isinstance(x, type) for x in l) @@ -545,7 +545,7 @@ def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans assert ( span[0] >= prev_span_max ), f"Spans must be postive-valued, non-overlapping and in ascending order, got {spans}" - assert span[1] < len(span_sequence), f"Span values must be indexes of the original span, got {spans}" + assert span[1] <= len(span_sequence), f"Span values must be indexes of the original span, got {spans}" prev_span_max = span[1] @staticmethod @@ -808,3 +808,136 @@ def aggregate_sequence_scores(attr, paired_attr, aggregate_fn, **kwargs): agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn out_dict[name] = agg_fn(sequence_scores, paired_attr.sequence_scores[name]) return out_dict + + +class SliceAggregator(ContiguousSpanAggregator): + """Slices the FeatureAttributionSequenceOutput object into a smaller one containing a subset of its elements. + + Args: + attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The starting attribution object. + source_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the + source sequence. Defaults to no slicing performed. + target_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the + target sequence. Defaults to no slicing performed. 
+ """ + + aggregator_name = "slices" + default_fn = None + + @classmethod + def aggregate( + cls, + attr: "FeatureAttributionSequenceOutput", + source_spans: Optional[IndexSpan] = None, + target_spans: Optional[IndexSpan] = None, + **kwargs, + ): + """Spans can be: + + 1. A list of the form [pos_start, pos_end] including the contiguous positions of tokens that + are to be aggregated, if all values are integers and len(span) < len(original_seq) + 2. A list of the form [(pos_start_0, pos_end_0), (pos_start_1, pos_end_1)], same as above but + for multiple contiguous spans. + """ + source_spans = cls.format_spans(source_spans) + target_spans = cls.format_spans(target_spans) + + if attr.source_attributions is None: + if source_spans is not None: + logger.warn( + "Source spans are specified but no source scores are given for decoder-only models. " + "Ignoring source spans and using target spans instead." + ) + source_spans = [(s[0], min(s[1], attr.attr_pos_start)) for s in target_spans] + + # Generated tokens are always included in the slices to preserve the output scores + is_gen_added = False + new_target_spans = [] + if target_spans is not None: + for span in target_spans: + if span[1] > attr.attr_pos_start and is_gen_added: + continue + elif span[1] > attr.attr_pos_start and not is_gen_added: + new_target_spans.append((span[0], attr.attr_pos_end)) + is_gen_added = True + else: + new_target_spans.append(span) + if not is_gen_added: + new_target_spans.append((attr.attr_pos_start, attr.attr_pos_end)) + return super().aggregate(attr, source_spans=source_spans, target_spans=new_target_spans, **kwargs) + + @staticmethod + def aggregate_source(attr: "FeatureAttributionSequenceOutput", source_spans: list[tuple[int, int]], **kwargs): + sliced_source = [] + for span in source_spans: + sliced_source.extend(attr.source[span[0] : span[1]]) + return sliced_source + + @staticmethod + def aggregate_target(attr: "FeatureAttributionSequenceOutput", target_spans: list[tuple[int, int]], 
**kwargs): + sliced_target = [] + for span in target_spans: + sliced_target.extend(attr.target[span[0] : span[1]]) + return sliced_target + + @staticmethod + def aggregate_source_attributions(attr: "FeatureAttributionSequenceOutput", source_spans, **kwargs): + if attr.source_attributions is None: + return attr.source_attributions + return torch.cat( + tuple(attr.source_attributions[span[0] : span[1], ...] for span in source_spans), + dim=0, + ) + + @staticmethod + def aggregate_target_attributions(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs): + if attr.target_attributions is None: + return attr.target_attributions + return torch.cat( + tuple(attr.target_attributions[span[0] : span[1], ...] for span in target_spans), + dim=0, + ) + + @staticmethod + def aggregate_step_scores(attr: "FeatureAttributionSequenceOutput", **kwargs): + return attr.step_scores + + @classmethod + def aggregate_sequence_scores( + cls, + attr: "FeatureAttributionSequenceOutput", + source_spans, + target_spans, + **kwargs, + ): + if not attr.sequence_scores: + return attr.sequence_scores + out_dict = {} + for name, step_scores in attr.sequence_scores.items(): + if name.startswith("decoder"): + out_dict[name] = torch.cat( + tuple(step_scores[span[0] : span[1], ...] for span in target_spans), + dim=0, + ) + elif name.startswith("encoder"): + out_dict[name] = torch.cat( + tuple(step_scores[span[0] : span[1], span[0] : span[1], ...] for span in source_spans), + dim=0, + ) + else: + out_dict[name] = torch.cat( + tuple(step_scores[span[0] : span[1], ...] 
for span in source_spans), + dim=0, + ) + return out_dict + + @staticmethod + def aggregate_attr_pos_start(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs): + if not target_spans: + return attr.attr_pos_start + tot_sliced_len = sum(min(s[1], attr.attr_pos_start) - s[0] for s in target_spans) + return tot_sliced_len + + @staticmethod + def aggregate_attr_pos_end(attr: "FeatureAttributionSequenceOutput", **kwargs): + return attr.attr_pos_end diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index b83da1df..fec015f8 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -165,6 +165,16 @@ def __post_init__(self): if self.attr_pos_end is None or self.attr_pos_end > len(self.target): self.attr_pos_end = len(self.target) + def __getitem__(self, s: Union[slice, int]) -> "FeatureAttributionSequenceOutput": + source_spans = None if self.source_attributions is None else (s.start, s.stop) + target_spans = None if self.source_attributions is not None else (s.start, s.stop) + return self.aggregate("slices", source_spans=source_spans, target_spans=target_spans) + + def __sub__(self, other: "FeatureAttributionSequenceOutput") -> "FeatureAttributionSequenceOutput": + if not isinstance(other, self.__class__): + raise ValueError(f"Cannot compare {type(other)} with {type(self)}") + return self.aggregate("pair", paired_attr=other, do_post_aggregation_checks=False) + @staticmethod def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable: if attr.source_attributions is None or name.startswith("decoder"): diff --git a/inseq/data/viz.py b/inseq/data/viz.py index 5721dff1..28d7004a 100644 --- a/inseq/data/viz.py +++ b/inseq/data/viz.py @@ -209,12 +209,16 @@ def get_saliency_heatmap_html( uuid = "".join(random.choices(string.ascii_lowercase, k=20)) out = saliency_heatmap_table_header # add top row containing target tokens + out += "" + for column_idx in range(len(column_labels)): + out += f"{column_idx}" + out += 
"" for column_label in column_labels: out += f"{sanitize_html(column_label)}" out += "" if scores is not None: for row_index in range(scores.shape[0]): - out += f"{sanitize_html(row_labels[row_index])}" + out += f"{row_index}{sanitize_html(row_labels[row_index])}" for col_index in range(scores.shape[1]): score = "" if not np.isnan(scores[row_index, col_index]): @@ -223,7 +227,7 @@ def get_saliency_heatmap_html( out += "" if step_scores is not None: for step_score_name, step_score_values in step_scores.items(): - out += f'{step_score_name}' + out += f'{step_score_name}' if isinstance(step_scores_threshold, float): threshold = step_scores_threshold else: @@ -254,20 +258,23 @@ def get_saliency_heatmap_rich( label: str = "", step_scores_threshold: Union[float, dict[str, float]] = 0.5, ): - columns = [Column(header="", justify="right", overflow="fold")] - for column_label in column_labels: - columns.append(Column(header=escape(column_label), justify="center", overflow="fold")) + columns = [ + Column(header="", justify="right", overflow="fold"), + Column(header="", justify="right", overflow="fold"), + ] + for idx, column_label in enumerate(column_labels): + columns.append(Column(header=f"{idx}\n{escape(column_label)}", justify="center", overflow="fold")) table = Table( *columns, title=f"{label + ' ' if label else ''}Saliency Heatmap", - caption="x: Generated tokens, y: Attributed tokens", + caption="→ : Generated tokens, ↓ : Attributed tokens", padding=(0, 1, 0, 1), show_lines=False, box=box.HEAVY_HEAD, ) if scores is not None: for row_index in range(scores.shape[0]): - row = [Text(escape(row_labels[row_index]), style="bold")] + row = [Text(f"{row_index}", style="bold"), Text(escape(row_labels[row_index]), style="bold")] for col_index in range(scores.shape[1]): color = Color.from_rgb(*input_colors[row_index][col_index]) score = "" @@ -282,7 +289,7 @@ def get_saliency_heatmap_rich( else: threshold = step_scores_threshold.get(step_score_name, 0.5) style = lambda val, 
limit: "bold" if abs(val) >= limit and isinstance(val, float) else "" - score_row = [Text(escape(step_score_name), style="bold")] + score_row = [Text(""), Text(escape(step_score_name), style="bold")] for score in step_score_values: curr_score = round(score.item(), 2) if isinstance(score, float) else score.item() score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold))) diff --git a/inseq/utils/viz_utils.py b/inseq/utils/viz_utils.py index 9de41d1a..69e4a94f 100644 --- a/inseq/utils/viz_utils.py +++ b/inseq/utils/viz_utils.py @@ -109,7 +109,6 @@ def get_colors( saliency_heatmap_table_header = """ - """ saliency_heatmap_html = """ @@ -117,7 +116,7 @@ def get_colors(
{label} Saliency Heatmap
- x: Generated tokens, y: Attributed tokens + → : Generated tokens, ↓ : Attributed tokens
{content} diff --git a/requirements-dev.txt b/requirements-dev.txt index d086fc10..4020d11f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # This file was autogenerated by uv via the following command: # uv pip compile --all-extras pyproject.toml -o requirements-dev.txt -aiohttp==3.9.3 +aiohttp==3.9.5 # via # datasets # fsspec @@ -19,10 +19,12 @@ authlib==1.3.0 babel==2.14.0 # via sphinx bandit==1.7.7 + # via inseq (pyproject.toml) beautifulsoup4==4.12.3 # via furo captum==0.7.0 -certifi==2024.2.2 + # via inseq (pyproject.toml) +certifi==2024.6.2 # via requests cffi==1.16.0 # via cryptography @@ -49,6 +51,7 @@ cryptography==42.0.5 cycler==0.12.1 # via matplotlib datasets==2.17.0 + # via inseq (pyproject.toml) debugpy==1.8.1 # via ipykernel decorator==5.1.1 @@ -90,6 +93,7 @@ fsspec==2023.10.0 # huggingface-hub # torch furo==2024.1.29 + # via inseq (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.41 @@ -110,21 +114,25 @@ imagesize==1.4.1 iniconfig==2.0.0 # via pytest ipykernel==6.29.2 + # via inseq (pyproject.toml) ipython==8.18.1 # via # ipykernel # ipywidgets ipywidgets==8.1.2 + # via inseq (pyproject.toml) jaxtyping==0.2.25 + # via inseq (pyproject.toml) jedi==0.19.1 # via ipython -jinja2==3.1.3 +jinja2==3.1.4 # via # safety # sphinx # torch joblib==1.3.2 # via + # inseq (pyproject.toml) # nltk # scikit-learn jupyter-client==8.6.0 @@ -144,7 +152,9 @@ markupsafe==2.1.5 marshmallow==3.20.2 # via safety matplotlib==3.8.2 - # via captum + # via + # inseq (pyproject.toml) + # captum matplotlib-inline==0.1.6 # via # ipykernel @@ -164,10 +174,12 @@ nest-asyncio==1.6.0 networkx==3.2.1 # via torch nltk==3.8.1 + # via inseq (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.26.4 # via + # inseq (pyproject.toml) # captum # contourpy # datasets @@ -199,7 +211,7 @@ pbr==6.0.0 # via stevedore pexpect==4.9.0 # via ipython -pillow==10.3.0 +pillow==10.4.0 # via matplotlib platformdirs==4.2.0 # via @@ -208,10 +220,13 @@ 
platformdirs==4.2.0 pluggy==1.4.0 # via pytest pre-commit==3.6.1 + # via inseq (pyproject.toml) prompt-toolkit==3.0.43 # via ipython protobuf==4.25.2 - # via transformers + # via + # inseq (pyproject.toml) + # transformers psutil==5.9.8 # via ipykernel ptyprocess==0.7.0 @@ -229,6 +244,7 @@ pydantic==1.10.14 # safety # safety-schemas pydoclint==0.4.0 + # via inseq (pyproject.toml) pygments==2.17.2 # via # furo @@ -239,10 +255,13 @@ pyparsing==3.1.1 # via matplotlib pytest==8.0.0 # via + # inseq (pyproject.toml) # pytest-cov # pytest-xdist pytest-cov==4.1.0 + # via inseq (pyproject.toml) pytest-xdist==3.5.0 + # via inseq (pyproject.toml) python-dateutil==2.8.2 # via # jupyter-client @@ -265,7 +284,7 @@ regex==2023.12.25 # via # nltk # transformers -requests==2.31.0 +requests==2.32.3 # via # datasets # fsspec @@ -275,6 +294,7 @@ requests==2.31.0 # transformers rich==13.7.0 # via + # inseq (pyproject.toml) # bandit # safety ruamel-yaml==0.18.6 @@ -284,12 +304,15 @@ ruamel-yaml==0.18.6 ruamel-yaml-clib==0.2.8 # via ruamel-yaml ruff==0.2.1 + # via inseq (pyproject.toml) safetensors==0.4.2 # via transformers safety==3.1.0 + # via inseq (pyproject.toml) safety-schemas==0.0.2 # via safety scikit-learn==1.4.0 + # via inseq (pyproject.toml) scipy==1.12.0 # via scikit-learn sentencepiece==0.1.99 @@ -310,6 +333,7 @@ soupsieve==2.5 # via beautifulsoup4 sphinx==7.2.6 # via + # inseq (pyproject.toml) # furo # sphinx-basic-ng # sphinx-copybutton @@ -320,8 +344,11 @@ sphinx==7.2.6 sphinx-basic-ng==1.0.0b2 # via furo sphinx-copybutton==0.5.2 + # via inseq (pyproject.toml) sphinx-design==0.5.0 + # via inseq (pyproject.toml) sphinx-gitstamp==0.4.0 + # via inseq (pyproject.toml) sphinxcontrib-applehelp==1.0.8 # via sphinx sphinxcontrib-devhelp==1.0.6 @@ -335,7 +362,9 @@ sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 # via sphinx sphinxemoji==0.3.1 + # via inseq (pyproject.toml) sphinxext-opengraph==0.9.1 + # via inseq (pyproject.toml) stack-data==0.6.3 # via ipython 
stevedore==5.1.0 @@ -347,13 +376,16 @@ threadpoolctl==3.2.0 tokenizers==0.15.2 # via transformers torch==2.2.0 - # via captum + # via + # inseq (pyproject.toml) + # captum tornado==6.4 # via # ipykernel # jupyter-client -tqdm==4.66.2 +tqdm==4.66.4 # via + # inseq (pyproject.toml) # captum # datasets # huggingface-hub @@ -369,8 +401,11 @@ traitlets==5.14.1 # jupyter-core # matplotlib-inline transformers==4.38.1 + # via inseq (pyproject.toml) typeguard==2.13.3 - # via jaxtyping + # via + # inseq (pyproject.toml) + # jaxtyping typer==0.9.0 # via safety typing-extensions==4.9.0 @@ -384,7 +419,7 @@ typing-extensions==4.9.0 # typer tzdata==2024.1 # via pandas -urllib3==2.2.0 +urllib3==2.2.2 # via # requests # safety diff --git a/requirements.txt b/requirements.txt index 05e5b4c7..9c021743 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ # This file was autogenerated by uv via the following command: # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 -certifi==2024.2.2 + # via inseq (pyproject.toml) +certifi==2024.6.2 # via requests charset-normalizer==3.3.2 # via requests @@ -27,7 +28,8 @@ huggingface-hub==0.20.3 idna==3.7 # via requests jaxtyping==0.2.25 -jinja2==3.1.3 + # via inseq (pyproject.toml) +jinja2==3.1.4 # via torch kiwisolver==1.4.5 # via matplotlib @@ -36,7 +38,9 @@ markdown-it-py==3.0.0 markupsafe==2.1.5 # via jinja2 matplotlib==3.8.2 - # via captum + # via + # inseq (pyproject.toml) + # captum mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 @@ -45,6 +49,7 @@ networkx==3.2.1 # via torch numpy==1.26.4 # via + # inseq (pyproject.toml) # captum # contourpy # jaxtyping @@ -55,10 +60,12 @@ packaging==23.2 # huggingface-hub # matplotlib # transformers -pillow==10.3.0 +pillow==10.4.0 # via matplotlib protobuf==4.25.2 - # via transformers + # via + # inseq (pyproject.toml) + # transformers pygments==2.17.2 # via rich pyparsing==3.1.1 @@ -71,11 +78,12 @@ pyyaml==6.0.1 # transformers regex==2023.12.25 # via transformers 
-requests==2.31.0 +requests==2.32.3 # via # huggingface-hub # transformers rich==13.7.0 + # via inseq (pyproject.toml) safetensors==0.4.2 # via transformers sentencepiece==0.1.99 @@ -87,19 +95,25 @@ sympy==1.12 tokenizers==0.15.2 # via transformers torch==2.2.0 - # via captum -tqdm==4.66.2 # via + # inseq (pyproject.toml) + # captum +tqdm==4.66.4 + # via + # inseq (pyproject.toml) # captum # huggingface-hub # transformers transformers==4.38.1 + # via inseq (pyproject.toml) typeguard==2.13.3 - # via jaxtyping + # via + # inseq (pyproject.toml) + # jaxtyping typing-extensions==4.9.0 # via # huggingface-hub # jaxtyping # torch -urllib3==2.2.0 +urllib3==2.2.2 # via requests diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py index f7e7c3e5..14360c7f 100644 --- a/tests/data/test_aggregator.py +++ b/tests/data/test_aggregator.py @@ -5,6 +5,7 @@ from pytest import fixture import inseq +from inseq.data import FeatureAttributionSequenceOutput from inseq.data.aggregator import ( AggregatorPipeline, ContiguousSpanAggregator, @@ -149,6 +150,10 @@ def test_pair_aggregator(saliency_mt_model: HuggingfaceEncoderDecoderModel): diff_seqattr_other = orig_seqattr_other.aggregate("pair", paired_attr=alt_seqattr_other) assert torch.allclose(diff_seqattr_other.source_attributions, diff_seqattr.source_attributions) + # Aggregate with __sub__ + diff_seqattr_sub = orig_seqattr - alt_seqattr + assert diff_seqattr_other == diff_seqattr_sub + def test_named_aggregate_fn_aggregation(saliency_mt_model: HuggingfaceEncoderDecoderModel): out = saliency_mt_model.attribute( @@ -192,3 +197,29 @@ def test_named_aggregate_fn_aggregation(saliency_mt_model: HuggingfaceEncoderDec aggregator=["scores", "scores", "subwords"], aggregate_fn=["mean", "mean", None] ) assert out_allmean_subwords == out_allmean_subwords_expanded + + +def test_slice_aggregator_decoder_only(saliency_gpt_model: HuggingfaceDecoderOnlyModel): + out = saliency_gpt_model.attribute( + EXAMPLES["source_summary"], 
EXAMPLES["source_summary"] + EXAMPLES["gen_summary"], show_progress=False + )[0] + out_sliced: FeatureAttributionSequenceOutput = out.aggregate("slices", target_spans=(14, 75)) + assert [t.token for t in out_sliced.source] == EXAMPLES["source_summary_tokens"] + assert [t.token for t in out_sliced.target] == EXAMPLES["source_summary_tokens"] + EXAMPLES["gen_summary_tokens"] + + # Slice with __getitem__ + out_sliced_getitem = out[14:75] + assert out_sliced == out_sliced_getitem + + +def test_slice_aggregator_encoder_decoder(saliency_mt_model: HuggingfaceEncoderDecoderModel): + out = saliency_mt_model.attribute( + EXAMPLES["source"], EXAMPLES["target"], show_progress=False, attribute_target=True + )[0] + out_sliced: FeatureAttributionSequenceOutput = out.aggregate("slices", source_spans=(1, 4)) + assert [t.token for t in out_sliced.source] == EXAMPLES["source_subwords"][1:4] + assert [t.token for t in out_sliced.target] == EXAMPLES["target_subwords"][1:] + + # Slice with __getitem__ + out_sliced_getitem = out[1:4] + assert out_sliced == out_sliced_getitem diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json index 53123526..51d6b20d 100644 --- a/tests/fixtures/aggregator.json +++ b/tests/fixtures/aggregator.json @@ -93,5 +93,90 @@ "\u2581models", ".", "" + ], + "source_summary": "Instruction: Summarize this article.\nInput_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. 
As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.\nSummary:", + "gen_summary": " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic.", + "source_summary_tokens": [ + "\u0120In", + "\u0120a", + "\u0120quiet", + "\u0120village", + "\u0120nest", + "led", + "\u0120between", + "\u0120rolling", + "\u0120hills", + ",", + "\u0120an", + "\u0120ancient", + "\u0120tree", + "\u0120whispered", + "\u0120secrets", + "\u0120to", + "\u0120those", + "\u0120who", + "\u0120listened", + ".", + "\u0120One", + "\u0120night", + ",", + "\u0120a", + "\u0120curious", + "\u0120child", + "\u0120named", + "\u0120El", + "ara", + "\u0120leaned", + "\u0120close", + "\u0120and", + "\u0120heard", + "\u0120tales", + "\u0120of", + "\u0120hidden", + "\u0120treasures", + "\u0120beneath", + "\u0120the", + "\u0120roots", + ".", + "\u0120As", + "\u0120dawn", + "\u0120broke", + ",", + "\u0120she", + "\u0120unearthed", + "\u0120a", + "\u0120shimmer", + "ing", + "\u0120box", + ",", + "\u0120unlocking", + "\u0120a", + "\u0120forgotten", + "\u0120world", + "\u0120of", + "\u0120wonder", + "\u0120and", + "\u0120magic", + "." + ], + "gen_summary_tokens": [ + "\u0120El", + "ara", + "\u0120discovers", + "\u0120a", + "\u0120shimmer", + "ing", + "\u0120box", + "\u0120under", + "\u0120an", + "\u0120ancient", + "\u0120tree", + ",", + "\u0120unlocking", + "\u0120a", + "\u0120world", + "\u0120of", + "\u0120magic", + "." ] }