Merge pull request #1076 from NLGithubWP/single_opt_ms
Add implementation for a single optimization step in model selection
lzjpaul authored Sep 1, 2023
2 parents b56eaab + 53c664f commit 1e55cfb
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions examples/model_selection_psql/ms_mlp/train_mlp.py
@@ -129,15 +129,58 @@ def __init__(self,
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening")

    def apply(self, param_name, param_value, param_grad):
        """Performs a single optimization step.

        Args:
            param_name(String): the name of the param
            param_value(Tensor): param values to be updated in place
            param_grad(Tensor): param gradients; the values may be updated
                in this function and must not be reused afterwards
        """
assert param_value.shape == param_grad.shape, ("shape mismatch",
param_value.shape,
param_grad.shape)
self.device_check(param_value, self.step_counter, self.lr_value,
self.mom_value, self.dam_value, self.decay_value)

        # the param dtype must match the dtype this optimizer was configured with
assert param_value.dtype == self.dtype

# TODO add branch operator
# if self.decay_value != 0:
if self.weight_decay.init_value != 0:
singa.Axpy(self.decay_value.data, param_value.data, param_grad.data)

if self.momentum.init_value != 0:
            if param_name not in self.moments:
                # create the momentum buffer with graph recording turned off,
                # so that the buffer allocation is not traced into the graph
                flag = param_value.device.graph_enabled()
                param_value.device.EnableGraph(False)
                self.moments[param_name] = tensor.zeros_like(param_value)
                param_value.device.EnableGraph(flag)

buf = self.moments[param_name]
buf *= self.mom_value
alpha = 1.0 - self.dam_value
singa.Axpy(alpha.data, param_grad.data, buf.data)

if self.nesterov:
singa.Axpy(self.mom_value.data, buf.data, param_grad.data)
else:
param_grad = buf

minus_lr = 0.0 - self.lr_value
singa.Axpy(minus_lr.data, param_grad.data, param_value.data)
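
# A minimal plain-Python sketch (illustration only, not part of this commit)
# of the update rule apply() implements, with scalar hyper-parameters and
# assuming singa.Axpy(alpha, x, y) computes y += alpha * x:
def sgd_step_sketch(param, grad, buf, lr, momentum=0.0, dampening=0.0,
                    weight_decay=0.0, nesterov=False):
    if weight_decay != 0:
        grad = grad + weight_decay * param        # Axpy(decay, param, grad)
    if momentum != 0:
        buf = momentum * buf + (1.0 - dampening) * grad
        grad = grad + momentum * buf if nesterov else buf
    param = param - lr * grad                     # Axpy(-lr, grad, param)
    return param, buf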


# Data augmentation
def augmentation(x, batch_size):
xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
for data_num in range(0, batch_size):
offset = np.random.randint(8, size=2)
x[data_num, :, :, :] = xpad[data_num, :,
offset[0]:offset[0] + x.shape[2],
offset[1]:offset[1] + x.shape[2]]
if_flip = np.random.randint(2)
if (if_flip):
x[data_num, :, :, :] = x[data_num, :, :, ::-1]
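
# Example usage of augmentation() above (a sketch, not part of this commit):
# it pads 4 pixels on each side, crops a random window of the original size,
# and flips horizontally with probability 0.5, modifying the batch in place.
# Assuming CIFAR-like NCHW float32 data:
x_batch = np.random.rand(8, 3, 32, 32).astype(np.float32)
augmentation(x_batch, batch_size=8)
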
@@ -189,7 +232,7 @@ def resize_dataset(x, image_size):
for d in range(0, dim):
X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
(image_size, image_size), Image.BILINEAR),
dtype=np.float32)
return X
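
# Example usage of resize_dataset() above (a sketch, not part of this
# commit); PIL treats each float32 channel as a mode-'F' image, so a
# CIFAR-sized batch can be upscaled to a larger input resolution:
x_small = np.random.rand(4, 3, 32, 32).astype(np.float32)
x_big = resize_dataset(x_small, image_size=224)  # shape (4, 3, 224, 224)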


@@ -251,8 +294,8 @@ def run(global_rank,
sys.path.insert(0, parent)
from mlp import model
model = model.create_model(data_size=data_size,
num_classes=num_classes)

elif model == 'msmlp':
import os, sys, inspect
current = os.path.dirname(
@@ -261,7 +304,7 @@ def run(global_rank,
sys.path.insert(0, parent)
from msmlp import model
model = model.create_model(data_size=data_size,
num_classes=num_classes)
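
    # An equivalent, more compact dispatch (a sketch, not part of this
    # commit) using importlib, assuming the parent directory is already on
    # sys.path as in the branches above:
    #
    #     import importlib
    #     mod = importlib.import_module(model + '.model')  # 'mlp' or 'msmlp'
    #     model = mod.create_model(data_size=data_size,
    #                              num_classes=num_classes)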

# For distributed training, sequential has better performance
if hasattr(mssgd, "communicator"):
