
Save tensors in lower precision #273

Merged · 16 commits into main · Jul 23, 2024

Conversation

LuukSuurmeijer (Contributor)

Description

Added support for saving attributions in a lower tensor precision.

Upon saving, tensors are converted to Hugging Face safetensors. They are then optionally quantized to float16, int8, or uint8 (when there are no negative values) using zeropoint quantization. The quantization parameters are stored in the safetensors metadata so the float32 values can be recovered upon loading. Since safetensors serialize to bytes objects, they are base64-encoded before being written to JSON.
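As a rough illustration of the zeropoint scheme described above (function names here are illustrative sketches, not the actual inseq helpers):

```python
import torch

def zeropoint_quantize(t: torch.Tensor) -> tuple[torch.Tensor, float, int]:
    """Quantize a float32 tensor to int8 (or uint8 when all values are
    non-negative), returning the scale and zero-point needed for recovery."""
    t_min, t_max = t.min().item(), t.max().item()
    if t_min >= 0:  # no negative values: use the full uint8 range
        qmin, qmax, dtype = 0, 255, torch.uint8
    else:
        qmin, qmax, dtype = -128, 127, torch.int8
    scale = (t_max - t_min) / (qmax - qmin)
    if scale == 0:  # constant tensor: any scale works, avoid dividing by zero
        scale = 1.0
    zero_point = round(qmin - t_min / scale)
    q = torch.clamp(torch.round(t / scale + zero_point), qmin, qmax).to(dtype)
    return q, scale, zero_point

def zeropoint_dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    """Recover an approximation of the original float32 values."""
    return (q.to(torch.float32) - zero_point) * scale
```

Storing `scale` and `zero_point` in the safetensors metadata (as strings) is what makes the float32 values recoverable at load time.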

List of changes:

  • `save` has an extra parameter `scores_precision`, with default value `float32`.
  • `FeatureAttributionSequenceOutput` has two new private methods, `_convert_to_safetensors` and `_recover_from_safetensors`, which convert the object's tensors between torch tensors and safetensors. They are used in saving and loading, respectively.
  • Two util functions in `torch_utils`, `convert_to_safetensor` and `dequantize_safetensor`, which convert a single tensor in each direction (see the round-trip sketch after this list).
  • Two new unit tests for saving and loading in float16/float8 in `test_attribution.py`.
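For the JSON embedding described above, a minimal round trip could look like this (a sketch using the public `safetensors.torch` API; the tensor name and metadata values are made up):

```python
import base64
import json

import torch
from safetensors.torch import load, save

attr = torch.rand(4, 4, dtype=torch.float16)
# safetensors.torch.save returns a bytes object; metadata values must be strings
blob = save({"attr": attr}, metadata={"scale": "1.0", "zero_point": "0"})
# bytes are not JSON-serializable, hence the base64 step
payload = json.dumps({"attr": base64.b64encode(blob).decode("ascii")})

# On load, reverse both steps
restored = load(base64.b64decode(json.loads(payload)["attr"]))["attr"]
assert torch.equal(restored, attr)
```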

This is my first PR on this project and my first time properly diving into inseq, so please be critical and help me improve the feature! There are several points where I am not sure about the implementation:

  • I have to deepcopy the objects while saving; is that really necessary?
  • Is it a good idea to introduce new private methods on `FeatureAttributionSequenceOutput`?
  • Does the output JSON preserve enough readability?
  • Should I include unit tests for the new `torch_utils` functions? I saw that most of them do not have unit tests, but I am happy to add them.

All tests pass with no errors.

Related Issue

inseq-team#202

Type of Change

  • 🥂 Improvement (non-breaking change which improves an existing feature)
  • 🚀 New feature (non-breaking change which adds functionality)

Checklist

  • I've read the CODE_OF_CONDUCT.md document.
  • I've read the CONTRIBUTING.md guide.
  • I've successfully run the style checks using make fix-style.
  • I've written tests for all new methods and classes that I created and successfully ran make test.
  • I've written the docstring in Google format for all the methods and classes that I used.

LuukSuurmeijer and others added 8 commits May 10, 2024 14:16
inseq-team#202
Adds functionality for saving feature attribution objects and tensors in float16 or float8 format, depending on the `scores_precision` parameter.
Tensors are saved in the Hugging Face safetensors format and quantized using zeropoint quantization. Because safetensors are bytes objects, they are base64-encoded to be saved in the output JSON and decoded upon reloading.
* Add device_map support

* Fix device setter in HF model
@LuukSuurmeijer LuukSuurmeijer marked this pull request as draft May 10, 2024 13:27
@LuukSuurmeijer LuukSuurmeijer marked this pull request as ready for review May 10, 2024 13:30
gsarti (Member) commented May 11, 2024

Hey @LuukSuurmeijer, thanks a lot for this PR!

I had a look and added some very minor fixes (added a `Literal` type for the allowed precision strings and a docstring for the new parameter in `save`). I also made sure the code works fine when `compress=False` but a different precision is specified. In one of my tests, however, I ran into a weird issue. If you run the following code:

```python
import torch
from inseq import load_model, FeatureAttributionOutput

saliency_mt_model = load_model("Helsinki-NLP/opus-mt-en-it", "attention")

out_path = "tmp_attr_8bit.json"
out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
out.save(out_path, scores_precision="float8", overwrite=True)
loaded_out = FeatureAttributionOutput.load(out_path)
assert torch.allclose(
    out.sequence_attributions[0].source_attributions,
    loaded_out.sequence_attributions[0].source_attributions,
    atol=1e-02,
)
```

You get an error in the parsing of the JSON metadata header. From a very quick exploration, it seems to be caused by the selection of the header via `json.loads(safetensor[8 : (7 + header_length)])["__metadata__"]` in `dequantize_safetensor`, which cuts the JSON one character too short. This is puzzling because the same code works fine for other precisions, and I confirm that it matches the Hugging Face example you referred to. If we do not find out what the issue is, we might want to set up some error handling to either 1) brute-force the extraction character by character until a valid JSON is formed (not ideal) or at least 2) raise an informative error about the problem, suggesting to try another precision.

LuukSuurmeijer (Contributor, Author)

The JSON decode error seems to have been an off-by-one error that only surfaced when quantizing to 8-bit. I managed to reproduce the error even without `overwrite=True`. Increasing the decoding range by 1 fixed it and did not cause any issues with float16 decoding (all tests still pass). I guess that fixed the issue, although I don't know the underlying cause. Do you think it's ready for merging now?
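For the record, the underlying cause is most likely the slice bounds themselves: a safetensors buffer starts with an unsigned 64-bit little-endian header length, followed by exactly that many bytes of JSON (sometimes padded with trailing spaces, which would explain why the short slice only fails for some precisions). A sketch of the extraction with corrected bounds:

```python
import json
import struct

def read_safetensors_metadata(blob: bytes) -> dict:
    # Bytes [0, 8): little-endian uint64 giving the JSON header length
    (header_length,) = struct.unpack("<Q", blob[:8])
    # The header occupies bytes [8, 8 + header_length); slicing to
    # 7 + header_length drops the last byte, which breaks json.loads
    # whenever that byte is the closing "}" rather than padding
    header = json.loads(blob[8 : 8 + header_length])
    return header.get("__metadata__", {})
```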

gsarti (Member) commented Jul 23, 2024

Hi @LuukSuurmeijer,

The code still had some issues due to the FP8 conversion not handling NaNs by default (target-side attribution matrices contain NaN for future tokens). Since torch now supports two experimental FP8 formats aimed at maximizing the expressivity of quantized tensors, as described in this paper, I decided to remove the manual quantization code and simply use `torch.float8_e4m3fn` as a drop-in replacement for FP8 conversion.
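A quick sketch of the resulting behavior (`torch.float8_e4m3fn` requires a recent PyTorch release; unlike the manual zeropoint path, it represents NaN natively):

```python
import torch

t = torch.tensor([0.25, -1.5, float("nan")])
q = t.to(torch.float8_e4m3fn)   # lossy 8-bit cast; e4m3fn has no inf but keeps NaN
restored = q.to(torch.float32)  # upcast back on load
print(restored)                 # tensor([0.2500, -1.5000, nan])
```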

@gsarti gsarti merged commit 979d223 into inseq-team:main Jul 23, 2024
4 checks passed