diff --git a/CHANGELOG.md b/CHANGELOG.md
index b4a36ee2..25dbc196 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,50 @@
- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config.
+- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)
+
+- - A new `SliceAggregator` (`"slices"`) is added to allow for slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax of `ContiguousSpanAggregator`. The `__getitem__` method of the `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282)
+
+```python
+import inseq
+from inseq.data.aggregator import SliceAggregator
+
+attrib_model = inseq.load_model("gpt2", "attention")
+input_prompt = """Instruction: Summarize this article.
+Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.
+Summary:"""
+
+full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."
+
+out = attrib_model.attribute(input_prompt, full_output_prompt)[0]
+
+# These are all equivalent ways to slice only the input text contents
+out_sliced = out.aggregate(SliceAggregator, target_spans=(13,73))
+out_sliced = out.aggregate("slices", target_spans=(13,73))
+out_sliced = out[13:73]
+```
+
+- The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` [#282](https://github.com/inseq-team/inseq/pull/282)
+
+
+```python
+import inseq
+
+attrib_model = inseq.load_model("gpt2", "saliency")
+
+out_male = attrib_model.attribute(
+ "The director went home because",
+ "The director went home because he was tired",
+ step_scores=["probability"]
+)[0]
+out_female = attrib_model.attribute(
+ "The director went home because",
+ "The director went home because she was tired",
+ step_scores=["probability"]
+)[0]
+(out_male - out_female).show()
+```
+
## 🔧 Fixes and Refactoring
- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal position in the tensor were set to nan if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)).
@@ -14,6 +58,9 @@
- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)).
+- The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282)
+
+
## 📝 Documentation and Tutorials
*No changes*
diff --git a/Makefile b/Makefile
index a03222c5..9df893fa 100644
--- a/Makefile
+++ b/Makefile
@@ -82,7 +82,7 @@ fix-style:
.PHONY: check-safety
check-safety:
- $(PYTHON) -m safety check --full-report
+ $(PYTHON) -m safety check --full-report -i 70612
.PHONY: lint
lint: fix-style check-safety
diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index ec4ad481..32265b09 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -357,9 +357,9 @@ def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs
assert attr.target_attributions.ndim == 2, attr.target_attributions.shape
except AssertionError as e:
raise RuntimeError(
- f"The aggregated attributions should be 2-dimensional to be visualized. Found dimensions: {e.args[0]}"
- "If you're performing intermediate aggregation and don't aim to visualize the output right away, use"
- "do_post_aggregation_checks=False in the aggregate method to bypass this check."
+ f"The aggregated attributions should be 2-dimensional to be visualized.\nFound dimensions: {e.args[0]}"
+ "\n\nIf you're performing intermediate aggregation and don't aim to visualize the output right away, "
+ "use do_post_aggregation_checks=False in the aggregate method to bypass this check."
) from e
@staticmethod
@@ -530,7 +530,7 @@ def format_spans(spans) -> list[tuple[int, int]]:
return [spans] if isinstance(spans[0], int) else spans
@classmethod
- def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans: Optional[IndexSpan] = None):
+ def validate_spans(cls, span_sequence: list[TokenWithId], spans: Optional[IndexSpan] = None):
if not spans:
return
allmatch = lambda l, type: all(isinstance(x, type) for x in l)
@@ -545,7 +545,7 @@ def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans
assert (
span[0] >= prev_span_max
), f"Spans must be postive-valued, non-overlapping and in ascending order, got {spans}"
- assert span[1] < len(span_sequence), f"Span values must be indexes of the original span, got {spans}"
+ assert span[1] <= len(span_sequence), f"Span values must be indexes of the original span, got {spans}"
prev_span_max = span[1]
@staticmethod
@@ -808,3 +808,136 @@ def aggregate_sequence_scores(attr, paired_attr, aggregate_fn, **kwargs):
agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
out_dict[name] = agg_fn(sequence_scores, paired_attr.sequence_scores[name])
return out_dict
+
+
+class SliceAggregator(ContiguousSpanAggregator):
+ """Slices the FeatureAttributionSequenceOutput object into a smaller one containing a subset of its elements.
+
+ Args:
+ attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The starting attribution object.
+ source_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the
+ source sequence. Defaults to no slicing performed.
+ target_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the
+ target sequence. Defaults to no slicing performed.
+ """
+
+ aggregator_name = "slices"
+ default_fn = None
+
+ @classmethod
+ def aggregate(
+ cls,
+ attr: "FeatureAttributionSequenceOutput",
+ source_spans: Optional[IndexSpan] = None,
+ target_spans: Optional[IndexSpan] = None,
+ **kwargs,
+ ):
+ """Spans can be:
+
+ 1. A list of the form [pos_start, pos_end] including the contiguous positions of tokens that
+ are to be aggregated, if all values are integers and len(span) < len(original_seq)
+ 2. A list of the form [(pos_start_0, pos_end_0), (pos_start_1, pos_end_1)], same as above but
+ for multiple contiguous spans.
+ """
+ source_spans = cls.format_spans(source_spans)
+ target_spans = cls.format_spans(target_spans)
+
+ if attr.source_attributions is None:
+ if source_spans is not None:
+ logger.warn(
+ "Source spans are specified but no source scores are given for decoder-only models. "
+ "Ignoring source spans and using target spans instead."
+ )
+ source_spans = [(s[0], min(s[1], attr.attr_pos_start)) for s in target_spans]
+
+ # Generated tokens are always included in the slices to preserve the output scores
+ is_gen_added = False
+ new_target_spans = []
+ if target_spans is not None:
+ for span in target_spans:
+ if span[1] > attr.attr_pos_start and is_gen_added:
+ continue
+ elif span[1] > attr.attr_pos_start and not is_gen_added:
+ new_target_spans.append((span[0], attr.attr_pos_end))
+ is_gen_added = True
+ else:
+ new_target_spans.append(span)
+ if not is_gen_added:
+ new_target_spans.append((attr.attr_pos_start, attr.attr_pos_end))
+ return super().aggregate(attr, source_spans=source_spans, target_spans=new_target_spans, **kwargs)
+
+ @staticmethod
+ def aggregate_source(attr: "FeatureAttributionSequenceOutput", source_spans: list[tuple[int, int]], **kwargs):
+ sliced_source = []
+ for span in source_spans:
+ sliced_source.extend(attr.source[span[0] : span[1]])
+ return sliced_source
+
+ @staticmethod
+ def aggregate_target(attr: "FeatureAttributionSequenceOutput", target_spans: list[tuple[int, int]], **kwargs):
+ sliced_target = []
+ for span in target_spans:
+ sliced_target.extend(attr.target[span[0] : span[1]])
+ return sliced_target
+
+ @staticmethod
+ def aggregate_source_attributions(attr: "FeatureAttributionSequenceOutput", source_spans, **kwargs):
+ if attr.source_attributions is None:
+ return attr.source_attributions
+ return torch.cat(
+ tuple(attr.source_attributions[span[0] : span[1], ...] for span in source_spans),
+ dim=0,
+ )
+
+ @staticmethod
+ def aggregate_target_attributions(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs):
+ if attr.target_attributions is None:
+ return attr.target_attributions
+ return torch.cat(
+ tuple(attr.target_attributions[span[0] : span[1], ...] for span in target_spans),
+ dim=0,
+ )
+
+ @staticmethod
+ def aggregate_step_scores(attr: "FeatureAttributionSequenceOutput", **kwargs):
+ return attr.step_scores
+
+ @classmethod
+ def aggregate_sequence_scores(
+ cls,
+ attr: "FeatureAttributionSequenceOutput",
+ source_spans,
+ target_spans,
+ **kwargs,
+ ):
+ if not attr.sequence_scores:
+ return attr.sequence_scores
+ out_dict = {}
+ for name, step_scores in attr.sequence_scores.items():
+ if name.startswith("decoder"):
+ out_dict[name] = torch.cat(
+ tuple(step_scores[span[0] : span[1], ...] for span in target_spans),
+ dim=0,
+ )
+ elif name.startswith("encoder"):
+ out_dict[name] = torch.cat(
+ tuple(step_scores[span[0] : span[1], span[0] : span[1], ...] for span in source_spans),
+ dim=0,
+ )
+ else:
+ out_dict[name] = torch.cat(
+ tuple(step_scores[span[0] : span[1], ...] for span in source_spans),
+ dim=0,
+ )
+ return out_dict
+
+ @staticmethod
+ def aggregate_attr_pos_start(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs):
+ if not target_spans:
+ return attr.attr_pos_start
+ tot_sliced_len = sum(min(s[1], attr.attr_pos_start) - s[0] for s in target_spans)
+ return tot_sliced_len
+
+ @staticmethod
+ def aggregate_attr_pos_end(attr: "FeatureAttributionSequenceOutput", **kwargs):
+ return attr.attr_pos_end
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index b83da1df..fec015f8 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -165,6 +165,16 @@ def __post_init__(self):
if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
self.attr_pos_end = len(self.target)
+ def __getitem__(self, s: Union[slice, int]) -> "FeatureAttributionSequenceOutput":
+ source_spans = None if self.source_attributions is None else (s.start, s.stop)
+ target_spans = None if self.source_attributions is not None else (s.start, s.stop)
+ return self.aggregate("slices", source_spans=source_spans, target_spans=target_spans)
+
+ def __sub__(self, other: "FeatureAttributionSequenceOutput") -> "FeatureAttributionSequenceOutput":
+ if not isinstance(other, self.__class__):
+ raise ValueError(f"Cannot compare {type(other)} with {type(self)}")
+ return self.aggregate("pair", paired_attr=other, do_post_aggregation_checks=False)
+
@staticmethod
def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
if attr.source_attributions is None or name.startswith("decoder"):
diff --git a/inseq/data/viz.py b/inseq/data/viz.py
index 5721dff1..28d7004a 100644
--- a/inseq/data/viz.py
+++ b/inseq/data/viz.py
@@ -209,12 +209,16 @@ def get_saliency_heatmap_html(
uuid = "".join(random.choices(string.ascii_lowercase, k=20))
out = saliency_heatmap_table_header
# add top row containing target tokens
+ out += "
| | "
+ for column_idx in range(len(column_labels)):
+ out += f"{column_idx} | "
+ out += "
---|
| | "
for column_label in column_labels:
out += f"{sanitize_html(column_label)} | "
out += "
"
if scores is not None:
for row_index in range(scores.shape[0]):
- out += f"{sanitize_html(row_labels[row_index])} | "
+ out += f"
---|
{row_index} | {sanitize_html(row_labels[row_index])} | "
for col_index in range(scores.shape[1]):
score = ""
if not np.isnan(scores[row_index, col_index]):
@@ -223,7 +227,7 @@ def get_saliency_heatmap_html(
out += "
"
if step_scores is not None:
for step_score_name, step_score_values in step_scores.items():
- out += f'{step_score_name} | '
+ out += f'
---|
| {step_score_name} | '
if isinstance(step_scores_threshold, float):
threshold = step_scores_threshold
else:
@@ -254,20 +258,23 @@ def get_saliency_heatmap_rich(
label: str = "",
step_scores_threshold: Union[float, dict[str, float]] = 0.5,
):
- columns = [Column(header="", justify="right", overflow="fold")]
- for column_label in column_labels:
- columns.append(Column(header=escape(column_label), justify="center", overflow="fold"))
+ columns = [
+ Column(header="", justify="right", overflow="fold"),
+ Column(header="", justify="right", overflow="fold"),
+ ]
+ for idx, column_label in enumerate(column_labels):
+ columns.append(Column(header=f"{idx}\n{escape(column_label)}", justify="center", overflow="fold"))
table = Table(
*columns,
title=f"{label + ' ' if label else ''}Saliency Heatmap",
- caption="x: Generated tokens, y: Attributed tokens",
+ caption="→ : Generated tokens, ↓ : Attributed tokens",
padding=(0, 1, 0, 1),
show_lines=False,
box=box.HEAVY_HEAD,
)
if scores is not None:
for row_index in range(scores.shape[0]):
- row = [Text(escape(row_labels[row_index]), style="bold")]
+ row = [Text(f"{row_index}", style="bold"), Text(escape(row_labels[row_index]), style="bold")]
for col_index in range(scores.shape[1]):
color = Color.from_rgb(*input_colors[row_index][col_index])
score = ""
@@ -282,7 +289,7 @@ def get_saliency_heatmap_rich(
else:
threshold = step_scores_threshold.get(step_score_name, 0.5)
style = lambda val, limit: "bold" if abs(val) >= limit and isinstance(val, float) else ""
- score_row = [Text(escape(step_score_name), style="bold")]
+ score_row = [Text(""), Text(escape(step_score_name), style="bold")]
for score in step_score_values:
curr_score = round(score.item(), 2) if isinstance(score, float) else score.item()
score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold)))
diff --git a/inseq/utils/viz_utils.py b/inseq/utils/viz_utils.py
index 9de41d1a..69e4a94f 100644
--- a/inseq/utils/viz_utils.py
+++ b/inseq/utils/viz_utils.py
@@ -109,7 +109,6 @@ def get_colors(
saliency_heatmap_table_header = """
- |
"""
saliency_heatmap_html = """
@@ -117,7 +116,7 @@ def get_colors(
{label} Saliency Heatmap
- x: Generated tokens, y: Attributed tokens
+ → : Generated tokens, ↓ : Attributed tokens
{content}
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d086fc10..4020d11f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,6 +1,6 @@
# This file was autogenerated by uv via the following command:
# uv pip compile --all-extras pyproject.toml -o requirements-dev.txt
-aiohttp==3.9.3
+aiohttp==3.9.5
# via
# datasets
# fsspec
@@ -19,10 +19,12 @@ authlib==1.3.0
babel==2.14.0
# via sphinx
bandit==1.7.7
+ # via inseq (pyproject.toml)
beautifulsoup4==4.12.3
# via furo
captum==0.7.0
-certifi==2024.2.2
+ # via inseq (pyproject.toml)
+certifi==2024.6.2
# via requests
cffi==1.16.0
# via cryptography
@@ -49,6 +51,7 @@ cryptography==42.0.5
cycler==0.12.1
# via matplotlib
datasets==2.17.0
+ # via inseq (pyproject.toml)
debugpy==1.8.1
# via ipykernel
decorator==5.1.1
@@ -90,6 +93,7 @@ fsspec==2023.10.0
# huggingface-hub
# torch
furo==2024.1.29
+ # via inseq (pyproject.toml)
gitdb==4.0.11
# via gitpython
gitpython==3.1.41
@@ -110,21 +114,25 @@ imagesize==1.4.1
iniconfig==2.0.0
# via pytest
ipykernel==6.29.2
+ # via inseq (pyproject.toml)
ipython==8.18.1
# via
# ipykernel
# ipywidgets
ipywidgets==8.1.2
+ # via inseq (pyproject.toml)
jaxtyping==0.2.25
+ # via inseq (pyproject.toml)
jedi==0.19.1
# via ipython
-jinja2==3.1.3
+jinja2==3.1.4
# via
# safety
# sphinx
# torch
joblib==1.3.2
# via
+ # inseq (pyproject.toml)
# nltk
# scikit-learn
jupyter-client==8.6.0
@@ -144,7 +152,9 @@ markupsafe==2.1.5
marshmallow==3.20.2
# via safety
matplotlib==3.8.2
- # via captum
+ # via
+ # inseq (pyproject.toml)
+ # captum
matplotlib-inline==0.1.6
# via
# ipykernel
@@ -164,10 +174,12 @@ nest-asyncio==1.6.0
networkx==3.2.1
# via torch
nltk==3.8.1
+ # via inseq (pyproject.toml)
nodeenv==1.8.0
# via pre-commit
numpy==1.26.4
# via
+ # inseq (pyproject.toml)
# captum
# contourpy
# datasets
@@ -199,7 +211,7 @@ pbr==6.0.0
# via stevedore
pexpect==4.9.0
# via ipython
-pillow==10.3.0
+pillow==10.4.0
# via matplotlib
platformdirs==4.2.0
# via
@@ -208,10 +220,13 @@ platformdirs==4.2.0
pluggy==1.4.0
# via pytest
pre-commit==3.6.1
+ # via inseq (pyproject.toml)
prompt-toolkit==3.0.43
# via ipython
protobuf==4.25.2
- # via transformers
+ # via
+ # inseq (pyproject.toml)
+ # transformers
psutil==5.9.8
# via ipykernel
ptyprocess==0.7.0
@@ -229,6 +244,7 @@ pydantic==1.10.14
# safety
# safety-schemas
pydoclint==0.4.0
+ # via inseq (pyproject.toml)
pygments==2.17.2
# via
# furo
@@ -239,10 +255,13 @@ pyparsing==3.1.1
# via matplotlib
pytest==8.0.0
# via
+ # inseq (pyproject.toml)
# pytest-cov
# pytest-xdist
pytest-cov==4.1.0
+ # via inseq (pyproject.toml)
pytest-xdist==3.5.0
+ # via inseq (pyproject.toml)
python-dateutil==2.8.2
# via
# jupyter-client
@@ -265,7 +284,7 @@ regex==2023.12.25
# via
# nltk
# transformers
-requests==2.31.0
+requests==2.32.3
# via
# datasets
# fsspec
@@ -275,6 +294,7 @@ requests==2.31.0
# transformers
rich==13.7.0
# via
+ # inseq (pyproject.toml)
# bandit
# safety
ruamel-yaml==0.18.6
@@ -284,12 +304,15 @@ ruamel-yaml==0.18.6
ruamel-yaml-clib==0.2.8
# via ruamel-yaml
ruff==0.2.1
+ # via inseq (pyproject.toml)
safetensors==0.4.2
# via transformers
safety==3.1.0
+ # via inseq (pyproject.toml)
safety-schemas==0.0.2
# via safety
scikit-learn==1.4.0
+ # via inseq (pyproject.toml)
scipy==1.12.0
# via scikit-learn
sentencepiece==0.1.99
@@ -310,6 +333,7 @@ soupsieve==2.5
# via beautifulsoup4
sphinx==7.2.6
# via
+ # inseq (pyproject.toml)
# furo
# sphinx-basic-ng
# sphinx-copybutton
@@ -320,8 +344,11 @@ sphinx==7.2.6
sphinx-basic-ng==1.0.0b2
# via furo
sphinx-copybutton==0.5.2
+ # via inseq (pyproject.toml)
sphinx-design==0.5.0
+ # via inseq (pyproject.toml)
sphinx-gitstamp==0.4.0
+ # via inseq (pyproject.toml)
sphinxcontrib-applehelp==1.0.8
# via sphinx
sphinxcontrib-devhelp==1.0.6
@@ -335,7 +362,9 @@ sphinxcontrib-qthelp==1.0.7
sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sphinxemoji==0.3.1
+ # via inseq (pyproject.toml)
sphinxext-opengraph==0.9.1
+ # via inseq (pyproject.toml)
stack-data==0.6.3
# via ipython
stevedore==5.1.0
@@ -347,13 +376,16 @@ threadpoolctl==3.2.0
tokenizers==0.15.2
# via transformers
torch==2.2.0
- # via captum
+ # via
+ # inseq (pyproject.toml)
+ # captum
tornado==6.4
# via
# ipykernel
# jupyter-client
-tqdm==4.66.2
+tqdm==4.66.4
# via
+ # inseq (pyproject.toml)
# captum
# datasets
# huggingface-hub
@@ -369,8 +401,11 @@ traitlets==5.14.1
# jupyter-core
# matplotlib-inline
transformers==4.38.1
+ # via inseq (pyproject.toml)
typeguard==2.13.3
- # via jaxtyping
+ # via
+ # inseq (pyproject.toml)
+ # jaxtyping
typer==0.9.0
# via safety
typing-extensions==4.9.0
@@ -384,7 +419,7 @@ typing-extensions==4.9.0
# typer
tzdata==2024.1
# via pandas
-urllib3==2.2.0
+urllib3==2.2.2
# via
# requests
# safety
diff --git a/requirements.txt b/requirements.txt
index 05e5b4c7..9c021743 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml -o requirements.txt
captum==0.7.0
-certifi==2024.2.2
+ # via inseq (pyproject.toml)
+certifi==2024.6.2
# via requests
charset-normalizer==3.3.2
# via requests
@@ -27,7 +28,8 @@ huggingface-hub==0.20.3
idna==3.7
# via requests
jaxtyping==0.2.25
-jinja2==3.1.3
+ # via inseq (pyproject.toml)
+jinja2==3.1.4
# via torch
kiwisolver==1.4.5
# via matplotlib
@@ -36,7 +38,9 @@ markdown-it-py==3.0.0
markupsafe==2.1.5
# via jinja2
matplotlib==3.8.2
- # via captum
+ # via
+ # inseq (pyproject.toml)
+ # captum
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
@@ -45,6 +49,7 @@ networkx==3.2.1
# via torch
numpy==1.26.4
# via
+ # inseq (pyproject.toml)
# captum
# contourpy
# jaxtyping
@@ -55,10 +60,12 @@ packaging==23.2
# huggingface-hub
# matplotlib
# transformers
-pillow==10.3.0
+pillow==10.4.0
# via matplotlib
protobuf==4.25.2
- # via transformers
+ # via
+ # inseq (pyproject.toml)
+ # transformers
pygments==2.17.2
# via rich
pyparsing==3.1.1
@@ -71,11 +78,12 @@ pyyaml==6.0.1
# transformers
regex==2023.12.25
# via transformers
-requests==2.31.0
+requests==2.32.3
# via
# huggingface-hub
# transformers
rich==13.7.0
+ # via inseq (pyproject.toml)
safetensors==0.4.2
# via transformers
sentencepiece==0.1.99
@@ -87,19 +95,25 @@ sympy==1.12
tokenizers==0.15.2
# via transformers
torch==2.2.0
- # via captum
-tqdm==4.66.2
# via
+ # inseq (pyproject.toml)
+ # captum
+tqdm==4.66.4
+ # via
+ # inseq (pyproject.toml)
# captum
# huggingface-hub
# transformers
transformers==4.38.1
+ # via inseq (pyproject.toml)
typeguard==2.13.3
- # via jaxtyping
+ # via
+ # inseq (pyproject.toml)
+ # jaxtyping
typing-extensions==4.9.0
# via
# huggingface-hub
# jaxtyping
# torch
-urllib3==2.2.0
+urllib3==2.2.2
# via requests
diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py
index f7e7c3e5..14360c7f 100644
--- a/tests/data/test_aggregator.py
+++ b/tests/data/test_aggregator.py
@@ -5,6 +5,7 @@
from pytest import fixture
import inseq
+from inseq.data import FeatureAttributionSequenceOutput
from inseq.data.aggregator import (
AggregatorPipeline,
ContiguousSpanAggregator,
@@ -149,6 +150,10 @@ def test_pair_aggregator(saliency_mt_model: HuggingfaceEncoderDecoderModel):
diff_seqattr_other = orig_seqattr_other.aggregate("pair", paired_attr=alt_seqattr_other)
assert torch.allclose(diff_seqattr_other.source_attributions, diff_seqattr.source_attributions)
+ # Aggregate with __sub__
+ diff_seqattr_sub = orig_seqattr - alt_seqattr
+ assert diff_seqattr_other == diff_seqattr_sub
+
def test_named_aggregate_fn_aggregation(saliency_mt_model: HuggingfaceEncoderDecoderModel):
out = saliency_mt_model.attribute(
@@ -192,3 +197,29 @@ def test_named_aggregate_fn_aggregation(saliency_mt_model: HuggingfaceEncoderDec
aggregator=["scores", "scores", "subwords"], aggregate_fn=["mean", "mean", None]
)
assert out_allmean_subwords == out_allmean_subwords_expanded
+
+
+def test_slice_aggregator_decoder_only(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
+ out = saliency_gpt_model.attribute(
+ EXAMPLES["source_summary"], EXAMPLES["source_summary"] + EXAMPLES["gen_summary"], show_progress=False
+ )[0]
+ out_sliced: FeatureAttributionSequenceOutput = out.aggregate("slices", target_spans=(14, 75))
+ assert [t.token for t in out_sliced.source] == EXAMPLES["source_summary_tokens"]
+ assert [t.token for t in out_sliced.target] == EXAMPLES["source_summary_tokens"] + EXAMPLES["gen_summary_tokens"]
+
+ # Slice with __getitem__
+ out_sliced_getitem = out[14:75]
+ assert out_sliced == out_sliced_getitem
+
+
+def test_slice_aggregator_encoder_decoder(saliency_mt_model: HuggingfaceEncoderDecoderModel):
+ out = saliency_mt_model.attribute(
+ EXAMPLES["source"], EXAMPLES["target"], show_progress=False, attribute_target=True
+ )[0]
+ out_sliced: FeatureAttributionSequenceOutput = out.aggregate("slices", source_spans=(1, 4))
+ assert [t.token for t in out_sliced.source] == EXAMPLES["source_subwords"][1:4]
+ assert [t.token for t in out_sliced.target] == EXAMPLES["target_subwords"][1:]
+
+ # Slice with __getitem__
+ out_sliced_getitem = out[1:4]
+ assert out_sliced == out_sliced_getitem
diff --git a/tests/fixtures/aggregator.json b/tests/fixtures/aggregator.json
index 53123526..51d6b20d 100644
--- a/tests/fixtures/aggregator.json
+++ b/tests/fixtures/aggregator.json
@@ -93,5 +93,90 @@
"\u2581models",
".",
""
+ ],
+ "source_summary": "Instruction: Summarize this article.\nInput_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.\nSummary:",
+ "gen_summary": " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic.",
+ "source_summary_tokens": [
+ "\u0120In",
+ "\u0120a",
+ "\u0120quiet",
+ "\u0120village",
+ "\u0120nest",
+ "led",
+ "\u0120between",
+ "\u0120rolling",
+ "\u0120hills",
+ ",",
+ "\u0120an",
+ "\u0120ancient",
+ "\u0120tree",
+ "\u0120whispered",
+ "\u0120secrets",
+ "\u0120to",
+ "\u0120those",
+ "\u0120who",
+ "\u0120listened",
+ ".",
+ "\u0120One",
+ "\u0120night",
+ ",",
+ "\u0120a",
+ "\u0120curious",
+ "\u0120child",
+ "\u0120named",
+ "\u0120El",
+ "ara",
+ "\u0120leaned",
+ "\u0120close",
+ "\u0120and",
+ "\u0120heard",
+ "\u0120tales",
+ "\u0120of",
+ "\u0120hidden",
+ "\u0120treasures",
+ "\u0120beneath",
+ "\u0120the",
+ "\u0120roots",
+ ".",
+ "\u0120As",
+ "\u0120dawn",
+ "\u0120broke",
+ ",",
+ "\u0120she",
+ "\u0120unearthed",
+ "\u0120a",
+ "\u0120shimmer",
+ "ing",
+ "\u0120box",
+ ",",
+ "\u0120unlocking",
+ "\u0120a",
+ "\u0120forgotten",
+ "\u0120world",
+ "\u0120of",
+ "\u0120wonder",
+ "\u0120and",
+ "\u0120magic",
+ "."
+ ],
+ "gen_summary_tokens": [
+ "\u0120El",
+ "ara",
+ "\u0120discovers",
+ "\u0120a",
+ "\u0120shimmer",
+ "ing",
+ "\u0120box",
+ "\u0120under",
+ "\u0120an",
+ "\u0120ancient",
+ "\u0120tree",
+ ",",
+ "\u0120unlocking",
+ "\u0120a",
+ "\u0120world",
+ "\u0120of",
+ "\u0120magic",
+ "."
]
}
---|
---|