From 62298a2ef3bc1d4a0004ff10edebd714569f49cd Mon Sep 17 00:00:00 2001
From: Luuk Suurmeijer
Date: Fri, 10 May 2024 12:43:17 +0200
Subject: [PATCH] Save attribution tensors with a lower precision.

https://github.com/inseq-team/inseq/issues/202

Adds functionality for saving feature attribution objects and tensors in
float16 or float8 format, depending on the `scores_precision` parameter.
Tensors are saved in the Hugging Face safetensors format and quantized using
zero-point quantization. Because safetensors are bytes objects, they are
base64-encoded so they can be stored in the output JSON, and decoded upon
reloading.
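Example round trip (a sketch for reviewers; the model, attribution method and
file name are illustrative, the new behavior is the `scores_precision`
argument to `save`):

    import inseq

    model = inseq.load_model("gpt2", "saliency")
    out = model.attribute("Hello world!")
    # "float8" triggers the zero-point quantization path; "float16" is a plain cast
    out.save("attr.json", scores_precision="float8")
    # scores are dequantized back to float32 on load
    reloaded = inseq.FeatureAttributionOutput.load("attr.json")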
+ """ + + if self.source_attributions is not None: + self.source_attributions = convert_to_safetensor( + self.source_attributions.contiguous(), quantization=scores_precision + ) + if self.target_attributions is not None: + self.target_attributions = convert_to_safetensor( + self.target_attributions.contiguous(), quantization=scores_precision + ) + if self.step_scores is not None: + self.step_scores = { + k: convert_to_safetensor(v.contiguous(), quantization=scores_precision) + for k, v in self.step_scores.items() + } + if self.sequence_scores is not None: + self.sequence_scores = { + k: convert_to_safetensor(v.contiguous(), quantization=scores_precision) + for k, v in self.sequence_scores.items() + } + return self + + def _recover_from_safetensors(self): + """ + Converts tensor attributes within the class from b64-encoded safetensors to torch tensors.`. + Args: + self + Returns: + self + """ + if self.source_attributions is not None: + self.source_attributions = dequantize_safetensor(base64.b64decode(self.source_attributions)) + if self.target_attributions is not None: + self.target_attributions = dequantize_safetensor(base64.b64decode(self.target_attributions)) + if self.step_scores is not None: + self.step_scores = {k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.step_scores.items()} + if self.sequence_scores is not None: + self.sequence_scores = { + k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.sequence_scores.items() + } + return self + @staticmethod def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable: if attr.source_attributions is None or name.startswith("decoder"): @@ -546,6 +602,7 @@ def save( ndarray_compact: bool = True, use_primitives: bool = False, split_sequences: bool = False, + scores_precision: str = "float32", ) -> None: """Save class contents to a JSON file. @@ -572,17 +629,25 @@ def save( raise ValueError(f"{path} already exists. 
             raise ValueError(f"{path} already exists. Override with overwrite=True.")
         save_outs = []
         paths = []
+        self_out = deepcopy(self)
         if split_sequences:
-            for i, seq in enumerate(self.sequence_attributions):
+            for i, seq in enumerate(self_out.sequence_attributions):
                 attr_out = deepcopy(self)
-                attr_out.sequence_attributions = [seq]
+                attr_out.sequence_attributions = [
+                    seq._convert_to_safetensors(scores_precision=scores_precision)
+                ]  # this overwrites the original
                 attr_out.step_attributions = None
                 attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
                 attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
                 save_outs.append(attr_out)
                 paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
         else:
-            save_outs.append(self)
+            self_out = deepcopy(self)
+            self_out.sequence_attributions = [
+                seq._convert_to_safetensors(scores_precision=scores_precision)
+                for seq in self_out.sequence_attributions
+            ]
+            save_outs.append(self_out)
             paths.append(path)
         for attr_out, path_out in zip(save_outs, paths):
             with open(path_out, f"w{'b' if compress else ''}") as f:
@@ -615,9 +680,9 @@
             :class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
         """
         out = json_advanced_load(path, decompression=decompress)
-        out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
+        out.sequence_attributions = [seq._recover_from_safetensors() for seq in out.sequence_attributions]
         if out.step_attributions is not None:
-            out.step_attributions = [step.torch() for step in out.step_attributions]
+            out.step_attributions = [step._recover_from_safetensors() for step in out.step_attributions]
         return out
 
     def aggregate(
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index f632ba32..53c39c5e 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -49,6 +49,8 @@
 from .torch_utils import (
     aggregate_contiguous,
     check_device,
+    convert_to_safetensor,
+    dequantize_safetensor,
     euclidean_distance,
     filter_logits,
     find_block_stack,
@@ -69,6 +71,8 @@
     "UnknownAttributionMethodError",
     "MissingAlignmentsError",
     "cache_results",
+    "convert_to_safetensor",
+    "dequantize_safetensor",
     "optional",
     "pad",
     "pretty_list",
diff --git a/inseq/utils/serialization.py b/inseq/utils/serialization.py
index b45f8580..f7966191 100644
--- a/inseq/utils/serialization.py
+++ b/inseq/utils/serialization.py
@@ -29,6 +29,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import base64
 import json
 from collections import OrderedDict
 from json import JSONEncoder
@@ -59,6 +60,8 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
     """
     if isinstance(obj, (list, dict)):
         return obj
+    if isinstance(obj, bytes):
+        return base64.b64encode(obj).decode("UTF8")
     if hasattr(obj, "__class__") and hasattr(obj, "__dict__"):
         if not hasattr(obj, "__new__"):
             raise TypeError(f"class '{obj.__class__}' does not have a __new__ method; ")
@@ -84,9 +87,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
             dct["attributes"] = hashodict(obj.__dict__)
         if use_primitives:
             attrs = dct.get("attributes", {})
-            return attrs
-        else:
-            return dct
+        return attrs if use_primitives else dct
     return obj
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index 86acd635..7d9fb9fc 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -1,7 +1,10 @@
+import json
 import logging
+import struct
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
 
+import safetensors
 import torch
 import torch.nn.functional as F
 from jaxtyping import Int, Num
@@ -38,6 +41,58 @@ def remap_from_filtered(
     return new_source.scatter(0, index, filtered)
 
 
+def convert_to_safetensor(tensor: torch.Tensor, quantization="float32") -> bytes:
+    """Converts a torch tensor to a safetensor, optionally quantizing the weights with zero-point quantization.
+
+    Quantization parameters are saved in the safetensor metadata and reused on reloading.
+    Adapted from https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c
+
+    Args:
+        tensor (torch.Tensor): The tensor to convert.
+        quantization (str): Format to quantize the weights to, one of [float32, float16, float8].
+
+    Returns:
+        bytes: A safetensor in bytes format.
+
+    Raises:
+        ValueError: If `quantization` does not match one of the possible options.
+    """
+    metadata_dict = {"quantization": quantization}
+    if quantization == "float32":
+        return safetensors.torch.save({"attribution": tensor}, metadata=metadata_dict)
+
+    negatives = torch.any(tensor < 0)
+    if quantization == "float16":
+        return safetensors.torch.save({"attribution": tensor.to(torch.float16)}, metadata=metadata_dict)
+    elif quantization == "float8":
+        xrange = torch.max(tensor) - torch.min(tensor)
+        scale = 255 / xrange
+        if negatives:
+            # Asymmetric range: shift into [-128, 127] and store as int8
+            zeropoint = (-scale * torch.min(tensor)).round() - 128
+            quant_tensor = torch.clip((tensor * scale + zeropoint).round(), -128, 127).to(torch.int8)
+        else:
+            # Non-negative range: shift into [0, 255] and store as uint8
+            zeropoint = (-scale * torch.min(tensor)).round()
+            quant_tensor = torch.clip((tensor * scale + zeropoint).round(), 0, 255).to(torch.uint8)
+        metadata_dict["scale"], metadata_dict["zeropoint"] = f"{scale}", f"{zeropoint}"
+        return safetensors.torch.save({"attribution": quant_tensor}, metadata=metadata_dict)
+    else:
+        raise ValueError("`quantization` has to be one of [float32, float16, float8]")
+
+
+def dequantize_safetensor(safetensor: bytes) -> torch.Tensor:
+    """Convert a safetensor to a torch tensor and dequantize weights to float32.
+
+    Adapted from https://huggingface.co/docs/safetensors/metadata_parsing
+    """
+    # The first 8 bytes of a safetensor hold the JSON header size as a little-endian uint64
+    header_length = struct.unpack("<Q", safetensor[:8])[0]
+    metadata = json.loads(safetensor[8 : 8 + header_length])["__metadata__"]
+    loaded = safetensors.torch.load(safetensor)["attribution"]
+    if metadata["quantization"] == "float8":
+        # Reverse the zero-point quantization applied in convert_to_safetensor
+        scale, zeropoint = float(metadata["scale"]), float(metadata["zeropoint"])
+        return (loaded - zeropoint) / scale
+    return loaded.to(torch.float32)
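
A minimal sanity check of the float8 round trip introduced above (a sketch;
assumes the two helpers are exported from `inseq.utils` as in this patch, and
allows two quantization steps of tolerance since scale = 255 / (max - min)):

    import torch
    from inseq.utils import convert_to_safetensor, dequantize_safetensor

    t = torch.randn(4, 5)
    st = convert_to_safetensor(t, quantization="float8")  # int8 payload + scale/zeropoint metadata
    recovered = dequantize_safetensor(st)
    # worst-case error per element is about 1 / scale; allow 2 / scale of slack
    assert torch.allclose(t, recovered, atol=2 * (t.max() - t.min()).item() / 255)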