From 2ebb75d358ee8add72e3503b24fa4e5f887f8ae9 Mon Sep 17 00:00:00 2001 From: Klaus Ma Date: Wed, 17 May 2023 23:01:56 +0800 Subject: [PATCH] Add dist training. Signed-off-by: Klaus Ma --- examples/pytorch/distributed_training.py | 46 +++++++++++++++++++ examples/pytorch/inference.py | 10 ++++ .../pytorch/{main.py => local_training.py} | 0 3 files changed, 56 insertions(+) create mode 100644 examples/pytorch/distributed_training.py create mode 100644 examples/pytorch/inference.py rename examples/pytorch/{main.py => local_training.py} (100%) diff --git a/examples/pytorch/distributed_training.py b/examples/pytorch/distributed_training.py new file mode 100644 index 0000000..1d365f0 --- /dev/null +++ b/examples/pytorch/distributed_training.py @@ -0,0 +1,46 @@ +# Copyright 2023 The xflops Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" Dataset partitioning helper """ +class Partition(object): + + def __init__(self, data, index): + self.data = data + self.index = index + + def __len__(self): + return len(self.index) + + def __getitem__(self, index): + data_idx = self.index[index] + return self.data[data_idx] + + +class DataPartitioner(object): + + def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234): + self.data = data + self.partitions = [] + rng = Random() + rng.seed(seed) + data_len = len(data) + indexes = [x for x in range(0, data_len)] + rng.shuffle(indexes) + + for frac in sizes: + part_len = int(frac * data_len) + self.partitions.append(indexes[0:part_len]) + indexes = indexes[part_len:] + + def use(self, partition): + return Partition(self.data, self.partitions[partition]) + diff --git a/examples/pytorch/inference.py b/examples/pytorch/inference.py new file mode 100644 index 0000000..a6b10e7 --- /dev/null +++ b/examples/pytorch/inference.py @@ -0,0 +1,10 @@ +# Copyright 2023 The xflops Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/pytorch/main.py b/examples/pytorch/local_training.py similarity index 100% rename from examples/pytorch/main.py rename to examples/pytorch/local_training.py