mxnet: fix for mxnet 2.0 #283

Open · wants to merge 5 commits into master
byteps/mxnet/__init__.py: 40 changes (34 additions, 6 deletions)
@@ -188,7 +188,14 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0):
                           "as its optimizer. We have unwrapped it for you.")
 
         param_list = []
-        if isinstance(params, mx.gluon.ParameterDict):
+
+        try:
+            from mxnet.gluon.parameter import ParameterDict
+            valid_types = (dict, ParameterDict)
+        except ImportError:
+            valid_types = (dict,)
+
+        if isinstance(params, valid_types):
             for key in sorted(list(params.keys())):
                 param_list.append(params[key])
 
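Note on the try/except above: it is a compatibility shim. A minimal sketch (not part of the diff, assuming MXNet 2.x removed mx.gluon.ParameterDict and made Block.collect_params() return a plain dict) of how the same check behaves on either version:

import mxnet as mx

try:
    # MXNet 1.x still ships ParameterDict
    from mxnet.gluon.parameter import ParameterDict
    valid_types = (dict, ParameterDict)
except ImportError:
    # MXNet 2.x: collect_params() returns a plain dict
    valid_types = (dict,)

net = mx.gluon.nn.Dense(4, in_units=4)
# Prints True on both 1.x and 2.x, so the Trainer can accept the result
# of collect_params() unchanged.
print(isinstance(net.collect_params(), valid_types))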
@@ -200,17 +207,31 @@ def __init__(self, params, optimizer, optimizer_params=None, root_rank=0):
         # average in push_pull, has better performance.
         self._scale /= size()
         self.root_rank = root_rank
-        for i, param in enumerate(self._params):
-            byteps_declare_tensor("parameter_" + str(i))
+        for param in self._params:
+            # we need to use the indexes in self._param2idx as the key,
+            # to ensure correctness when the model uses share_parameters
+            if hasattr(param, '_uuid'):
+                param_id = param._uuid
+            else:
+                param_id = param.name
+            idx = self._param2idx[param_id]
+            byteps_declare_tensor("parameter_" + str(idx))
             if param.grad_req != 'null':
-                byteps_declare_tensor("gradient_" + str(i))
+                byteps_declare_tensor("gradient_" + str(idx))
 
 
     def _allreduce_grads(self):
         for i, param in enumerate(self._params):
             if param.grad_req != 'null':
+                # In MXNet 2.0, param.name is no longer unique
+                # and thus cannot be used as the key for the parameter.
+                if hasattr(param, '_uuid'):
+                    param_id = param._uuid
+                else:
+                    param_id = param.name
+                idx = self._param2idx[param_id]
                 byteps_push_pull(param.list_grad()[0], is_average=False,
-                                 name="gradient_" + str(i), priority=-i)
+                                 name="gradient_" + str(idx), priority=-i)
 
     def _init_params(self):
         tensors = []
@@ -219,7 +240,14 @@ def _init_params(self):
                 tensors.append(param)
             else:
                 param_arrays = param._check_and_get(param._data, list)
-                idx = self._param2idx[param.name]
+                # In MXNet 2.0, param.name is no longer unique
+                # and thus cannot be used as the key for the parameter.
+                if hasattr(param, '_uuid'):
+                    param_id = param._uuid
+                else:
+                    param_id = param.name
+
+                idx = self._param2idx[param_id]
 
                 if rank() != self.root_rank:
                     param_arrays[0].__imul__(0)
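The same "_uuid if present, else name" lookup is repeated in __init__, _allreduce_grads, and _init_params above. A hypothetical helper (not part of this PR), shown only to make the keying rule explicit:

def param_key(param):
    # MXNet 2.0 gives every Parameter a _uuid that stays unique even when
    # share_parameters() reuses the same human-readable name; MXNet 1.x has
    # no _uuid, but there param.name is already unique.
    return getattr(param, '_uuid', param.name)

# usage in each of the three methods would then be:
# idx = self._param2idx[param_key(param)]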
tests/test_mxnet.py: 19 changes (11 additions, 8 deletions)
@@ -53,19 +53,22 @@ def test_byteps_trainer_param_order(self):
         dims = [1]
         ctx = self._current_context()
         net = mx.gluon.nn.Sequential()
-        # layers may be added in a random order for all workers
-        layers = {'ones_': 1, 'zeros_': 0}
-        for name, init in layers.items():
-            net.add(mx.gluon.nn.Dense(10, in_units=10, weight_initializer=mx.init.Constant(init),
-                                      use_bias=False, prefix=name))
+        net.add(mx.gluon.nn.Dense(5, in_units=5, weight_initializer=mx.init.Constant(0),
+                                  use_bias=False))
+        net.add(mx.gluon.nn.Dense(5, in_units=10, weight_initializer=mx.init.Constant(1),
+                                  use_bias=False))
         params = net.collect_params()
         net.initialize()
         trainer = bps.DistributedTrainer(params, 'sgd')
         trainer._init_params()
         # check the result of bps_broadcast
-        for name, init in layers.items():
-            weight = params[name + 'weight'].data()[0].asnumpy()
-            expected = np.full(shape=weight.shape, fill_value=init, dtype=weight.dtype)
+        for p in params.values():
+            weight = p.data().asnumpy()
+            if weight.shape[1] == 10:
+                init_val = 1
+            else:
+                init_val = 0
+            expected = np.full(shape=weight.shape, fill_value=init_val, dtype=weight.dtype)
             assert np.array_equal(weight, expected), (weight, expected)
 
         print('test_byteps_trainer_param_order passed')
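For context on the test rewrite: assuming MXNet 2.x Gluon, Dense no longer accepts a prefix argument and the keys returned by collect_params() are derived from the block structure rather than from user-chosen prefixes, so the test identifies each layer by its weight shape (in_units appears as weight.shape[1]) instead of by name. A small sketch of that idea:

import mxnet as mx

net = mx.gluon.nn.Sequential()
net.add(mx.gluon.nn.Dense(5, in_units=5, use_bias=False))
net.add(mx.gluon.nn.Dense(5, in_units=10, use_bias=False))
net.initialize()

for p in net.collect_params().values():
    # shape[1] is in_units, so (5, 5) is the first layer and (5, 10) the second
    print(tuple(p.data().shape))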