From cb07c6b19b25d643ea0086135aeb75d02ef08ae9 Mon Sep 17 00:00:00 2001
From: Kewen Zhao
Date: Tue, 26 Dec 2023 17:02:27 -0500
Subject: [PATCH 1/9] add adaptive threshold

---
 analog/lora/lora.py                  | 20 +++++++++++++++-----
 analog/lora/utils.py                 | 12 ++++++++++++
 examples/mnist_influence/config.yaml |  1 +
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index 88d3c430..bfe96c2e 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -2,9 +2,10 @@

 import torch.nn as nn

+from analog.constants import FORWARD, BACKWARD
 from analog.state import AnaLogState
 from analog.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
-from analog.lora.utils import find_parameter_sharing_group, _get_submodules
+from analog.lora.utils import find_parameter_sharing_group, _get_submodules, find_rank_pca_covariance
 from analog.utils import get_logger

@@ -24,6 +25,7 @@ def __init__(self, config: Dict[str, Any], state: AnaLogState):
     def parse_config(self):
         self.init_strategy = self.config.get("init", "random")
         self.rank = self.config.get("rank", 64)
+        self.adaptive_threshold = self.config.get("adaptive_threshold", None)
         self.parameter_sharing = self.config.get("parameter_sharing", False)
         self.parameter_sharing_groups = self.config.get(
             "parameter_sharing_groups", None
@@ -76,20 +78,28 @@ def add_lora(
                 lora_cls = LoraEmbedding

             psg = find_parameter_sharing_group(name, self.parameter_sharing_groups)
+
+            rank = self.rank
+            if self.adaptive_threshold is not None:
+                rank_forward = find_rank_pca_covariance(hessian_state[name][FORWARD], self.adaptive_threshold)
+                rank_backward = find_rank_pca_covariance(hessian_state[name][BACKWARD], self.adaptive_threshold)
+                rank = max(rank_forward, rank_backward)
+                get_logger().info(f"using adaptive r = {rank} for {name}\n")
+
             if self.parameter_sharing and psg not in shared_modules:
                 if isinstance(module, nn.Linear):
-                    shared_module = nn.Linear(self.rank, self.rank, bias=False)
+                    shared_module = nn.Linear(rank, rank, bias=False)
                 elif isinstance(module, nn.Conv1d):
                     shared_module = nn.Conv1d(
-                        self.rank, self.rank, kernel_size=1, bias=False
+                        rank, rank, kernel_size=1, bias=False
                     )
                 elif isinstance(module, nn.Conv2d):
                     shared_module = nn.Conv2d(
-                        self.rank, self.rank, kernel_size=1, bias=False
+                        rank, rank, kernel_size=1, bias=False
                     )
                 shared_modules[psg] = shared_module

-            lora_module = lora_cls(self.rank, module, shared_modules.get(psg, None))
+            lora_module = lora_cls(rank, module, shared_modules.get(psg, None))
             if self.init_strategy == "pca":
                 lora_module.pca_init_weight(self.init_strategy, hessian_state[name])
             lora_module.to(device)
diff --git a/analog/lora/utils.py b/analog/lora/utils.py
index 55cdea45..8f06f505 100644
--- a/analog/lora/utils.py
+++ b/analog/lora/utils.py
@@ -2,6 +2,18 @@

 import torch

+def find_rank_pca_covariance(hessian, threshold):
+    """
+    compute the least pca rank needed for threshold covariance in hessian_state
+    """
+    U, S, Vh = torch.linalg.svd(hessian)
+    rank = 0
+    cur, total = 0, sum(S)
+    while rank < len(S) and (cur / total) < threshold:
+        cur += S[rank]
+        rank += 1
+
+    return rank

 def compute_top_k_singular_vectors(matrix, k):
     """
diff --git a/examples/mnist_influence/config.yaml b/examples/mnist_influence/config.yaml
index 04807799..4cc25753 100644
--- a/examples/mnist_influence/config.yaml
+++ b/examples/mnist_influence/config.yaml
@@ -2,3 +2,4 @@ storage:
   type: default
 lora:
   init: pca
+  adaptive_threshold: 0.8
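What `adaptive_threshold: 0.8` means in practice: `find_rank_pca_covariance` keeps accumulating singular values until their cumulative share of the spectrum passes the threshold, and the larger of the forward/backward ranks is used per module. A minimal, self-contained sketch of that selection rule (the toy matrix below is an illustrative assumption, not project data):

```python
import torch

def find_rank_pca_covariance(hessian, threshold):
    # Same rule as the patch: smallest rank whose leading singular
    # values account for at least `threshold` of the spectrum's mass.
    U, S, Vh = torch.linalg.svd(hessian)
    rank = 0
    cur, total = 0, sum(S)
    while rank < len(S) and (cur / total) < threshold:
        cur += S[rank]
        rank += 1
    return rank

# Toy covariance whose spectrum is concentrated in two directions.
cov = torch.diag(torch.tensor([10.0, 5.0, 1.0, 0.5]))
print(find_rank_pca_covariance(cov, 0.8))  # 2: (10 + 5) / 16.5 ~ 0.91 >= 0.8
```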
From 92910f7191c040725c26f4ac5aa0591e2bd3cfbe Mon Sep 17 00:00:00 2001
From: Kewen Zhao
Date: Tue, 26 Dec 2023 21:43:49 -0500
Subject: [PATCH 2/9] format

---
 analog/lora/lora.py  | 22 +++++++++++++---------
 analog/lora/utils.py |  2 ++
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index bfe96c2e..5a18117a 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -5,7 +5,11 @@
 from analog.constants import FORWARD, BACKWARD
 from analog.state import AnaLogState
 from analog.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
-from analog.lora.utils import find_parameter_sharing_group, _get_submodules, find_rank_pca_covariance
+from analog.lora.utils import (
+    find_parameter_sharing_group,
+    _get_submodules,
+    find_rank_pca_covariance,
+)
 from analog.utils import get_logger

@@ -81,8 +85,12 @@ def add_lora(

             rank = self.rank
             if self.adaptive_threshold is not None:
-                rank_forward = find_rank_pca_covariance(hessian_state[name][FORWARD], self.adaptive_threshold)
-                rank_backward = find_rank_pca_covariance(hessian_state[name][BACKWARD], self.adaptive_threshold)
+                rank_forward = find_rank_pca_covariance(
+                    hessian_state[name][FORWARD], self.adaptive_threshold
+                )
+                rank_backward = find_rank_pca_covariance(
+                    hessian_state[name][BACKWARD], self.adaptive_threshold
+                )
                 rank = max(rank_forward, rank_backward)
                 get_logger().info(f"using adaptive r = {rank} for {name}\n")

@@ -90,13 +98,9 @@ def add_lora(
             if self.parameter_sharing and psg not in shared_modules:
                 if isinstance(module, nn.Linear):
                     shared_module = nn.Linear(rank, rank, bias=False)
                 elif isinstance(module, nn.Conv1d):
-                    shared_module = nn.Conv1d(
-                        rank, rank, kernel_size=1, bias=False
-                    )
+                    shared_module = nn.Conv1d(rank, rank, kernel_size=1, bias=False)
                 elif isinstance(module, nn.Conv2d):
-                    shared_module = nn.Conv2d(
-                        rank, rank, kernel_size=1, bias=False
-                    )
+                    shared_module = nn.Conv2d(rank, rank, kernel_size=1, bias=False)
                 shared_modules[psg] = shared_module

             lora_module = lora_cls(rank, module, shared_modules.get(psg, None))
diff --git a/analog/lora/utils.py b/analog/lora/utils.py
index 8f06f505..b0717501 100644
--- a/analog/lora/utils.py
+++ b/analog/lora/utils.py
@@ -2,6 +2,7 @@

 import torch

+
 def find_rank_pca_covariance(hessian, threshold):
     """
     compute the least pca rank needed for threshold covariance in hessian_state
@@ -15,6 +16,7 @@

     return rank

+
 def compute_top_k_singular_vectors(matrix, k):
     """
     Compute the top k singular vectors of a matrix.
""" - U, S, Vh = torch.linalg.svd(hessian) + U, S, Vh = torch.linalg.svd(matrix) rank = 0 cur, total = 0, sum(S) while rank < len(S) and (cur / total) < threshold: From d5f47791c99ea79808fc52b79a4ccd828f757e8d Mon Sep 17 00:00:00 2001 From: Kewen Zhao Date: Fri, 5 Jan 2024 11:57:48 -0500 Subject: [PATCH 4/9] add log=grad as default option to eval --- analog/analog.py | 2 -- analog/logging/option.py | 7 ++++++- examples/mnist_influence/compute_influences_scheduler.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/analog/analog.py b/analog/analog.py index 70ec5d22..3f135a06 100644 --- a/analog/analog.py +++ b/analog/analog.py @@ -324,8 +324,6 @@ def initialize_from_log(self) -> None: lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt")) if not is_lora(self.model): self.add_lora(lora_state=lora_state) - if not any("analog_lora_A" in name for name in self.model.state_dict()): - self.add_lora(lora_state=lora_state) for name in lora_state: assert name in self.model.state_dict(), f"{name} not in model!" self.model.load_state_dict(lora_state, strict=False) diff --git a/analog/logging/option.py b/analog/logging/option.py index 620ca9a5..65407e18 100644 --- a/analog/logging/option.py +++ b/analog/logging/option.py @@ -97,11 +97,16 @@ def _sanity_check(self): ) self._log["grad"] = True - def eval(self): + def eval(self, log="grad"): """ Enable the evaluation mode. This will turn of saving and updating statistic. """ + if isinstance(log, str): + self._log[log] = True + else: + raise ValueError(f"Unsupported log type for eval: {type(log)}") + self.clear(log=False, save=True, statistic=True) def clear(self, log=True, save=True, statistic=True): diff --git a/examples/mnist_influence/compute_influences_scheduler.py b/examples/mnist_influence/compute_influences_scheduler.py index dcda21f7..b5f19acc 100644 --- a/examples/mnist_influence/compute_influences_scheduler.py +++ b/examples/mnist_influence/compute_influences_scheduler.py @@ -89,6 +89,6 @@ # Save if_scores = if_scores.numpy().tolist()[0] -torch.save(if_scores, "if_analog_scheduler_init_from_log_0.8.pt") +torch.save(if_scores, "examples/mnist_influence/if_analog_scheduler.pt") print("Computation time:", time.time() - start) print("Top influential data indices:", top_influential_data.numpy().tolist()) From 4ab1fb873e32de0912403b020aac979f16b4de41 Mon Sep 17 00:00:00 2001 From: Kewen Zhao Date: Mon, 8 Jan 2024 03:08:23 -0500 Subject: [PATCH 5/9] add compression by ratio --- analog/lora/lora.py | 28 +++++++++++++++++++ analog/lora/utils.py | 28 +++++++++++++++++++ examples/cifar_influence/compute_influence.py | 8 ++++-- .../cifar_influence/compute_influences_pca.py | 22 ++++++++------- examples/cifar_influence/config.yaml | 1 + .../compute_influences_scheduler.py | 4 +-- examples/mnist_influence/config.yaml | 2 +- 7 files changed, 77 insertions(+), 16 deletions(-) diff --git a/analog/lora/lora.py b/analog/lora/lora.py index 677b1624..9d7753dc 100644 --- a/analog/lora/lora.py +++ b/analog/lora/lora.py @@ -8,6 +8,7 @@ from analog.lora.utils import ( find_parameter_sharing_group, _get_submodules, + find_rank_pca_compression, find_rank_pca_covariance, pca_rank_by_weight_shape, ) @@ -34,10 +35,14 @@ def parse_config(self): self.compression_ratio_by_covariance = self.config.get( "compression_ratio_by_covariance", None ) + self.compression_ratio_by_memory = self.config.get( + "compression_ratio_by_memory", None + ) self.parameter_sharing = self.config.get("parameter_sharing", False) self.parameter_sharing_groups = 
From 4ab1fb873e32de0912403b020aac979f16b4de41 Mon Sep 17 00:00:00 2001
From: Kewen Zhao
Date: Mon, 8 Jan 2024 03:08:23 -0500
Subject: [PATCH 5/9] add compression by ratio

---
 analog/lora/lora.py                            | 28 +++++++++++++++++
 analog/lora/utils.py                           | 28 +++++++++++++++++
 examples/cifar_influence/compute_influence.py  |  8 ++++--
 .../cifar_influence/compute_influences_pca.py  | 22 ++++++++-------
 examples/cifar_influence/config.yaml           |  1 +
 .../compute_influences_scheduler.py            |  4 +--
 examples/mnist_influence/config.yaml           |  2 +-
 7 files changed, 77 insertions(+), 16 deletions(-)

diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index 677b1624..9d7753dc 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -8,6 +8,7 @@
 from analog.lora.utils import (
     find_parameter_sharing_group,
     _get_submodules,
+    find_rank_pca_compression,
     find_rank_pca_covariance,
     pca_rank_by_weight_shape,
 )
@@ -34,10 +35,14 @@ def parse_config(self):
         self.compression_ratio_by_covariance = self.config.get(
             "compression_ratio_by_covariance", None
         )
+        self.compression_ratio_by_memory = self.config.get(
+            "compression_ratio_by_memory", None
+        )
         self.parameter_sharing = self.config.get("parameter_sharing", False)
         self.parameter_sharing_groups = self.config.get(
             "parameter_sharing_groups", None
         )
+        self._sanity_check()

     def add_lora(
         self,
@@ -82,6 +87,7 @@ def add_lora(
             psg = find_parameter_sharing_group(name, self.parameter_sharing_groups)

             rank_forward = rank_backward = self.rank_default  # default rank
+
             if lora_state is not None:  # add lora matching the rank of the lora_state
                 rank_forward, rank_backward = pca_rank_by_weight_shape(
                     lora_state[name + ".analog_lora_B.weight"].shape, module
@@ -101,6 +107,17 @@ def add_lora(
                 get_logger().info(
                     f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
                 )
+            elif (
+                self.init_strategy == "pca"
+                and self.compression_ratio_by_memory is not None
+            ):
+                rank_forward = rank_backward = find_rank_pca_compression(
+                    module,
+                    self.compression_ratio_by_memory,
+                )
+                get_logger().info(
+                    f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
+                )

             if self.parameter_sharing and psg not in shared_modules:
                 if isinstance(module, nn.Linear):
@@ -124,3 +141,14 @@ def add_lora(

             parent, target, target_name = _get_submodules(model, name)
             setattr(parent, target_name, lora_module)
+
+    def _sanity_check(self):
+        if (
+            self.init_strategy == "pca"
+            and self.compression_ratio_by_covariance is not None
+            and self.compression_ratio_by_memory is not None
+        ):
+            get_logger().warning(
+                "compression_ratio_by_covariance and compression_ratio_by_memory are both set. "
+                + "compression_ratio_by_covariance will be used."
+            )
+ """ + weight = module.weight.detach().cpu().numpy() + if isinstance(module, nn.Linear): + # r * r = m * n * ratio + in_features, out_features = weight.shape + rank = math.ceil(math.sqrt(in_features * out_features * ratio)) + elif isinstance(module, nn.Conv2d): + # r * r * 1 * 1 = in_channels * out_channels * kernel_size[0] * kernel_size[1] * ratio + in_channels, out_channels, kernel_size0, kernel_size1 = weight.shape + rank = math.ceil( + math.sqrt(in_channels * out_channels * kernel_size0 * kernel_size1 * ratio) + ) + return rank + elif isinstance(module, nn.Embedding): + # r * r = m * n * ratio + num_embeddings, embedding_dim = weight.shape + rank = math.ceil(math.sqrt(num_embeddings * embedding_dim * ratio)) + else: + raise NotImplementedError + + return rank + + def pca_rank_by_weight_shape(shape, module): if isinstance(module, nn.Linear): assert len(shape) == 2 diff --git a/examples/cifar_influence/compute_influence.py b/examples/cifar_influence/compute_influence.py index 1126e19c..9349b051 100644 --- a/examples/cifar_influence/compute_influence.py +++ b/examples/cifar_influence/compute_influence.py @@ -41,9 +41,8 @@ # Gradient & Hessian logging analog.watch(model) analog.setup({"log": "grad", "save": "grad", "statistic": "kfac"}) - +id_gen = DataIDGenerator() if not args.resume: - id_gen = DataIDGenerator() for inputs, targets in train_loader: data_id = id_gen(inputs) with analog(data_id=data_id): @@ -62,7 +61,10 @@ analog.add_analysis({"influence": InfluenceFunction}) query_iter = iter(query_loader) -with analog(log=["grad"]) as al: +test_input, test_target = next(query_iter) +test_id = id_gen(test_input) +analog.eval() +with analog(data_id=test_id) as al: test_input, test_target = next(query_iter) test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE) model.zero_grad() diff --git a/examples/cifar_influence/compute_influences_pca.py b/examples/cifar_influence/compute_influences_pca.py index 924b03fc..faafee7d 100644 --- a/examples/cifar_influence/compute_influences_pca.py +++ b/examples/cifar_influence/compute_influences_pca.py @@ -16,6 +16,9 @@ parser.add_argument("--data", type=str, default="cifar10", help="cifar10/100") parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0]) parser.add_argument("--damping", type=float, default=1e-5) +parser.add_argument("--ekfac", action="store_true") +parser.add_argument("--lora", action="store_true") +parser.add_argument("--sample", action="store_true") parser.add_argument("--resume", action="store_true") args = parser.parse_args() @@ -39,13 +42,12 @@ ) analog = AnaLog(project="test", config="./config.yaml") -analog_scheduler = AnaLogScheduler(analog, lora=True) +analog_scheduler = AnaLogScheduler(analog, lora=args.lora) # Gradient & Hessian logging analog.watch(model) - -if True: - id_gen = DataIDGenerator() +id_gen = DataIDGenerator() +if not args.resume: for epoch in analog_scheduler: for inputs, targets in train_loader: data_id = id_gen(inputs) @@ -57,17 +59,17 @@ loss.backward() analog.finalize() else: - analog.add_lora() analog.initialize_from_log() # Influence Analysis log_loader = analog.build_log_dataloader() -analog.eval() analog.add_analysis({"influence": InfluenceFunction}) query_iter = iter(query_loader) -with analog(log=["grad"]) as al: - test_input, test_target = next(query_iter) +test_input, test_target = next(query_iter) +test_id = id_gen(test_input) +analog.eval() +with analog(data_id=test_id) as al: test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE) model.zero_grad() 
diff --git a/examples/cifar_influence/config.yaml b/examples/cifar_influence/config.yaml
index 04807799..c1b0384a 100644
--- a/examples/cifar_influence/config.yaml
+++ b/examples/cifar_influence/config.yaml
@@ -2,3 +2,4 @@ storage:
   type: default
 lora:
   init: pca
+  compression_ratio_by_memory: 0.1
diff --git a/examples/mnist_influence/compute_influences_scheduler.py b/examples/mnist_influence/compute_influences_scheduler.py
index b5f19acc..2ac9357b 100644
--- a/examples/mnist_influence/compute_influences_scheduler.py
+++ b/examples/mnist_influence/compute_influences_scheduler.py
@@ -38,7 +38,7 @@
     batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs
 )

-analog = AnaLog(project="test", config="examples/mnist_influence/config.yaml")
+analog = AnaLog(project="test", config="config.yaml")
 al_scheduler = AnaLogScheduler(
     analog, ekfac=args.ekfac, lora=args.lora, sample=args.sample
 )
@@ -89,6 +89,6 @@

 # Save
 if_scores = if_scores.numpy().tolist()[0]
-torch.save(if_scores, "examples/mnist_influence/if_analog_scheduler.pt")
+torch.save(if_scores, "if_analog_scheduler.pt")
 print("Computation time:", time.time() - start)
 print("Top influential data indices:", top_influential_data.numpy().tolist())
diff --git a/examples/mnist_influence/config.yaml b/examples/mnist_influence/config.yaml
index 5101af10..eb54fde7 100644
--- a/examples/mnist_influence/config.yaml
+++ b/examples/mnist_influence/config.yaml
@@ -1,3 +1,3 @@
 lora:
   init: pca
-  compression_ratio_by_covariance: 0.8
+  compression_ratio_by_memory: 0.1
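For intuition on the new `compression_ratio_by_memory: 0.1` setting: `find_rank_pca_compression` sizes the rank so that an r x r matrix holds roughly `ratio` of the original weight's entries. A self-contained sketch of that arithmetic (the layer shape is an arbitrary example, not one from the repo):

```python
import math

import torch.nn as nn

layer = nn.Linear(784, 512, bias=False)  # example shape only
ratio = 0.1  # compression_ratio_by_memory

# Mirrors the nn.Linear branch of find_rank_pca_compression: r * r = m * n * ratio.
m, n = layer.weight.shape
rank = math.ceil(math.sqrt(m * n * ratio))
print(rank)  # 201, since sqrt(784 * 512 * 0.1) ~ 200.4
```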
From fd0f5f12c1e1ae0e1e2e6d7699600d2ff7db2d37 Mon Sep 17 00:00:00 2001
From: Kewen Zhao
Date: Mon, 8 Jan 2024 23:57:20 -0500
Subject: [PATCH 6/9] add lora test, add eval warning

---
 analog/logging/option.py                  |  5 +-
 tests/examples/configs/lora.yaml          |  3 +
 tests/examples/test_add_lora.py           | 91 +++++++++++++++++++++++
 tests/examples/test_compute_influences.py | 49 +-----------
 tests/examples/utils.py                   | 51 +++++++++++++
 5 files changed, 151 insertions(+), 48 deletions(-)
 create mode 100644 tests/examples/configs/lora.yaml
 create mode 100644 tests/examples/test_add_lora.py
 create mode 100644 tests/examples/utils.py

diff --git a/analog/logging/option.py b/analog/logging/option.py
index 65407e18..b95b7484 100644
--- a/analog/logging/option.py
+++ b/analog/logging/option.py
@@ -97,11 +97,14 @@ def _sanity_check(self):
             )
             self._log["grad"] = True

-    def eval(self, log="grad"):
+    def eval(self, log=None):
         """
         Enable the evaluation mode. This will turn off saving and updating
         statistics.
         """
+        if log is None:
+            get_logger().warning("We automatically set 'log' to 'grad'. If this is not the desired behavior, please explicitly set your 'log' value.")
+            log = "grad"
         if isinstance(log, str):
             self._log[log] = True
         else:
diff --git a/tests/examples/configs/lora.yaml b/tests/examples/configs/lora.yaml
new file mode 100644
index 00000000..b37f67c7
--- /dev/null
+++ b/tests/examples/configs/lora.yaml
@@ -0,0 +1,3 @@
+lora:
+  init: random
+  rank: 2
diff --git a/tests/examples/test_add_lora.py b/tests/examples/test_add_lora.py
new file mode 100644
index 00000000..509bf8bc
--- /dev/null
+++ b/tests/examples/test_add_lora.py
@@ -0,0 +1,91 @@
+import unittest
+import torch
+import torchvision
+import torch.nn as nn
+import numpy as np
+import os
+
+from analog import AnaLog, AnaLogScheduler
+from analog.utils import DataIDGenerator
+from analog.analysis import InfluenceFunction
+from tests.examples.utils import get_mnist_dataloader, construct_mlp
+
+DEVICE = torch.device("cpu")
+
+
+class TestAddLora(unittest.TestCase):
+    def test_add_lora(self):
+        eval_idxs = (1,)
+        train_idxs = [i for i in range(0, 50000, 1000)]
+
+        model = construct_mlp().to(DEVICE)
+        # Get a single checkpoint (first model_id and last epoch).
+        model.load_state_dict(
+            torch.load(
+                f"{os.path.dirname(os.path.abspath(__file__))}/checkpoints/mnist_0_epoch_9.pt",
+                map_location="cpu",
+            )
+        )
+        model.eval()
+
+        dataloader_fn = get_mnist_dataloader
+        train_loader = dataloader_fn(
+            batch_size=512, split="train", shuffle=False, indices=train_idxs
+        )
+        query_loader = dataloader_fn(
+            batch_size=1, split="valid", shuffle=False, indices=eval_idxs
+        )
+
+        analog = AnaLog(
+            project="test",
+            config=f"{os.path.dirname(os.path.abspath(__file__))}/configs/lora.yaml",
+        )
+        # Gradient & Hessian logging
+        al_scheduler = AnaLogScheduler(analog, ekfac=False, lora=False, sample=False)
+        analog.watch(model)
+        id_gen = DataIDGenerator()
+        for epoch in al_scheduler:
+            for inputs, targets in train_loader:
+                data_id = id_gen(inputs)
+                with analog(data_id=data_id):
+                    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
+                    model.zero_grad()
+                    outs = model(inputs)
+                    loss = torch.nn.functional.cross_entropy(
+                        outs, targets, reduction="sum"
+                    )
+                    loss.backward()
+            analog.finalize()
+
+        log_loader = analog.build_log_dataloader()
+
+        analog.add_analysis({"influence": InfluenceFunction})
+        query_iter = iter(query_loader)
+        test_input, test_target = next(query_iter)
+        test_id = id_gen(test_input)
+        analog.eval()
+        with analog(data_id=test_id):
+            test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
+            model.zero_grad()
+            test_out = model(test_input)
+            test_loss = torch.nn.functional.cross_entropy(
+                test_out, test_target, reduction="sum"
+            )
+            test_loss.backward()
+            test_log = analog.get_log()
+        if_scores = analog.influence.compute_influence_all(
+            test_log, log_loader, damping=1e-5
+        )
+
+        # Save
+        if_scores = if_scores[0]
+        print(if_scores)
+        # torch.save(if_scores, f"{os.path.dirname(os.path.abspath(__file__))}/if_analog_lora.pt")
+        if_score_saved = torch.load(
+            f"{os.path.dirname(os.path.abspath(__file__))}/if_analog_lora.pt"
+        )
+        self.assertTrue(torch.allclose(if_score_saved, if_scores))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/examples/test_compute_influences.py b/tests/examples/test_compute_influences.py
index f5831bb4..f44c597c 100644
--- a/tests/examples/test_compute_influences.py
+++ b/tests/examples/test_compute_influences.py
@@ -5,54 +5,9 @@
 import numpy as np
 import os

-DEVICE = torch.device("cpu")
-
-
-def construct_mlp(num_inputs=784, num_classes=10):
-    return torch.nn.Sequential(
-        nn.Flatten(),
-        nn.Linear(num_inputs, 4, bias=False),
-        nn.ReLU(),
-        nn.Linear(4, 2, bias=False),
-        nn.ReLU(),
-        nn.Linear(2, num_classes, bias=False),
-    )
-
+from tests.examples.utils import get_mnist_dataloader, construct_mlp

-def get_mnist_dataloader(
-    batch_size=128,
-    split="train",
-    shuffle=False,
-    subsample=False,
-    indices=None,
-    drop_last=False,
-):
-    transforms = torchvision.transforms.Compose(
-        [
-            torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
-        ]
-    )
-    is_train = split == "train"
-    dataset = torchvision.datasets.MNIST(
-        root="/tmp/mnist/", download=True, train=is_train, transform=transforms
-    )
-
-    if subsample and split == "train" and indices is None:
-        dataset = torch.utils.data.Subset(dataset, np.arange(6_000))
-
-    if indices is not None:
-        if subsample and split == "train":
-            print("Overriding `subsample` argument as `indices` was provided.")
-        dataset = torch.utils.data.Subset(dataset, indices)
-
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        shuffle=shuffle,
-        batch_size=batch_size,
-        num_workers=0,
-        drop_last=drop_last,
-    )
+DEVICE = torch.device("cpu")


 class TestSingleCheckpointInfluence(unittest.TestCase):
diff --git a/tests/examples/utils.py b/tests/examples/utils.py
new file mode 100644
index 00000000..b638d9fd
--- /dev/null
+++ b/tests/examples/utils.py
@@ -0,0 +1,51 @@
+import torch
+import torchvision
+import torch.nn as nn
+import numpy as np
+
+
+def construct_mlp(num_inputs=784, num_classes=10):
+    return torch.nn.Sequential(
+        nn.Flatten(),
+        nn.Linear(num_inputs, 4, bias=False),
+        nn.ReLU(),
+        nn.Linear(4, 2, bias=False),
+        nn.ReLU(),
+        nn.Linear(2, num_classes, bias=False),
+    )
+
+
+def get_mnist_dataloader(
+    batch_size=128,
+    split="train",
+    shuffle=False,
+    subsample=False,
+    indices=None,
+    drop_last=False,
+):
+    transforms = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
+        ]
+    )
+    is_train = split == "train"
+    dataset = torchvision.datasets.MNIST(
+        root="/tmp/mnist/", download=True, train=is_train, transform=transforms
+    )
+
+    if subsample and split == "train" and indices is None:
+        dataset = torch.utils.data.Subset(dataset, np.arange(6_000))
+
+    if indices is not None:
+        if subsample and split == "train":
+            print("Overriding `subsample` argument as `indices` was provided.")
+        dataset = torch.utils.data.Subset(dataset, indices)
+
+    return torch.utils.data.DataLoader(
+        dataset=dataset,
+        shuffle=shuffle,
+        batch_size=batch_size,
+        num_workers=0,
+        drop_last=drop_last,
+    )
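To summarize the new `LogOption.eval()` behavior introduced here (the `opt` name below is an illustrative stand-in for the option instance, not an API from the repo):

```python
opt.eval()              # logs the warning, then behaves as log="grad"
opt.eval(log="grad")    # explicit: same effect, no warning
opt.eval(log=["grad"])  # raises ValueError: eval() accepts a single string
```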
From 122a8982161f6bb929628d95349289155c41104d Mon Sep 17 00:00:00 2001
From: Kewen Zhao
Date: Tue, 9 Jan 2024 10:07:46 -0500
Subject: [PATCH 7/9] format

---
 analog/logging/option.py              |   4 +++-
 tests/examples/data/if_analog_lora.pt | Bin 0 -> 1407 bytes
 tests/examples/test_add_lora.py       |   2 +-
 3 files changed, 4 insertions(+), 2 deletions(-)
 create mode 100644 tests/examples/data/if_analog_lora.pt

diff --git a/analog/logging/option.py b/analog/logging/option.py
index b95b7484..088ae6c6 100644
--- a/analog/logging/option.py
+++ b/analog/logging/option.py
@@ -103,7 +103,9 @@ def eval(self, log=None):
         statistics.
         """
         if log is None:
-            get_logger().warning("We automatically set 'log' to 'grad'. If this is not the desired behavior, please explicitly set your 'log' value.")
+            get_logger().warning(
+                "We automatically set 'log' to 'grad'. If this is not the desired behavior, please explicitly set your 'log' value."
+            )
             log = "grad"
         if isinstance(log, str):
             self._log[log] = True
         else:
diff --git a/tests/examples/data/if_analog_lora.pt b/tests/examples/data/if_analog_lora.pt
new file mode 100644
index 0000000000000000000000000000000000000000..217e502c50eeac67c6ed7b15dbdaf8dc14ee05d6
GIT binary patch
literal 1407
[base85 payload omitted]

literal 0
HcmV?d00001
From: Kewen Zhao
Date: Tue, 9 Jan 2024 12:25:25 -0500
Subject: [PATCH 8/9] fix bug

---
 tests/examples/configs/lora.yaml      |   2 +-
 tests/examples/data/if_analog_lora.pt | Bin 1407 -> 1407 bytes
 tests/examples/test_add_lora.py       |   4 ++--
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/examples/configs/lora.yaml b/tests/examples/configs/lora.yaml
index b37f67c7..236d409e 100644
--- a/tests/examples/configs/lora.yaml
+++ b/tests/examples/configs/lora.yaml
@@ -1,3 +1,3 @@
 lora:
-  init: random
+  init: pca
   rank: 2
diff --git a/tests/examples/data/if_analog_lora.pt b/tests/examples/data/if_analog_lora.pt
index 217e502c50eeac67c6ed7b15dbdaf8dc14ee05d6..484e5563080883633739aed259e414dc49a0dc6b 100644
GIT binary patch
delta 266
[base85 payload omitted]

delta 266
[base85 payload omitted]

From: Kewen Zhao
Date: Tue, 9 Jan 2024 16:58:42 -0500
Subject: [PATCH 9/9] bug fix

---
 analog/lora/lora.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index 9d7753dc..512b8308 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -92,10 +92,7 @@ def add_lora(
                 rank_forward, rank_backward = pca_rank_by_weight_shape(
                     lora_state[name + ".analog_lora_B.weight"].shape, module
                 )
-            elif (
-                self.init_strategy == "pca"
-                and self.compression_ratio_by_covariance is not None
-            ):
+            elif self.compression_ratio_by_covariance is not None:
                 rank_forward = find_rank_pca_covariance(
                     covariance_state[name][FORWARD],
                     self.compression_ratio_by_covariance,
@@ -107,10 +104,7 @@ def add_lora(
                 get_logger().info(
                     f"using adaptive rank_forward = {rank_forward}, rank_backward = {rank_backward} for {name}\n"
                 )
-            elif (
-                self.init_strategy == "pca"
-                and self.compression_ratio_by_memory is not None
-            ):
+            elif self.compression_ratio_by_memory is not None:
                 rank_forward = rank_backward = find_rank_pca_compression(
                     module,
                     self.compression_ratio_by_memory,
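One practical consequence of this final fix, sketched as a hypothetical config: ratio-driven rank selection no longer requires `init: pca`, so a config like the one below now gets adaptive ranks too (weights are still only PCA-initialized when `init: pca`):

```yaml
lora:
  init: random                      # adaptive rank now applies here as well
  compression_ratio_by_memory: 0.1
```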