Adaptive lora #66

Open
wants to merge 12 commits into base: main

Changes from 8 commits
6 changes: 4 additions & 2 deletions analog/analog.py
@@ -118,6 +118,7 @@ def add_lora(
model: Optional[nn.Module] = None,
watch: bool = True,
clear: bool = True,
lora_state: Dict[str, Any] = None,
) -> None:
"""
Adds LoRA for gradient compression.
@@ -140,6 +141,7 @@ def add_lora(
model=model,
type_filter=self.type_filter,
name_filter=self.name_filter,
lora_state=lora_state,
)

# Clear state and logger
@@ -319,9 +321,9 @@ def initialize_from_log(self) -> None:
# Load LoRA state
lora_dir = os.path.join(self.log_dir, "lora")
if os.path.exists(lora_dir):
if not is_lora(self.model):
self.add_lora()
lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
if not is_lora(self.model):
self.add_lora(lora_state=lora_state)
for name in lora_state:
assert name in self.model.state_dict(), f"{name} not in model!"
self.model.load_state_dict(lora_state, strict=False)
7 changes: 6 additions & 1 deletion analog/logging/option.py
@@ -97,11 +97,16 @@ def _sanity_check(self):
)
self._log["grad"] = True

def eval(self):
def eval(self, log="grad"):

Collaborator commented:

Instead of having "grad" as a default value, what do you think about having None as a default value, and when it's None we set it to "grad" with a warning message like:

def eval(self, log=None):
    if log is None:
        get_logger().warning("we automatically set 'log' to 'grad'. if this is not a desired behavior, please explicitly set your 'log' value.")
        log = "grad"

    if isinstance(log, str):
        ...

"""
Enable the evaluation mode. This will turn off saving and updating
statistics.
"""
if isinstance(log, str):
self._log[log] = True
else:
raise ValueError(f"Unsupported log type for eval: {type(log)}")

self.clear(log=False, save=True, statistic=True)

def clear(self, log=True, save=True, statistic=True):
73 changes: 68 additions & 5 deletions analog/lora/lora.py
@@ -2,8 +2,16 @@

import torch.nn as nn

from analog.constants import FORWARD, BACKWARD
from analog.state import StatisticState
from analog.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
from analog.lora.utils import (
find_parameter_sharing_group,
_get_submodules,
find_rank_pca_compression,
find_rank_pca_covariance,
pca_rank_by_weight_shape,
)
from analog.lora.utils import find_parameter_sharing_group, _get_submodules
from analog.utils import get_logger, module_check

@@ -23,17 +31,25 @@ def __init__(self, config: Dict[str, Any], state: StatisticState):

def parse_config(self):
self.init_strategy = self.config.get("init", "random")
self.rank = self.config.get("rank", 64)
self.rank_default = self.config.get("rank", 64)
self.compression_ratio_by_covariance = self.config.get(
"compression_ratio_by_covariance", None
)
self.compression_ratio_by_memory = self.config.get(
"compression_ratio_by_memory", None
)
self.parameter_sharing = self.config.get("parameter_sharing", False)
self.parameter_sharing_groups = self.config.get(
"parameter_sharing_groups", None
)
self._sanity_check()

def add_lora(
self,
model: nn.Module,
type_filter: List[nn.Module],
name_filter: List[str],
lora_state: Dict[str, Any] = None,
):
"""
Add LoRA modules to a model.
@@ -69,23 +85,70 @@ def add_lora(
lora_cls = LoraEmbedding

psg = find_parameter_sharing_group(name, self.parameter_sharing_groups)

rank_forward = rank_backward = self.rank_default # default rank

if lora_state is not None: # add lora matching the rank of the lora_state
rank_forward, rank_backward = pca_rank_by_weight_shape(
lora_state[name + ".analog_lora_B.weight"].shape, module
)
elif (
self.init_strategy == "pca"
and self.compression_ratio_by_covariance is not None
):
rank_forward = find_rank_pca_covariance(
covariance_state[name][FORWARD],
self.compression_ratio_by_covariance,
)
rank_backward = find_rank_pca_covariance(
covariance_state[name][BACKWARD],
self.compression_ratio_by_covariance,
)
get_logger().info(
f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
)
elif (
self.init_strategy == "pca"
and self.compression_ratio_by_memory is not None
):
rank_forward = rank_backward = find_rank_pca_compression(
module,
self.compression_ratio_by_memory,
)
get_logger().info(
f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
)

if self.parameter_sharing and psg not in shared_modules:
if isinstance(module, nn.Linear):
shared_module = nn.Linear(self.rank, self.rank, bias=False)
shared_module = nn.Linear(rank_forward, rank_backward, bias=False)
elif isinstance(module, nn.Conv1d):
shared_module = nn.Conv1d(
self.rank, self.rank, kernel_size=1, bias=False
rank_forward, rank_backward, kernel_size=1, bias=False
)
elif isinstance(module, nn.Conv2d):
shared_module = nn.Conv2d(
self.rank, self.rank, kernel_size=1, bias=False
rank_forward, rank_backward, kernel_size=1, bias=False
)
shared_modules[psg] = shared_module

lora_module = lora_cls(self.rank, module, shared_modules.get(psg, None))
lora_module = lora_cls(
rank_forward, rank_backward, module, shared_modules.get(psg, None)
)
if self.init_strategy == "pca":
lora_module.pca_init_weight(covariance_state[name])
lora_module.to(device)

parent, target, target_name = _get_submodules(model, name)
setattr(parent, target_name, lora_module)

def _sanity_check(self):
if (
self.init_strategy == "pca"
and self.compression_ratio_by_covariance is not None
and self.compression_ratio_by_memory is not None
):
get_logger().warning(
"compression_ratio_by_covariance and compression_ratio_by_memory are both set. "
+ "compression_ratio_by_covariance will be used."
)
58 changes: 40 additions & 18 deletions analog/lora/modules.py
@@ -8,7 +8,13 @@


class LoraLinear(nn.Linear):
def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None):
def __init__(
self,
rank_forward: int,
rank_backward: int,
linear: nn.Linear,
shared_module: nn.Linear = None,
):
"""Transforms a linear layer into a LoraLinear layer.

Args:
@@ -19,13 +25,14 @@ def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None
out_features = linear.out_features

super().__init__(in_features, out_features)
self.rank = min(rank, in_features, out_features)
self.rank_forward = min(rank_forward, in_features)
self.rank_backward = min(rank_backward, out_features)

self.analog_lora_A = nn.Linear(in_features, self.rank, bias=False)
self.analog_lora_A = nn.Linear(in_features, self.rank_forward, bias=False)
self.analog_lora_B = shared_module or nn.Linear(
self.rank, self.rank, bias=False
self.rank_forward, self.rank_backward, bias=False
)
self.analog_lora_C = nn.Linear(self.rank, out_features, bias=False)
self.analog_lora_C = nn.Linear(self.rank_backward, out_features, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
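
A minimal shape check for the two-rank LoraLinear above (toy sizes; a sketch assuming the constructor is importable and callable as shown in this diff):

    import torch.nn as nn
    from analog.lora.modules import LoraLinear

    base = nn.Linear(300, 200)
    lora = LoraLinear(rank_forward=16, rank_backward=8, linear=base)
    # A: in_features -> rank_forward, B: rank_forward -> rank_backward, C: rank_backward -> out_features
    assert lora.analog_lora_A.weight.shape == (16, 300)
    assert lora.analog_lora_B.weight.shape == (8, 16)
    assert lora.analog_lora_C.weight.shape == (200, 8)
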
@@ -49,17 +56,23 @@ def pca_init_weight(self, covariance=None):
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T)
self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward)


class LoraConv2d(nn.Conv2d):
def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
def __init__(
self,
rank_forward: int,
rank_backward: int,
conv: nn.Conv2d,
shared_module: nn.Conv2d = None,
):
"""Transforms a conv2d layer into a LoraConv2d layer.

Args:
@@ -76,15 +89,23 @@ def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
in_channels, out_channels, kernel_size, stride, padding, bias=False
)

self.rank = min(rank, self.in_channels, self.out_channels)
self.rank_forward = min(rank_forward, in_channels)
self.rank_backward = min(rank_backward, out_channels)

self.analog_lora_A = nn.Conv2d(
self.in_channels, self.rank, kernel_size, stride, padding, bias=False
self.in_channels,
self.rank_forward,
kernel_size,
stride,
padding,
bias=False,
)
self.analog_lora_B = shared_module or nn.Conv2d(
self.rank, self.rank, 1, bias=False
self.rank_forward, self.rank_backward, 1, bias=False
)
self.analog_lora_C = nn.Conv2d(
self.rank_backward, self.out_channels, 1, bias=False
)
self.analog_lora_C = nn.Conv2d(self.rank, self.out_channels, 1, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
@@ -108,11 +129,11 @@ def pca_init_weight(self, covariance):
(
top_r_singular_vector_forward,
top_r_singular_value_forward,
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
(
top_r_singular_vector_backward,
top_r_singular_value_backward,
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
shape_A = self.analog_lora_A.weight.shape
shape_C = self.analog_lora_C.weight.shape
self.analog_lora_A.weight.data.copy_(
@@ -137,13 +158,14 @@ def __init__(
embedding_dim = embedding.embedding_dim

super().__init__(num_embeddings, embedding_dim)
self.rank = min(rank, num_embeddings, embedding_dim)
self.rank_forward = min(rank, num_embeddings)
self.rank_backward = min(rank, embedding_dim)

self.analog_lora_A = nn.Embedding(num_embeddings, self.rank)
self.analog_lora_A = nn.Embedding(num_embeddings, self.rank_forward)
self.analog_lora_B = shared_module or nn.Linear(
self.rank, self.rank, bias=False
self.rank_forward, self.rank_backward, bias=False
)
self.analog_lora_C = nn.Linear(self.rank, embedding_dim, bias=False)
self.analog_lora_C = nn.Linear(self.rank_backward, embedding_dim, bias=False)

nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.analog_lora_B.weight)
56 changes: 56 additions & 0 deletions analog/lora/utils.py
@@ -1,6 +1,62 @@
from typing import List

import math
import torch
import torch.nn as nn


def find_rank_pca_covariance(matrix, threshold):
"""
Calculate the minimum principal component analysis (PCA) rank required
to explain at least the specified fraction (threshold) of the total covariance.
"""
U, S, Vh = torch.linalg.svd(matrix)
rank = 0
cur, total = 0, sum(S)
while rank < len(S) and (cur / total) < threshold:
cur += S[rank]
rank += 1

return rank
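
A small worked example of the threshold semantics (toy covariance, not from the PR): with singular values [4, 3, 2, 1], the cumulative fractions are 0.4, 0.7, 0.9, 1.0, so a threshold of 0.7 yields rank 2.

    import torch

    cov = torch.diag(torch.tensor([4.0, 3.0, 2.0, 1.0]))  # singular values are exactly [4, 3, 2, 1]
    assert find_rank_pca_covariance(cov, 0.7) == 2
    assert find_rank_pca_covariance(cov, 0.95) == 4
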


def find_rank_pca_compression(module, ratio):
"""
Calculate the minimum principal component analysis (PCA) rank required
to reach the target compression ratio.
"""
weight = module.weight.detach().cpu().numpy()
if isinstance(module, nn.Linear):
# r * r = m * n * ratio
in_features, out_features = weight.shape
rank = math.ceil(math.sqrt(in_features * out_features * ratio))
elif isinstance(module, nn.Conv2d):
# r * r * 1 * 1 = in_channels * out_channels * kernel_size[0] * kernel_size[1] * ratio
in_channels, out_channels, kernel_size0, kernel_size1 = weight.shape
rank = math.ceil(
math.sqrt(in_channels * out_channels * kernel_size0 * kernel_size1 * ratio)
)
return rank
elif isinstance(module, nn.Embedding):
# r * r = m * n * ratio
num_embeddings, embedding_dim = weight.shape
rank = math.ceil(math.sqrt(num_embeddings * embedding_dim * ratio))
else:
raise NotImplementedError

return rank
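
A worked instance of the formula (hypothetical layer size): for nn.Linear(768, 768) with ratio 0.01, rank = ceil(sqrt(768 * 768 * 0.01)) = ceil(76.8) = 77.

    import torch.nn as nn

    layer = nn.Linear(768, 768)
    assert find_rank_pca_compression(layer, 0.01) == 77
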


def pca_rank_by_weight_shape(shape, module):
if isinstance(module, nn.Linear):
assert len(shape) == 2
return shape[1], shape[0]
elif isinstance(module, nn.Conv2d):
assert len(shape) == 4
return shape[1], shape[0]
elif isinstance(module, nn.Embedding):
assert len(shape) == 2
return shape[1], shape[0]
else:
raise NotImplementedError
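
Round-trip check against the LoraLinear layout in modules.py (toy ranks): a LoraLinear built with rank_forward=16 and rank_backward=8 stores analog_lora_B as nn.Linear(16, 8), whose weight shape is (8, 16), so the shape maps back to (16, 8).

    import torch.nn as nn

    B_weight_shape = nn.Linear(16, 8, bias=False).weight.shape  # (8, 16)
    assert pca_rank_by_weight_shape(B_weight_shape, nn.Linear(300, 200)) == (16, 8)
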


def is_lora(model):
8 changes: 5 additions & 3 deletions examples/cifar_influence/compute_influence.py
@@ -41,9 +41,8 @@
# Gradient & Hessian logging
analog.watch(model)
analog.setup({"log": "grad", "save": "grad", "statistic": "kfac"})

id_gen = DataIDGenerator()
if not args.resume:
id_gen = DataIDGenerator()
for inputs, targets in train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id):
@@ -62,7 +61,10 @@

analog.add_analysis({"influence": InfluenceFunction})
query_iter = iter(query_loader)
with analog(log=["grad"]) as al:
test_input, test_target = next(query_iter)
test_id = id_gen(test_input)
analog.eval()
with analog(data_id=test_id) as al:
test_input, test_target = next(query_iter)
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
model.zero_grad()