Save attribution tensors at lower precision.
inseq-team#202
Adds functionality for saving feature attribution objects and tensors in float16 or float8 format,
depending on the `scores_precision` parameter.
Tensors are saved in the Hugging Face safetensors format and quantized using
zero-point quantization. Because safetensors are bytes objects, they are
base64-encoded for storage in the output JSON and decoded upon reloading.
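
A minimal usage sketch of the new option (the model/method pairing is illustrative, and importing `FeatureAttributionOutput` from `inseq.data`, where the class is defined, is an assumption; the `save`/`load` signatures follow the tests in this commit):

import inseq
from inseq.data import FeatureAttributionOutput

# Any supported model/method works; this pairing is only an example.
model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")
out = model.attribute("This is a test.", show_progress=False)

# float8 trades accuracy for size; float16 is nearly lossless.
out.save("attr.json.gz", compress=True, scores_precision="float8")

# Reloading decodes the base64 payload and dequantizes scores back to float32.
loaded = FeatureAttributionOutput.load("attr.json.gz", decompress=True)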
LuukSuurmeijer committed May 10, 2024
1 parent fa088c8 commit 62298a2
Showing 6 changed files with 173 additions and 24 deletions.
28 changes: 14 additions & 14 deletions .pre-commit-config.yaml
@@ -24,20 +24,20 @@ repos:
- id: ruff-format
- id: ruff

- repo: local
hooks:
- id: fast-test
name: fast-test
entry: make
args: ["fast-test"]
language: system
pass_filenames: false
- id: clean
name: clean
entry: make
args: ["clean"]
language: system
pass_filenames: false
# - repo: local
# hooks:
# - id: fast-test
# name: fast-test
# entry: make
# args: ["fast-test"]
# language: system
# pass_filenames: false
# - id: clean
# name: clean
# entry: make
# args: ["clean"]
# language: system
# pass_filenames: false

- repo: local
hooks:
75 changes: 70 additions & 5 deletions inseq/data/attribution.py
@@ -1,3 +1,4 @@
import base64
import logging
from copy import deepcopy
from dataclasses import dataclass, field
@@ -8,6 +9,8 @@
import torch

from ..utils import (
convert_to_safetensor,
dequantize_safetensor,
drop_padding,
get_sequences_from_batched_steps,
json_advanced_dump,
@@ -159,6 +162,59 @@ def __post_init__(self):
if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
self.attr_pos_end = len(self.target)

def _convert_to_safetensors(self, scores_precision="float32"):
"""
Converts tensor attributes within the class to the precision specified by `scores_precision`.
If the input tensor is already of the desired precision, no conversion occurs.
For float8, the function scales the values and converts them to int8/uint8, which are dequantized back to float32 upon reloading.
Args:
scores_precision (str, optional): Desired output precision for the scores. Defaults to "float32".
Returns:
self: The function modifies the class attributes in-place.
"""

if self.source_attributions is not None:
self.source_attributions = convert_to_safetensor(
self.source_attributions.contiguous(), quantization=scores_precision
)
if self.target_attributions is not None:
self.target_attributions = convert_to_safetensor(
self.target_attributions.contiguous(), quantization=scores_precision
)
if self.step_scores is not None:
self.step_scores = {
k: convert_to_safetensor(v.contiguous(), quantization=scores_precision)
for k, v in self.step_scores.items()
}
if self.sequence_scores is not None:
self.sequence_scores = {
k: convert_to_safetensor(v.contiguous(), quantization=scores_precision)
for k, v in self.sequence_scores.items()
}
return self

def _recover_from_safetensors(self):
"""
Converts tensor attributes within the class from base64-encoded safetensors back to torch tensors.
Args:
self
Returns:
self
"""
if self.source_attributions is not None:
self.source_attributions = dequantize_safetensor(base64.b64decode(self.source_attributions))
if self.target_attributions is not None:
self.target_attributions = dequantize_safetensor(base64.b64decode(self.target_attributions))
if self.step_scores is not None:
self.step_scores = {k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.step_scores.items()}
if self.sequence_scores is not None:
self.sequence_scores = {
k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.sequence_scores.items()
}
return self

@staticmethod
def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
if attr.source_attributions is None or name.startswith("decoder"):
@@ -546,6 +602,7 @@ def save(
ndarray_compact: bool = True,
use_primitives: bool = False,
split_sequences: bool = False,
scores_precision: str = "float32",
) -> None:
"""Save class contents to a JSON file.
@@ -572,17 +629,25 @@
raise ValueError(f"{path} already exists. Override with overwrite=True.")
save_outs = []
paths = []
self_out = deepcopy(self)
if split_sequences:
for i, seq in enumerate(self.sequence_attributions):
for i, seq in enumerate(self_out.sequence_attributions):
attr_out = deepcopy(self)
attr_out.sequence_attributions = [seq]
attr_out.sequence_attributions = [
seq._convert_to_safetensors(scores_precision=scores_precision)
]  # in-place conversion of the deepcopied sequence, not the caller's object
attr_out.step_attributions = None
attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
save_outs.append(attr_out)
paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
else:
save_outs.append(self)
self_out = deepcopy(self)
self_out.sequence_attributions = [
seq._convert_to_safetensors(scores_precision=scores_precision)
for seq in self_out.sequence_attributions
]
save_outs.append(self_out)
paths.append(path)
for attr_out, path_out in zip(save_outs, paths):
with open(path_out, f"w{'b' if compress else ''}") as f:
@@ -615,9 +680,9 @@ def load(
:class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
"""
out = json_advanced_load(path, decompression=decompress)
out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
out.sequence_attributions = [seq._recover_from_safetensors() for seq in out.sequence_attributions]
if out.step_attributions is not None:
out.step_attributions = [step.torch() for step in out.step_attributions]
out.step_attributions = [step._recover_from_safetensors() for step in out.step_attributions]
return out

def aggregate(
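With `split_sequences=True`, `save` writes one file per attributed sequence, suffixing the index before the extension; a standalone sketch of the path expression used in `save` above:

from pathlib import Path

path = Path("out/attr.json")
compress = False
# Mirrors: f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}"
paths = [f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}" for i in range(2)]
print(paths)  # ['out/attr_0.json', 'out/attr_1.json']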
4 changes: 4 additions & 0 deletions inseq/utils/__init__.py
@@ -49,6 +49,8 @@
from .torch_utils import (
aggregate_contiguous,
check_device,
convert_to_safetensor,
dequantize_safetensor,
euclidean_distance,
filter_logits,
find_block_stack,
@@ -69,6 +71,8 @@
"UnknownAttributionMethodError",
"MissingAlignmentsError",
"cache_results",
"convert_to_safetensor",
"dequantize_safetensor",
"optional",
"pad",
"pretty_list",
7 changes: 4 additions & 3 deletions inseq/utils/serialization.py
@@ -29,6 +29,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import base64
import json
from collections import OrderedDict
from json import JSONEncoder
@@ -59,6 +60,8 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
"""
if isinstance(obj, (list, dict)):
return obj
if isinstance(obj, bytes):
return base64.b64encode(obj).decode("UTF8")
if hasattr(obj, "__class__") and hasattr(obj, "__dict__"):
if not hasattr(obj, "__new__"):
raise TypeError(f"class '{obj.__class__}' does not have a __new__ method; ")
@@ -84,9 +87,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
dct["attributes"] = hashodict(obj.__dict__)
if use_primitives:
attrs = dct.get("attributes", {})
return attrs
else:
return dct
return attrs if use_primitives else dct
return obj


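The bytes handling added in `class_instance_encode` pairs with the base64 decoding in `_recover_from_safetensors`; the round trip reduces to the following standalone sketch:

import base64
import json

blob = b"\x00\x01\xfe"  # stand-in for serialized safetensor bytes
encoded = base64.b64encode(blob).decode("UTF8")  # JSON-safe string
payload = json.dumps({"attribution": encoded})

# Reloading decodes the string back to the original bytes exactly.
restored = base64.b64decode(json.loads(payload)["attribution"])
assert restored == blob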
55 changes: 55 additions & 0 deletions inseq/utils/torch_utils.py
@@ -1,7 +1,10 @@
import json
import logging
import struct
from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

import safetensors
import torch
import torch.nn.functional as F
from jaxtyping import Int, Num
@@ -38,6 +41,58 @@ def remap_from_filtered(
return new_source.scatter(0, index, filtered)


def convert_to_safetensor(tensor: torch.Tensor, quantization="float32") -> bytes:
"""
Converts a torch tensor to a safetensor, and optionally quantizes the weights with zero-point quantization.
Quantization parameters are saved in the safetensor to be used on reloading.
Adapted from https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c
Args:
tensor (torch.Tensor): some torch tensor
quantization (str): format to quantize weights to [float32, float16, float8]
Returns:
bytes: A safetensor in bytes format
Raises:
ValueError: If `quantization` is not one of [float32, float16, float8].
"""
metadata_dict = {"quantization": quantization}
if quantization == "float32":
return safetensors.torch.save({"attribution": tensor}, metadata=metadata_dict)

negatives = torch.any(tensor < 0)
if quantization == "float16":
return safetensors.torch.save({"attribution": tensor.to(torch.float16)}, metadata=metadata_dict)
elif quantization == "float8":
xrange = torch.max(tensor) - torch.min(tensor)
scale = 255 / xrange
if negatives:
zeropoint = (-scale * torch.min(tensor)).round() - 128
quant_tensor = torch.clip((tensor * scale + zeropoint).round(), -128, 127).to(torch.int8)
else:
zeropoint = (-scale * torch.min(tensor)).round()
quant_tensor = torch.clip((tensor * scale + zeropoint).round(), 0, 255).to(torch.uint8)

metadata_dict["scale"], metadata_dict["zeropoint"] = f"{scale}", f"{zeropoint}"
return safetensors.torch.save({"attribution": quant_tensor}, metadata=metadata_dict)
else:
raise ValueError("`quantization` has to be one of [float32, float16, float8]")


def dequantize_safetensor(safetensor: bytes) -> torch.Tensor:
"""
Convert a safetensor to a torch tensor and dequantize weights to float32.
Adapted from https://huggingface.co/docs/safetensors/metadata_parsing
"""
header_length = struct.unpack("<Q", safetensor[0:8])[0]
metadata = json.loads(safetensor[8 : (8 + header_length)])["__metadata__"]  # header occupies bytes [8, 8 + header_length)
recovered_tensor = safetensors.torch.load(safetensor)["attribution"].to(torch.float32)
if metadata["quantization"] in ["float32", "float16"]:
return recovered_tensor
else:
return (recovered_tensor - float(metadata["zeropoint"])) / float(metadata["scale"])  # parse stored floats; avoids eval on metadata


def normalize(
attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]],
norm_dim: int = 0,
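As a quick numeric check of the zero-point scheme implemented above (standalone sketch, independent of the inseq helpers):

import torch

t = torch.tensor([-1.0, 0.0, 0.5, 2.0])

# Map the full value range onto 256 integer levels, as in the float8 path.
scale = 255 / (t.max() - t.min())
zeropoint = (-scale * t.min()).round() - 128  # t.min() maps to -128
q = torch.clip((t * scale + zeropoint).round(), -128, 127).to(torch.int8)

# Dequantization recovers values up to half a quantization step,
# i.e. (max - min) / 255 / 2, roughly 6e-3 for this tensor.
t_hat = (q.to(torch.float32) - zeropoint) / scale
assert torch.allclose(t, t_hat, atol=1e-2)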
28 changes: 26 additions & 2 deletions tests/data/test_attribution.py
@@ -26,10 +26,10 @@ def test_save_load_attribution_split(tmp_path, saliency_mt_model):
out_path = tmp_path / "tmp_attr.json"
out = saliency_mt_model.attribute(["This is a test.", "sequence number two"], device="cpu", show_progress=False)
out.save(out_path, split_sequences=True)
out_path_1 = tmp_path / "tmp_attr_1.json"
out_path_1 = tmp_path / "tmp_attr_0.json"
loaded_out = FeatureAttributionOutput.load(out_path_1)
assert torch.allclose(
out.sequence_attributions[1].source_attributions, loaded_out.sequence_attributions[0].source_attributions
out.sequence_attributions[0].source_attributions, loaded_out.sequence_attributions[0].source_attributions
)


@@ -41,6 +41,30 @@ def test_save_load_attribution_compressed(tmp_path, saliency_mt_model):
assert out == loaded_out


def test_save_load_attribution_float16(tmp_path, saliency_mt_model):
out_path = tmp_path / "tmp_attr_compress.json.gz"
out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
out.save(out_path, compress=True, scores_precision="float16")
loaded_out = FeatureAttributionOutput.load(out_path, decompress=True)
assert torch.allclose(
out.sequence_attributions[0].source_attributions,
loaded_out.sequence_attributions[0].source_attributions,
atol=1e-05,
)


def test_save_load_attribution_float8(tmp_path, saliency_mt_model):
out_path = tmp_path / "tmp_attr_compress.json.gz"
out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
out.save(out_path, compress=True, scores_precision="float8")
loaded_out = FeatureAttributionOutput.load(out_path, decompress=True)
assert torch.allclose(
out.sequence_attributions[0].source_attributions,
loaded_out.sequence_attributions[0].source_attributions,
atol=1e-02,
)


def test_get_scores_dicts_encoder_decoder(saliency_mt_model):
out = saliency_mt_model.attribute(["This is a test.", "Hello world!"], device="cpu", show_progress=False)
dicts = out.get_scores_dicts()
