From 53c664f671ed86923151bf6dbec9a66485fe7b7d Mon Sep 17 00:00:00 2001
From: working <57171759+NLGithubWP@users.noreply.github.com>
Date: Thu, 31 Aug 2023 21:05:43 +0800
Subject: [PATCH] Add implementation for a single optimization step in model
 selection

---
 .../model_selection_psql/ms_mlp/train_mlp.py | 55 +++++++++++++++++--
 1 file changed, 49 insertions(+), 6 deletions(-)

diff --git a/examples/model_selection_psql/ms_mlp/train_mlp.py b/examples/model_selection_psql/ms_mlp/train_mlp.py
index 88d6cf549..aa7bec7d1 100644
--- a/examples/model_selection_psql/ms_mlp/train_mlp.py
+++ b/examples/model_selection_psql/ms_mlp/train_mlp.py
@@ -129,6 +129,49 @@ def __init__(self,
             raise ValueError(
                 "Nesterov momentum requires a momentum and zero dampening")
 
+    def apply(self, param_name, param_value, param_grad):
+        """Performs a single optimization step.
+
+        Args:
+                param_name(String): the name of the param
+                param_value(Tensor): param values to be updated in place
+                param_grad(Tensor): param gradients; the values may be updated
+                        in this function and should not be used afterwards
+        """
+        assert param_value.shape == param_grad.shape, ("shape mismatch",
+                                                       param_value.shape,
+                                                       param_grad.shape)
+        self.device_check(param_value, self.step_counter, self.lr_value,
+                          self.mom_value, self.dam_value, self.decay_value)
+
+        # derive dtype from input
+        assert param_value.dtype == self.dtype
+
+        # TODO add branch operator
+        # if self.decay_value != 0:
+        if self.weight_decay.init_value != 0:
+            singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)
+
+        if self.momentum.init_value != 0:
+            if param_name not in self.moments:
+                flag = param_value.device.graph_enabled()
+                param_value.device.EnableGraph(False)
+                self.moments[param_name] = tensor.zeros_like(param_value)
+                param_value.device.EnableGraph(flag)
+
+            buf = self.moments[param_name]
+            buf *= self.mom_value
+            alpha = 1.0 - self.dam_value
+            singa.Axpy(alpha.data, param_grad.data, buf.data)
+
+            if self.nesterov:
+                singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
+            else:
+                param_grad = buf
+
+        minus_lr = 0.0 - self.lr_value
+        singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
+
 
 # Data augmentation
 def augmentation(x, batch_size):
@@ -136,8 +179,8 @@ def augmentation(x, batch_size):
     for data_num in range(0, batch_size):
         offset = np.random.randint(8, size=2)
         x[data_num, :, :, :] = xpad[data_num, :,
-                                    offset[0]:offset[0] + x.shape[2],
-                                    offset[1]:offset[1] + x.shape[2]]
+                               offset[0]:offset[0] + x.shape[2],
+                               offset[1]:offset[1] + x.shape[2]]
         if_flip = np.random.randint(2)
         if (if_flip):
             x[data_num, :, :, :] = x[data_num, :, :, ::-1]
@@ -189,7 +232,7 @@ def resize_dataset(x, image_size):
         for d in range(0, dim):
             X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
                 (image_size, image_size), Image.BILINEAR),
-                                     dtype=np.float32)
+                dtype=np.float32)
     return X
 
 
@@ -251,8 +294,8 @@ def run(global_rank,
         sys.path.insert(0, parent)
         from mlp import model
         model = model.create_model(data_size=data_size,
-                                       num_classes=num_classes)
- 
+                                   num_classes=num_classes)
+
     elif model == 'msmlp':
         import os, sys, inspect
         current = os.path.dirname(
@@ -261,7 +304,7 @@ def run(global_rank,
         sys.path.insert(0, parent)
         from msmlp import model
         model = model.create_model(data_size=data_size,
-                                       num_classes=num_classes)
+                                   num_classes=num_classes)
 
     # For distributed training, sequential has better performance
     if hasattr(mssgd, "communicator"):
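
Note (illustrative, not part of the patch): the added apply() follows the usual SGD
update with weight decay, momentum/dampening, and optional Nesterov momentum,
expressed through singa.Axpy, which is read here as y += a * x (BLAS axpy semantics).
The sketch below restates that update in plain NumPy so the hunk is easier to review;
the function name sgd_step and its signature are hypothetical and only for illustration.

    import numpy as np

    def sgd_step(param, grad, moment, lr, momentum=0.0, dampening=0.0,
                 weight_decay=0.0, nesterov=False):
        """One step mirroring the patched apply(); returns updated (param, moment)."""
        if weight_decay != 0.0:
            grad = grad + weight_decay * param       # Axpy(decay, param, grad)
        if momentum != 0.0:
            # moment buffer: buf = momentum * buf + (1 - dampening) * grad
            moment = momentum * moment + (1.0 - dampening) * grad
            if nesterov:
                grad = grad + momentum * moment      # Axpy(mom, buf, grad)
            else:
                grad = moment                        # param_grad = buf
        return param - lr * grad, moment             # Axpy(-lr, grad, param)

    # Example: one step on a toy parameter vector
    p = np.array([1.0, 2.0]); g = np.array([0.1, -0.2]); m = np.zeros_like(p)
    p, m = sgd_step(p, g, m, lr=0.05, momentum=0.9)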