Embedding4bit and Embedding8bit implementation (#1292)
* Embedding4bit and Embedding8bit implementation

* lint

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <[email protected]>

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <[email protected]>

* Update bitsandbytes/nn/modules.py

Co-authored-by: Matthew Douglas <[email protected]>

* saving -> Saving

---------

Co-authored-by: Matthew Douglas <[email protected]>
galqiwi and matthewdouglas authored Aug 6, 2024
1 parent 4be1883 commit 6d714a5
Showing 3 changed files with 355 additions and 13 deletions.
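
For orientation, here is a minimal usage sketch of the embedding modules this commit introduces, following the docstring examples in the diff below (assumes a CUDA device; the sizes and variable names are illustrative):

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import Embedding8bit, EmbeddingNF4

# Start from an ordinary floating-point embedding and copy its weights into
# the quantized modules; quantization happens when the module moves to the GPU.
fp16_module = nn.Embedding(128, 64)

int8_module = Embedding8bit(128, 64)
int8_module.load_state_dict(fp16_module.state_dict())
int8_module = int8_module.to(0)  # LLM.int8() quantization happens here

nf4_module = EmbeddingNF4(128, 64)
nf4_module.load_state_dict(fp16_module.state_dict())
nf4_module = nf4_module.to(0)  # NF4 quantization happens here

ids = torch.randint(0, 128, (4, 16), device="cuda")
print(int8_module(ids).shape, nf4_module(ids).shape)  # both torch.Size([4, 16, 64])
```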
4 changes: 4 additions & 0 deletions bitsandbytes/nn/__init__.py
@@ -4,6 +4,10 @@
# LICENSE file in the root directory of this source tree.
from .modules import (
Embedding,
Embedding4bit,
Embedding8bit,
EmbeddingFP4,
EmbeddingNF4,
Int8Params,
Linear4bit,
Linear8bitLt,
216 changes: 204 additions & 12 deletions bitsandbytes/nn/modules.py
@@ -347,6 +347,23 @@ def to(self, *args, **kwargs):
return new_param


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
return

if getattr(module, "quant_state", None) is None:
warnings.warn(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)

# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert module.weight.shape[1] == 1
if not isinstance(module.weight, Params4bit):
module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
module.weight.quant_state = module.quant_state


class Linear4bit(nn.Linear):
"""
This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
@@ -449,22 +466,12 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
destination[prefix + "weight." + k] = v if keep_vars else v.detach()

def forward(self, x: torch.Tensor):
fix_4bit_weight_quant_state_from_module(self)

# weights are cast automatically as Params4bit, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)

if getattr(self.weight, "quant_state", None) is None:
if getattr(self, "quant_state", None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1
if not isinstance(self.weight, Params4bit):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
self.weight.quant_state = self.quant_state
else:
print(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
@@ -658,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)


class Embedding8bit(nn.Embedding):
"""
This class implements the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for the embedding layer.
The quantization API is similar to that of Linear8bitLt:
```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Embedding8bit
fp16_module = nn.Embedding(128, 64)
int8_module = Embedding8bit(128, 64)
int8_module.load_state_dict(fp16_module.state_dict())
int8_module = int8_module.to(0) # Quantization happens here
```
"""

def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
self.dtype = self.weight.data.dtype

self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)

def _save_to_state_dict(self, destination, prefix, keep_vars):
raise NotImplementedError("Saving Embedding8bit module is not implemented")

def forward(self, input: Tensor) -> Tensor:
if not hasattr(self.weight, "SCB"):
raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")

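# weight.data holds the int8 codes; weight.SCB holds one absmax scale per row.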
rows = self.weight.data
row_stats = self.weight.SCB

assert rows.shape == (self.num_embeddings, self.embedding_dim)
assert row_stats.shape == (self.num_embeddings,)

compressed_output = F.embedding(input, rows)
compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))

output = compressed_output * (compressed_output_stats / 127.0)

return output.to(self.dtype)
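
The forward pass above looks up both the int8 rows and their per-row absmax scales (`weight.SCB`), then rescales by `absmax / 127.0`. A minimal sketch of that arithmetic on a single illustrative row in plain PyTorch (the variable names are not part of the library):

```python
import torch

# Illustrative row-wise absmax quantization, mirroring what Embedding8bit stores.
row = torch.randn(64)
absmax = row.abs().max()                                # kept per row in weight.SCB
codes = torch.round(row / absmax * 127).to(torch.int8)  # int8 storage in weight.data

reconstructed = codes.float() * (absmax / 127.0)        # what forward() computes per row
print(torch.allclose(reconstructed, row, atol=absmax.item() / 127))  # True
```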


class Embedding4bit(nn.Embedding):
"""
This is the base class for 4-bit embeddings, analogous to Linear4bit. It implements the 4-bit quantization
algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314) for embedding layers.
The quantization API is similar to that of Linear4bit:
```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Embedding4bit
fp16_module = nn.Embedding(128, 64)
quantized_module = Embedding4bit(128, 64)
quantized_module.load_state_dict(fp16_module.state_dict())
quantized_module = quantized_module.to(0) # Quantization happens here
```
"""

def __init__(
self,
num_embeddings,
embedding_dim,
dtype=None,
quant_type="fp4",
quant_storage=torch.uint8,
device=None,
):
super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
self.dtype = self.weight.data.dtype

self.weight = Params4bit(
self.weight.data,
requires_grad=False,
compress_statistics=None,
quant_type=quant_type,
quant_storage=quant_storage,
module=self,
)

blocksize = self.weight.blocksize

if embedding_dim % blocksize != 0:
warnings.warn(
f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
"This will lead to slow inference.",
)

def _forward_with_partial_dequantize(self, input: Tensor):
assert self.embedding_dim % self.weight.quant_state.blocksize == 0

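# Two 4-bit codes are packed per uint8, so each embedding row occupies
# embedding_dim // 2 bytes of storage; gather the packed rows with an embedding lookup.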
w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

output_4bit = torch.nn.functional.embedding(
weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
input=input,
).view(-1, 1)
assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)

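# Each row owns embedding_dim // blocksize consecutive absmax scales; gather them
# for the selected rows the same way.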
blocks_per_emb = self.embedding_dim // self.weight.blocksize

absmax = self.weight.quant_state.absmax
assert absmax.shape == (self.num_embeddings * blocks_per_emb,)

output_absmax = torch.nn.functional.embedding(
weight=absmax.view(self.num_embeddings, blocks_per_emb),
input=input,
).view(
-1,
)
assert output_absmax.shape == (input.numel() * blocks_per_emb,)

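# Dequantize only the gathered rows: clone the quant_state and point it at the
# gathered scales and the output shape.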
output_quant_state = copy.deepcopy(self.weight.quant_state)
output_quant_state.absmax = output_absmax
output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
assert output.shape == (*input.shape, self.embedding_dim)

return output.to(self.dtype)

def _save_to_state_dict(self, destination, prefix, keep_vars):
raise NotImplementedError("Saving Embedding4bit module is not implemented")

def forward(self, input: Tensor) -> Tensor:
fix_4bit_weight_quant_state_from_module(self)

if self.embedding_dim % self.weight.quant_state.blocksize == 0:
return self._forward_with_partial_dequantize(input)

dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)

return torch.nn.functional.embedding(
weight=dequantized_weight,
input=input,
).to(self.dtype)
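
When `embedding_dim` is divisible by the quantization blocksize, `_forward_with_partial_dequantize` dequantizes only the gathered rows instead of the whole table. A hedged sanity check of that fast path against a full dequantize-then-gather reference (assumes a CUDA device; sizes are illustrative):

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import EmbeddingNF4

emb = EmbeddingNF4(128, 64).to(0)  # embedding_dim divisible by blocksize -> fast path
ids = torch.randint(0, 128, (4, 16), device="cuda")

# Reference: dequantize the full table, then gather rows.
full = bnb.functional.dequantize_4bit(emb.weight.data, emb.weight.quant_state)
reference = torch.nn.functional.embedding(ids, full).to(emb.dtype)

print(torch.allclose(emb(ids), reference))  # expected: True
```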


class EmbeddingFP4(Embedding4bit):
def __init__(
self,
num_embeddings,
embedding_dim,
dtype=None,
quant_storage=torch.uint8,
device=None,
):
super().__init__(
num_embeddings,
embedding_dim,
dtype=dtype,
quant_type="fp4",
quant_storage=quant_storage,
device=device,
)


class EmbeddingNF4(Embedding4bit):
def __init__(
self,
num_embeddings,
embedding_dim,
dtype=None,
quant_storage=torch.uint8,
device=None,
):
super().__init__(
num_embeddings,
embedding_dim,
dtype=dtype,
quant_type="nf4",
quant_storage=quant_storage,
device=device,
)


class Linear8bitLt(nn.Linear):
"""
This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
