From 6d714a5cce3db5bd7f577bc447becc7a92d5ccc7 Mon Sep 17 00:00:00 2001
From: Vladimir Malinovskii
Date: Tue, 6 Aug 2024 22:00:58 +0300
Subject: [PATCH] Embedding4bit and Embedding8bit implementation (#1292)

* Embedding4bit and Embedding8bit implementation

* lint

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* saving -> Saving

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
---
 bitsandbytes/nn/__init__.py |   4 +
 bitsandbytes/nn/modules.py  | 216 ++++++++++++++++++++++++++++++++++--
 tests/test_modules.py       | 148 +++++++++++++++++++++++-
 3 files changed, 355 insertions(+), 13 deletions(-)

diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 96f4359bf..20aff67a3 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -4,6 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 from .modules import (
     Embedding,
+    Embedding4bit,
+    Embedding8bit,
+    EmbeddingFP4,
+    EmbeddingNF4,
     Int8Params,
     Linear4bit,
     Linear8bitLt,
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index f113b3648..6c78494aa 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -347,6 +347,23 @@ def to(self, *args, **kwargs):
         return new_param
 
 
+def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
+    if getattr(module.weight, "quant_state", None) is not None:
+        return
+
+    if getattr(module, "quant_state", None) is None:
+        warnings.warn(
+            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
+        )
+
+    # the quant state got lost when the parameter got converted. This happens for example for fsdp
+    # since we registered the module, we can recover the state here
+    assert module.weight.shape[1] == 1
+    if not isinstance(module.weight, Params4bit):
+        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
+    module.weight.quant_state = module.quant_state
+
+
 class Linear4bit(nn.Linear):
     """
     This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
@@ -449,22 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             destination[prefix + "weight." + k] = v if keep_vars else v.detach()
 
     def forward(self, x: torch.Tensor):
+        fix_4bit_weight_quant_state_from_module(self)
+
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
-        if getattr(self.weight, "quant_state", None) is None:
-            if getattr(self, "quant_state", None) is not None:
-                # the quant state got lost when the parameter got converted. This happens for example for fsdp
-                # since we registered the module, we can recover the state here
-                assert self.weight.shape[1] == 1
-                if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
-                self.weight.quant_state = self.quant_state
-            else:
-                print(
-                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
-                )
         if not self.compute_type_is_set:
             self.set_compute_type(x)
             self.compute_type_is_set = True
@@ -658,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
 
 
+class Embedding8bit(nn.Embedding):
+    """
+    This class implements [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for embedding layer
+
+    Quantization API is similar to Linear8bitLt:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding8bit
+
+    fp16_module = nn.Embedding(128, 64)
+    int8_module = Embedding8bit(128, 64)
+
+    int8_module.load_state_dict(fp16_module.state_dict())
+
+    int8_module = int8_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding8bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        if not hasattr(self.weight, "SCB"):
+            raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")
+
+        rows = self.weight.data
+        row_stats = self.weight.SCB
+
+        assert rows.shape == (self.num_embeddings, self.embedding_dim)
+        assert row_stats.shape == (self.num_embeddings,)
+
+        compressed_output = F.embedding(input, rows)
+        compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
+
+        output = compressed_output * (compressed_output_stats / 127.0)
+
+        return output.to(self.dtype)
+
+
+class Embedding4bit(nn.Embedding):
+    """
+    This is the base class similar to Linear4bit. It implements the 4-bit quantization algorithm presented in
+    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.
+
+    Quantization API is similar to Linear4bit:
+    ```python
+    import torch
+    import torch.nn as nn
+
+    from bitsandbytes.nn import Embedding4bit
+
+    fp16_module = nn.Embedding(128, 64)
+    quantized_module = Embedding4bit(128, 64)
+
+    quantized_module.load_state_dict(fp16_module.state_dict())
+
+    quantized_module = quantized_module.to(0)  # Quantization happens here
+    ```
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_type="fp4",
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
+        self.dtype = self.weight.data.dtype
+
+        self.weight = Params4bit(
+            self.weight.data,
+            requires_grad=False,
+            compress_statistics=None,
+            quant_type=quant_type,
+            quant_storage=quant_storage,
+            module=self,
+        )
+
+        blocksize = self.weight.blocksize
+
+        if embedding_dim % blocksize != 0:
+            warnings.warn(
+                f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
+                "This will lead to slow inference.",
+            )
+
+    def _forward_with_partial_dequantize(self, input: Tensor):
+        assert self.embedding_dim % self.weight.quant_state.blocksize == 0
+
+        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)
+
+        output_4bit = torch.nn.functional.embedding(
+            weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
+            input=input,
+        ).view(-1, 1)
+        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)
+
+        blocks_per_emb = self.embedding_dim // self.weight.blocksize
+
+        absmax = self.weight.quant_state.absmax
+        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)
+
+        output_absmax = torch.nn.functional.embedding(
+            weight=absmax.view(self.num_embeddings, blocks_per_emb),
+            input=input,
+        ).view(
+            -1,
+        )
+        assert output_absmax.shape == (input.numel() * blocks_per_emb,)
+
+        output_quant_state = copy.deepcopy(self.weight.quant_state)
+        output_quant_state.absmax = output_absmax
+        output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))
+
+        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
+        assert output.shape == (*input.shape, self.embedding_dim)
+
+        return output.to(self.dtype)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        raise NotImplementedError("Saving Embedding4bit module is not implemented")
+
+    def forward(self, input: Tensor) -> Tensor:
+        fix_4bit_weight_quant_state_from_module(self)
+
+        if self.embedding_dim % self.weight.quant_state.blocksize == 0:
+            return self._forward_with_partial_dequantize(input)
+
+        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)
+
+        return torch.nn.functional.embedding(
+            weight=dequantized_weight,
+            input=input,
+        ).to(self.dtype)
+
+
+class EmbeddingFP4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="fp4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
+class EmbeddingNF4(Embedding4bit):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        dtype=None,
+        quant_storage=torch.uint8,
+        device=None,
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            dtype=dtype,
+            quant_type="nf4",
+            quant_storage=quant_storage,
+            device=device,
+        )
+
+
 class Linear8bitLt(nn.Linear):
     """
     This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 9d507c6b4..2176f1d48 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,3 +1,4 @@
+import inspect
 import math
 
 import einops
@@ -616,7 +617,97 @@ def test_fp8linear():
     assert bgraderr < 0.00002
 
 
-def test_4bit_warnings(requires_cuda):
+@pytest.mark.parametrize("embedding_dim", [64, 65])
+@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
+@pytest.mark.parametrize(
+    "embedding_class,quant_storage",
+    [
+        (bnb.nn.Embedding8bit, None),
+        (bnb.nn.EmbeddingFP4, torch.uint8),
+        (bnb.nn.EmbeddingFP4, torch.float32),
+        (bnb.nn.EmbeddingNF4, torch.uint8),
+        (bnb.nn.EmbeddingNF4, torch.float32),
+    ],
+    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
+)
+def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
+    num_embeddings = 128
+
+    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
+        torch.float32
+    ) * 2 - 1  # Embeddings filled with {-1, 1} values. It should compress losslessly
+
+    emb_base = nn.Embedding(
+        num_embeddings=num_embeddings,
+        embedding_dim=embedding_dim,
+        _freeze=True,
+        _weight=src_weight,
+    )
+    if embedding_class is bnb.nn.Embedding8bit:
+        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+    else:
+        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)
+
+    e.load_state_dict(emb_base.state_dict())
+
+    emb_base.cuda()
+    e.cuda()
+
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+
+    torch.testing.assert_close(
+        actual=e(input_tokens),
+        expected=emb_base(input_tokens),
+    )
+
+
+@pytest.mark.parametrize("embedding_dim", [64, 65])
+@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
+@pytest.mark.parametrize(
+    "embedding_class,quant_storage",
+    [
+        (bnb.nn.Embedding8bit, None),
+        (bnb.nn.EmbeddingFP4, torch.uint8),
+        (bnb.nn.EmbeddingFP4, torch.float32),
+        (bnb.nn.EmbeddingNF4, torch.uint8),
+        (bnb.nn.EmbeddingNF4, torch.float32),
+    ],
+    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
+)
+def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
+    is_8bit = embedding_class is bnb.nn.Embedding8bit
+
+    num_embeddings = 128
+
+    src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32)
+
+    emb_base = nn.Embedding(
+        num_embeddings=num_embeddings,
+        embedding_dim=embedding_dim,
+        _freeze=True,
+        _weight=src_weight,
+    )
+    if is_8bit:
+        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+    else:
+        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)
+
+    e.load_state_dict(emb_base.state_dict())
+
+    emb_base.cuda()
+    e.cuda()
+
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
+
+    torch.testing.assert_close(
+        actual=e(input_tokens),
+        expected=emb_base(input_tokens),
+        atol=0.05 if is_8bit else 0.20,
+        rtol=0.0,
+    )
+
+
+def test_4bit_linear_warnings():
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
@@ -642,3 +733,58 @@ def test_4bit_warnings(requires_cuda):
         net(inp)
 
     assert len(record) == 2
+
+
+def test_4bit_embedding_warnings():
+    num_embeddings = 128
+    default_block_size = 64
+
+    with pytest.warns(UserWarning, match=r"inference."):
+        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
+        net.cuda()
+        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
+        net(inp)
+
+
+def test_4bit_embedding_weight_fsdp_fix():
+    num_embeddings = 64
+    embedding_dim = 32
+
+    module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+
+    module.cuda()
+
+    module.weight.quant_state = None
+
+    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
+
+    module(input_tokens)
+
+    assert module.weight.quant_state is not None
+
+
+def test_4bit_linear_weight_fsdp_fix():
+    inp_size = 64
+    out_size = 32
+
+    module = bnb.nn.Linear4bit(inp_size, out_size)
+
+    module.cuda()
+
+    module.weight.quant_state = None
+
+    input_tensor = torch.randn((1, inp_size), device="cuda")
+
+    module(input_tensor)
+
+    assert module.weight.quant_state is not None
+
+
+def test_embedding_not_implemented_error():
+    with pytest.raises(NotImplementedError):
+        emb = bnb.nn.Embedding4bit(32, 32)
+        emb.state_dict()
+
+    with pytest.raises(NotImplementedError):
+        emb = bnb.nn.Embedding8bit(32, 32)
+        emb.state_dict()
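
Usage sketch (review aid, not part of the diff): the new embedding modules can be exercised as below, assuming a CUDA device is available. The variable names and tensor shapes are illustrative only; the flow mirrors the tests and docstring examples above.

```python
import torch
import torch.nn as nn

import bitsandbytes as bnb

num_embeddings, embedding_dim = 128, 64

# Reference fp32 embedding and an int8-quantized copy of it.
ref = nn.Embedding(num_embeddings, embedding_dim)
emb8 = bnb.nn.Embedding8bit(num_embeddings, embedding_dim)
emb8.load_state_dict(ref.state_dict())
emb8.cuda()  # quantization happens on the device transfer

tokens = torch.randint(0, num_embeddings, (4, 16), device="cuda")
out8 = emb8(tokens)  # dequantized as int8 rows * (SCB / 127), per Embedding8bit.forward

# 4-bit NF4 variant with a non-default storage dtype; embedding_dim is a multiple of the
# default 64 blocksize, so the partial-dequantize path is used and no warning is raised.
emb4 = bnb.nn.EmbeddingNF4(num_embeddings, embedding_dim, quant_storage=torch.float32)
emb4.load_state_dict(ref.state_dict())
emb4.cuda()
out4 = emb4(tokens)
```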