diff --git a/analog/analog.py b/analog/analog.py
index b95f4a9f..3f135a06 100644
--- a/analog/analog.py
+++ b/analog/analog.py
@@ -118,6 +118,7 @@ def add_lora(
         model: Optional[nn.Module] = None,
         watch: bool = True,
         clear: bool = True,
+        lora_state: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
         Adds LoRA for gradient compression.
@@ -140,6 +141,7 @@ def add_lora(
             model=model,
             type_filter=self.type_filter,
             name_filter=self.name_filter,
+            lora_state=lora_state,
         )

         # Clear state and logger
@@ -319,9 +321,9 @@ def initialize_from_log(self) -> None:
         # Load LoRA state
         lora_dir = os.path.join(self.log_dir, "lora")
         if os.path.exists(lora_dir):
-            if not is_lora(self.model):
-                self.add_lora()
             lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
+            if not is_lora(self.model):
+                self.add_lora(lora_state=lora_state)
             for name in lora_state:
                 assert name in self.model.state_dict(), f"{name} not in model!"
             self.model.load_state_dict(lora_state, strict=False)
diff --git a/analog/logging/option.py b/analog/logging/option.py
index 620ca9a5..088ae6c6 100644
--- a/analog/logging/option.py
+++ b/analog/logging/option.py
@@ -97,11 +97,22 @@ def _sanity_check(self):
             )
             self._log["grad"] = True

-    def eval(self):
+    def eval(self, log=None):
         """
-        Enable the evaluation mode. This will turn of saving and updating statistic.
+        Enable the evaluation mode. This turns off saving and updating statistics.
         """
+        if log is None:
+            get_logger().warning(
+                "'log' was not specified for eval; defaulting to 'grad'. "
+                "If this is not the desired behavior, set 'log' explicitly."
+            )
+            log = "grad"
+        if isinstance(log, str):
+            self._log[log] = True
+        else:
+            raise ValueError(f"Unsupported log type for eval: {type(log)}")
+
         self.clear(log=False, save=True, statistic=True)

     def clear(self, log=True, save=True, statistic=True):
diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index 7a575974..512b8308 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -2,8 +2,15 @@

 import torch.nn as nn

+from analog.constants import FORWARD, BACKWARD
 from analog.state import StatisticState
 from analog.lora.modules import LoraLinear, LoraConv2d, LoraEmbedding
-from analog.lora.utils import find_parameter_sharing_group, _get_submodules
+from analog.lora.utils import (
+    find_parameter_sharing_group,
+    _get_submodules,
+    find_rank_pca_compression,
+    find_rank_pca_covariance,
+    pca_rank_by_weight_shape,
+)
 from analog.utils import get_logger, module_check
@@ -23,17 +31,29 @@ def __init__(self, config: Dict[str, Any], state: StatisticState):

     def parse_config(self):
         self.init_strategy = self.config.get("init", "random")
-        self.rank = self.config.get("rank", 64)
+        self.rank_default = self.config.get("rank", 64)
+        self.compression_ratio_by_covariance = self.config.get(
+            "compression_ratio_by_covariance", None
+        )
+        self.compression_ratio_by_memory = self.config.get(
+            "compression_ratio_by_memory", None
+        )
         self.parameter_sharing = self.config.get("parameter_sharing", False)
         self.parameter_sharing_groups = self.config.get(
             "parameter_sharing_groups", None
         )
+        self._sanity_check()

     def add_lora(
         self,
         model: nn.Module,
         type_filter: List[nn.Module],
         name_filter: List[str],
+        lora_state: Dict[str, Any] = None,
     ):
         """
         Add LoRA modules to a model.
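+
+        The per-module rank is chosen in order of precedence: from the shapes in
+        ``lora_state`` when resuming, else from ``compression_ratio_by_covariance``,
+        else from ``compression_ratio_by_memory``, else the default ``rank``.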
@@ -69,23 +85,64 @@ def add_lora(
                 lora_cls = LoraEmbedding

             psg = find_parameter_sharing_group(name, self.parameter_sharing_groups)
+
+            rank_forward = rank_backward = self.rank_default  # default rank
+
+            if lora_state is not None:  # match the ranks stored in lora_state
+                rank_forward, rank_backward = pca_rank_by_weight_shape(
+                    lora_state[name + ".analog_lora_B.weight"].shape, module
+                )
+            elif self.compression_ratio_by_covariance is not None:
+                rank_forward = find_rank_pca_covariance(
+                    covariance_state[name][FORWARD],
+                    self.compression_ratio_by_covariance,
+                )
+                rank_backward = find_rank_pca_covariance(
+                    covariance_state[name][BACKWARD],
+                    self.compression_ratio_by_covariance,
+                )
+                get_logger().info(
+                    f"Using adaptive rank_forward={rank_forward}, "
+                    f"rank_backward={rank_backward} for {name}"
+                )
+            elif self.compression_ratio_by_memory is not None:
+                rank_forward = rank_backward = find_rank_pca_compression(
+                    module,
+                    self.compression_ratio_by_memory,
+                )
+                get_logger().info(
+                    f"Using adaptive rank_forward={rank_forward}, "
+                    f"rank_backward={rank_backward} for {name}"
+                )
+
             if self.parameter_sharing and psg not in shared_modules:
                 if isinstance(module, nn.Linear):
-                    shared_module = nn.Linear(self.rank, self.rank, bias=False)
+                    shared_module = nn.Linear(rank_forward, rank_backward, bias=False)
                 elif isinstance(module, nn.Conv1d):
                     shared_module = nn.Conv1d(
-                        self.rank, self.rank, kernel_size=1, bias=False
+                        rank_forward, rank_backward, kernel_size=1, bias=False
                     )
                 elif isinstance(module, nn.Conv2d):
                     shared_module = nn.Conv2d(
-                        self.rank, self.rank, kernel_size=1, bias=False
+                        rank_forward, rank_backward, kernel_size=1, bias=False
                     )
                 shared_modules[psg] = shared_module

-            lora_module = lora_cls(self.rank, module, shared_modules.get(psg, None))
+            lora_module = lora_cls(
+                rank_forward, rank_backward, module, shared_modules.get(psg, None)
+            )
             if self.init_strategy == "pca":
                 lora_module.pca_init_weight(covariance_state[name])
             lora_module.to(device)

             parent, target, target_name = _get_submodules(model, name)
             setattr(parent, target_name, lora_module)
+
+    def _sanity_check(self):
+        if (
+            self.compression_ratio_by_covariance is not None
+            and self.compression_ratio_by_memory is not None
+        ):
+            get_logger().warning(
+                "compression_ratio_by_covariance and compression_ratio_by_memory "
+                "are both set; compression_ratio_by_covariance takes precedence."
+            )
diff --git a/analog/lora/modules.py b/analog/lora/modules.py
index 4da4f61e..bf2153a7 100644
--- a/analog/lora/modules.py
+++ b/analog/lora/modules.py
@@ -8,7 +8,18 @@

 class LoraLinear(nn.Linear):
-    def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None):
+    def __init__(
+        self,
+        rank_forward: int,
+        rank_backward: int,
+        linear: nn.Linear,
+        shared_module: nn.Linear = None,
+    ):
         """Transforms a linear layer into a LoraLinear layer.
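+
+        The layer gains a low-rank path analog_lora_A (in_features ->
+        rank_forward), analog_lora_B (rank_forward -> rank_backward), and
+        analog_lora_C (rank_backward -> out_features); B is zero-initialized,
+        so the low-rank path initially contributes nothing.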
         Args:
@@ -19,13 +25,14 @@ def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None
         out_features = linear.out_features
         super().__init__(in_features, out_features)

-        self.rank = min(rank, in_features, out_features)
+        self.rank_forward = min(rank_forward, in_features)
+        self.rank_backward = min(rank_backward, out_features)

-        self.analog_lora_A = nn.Linear(in_features, self.rank, bias=False)
+        self.analog_lora_A = nn.Linear(in_features, self.rank_forward, bias=False)
         self.analog_lora_B = shared_module or nn.Linear(
-            self.rank, self.rank, bias=False
+            self.rank_forward, self.rank_backward, bias=False
         )
-        self.analog_lora_C = nn.Linear(self.rank, out_features, bias=False)
+        self.analog_lora_C = nn.Linear(self.rank_backward, out_features, bias=False)

         nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
         nn.init.zeros_(self.analog_lora_B.weight)

@@ -49,17 +56,28 @@ def pca_init_weight(self, covariance=None):
             (
                 top_r_singular_vector_forward,
                 top_r_singular_value_forward,
-            ) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
+            ) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
             (
                 top_r_singular_vector_backward,
                 top_r_singular_value_backward,
-            ) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
+            ) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
             self.analog_lora_A.weight.data.copy_(top_r_singular_vector_forward.T)
             self.analog_lora_C.weight.data.copy_(top_r_singular_vector_backward)


 class LoraConv2d(nn.Conv2d):
-    def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
+    def __init__(
+        self,
+        rank_forward: int,
+        rank_backward: int,
+        conv: nn.Conv2d,
+        shared_module: nn.Conv2d = None,
+    ):
         """Transforms a conv2d layer into a LoraConv2d layer.
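+
+        Mirrors LoraLinear: analog_lora_A keeps the original kernel size and
+        stride, while analog_lora_B and analog_lora_C are 1x1 convolutions
+        bridging rank_forward -> rank_backward -> out_channels; B is
+        zero-initialized so the low-rank path starts as a no-op.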
         Args:
@@ -76,15 +89,23 @@ def __init__(self, rank: int, conv: nn.Conv2d, shared_module: nn.Conv2d = None):
             in_channels, out_channels, kernel_size, stride, padding, bias=False
         )

-        self.rank = min(rank, self.in_channels, self.out_channels)
+        self.rank_forward = min(rank_forward, in_channels)
+        self.rank_backward = min(rank_backward, out_channels)

         self.analog_lora_A = nn.Conv2d(
-            self.in_channels, self.rank, kernel_size, stride, padding, bias=False
+            self.in_channels,
+            self.rank_forward,
+            kernel_size,
+            stride,
+            padding,
+            bias=False,
         )
         self.analog_lora_B = shared_module or nn.Conv2d(
-            self.rank, self.rank, 1, bias=False
+            self.rank_forward, self.rank_backward, 1, bias=False
+        )
+        self.analog_lora_C = nn.Conv2d(
+            self.rank_backward, self.out_channels, 1, bias=False
         )
-        self.analog_lora_C = nn.Conv2d(self.rank, self.out_channels, 1, bias=False)

         nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
         nn.init.zeros_(self.analog_lora_B.weight)

@@ -108,11 +129,11 @@ def pca_init_weight(self, covariance):
             (
                 top_r_singular_vector_forward,
                 top_r_singular_value_forward,
-            ) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank)
+            ) = compute_top_k_singular_vectors(covariance[FORWARD], self.rank_forward)
             (
                 top_r_singular_vector_backward,
                 top_r_singular_value_backward,
-            ) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank)
+            ) = compute_top_k_singular_vectors(covariance[BACKWARD], self.rank_backward)
             shape_A = self.analog_lora_A.weight.shape
             shape_C = self.analog_lora_C.weight.shape
             self.analog_lora_A.weight.data.copy_(
@@ -137,13 +158,14 @@ def __init__(
         embedding_dim = embedding.embedding_dim
         super().__init__(num_embeddings, embedding_dim)

-        self.rank = min(rank, num_embeddings, embedding_dim)
+        self.rank_forward = min(rank, num_embeddings)
+        self.rank_backward = min(rank, embedding_dim)

-        self.analog_lora_A = nn.Embedding(num_embeddings, self.rank)
+        self.analog_lora_A = nn.Embedding(num_embeddings, self.rank_forward)
         self.analog_lora_B = shared_module or nn.Linear(
-            self.rank, self.rank, bias=False
+            self.rank_forward, self.rank_backward, bias=False
         )
-        self.analog_lora_C = nn.Linear(self.rank, embedding_dim, bias=False)
+        self.analog_lora_C = nn.Linear(self.rank_backward, embedding_dim, bias=False)

         nn.init.kaiming_uniform_(self.analog_lora_A.weight, a=math.sqrt(5))
         nn.init.zeros_(self.analog_lora_B.weight)
diff --git a/analog/lora/utils.py b/analog/lora/utils.py
index 676f3a0d..7d51135c 100644
--- a/analog/lora/utils.py
+++ b/analog/lora/utils.py
@@ -1,6 +1,71 @@
 from typing import List

+import math
 import torch
+import torch.nn as nn
+
+
+def find_rank_pca_covariance(matrix, threshold):
+    """
+    Calculate the minimum principal component analysis (PCA) rank required
+    to explain at least the specified fraction (threshold) of the total
+    spectrum of the covariance matrix.
+    """
+    S = torch.linalg.svdvals(matrix)
+    rank = 0
+    cur, total = 0, S.sum()
+    while rank < len(S) and (cur / total) < threshold:
+        cur += S[rank]
+        rank += 1
+
+    return rank
+
+
+def find_rank_pca_compression(module, ratio):
+    """
+    Calculate the rank at which the LoRA bottleneck module reaches the
+    target compression ratio relative to the original weight.
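+
+    For example (illustrative): an nn.Linear layer with a 784 x 512 weight and
+    ratio=0.1 yields rank = ceil(sqrt(784 * 512 * 0.1)) = 201, since the rank-r
+    bottleneck B holds r * r parameters.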
+    """
+    weight_shape = module.weight.shape
+    if isinstance(module, nn.Linear):
+        # r * r = m * n * ratio
+        out_features, in_features = weight_shape
+        rank = math.ceil(math.sqrt(in_features * out_features * ratio))
+    elif isinstance(module, nn.Conv2d):
+        # r * r * 1 * 1 = in_channels * out_channels * kernel_size[0] * kernel_size[1] * ratio
+        out_channels, in_channels, kernel_size0, kernel_size1 = weight_shape
+        rank = math.ceil(
+            math.sqrt(in_channels * out_channels * kernel_size0 * kernel_size1 * ratio)
+        )
+    elif isinstance(module, nn.Embedding):
+        # r * r = m * n * ratio
+        num_embeddings, embedding_dim = weight_shape
+        rank = math.ceil(math.sqrt(num_embeddings * embedding_dim * ratio))
+    else:
+        raise NotImplementedError
+
+    return rank
+
+
+def pca_rank_by_weight_shape(shape, module):
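+    # analog_lora_B is nn.Linear(rank_forward, rank_backward) or its 1x1-conv
+    # analogue, so its weight shape is (rank_backward, rank_forward[, 1, 1]);
+    # (shape[1], shape[0]) therefore recovers (rank_forward, rank_backward).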
+    if isinstance(module, nn.Linear):
+        assert len(shape) == 2
+        return shape[1], shape[0]
+    elif isinstance(module, nn.Conv2d):
+        assert len(shape) == 4
+        return shape[1], shape[0]
+    elif isinstance(module, nn.Embedding):
+        assert len(shape) == 2
+        return shape[1], shape[0]
+    else:
+        raise NotImplementedError


 def is_lora(model):
diff --git a/examples/cifar_influence/compute_influence.py b/examples/cifar_influence/compute_influence.py
index 1126e19c..9349b051 100644
--- a/examples/cifar_influence/compute_influence.py
+++ b/examples/cifar_influence/compute_influence.py
@@ -41,9 +41,8 @@
 # Gradient & Hessian logging
 analog.watch(model)
 analog.setup({"log": "grad", "save": "grad", "statistic": "kfac"})
-
+id_gen = DataIDGenerator()
 if not args.resume:
-    id_gen = DataIDGenerator()
     for inputs, targets in train_loader:
         data_id = id_gen(inputs)
         with analog(data_id=data_id):
@@ -62,7 +61,9 @@
 analog.add_analysis({"influence": InfluenceFunction})
 query_iter = iter(query_loader)

-with analog(log=["grad"]) as al:
-    test_input, test_target = next(query_iter)
+test_input, test_target = next(query_iter)
+test_id = id_gen(test_input)
+analog.eval()
+with analog(data_id=test_id) as al:
     test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
     model.zero_grad()
diff --git a/examples/cifar_influence/compute_influences_pca.py b/examples/cifar_influence/compute_influences_pca.py
index 924b03fc..faafee7d 100644
--- a/examples/cifar_influence/compute_influences_pca.py
+++ b/examples/cifar_influence/compute_influences_pca.py
@@ -16,6 +16,9 @@
 parser.add_argument("--data", type=str, default="cifar10", help="cifar10/100")
 parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
 parser.add_argument("--damping", type=float, default=1e-5)
+parser.add_argument("--ekfac", action="store_true")
+parser.add_argument("--lora", action="store_true")
+parser.add_argument("--sample", action="store_true")
 parser.add_argument("--resume", action="store_true")
 args = parser.parse_args()

@@ -39,13 +42,14 @@
 )

 analog = AnaLog(project="test", config="./config.yaml")
-analog_scheduler = AnaLogScheduler(analog, lora=True)
+analog_scheduler = AnaLogScheduler(
+    analog, ekfac=args.ekfac, lora=args.lora, sample=args.sample
+)

 # Gradient & Hessian logging
 analog.watch(model)
-
-if True:
-    id_gen = DataIDGenerator()
+id_gen = DataIDGenerator()
+if not args.resume:
     for epoch in analog_scheduler:
         for inputs, targets in train_loader:
             data_id = id_gen(inputs)
@@ -57,17 +59,17 @@
             loss.backward()
         analog.finalize()
 else:
-    analog.add_lora()
     analog.initialize_from_log()

 # Influence Analysis
 log_loader = analog.build_log_dataloader()

-analog.eval()
 analog.add_analysis({"influence": InfluenceFunction})
 query_iter = iter(query_loader)

-with analog(log=["grad"]) as al:
-    test_input, test_target = next(query_iter)
+test_input, test_target = next(query_iter)
+test_id = id_gen(test_input)
+analog.eval()
+with analog(data_id=test_id) as al:
     test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
     model.zero_grad()
     test_out = model(test_input)
@@ -82,6 +84,6 @@
 )

 # Save
-if_scores = if_scores.numpy().tolist()
+if_scores = if_scores.numpy().tolist()[0]
 torch.save(if_scores, "if_analog_pca.pt")
 print("Computation time:", time.time() - start)
diff --git a/examples/cifar_influence/config.yaml b/examples/cifar_influence/config.yaml
index 04807799..c1b0384a 100644
--- a/examples/cifar_influence/config.yaml
+++ b/examples/cifar_influence/config.yaml
@@ -2,3 +2,4 @@ storage:
   type: default
 lora:
   init: pca
+  compression_ratio_by_memory: 0.1
diff --git a/examples/mnist_influence/compute_influences_scheduler.py b/examples/mnist_influence/compute_influences_scheduler.py
index 06056505..2ac9357b 100644
--- a/examples/mnist_influence/compute_influences_scheduler.py
+++ b/examples/mnist_influence/compute_influences_scheduler.py
@@ -38,7 +38,7 @@
     batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs
 )

-analog = AnaLog(project="test", config="./config.yaml")
+analog = AnaLog(project="test", config="config.yaml")
 al_scheduler = AnaLogScheduler(
     analog, ekfac=args.ekfac, lora=args.lora, sample=args.sample
 )
@@ -62,8 +62,6 @@
             loss.backward()
         analog.finalize()
 else:
-    if args.lora:
-        analog.add_lora()
     analog.initialize_from_log()

 # Influence Analysis
diff --git a/examples/mnist_influence/config.yaml b/examples/mnist_influence/config.yaml
index 27f2f85f..eb54fde7 100644
--- a/examples/mnist_influence/config.yaml
+++ b/examples/mnist_influence/config.yaml
@@ -1,2 +1,3 @@
 lora:
   init: pca
+  compression_ratio_by_memory: 0.1
diff --git a/tests/examples/configs/lora.yaml b/tests/examples/configs/lora.yaml
new file mode 100644
index 00000000..236d409e
--- /dev/null
+++ b/tests/examples/configs/lora.yaml
@@ -0,0 +1,3 @@
+lora:
+  init: pca
+  rank: 2
diff --git a/tests/examples/data/if_analog_lora.pt b/tests/examples/data/if_analog_lora.pt
new file mode 100644
index 00000000..484e5563
Binary files /dev/null and b/tests/examples/data/if_analog_lora.pt differ
diff --git a/tests/examples/test_add_lora.py b/tests/examples/test_add_lora.py
new file mode 100644
index 00000000..45e98760
--- /dev/null
+++ b/tests/examples/test_add_lora.py
@@ -0,0 +1,89 @@
+import unittest
+import torch
+import os
+
+from analog import AnaLog, AnaLogScheduler
+from analog.utils import DataIDGenerator
+from analog.analysis import InfluenceFunction
+from tests.examples.utils import get_mnist_dataloader, construct_mlp
+
+DEVICE = torch.device("cpu")
+
+
+class TestAddLora(unittest.TestCase):
+    def test_add_lora(self):
+        eval_idxs = (1,)
+        train_idxs = [i for i in range(0, 50000, 1000)]
+
+        model = construct_mlp().to(DEVICE)
+        # Get a single checkpoint (first model_id and last epoch).
+        model.load_state_dict(
+            torch.load(
+                f"{os.path.dirname(os.path.abspath(__file__))}/checkpoints/mnist_0_epoch_9.pt",
+                map_location="cpu",
+            )
+        )
+        model.eval()
+
+        dataloader_fn = get_mnist_dataloader
+        train_loader = dataloader_fn(
+            batch_size=512, split="train", shuffle=False, indices=train_idxs
+        )
+        query_loader = dataloader_fn(
+            batch_size=1, split="valid", shuffle=False, indices=eval_idxs
+        )
+
+        analog = AnaLog(
+            project="test",
+            config=f"{os.path.dirname(os.path.abspath(__file__))}/configs/lora.yaml",
+        )
+        # Gradient & Hessian logging
+        al_scheduler = AnaLogScheduler(analog, ekfac=False, lora=True, sample=False)
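+        # With lora=True the scheduler is expected to first log covariance
+        # statistics, then add (PCA-initialized) LoRA and save compressed gradients.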
+        analog.watch(model)
+        id_gen = DataIDGenerator()
+        for epoch in al_scheduler:
+            for inputs, targets in train_loader:
+                data_id = id_gen(inputs)
+                with analog(data_id=data_id):
+                    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
+                    model.zero_grad()
+                    outs = model(inputs)
+                    loss = torch.nn.functional.cross_entropy(
+                        outs, targets, reduction="sum"
+                    )
+                    loss.backward()
+            analog.finalize()
+
+        log_loader = analog.build_log_dataloader()
+
+        analog.add_analysis({"influence": InfluenceFunction})
+        query_iter = iter(query_loader)
+        test_input, test_target = next(query_iter)
+        test_id = id_gen(test_input)
+        analog.eval()
+        with analog(data_id=test_id):
+            test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
+            model.zero_grad()
+            test_out = model(test_input)
+            test_loss = torch.nn.functional.cross_entropy(
+                test_out, test_target, reduction="sum"
+            )
+            test_loss.backward()
+            test_log = analog.get_log()
+        if_scores = analog.influence.compute_influence_all(
+            test_log, log_loader, damping=1e-5
+        )
+
+        # Compare against the saved reference scores
+        if_scores = if_scores[0]
+        # torch.save(if_scores, f"{os.path.dirname(os.path.abspath(__file__))}/data/if_analog_lora.pt")
+        if_score_saved = torch.load(
+            f"{os.path.dirname(os.path.abspath(__file__))}/data/if_analog_lora.pt"
+        )
+        self.assertTrue(torch.allclose(if_score_saved, if_scores))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/examples/test_compute_influences.py b/tests/examples/test_compute_influences.py
index f5831bb4..f44c597c 100644
--- a/tests/examples/test_compute_influences.py
+++ b/tests/examples/test_compute_influences.py
@@ -5,54 +5,9 @@
 import numpy as np
 import os

-DEVICE = torch.device("cpu")
-
-
-def construct_mlp(num_inputs=784, num_classes=10):
-    return torch.nn.Sequential(
-        nn.Flatten(),
-        nn.Linear(num_inputs, 4, bias=False),
-        nn.ReLU(),
-        nn.Linear(4, 2, bias=False),
-        nn.ReLU(),
-        nn.Linear(2, num_classes, bias=False),
-    )
-
+from tests.examples.utils import get_mnist_dataloader, construct_mlp

-def get_mnist_dataloader(
-    batch_size=128,
-    split="train",
-    shuffle=False,
-    subsample=False,
-    indices=None,
-    drop_last=False,
-):
-    transforms = torchvision.transforms.Compose(
-        [
-            torchvision.transforms.ToTensor(),
-            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
-        ]
-    )
-    is_train = split == "train"
-    dataset = torchvision.datasets.MNIST(
-        root="/tmp/mnist/", download=True, train=is_train, transform=transforms
-    )
-
-    if subsample and split == "train" and indices is None:
-        dataset = torch.utils.data.Subset(dataset, np.arange(6_000))
-
-    if indices is not None:
-        if subsample and split == "train":
-            print("Overriding `subsample` argument as `indices` was provided.")
-        dataset = torch.utils.data.Subset(dataset, indices)
-
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        shuffle=shuffle,
-        batch_size=batch_size,
-        num_workers=0,
-        drop_last=drop_last,
-    )
+DEVICE = torch.device("cpu")


 class TestSingleCheckpointInfluence(unittest.TestCase):
diff --git a/tests/examples/utils.py b/tests/examples/utils.py
new file mode 100644
index 00000000..b638d9fd
--- /dev/null
+++ b/tests/examples/utils.py
@@ -0,0 +1,57 @@
+import torch
+import torchvision
+import torch.nn as nn
+import numpy as np
+
+
+def construct_mlp(num_inputs=784, num_classes=10):
+    return torch.nn.Sequential(
+        nn.Flatten(),
+        nn.Linear(num_inputs, 4, bias=False),
+        nn.ReLU(),
+        nn.Linear(4, 2, bias=False),
+        nn.ReLU(),
+        nn.Linear(2, num_classes, bias=False),
+    )
+
+
+def get_mnist_dataloader(
+    batch_size=128,
+    split="train",
+    shuffle=False,
+    subsample=False,
+    indices=None,
+    drop_last=False,
+):
+    transforms = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.ToTensor(),
+            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
+        ]
+    )
+    is_train = split == "train"
+    dataset = torchvision.datasets.MNIST(
+        root="/tmp/mnist/", download=True, train=is_train, transform=transforms
+    )
+
+    if subsample and split == "train" and indices is None:
+        dataset = torch.utils.data.Subset(dataset, np.arange(6_000))
+
+    if indices is not None:
+        if subsample and split == "train":
+            print("Overriding `subsample` argument as `indices` was provided.")
+        dataset = torch.utils.data.Subset(dataset, indices)
+
+    return torch.utils.data.DataLoader(
+        dataset=dataset,
+        shuffle=shuffle,
+        batch_size=batch_size,
+        num_workers=0,
+        drop_last=drop_last,
+    )
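+
+
+# Usage in the tests (illustrative):
+#   train_loader = get_mnist_dataloader(
+#       batch_size=512, split="train", shuffle=False, indices=train_idxs
+#   )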