Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gradient compression updates #395

Open
wants to merge 308 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
308 commits
Select commit Hold shift + click to select a range
62966d7
missing semicolon...
jasperzhong Aug 30, 2020
e6a521a
fix typos
jasperzhong Aug 30, 2020
9260d32
missing semicolon...
jasperzhong Aug 30, 2020
bb93d64
add default
jasperzhong Aug 30, 2020
6346c5c
fix bugs
jasperzhong Aug 31, 2020
6491eef
fix missing s
jasperzhong Aug 31, 2020
0bfd387
use nag for dithering
jasperzhong Aug 31, 2020
8e0c85b
update train script
jasperzhong Aug 31, 2020
152ce49
update ef for dithering
jasperzhong Sep 1, 2020
6a2c22a
fix typo
jasperzhong Sep 1, 2020
7fc3d54
missing scope
jasperzhong Sep 1, 2020
91d6c2e
fix compile error
jasperzhong Sep 1, 2020
4be835d
debug
jasperzhong Sep 1, 2020
14139a8
debug
jasperzhong Sep 1, 2020
69b4117
update
jasperzhong Sep 1, 2020
a02607d
add omp for randomk
jasperzhong Sep 5, 2020
5251789
update
jasperzhong Sep 5, 2020
f5f5d58
fix typo
jasperzhong Sep 5, 2020
1d3d8e3
missing copy
jasperzhong Sep 5, 2020
60b31be
update
jasperzhong Sep 5, 2020
caadd7f
update
jasperzhong Sep 5, 2020
c3d8341
randomk with replacement
jasperzhong Sep 8, 2020
7c65805
fix missing header
jasperzhong Sep 8, 2020
5001140
fix typo
jasperzhong Sep 8, 2020
94b7ba3
use unordered_map
jasperzhong Sep 8, 2020
c3e8e26
update
jasperzhong Sep 16, 2020
2a39ff4
fix small bug
jasperzhong Sep 19, 2020
14a9d10
worker decompress use buf as input
jasperzhong Sep 20, 2020
9a34a97
use __restrict__
jasperzhong Sep 20, 2020
affde37
fix missing =default
jasperzhong Sep 20, 2020
89571c4
update tests
jasperzhong Sep 20, 2020
c9e2b7a
update randomk and tests
jasperzhong Sep 20, 2020
e08c20d
debug
jasperzhong Sep 20, 2020
f58ddfb
fix typo
jasperzhong Sep 20, 2020
26e6793
update
jasperzhong Sep 20, 2020
868cfc0
fix typo
jasperzhong Sep 20, 2020
578025d
debug
jasperzhong Sep 20, 2020
3f0a1ee
fix typo
jasperzhong Sep 20, 2020
f0284e8
update
jasperzhong Sep 20, 2020
491ea61
debug
jasperzhong Sep 20, 2020
eacda51
some optimizations in dithering
jasperzhong Sep 20, 2020
780353e
debug
jasperzhong Sep 20, 2020
0c54603
fix typo
jasperzhong Sep 20, 2020
dbdc67f
debug
jasperzhong Sep 20, 2020
8a37cb4
debug
jasperzhong Sep 20, 2020
cdba877
debug
jasperzhong Sep 20, 2020
3e84913
debug
jasperzhong Sep 20, 2020
3efcd5c
debug
jasperzhong Sep 20, 2020
7b1fdf2
debug
jasperzhong Sep 20, 2020
442a3ba
debug
jasperzhong Sep 20, 2020
f4706d9
debug
jasperzhong Sep 20, 2020
4728bac
fix bug
jasperzhong Sep 20, 2020
6be261b
debug
jasperzhong Sep 20, 2020
cc543ef
mix precision push pull
jasperzhong Sep 21, 2020
dd176a2
fix typo
jasperzhong Sep 21, 2020
ee23f59
fix missing
jasperzhong Sep 21, 2020
13136f3
fix log bug
jasperzhong Sep 23, 2020
183632b
add check
jasperzhong Sep 23, 2020
4db526a
debug
jasperzhong Sep 23, 2020
96f1e8a
debug
jasperzhong Sep 23, 2020
800359c
fix
jasperzhong Sep 23, 2020
405e01f
debug
jasperzhong Sep 23, 2020
fd34833
fix out of range
jasperzhong Sep 23, 2020
3096a06
update test
jasperzhong Sep 23, 2020
b0768d0
fix typo
jasperzhong Sep 23, 2020
48f5993
update test
jasperzhong Sep 23, 2020
b8a2a9c
fix bug
jasperzhong Sep 23, 2020
52f3e0f
fix bug
jasperzhong Sep 23, 2020
990ec9e
fix example bug
jasperzhong Sep 23, 2020
57d45e4
fix typo
jasperzhong Sep 23, 2020
3790e9c
mix precision true
jasperzhong Sep 23, 2020
9f7017a
update tests
jasperzhong Sep 23, 2020
bdf58dd
update tests
jasperzhong Sep 23, 2020
88b93e2
debug
jasperzhong Sep 23, 2020
7de2f4d
update tests
jasperzhong Sep 23, 2020
f283367
update
jasperzhong Sep 23, 2020
b4c025d
update
jasperzhong Sep 23, 2020
9741458
patch
jasperzhong Sep 24, 2020
44188ef
update test
jasperzhong Sep 24, 2020
a82caad
fix typo
jasperzhong Sep 24, 2020
388818e
hack
jasperzhong Sep 24, 2020
6fc1603
fix
jasperzhong Sep 24, 2020
ede7d9f
fuck
jasperzhong Sep 24, 2020
349eb08
fix
jasperzhong Sep 24, 2020
79f0e9f
test
jasperzhong Sep 24, 2020
cc1a1b2
fix
jasperzhong Sep 24, 2020
b970d8d
fk
jasperzhong Sep 24, 2020
2f27960
set default threshold 0
jasperzhong Sep 30, 2020
23a5ef0
restore dist_launcher
jasperzhong Sep 30, 2020
9e152f3
Merge branch 'master' into dev
jasperzhong Oct 5, 2020
7e0ff01
fix test launch server problem
jasperzhong Oct 5, 2020
ea5b85a
Squashed commit of the following:
jasperzhong Oct 18, 2020
e77ba98
update
jasperzhong Oct 25, 2020
8dfbc63
update dithering
jasperzhong Oct 26, 2020
8e6b6b7
update api caller
jasperzhong Oct 26, 2020
d3b4284
update registery
jasperzhong Oct 26, 2020
e784b3b
fix typo
jasperzhong Oct 26, 2020
ea5162f
fix typo
jasperzhong Oct 26, 2020
5becdfd
update
jasperzhong Oct 26, 2020
24e8e9b
fix
jasperzhong Oct 26, 2020
cef9324
update
jasperzhong Oct 26, 2020
9ee5c5c
update
jasperzhong Oct 26, 2020
791a645
update
jasperzhong Oct 26, 2020
7a70a2a
fix bug
jasperzhong Oct 26, 2020
f916042
fix
jasperzhong Oct 26, 2020
43a0641
fix
jasperzhong Oct 26, 2020
9d50837
fix
jasperzhong Oct 26, 2020
2d01a64
fix
jasperzhong Oct 26, 2020
af6022e
fix test
jasperzhong Oct 26, 2020
0b97b6e
update example
jasperzhong Oct 26, 2020
40058af
fix bug
jasperzhong Oct 26, 2020
9a1c659
I don't know
jasperzhong Oct 26, 2020
3d7fc1c
debug
jasperzhong Oct 26, 2020
46a1ade
update test
jasperzhong Oct 26, 2020
e417c03
update
jasperzhong Oct 26, 2020
6acbf7c
update
jasperzhong Oct 26, 2020
d0f7a65
debug
jasperzhong Oct 26, 2020
3e60eea
debug
jasperzhong Oct 26, 2020
70596d8
debug
jasperzhong Oct 26, 2020
6468339
debug
jasperzhong Oct 26, 2020
0e07d0c
debug
jasperzhong Oct 26, 2020
47ef502
fix bug
jasperzhong Oct 26, 2020
849a6a8
fix
jasperzhong Oct 26, 2020
808b4c5
debug
jasperzhong Oct 26, 2020
6318f11
add log
jasperzhong Oct 26, 2020
9edce6a
add log
jasperzhong Oct 26, 2020
6e2f824
debug
jasperzhong Oct 26, 2020
231e52d
add log
jasperzhong Oct 26, 2020
813d75a
update
jasperzhong Oct 26, 2020
0685698
update
jasperzhong Oct 26, 2020
7ecf473
add log
jasperzhong Oct 26, 2020
36f30ba
update
jasperzhong Oct 26, 2020
c854147
test
jasperzhong Oct 27, 2020
e5d7946
debug
jasperzhong Oct 27, 2020
fd30683
missing
jasperzhong Oct 27, 2020
38034bf
udp
jasperzhong Oct 27, 2020
e98ecdc
debug
jasperzhong Nov 3, 2020
9aecbcd
debug
jasperzhong Nov 3, 2020
02190fb
debug
jasperzhong Nov 3, 2020
0efd31a
debug
jasperzhong Nov 3, 2020
5dd7d9f
TEST
jasperzhong Nov 3, 2020
7adc557
test
jasperzhong Nov 3, 2020
409063b
debug
jasperzhong Nov 3, 2020
3761868
test
jasperzhong Nov 3, 2020
a4d9818
test
jasperzhong Nov 3, 2020
71e3b7d
test
jasperzhong Nov 3, 2020
00a6191
test
jasperzhong Nov 3, 2020
30c9162
fix typo
jasperzhong Nov 3, 2020
db6fbc5
test
jasperzhong Nov 3, 2020
18969df
update
jasperzhong Nov 3, 2020
ced12aa
update
jasperzhong Nov 4, 2020
1334846
test
jasperzhong Nov 4, 2020
04e437b
tset
jasperzhong Nov 4, 2020
82f05be
debug
jasperzhong Nov 4, 2020
c390a8e
bale
jasperzhong Nov 4, 2020
ee11705
debug
jasperzhong Nov 4, 2020
b45a13f
remove debug
jasperzhong Nov 14, 2020
24654ba
update
jasperzhong Nov 16, 2020
d449396
update
jasperzhong Nov 16, 2020
7061f22
missing command
jasperzhong Nov 16, 2020
69febcb
Merge branch 'master' into refactor
jasperzhong Nov 25, 2020
32057ee
delete unnecessary files
jasperzhong Nov 25, 2020
96e2d22
Squashed commit of the following:
jasperzhong Dec 1, 2020
026ac00
fix
jasperzhong Dec 1, 2020
cef6621
fix typo
jasperzhong Dec 1, 2020
c2a6f28
update script
jasperzhong Dec 1, 2020
732728a
remove persistent worker
jasperzhong Dec 2, 2020
0331f44
support pre/post scale
jasperzhong Dec 2, 2020
ffe52df
add nan check
jasperzhong Dec 5, 2020
ac70156
add test
jasperzhong Dec 5, 2020
da2355c
fix typo
jasperzhong Dec 5, 2020
14d737c
update
jasperzhong Dec 5, 2020
e122e99
Update tests
jasperzhong Dec 6, 2020
daf602f
update
jasperzhong Dec 7, 2020
5309e83
Merge branch 'master' into refactor
jasperzhong Dec 7, 2020
0dcd024
update setup
jasperzhong Dec 7, 2020
ceb1807
update
jasperzhong Dec 10, 2020
b6f4380
update
jasperzhong Dec 10, 2020
5e3e5bf
test
jasperzhong Dec 10, 2020
6356331
add apex example
jasperzhong Dec 12, 2020
7ca3c31
update
jasperzhong Dec 12, 2020
83172b0
update
jasperzhong Dec 12, 2020
0df2272
update
jasperzhong Dec 12, 2020
cdadd79
update
jasperzhong Dec 12, 2020
805efba
trace
jasperzhong Dec 12, 2020
d3ebdf6
add test
jasperzhong Dec 15, 2020
58518a5
fix typo
jasperzhong Dec 15, 2020
112a67e
update test
jasperzhong Dec 15, 2020
baa90d0
test
jasperzhong Dec 16, 2020
54a0d8c
test
jasperzhong Dec 16, 2020
db06885
try not to compress embeddings
jasperzhong Dec 24, 2020
9dbe060
restore
jasperzhong Dec 25, 2020
fb33e05
update
jasperzhong Dec 25, 2020
b7e222a
debug
jasperzhong Dec 25, 2020
000e1c0
debug
jasperzhong Dec 25, 2020
00fa90e
update debug
jasperzhong Dec 25, 2020
7734d6b
use isfinite instead of isnan
jasperzhong Dec 25, 2020
6412793
fix typo
jasperzhong Dec 25, 2020
9c2cfc0
fix
jasperzhong Dec 25, 2020
1519a15
debug
jasperzhong Dec 25, 2020
ae2a648
print max
jasperzhong Dec 25, 2020
a62f0cb
update debug
jasperzhong Dec 25, 2020
5530b8e
fix fatal bug
jasperzhong Dec 25, 2020
2f6d320
remove debug print
jasperzhong Dec 25, 2020
6d42e7d
test
jasperzhong Dec 28, 2020
5c188d9
add lr
jasperzhong Dec 28, 2020
150eaae
update
jasperzhong Dec 28, 2020
3c63977
update
jasperzhong Dec 28, 2020
92a6157
debug
jasperzhong Dec 28, 2020
39dfc19
debug
jasperzhong Dec 28, 2020
4b0a15b
update
jasperzhong Dec 29, 2020
a137f43
test
jasperzhong Dec 30, 2020
a5f841a
debug
jasperzhong Dec 30, 2020
ecc27cf
debug
jasperzhong Dec 30, 2020
7c40999
debug debug
jasperzhong Dec 30, 2020
103fbd1
debug
jasperzhong Dec 31, 2020
0c7a921
fix fatal bug
jasperzhong Dec 31, 2020
9931914
update test
jasperzhong Dec 31, 2020
1f873b5
update
jasperzhong Jan 1, 2021
76c84bc
remove nan check
jasperzhong Jan 4, 2021
d8e2f2e
keep
jasperzhong Jan 4, 2021
37dc7c0
add
jasperzhong Jan 5, 2021
4240ebd
update
jasperzhong Jan 8, 2021
c0ac6aa
update
jasperzhong Jan 10, 2021
64775f2
disable rdma
jasperzhong Jan 11, 2021
9c8d8da
try to balance workload
jasperzhong Jan 14, 2021
a7d8e91
update
jasperzhong Jan 14, 2021
684e61a
update
jasperzhong Jan 14, 2021
d28bf62
update
jasperzhong Jan 14, 2021
a0bcb30
update
jasperzhong Jan 14, 2021
1096f84
restore
jasperzhong Jan 14, 2021
8104719
add MSHADOW_USE_F16C=1
jasperzhong Jan 17, 2021
352cf47
test
jasperzhong Jan 18, 2021
6add67d
test
jasperzhong Jan 18, 2021
30005b5
update
jasperzhong Jan 18, 2021
7a95a07
add check
jasperzhong Jan 26, 2021
914b5cf
remove
jasperzhong Jan 26, 2021
4c849a9
update
jasperzhong Jan 27, 2021
f131f14
update
jasperzhong Jan 27, 2021
bdb9d41
update
jasperzhong Jan 27, 2021
aa809a7
update
jasperzhong Jan 27, 2021
b741a9d
debug
jasperzhong Jan 27, 2021
a21d46f
test topk
jasperzhong Feb 3, 2021
20e27cc
remove unnecessary tests
jasperzhong May 3, 2021
9c95ba6
Merge branch 'master' into apex
jasperzhong May 3, 2021
178232b
remove
jasperzhong May 3, 2021
7fe8f43
update
jasperzhong May 10, 2021
65db55b
update docs
jasperzhong May 10, 2021
0821569
update docs
jasperzhong May 10, 2021
78eb37e
revert
jasperzhong May 22, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,10 @@ venv.bak/
# for development
scripts/
exps/
sshlog/
hosts
example/mxnet/nohup.out
example/mxnet/run.sh
server-hosts
worker-hosts
*.sh
24 changes: 16 additions & 8 deletions byteps/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@
#include <vector>

// Add for profiling communication events
#include <stdio.h>
#include <stdlib.h>

#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <queue>
Expand Down Expand Up @@ -267,7 +266,8 @@ using TensorTable = std::unordered_map<std::string, TensorTableEntry>;
enum class RequestType {
kDefaultPushPull,
kRowSparsePushPull,
kCompressedPushPull
kCompressedPushPull,
kConfigPushPull
};

int GetCommandType(RequestType requestType, int d);
Expand All @@ -278,10 +278,18 @@ ncclDataType_t getNcclDataType(DataType dtype);

int getDataTypeLength(int dtype);

// Rounds `size` up to the next multiple of a dtype-dependent granule, where
// the granule is (element length)^2 * 8 as reported by getDataTypeLength().
// Sizes that are already a multiple of the granule are returned unchanged.
inline size_t Align(size_t size, int dtype) {
  const size_t granule =
      (getDataTypeLength(dtype) * getDataTypeLength(dtype)) * 8;
  const size_t remainder = size % granule;
  if (remainder == 0) {
    return size;
  }
  return size + (granule - remainder);
}
// Rounds `size` up to the next multiple of 512. Values that are already a
// multiple of 512 (including 0) are returned unchanged.
inline size_t Align(size_t size) {
  constexpr size_t kGranule = 512;
  const size_t remainder = size % kGranule;
  if (remainder == 0) {
    return size;
  }
  return size + (kGranule - remainder);
}

// Promote a low-precision type to float32, in place.
//
// When the element length reported by getDataTypeLength(dtype) is below 4,
// `size` is multiplied by (4 / element length) and `dtype` is overwritten
// with BYTEPS_FLOAT32. Types whose element length is 4 or more are left
// untouched.
// NOTE(review): an element length of 0 would divide by zero here - confirm
// getDataTypeLength() never returns 0 for dtypes that reach this function.
inline void Promote(size_t& size, int& dtype) {
  const size_t width = getDataTypeLength(dtype);
  if (width >= 4) {
    return;
  }
  size *= (4 / width);
  dtype = BYTEPS_FLOAT32;
}
} // namespace common
} // namespace byteps
Expand Down
32 changes: 32 additions & 0 deletions byteps/common/compressor/cast.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2020 Amazon Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "cast.h"

namespace byteps {
namespace common {
namespace compressor {

// Runs the gradient through CastToFP32() first, then hands the converted
// tensor to the wrapped compressor, which writes its result into `output`.
void Cast::Compress(tensor_t grad, tensor_t& output) {
  tensor_t widened = CastToFP32(grad);
  _cptr->Compress(widened, output);
}

// No conversion happens on the decompression path: the call is passed
// straight through to the wrapped compressor.
void Cast::Decompress(tensor_t compressed, tensor_t& output) {
  Compressor& inner = *_cptr;
  inner.Decompress(compressed, output);
}
} // namespace compressor
} // namespace common
} // namespace byteps
62 changes: 62 additions & 0 deletions byteps/common/compressor/cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2020 Amazon Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef BYTEPS_COMPRESSOR_CAST_H
#define BYTEPS_COMPRESSOR_CAST_H

#include "../cpu_reducer.h"
#include "compressor.h"

namespace byteps {
namespace common {
namespace compressor {
/*!
 * \brief Compressor wrapper that deals with low-precision gradient types.
 *
 * Summation in a low-precision type is prone to overflow. To avoid that,
 * float32 is used as the data type of intermediate buffers: once intra-node
 * all-reduce has finished, the locally aggregated gradients are converted to
 * float32 by this wrapper before being passed to the wrapped compressor.
 *
 * The wrapper owns an internal, zero-initialized fp32 buffer of `size` bytes
 * for the converted data; subclasses supply the conversion via CastToFP32().
 * NOTE(review): the buffer is allocated with the `size` given to the
 * constructor - confirm callers pass the fp32 byte count, not the
 * low-precision one.
 *
 * \sa Compressor
 */
class Cast : public Compressor {
 public:
  /*!
   * \param size  forwarded to Compressor; also the byte length of the
   *              internal fp32 buffer
   * \param dtype forwarded to Compressor
   * \param cptr  the wrapped compressor; ownership is taken
   */
  Cast(size_t size, DataType dtype, std::unique_ptr<Compressor> cptr)
      : Compressor(size, dtype),
        _fp32_buf(new byte_t[size]()),
        _cptr(std::move(cptr)) {}

  ~Cast() override = default;

  void Compress(tensor_t grad, tensor_t& output) final;
  void Decompress(tensor_t compressed, tensor_t& output) final;

 protected:
  /*! \brief Conversion hook implemented by concrete subclasses. */
  virtual tensor_t CastToFP32(tensor_t grad) = 0;

  /*! \brief Scratch storage available to subclasses for converted data. */
  std::unique_ptr<byte_t[]> _fp32_buf;

 private:
  /*! \brief The compressor that Compress/Decompress delegate to. */
  std::unique_ptr<Compressor> _cptr;
};
} // namespace compressor
} // namespace common
} // namespace byteps

#endif // BYTEPS_COMPRESSOR_CAST_H
131 changes: 117 additions & 14 deletions byteps/common/compressor/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,32 @@
#if __F16C__
#include "../half.h"
using half_t = mshadow::half::half_t;
#else
using half_t = void;
#endif

namespace byteps {
namespace common {
namespace compressor {
typedef char byte_t;
/*!
* \brief Tensor type
*/
typedef struct BPSTensor {
byte_t* data;
size_t size;
int dtype;

BPSTensor() : data(nullptr), size(0), dtype(0) {}
using byte_t = char;

struct BPSTensor {
byte_t* data{nullptr};
size_t size{0};
int dtype{0};

BPSTensor() = default;
BPSTensor(void* data, size_t size = 0, int dtype = 0)
: data(reinterpret_cast<byte_t*>(data)), size(size), dtype(dtype) {}
} tensor_t;
};
using tensor_t = BPSTensor;

using kwargs_t = std::unordered_map<std::string, std::string>;

#define COMPRESS_IMPL_SWITCH(dtype, func, dst, src, size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<uint16_t*>(dst), \
return func(reinterpret_cast<uint32_t*>(dst), \
reinterpret_cast<const half_t*>(src), \
size / sizeof(half_t)); \
case BYTEPS_FLOAT32: \
Expand All @@ -58,11 +59,45 @@ using kwargs_t = std::unordered_map<std::string, std::string>;
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with same-typed buffers.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `dst` and `src` are
// both reinterpreted as pointers to the matching element type (half_t /
// float / double; `src` as pointer-to-const), and `size` is divided by that
// element's sizeof before being passed as the third argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function. `size` is pasted unparenthesized in front of `/`;
// pass a simple expression.
#define COMPRESS_IMPL_SWITCH2(dtype, func, dst, src, size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<const half_t*>(src), \
size / sizeof(half_t)); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<const float*>(src), size / sizeof(float)); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<double*>(dst), \
reinterpret_cast<const double*>(src), \
size / sizeof(double)); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with an index_t destination.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `src` is reinterpreted
// as a const pointer to the matching element type (half_t / float / double),
// `dst` is always reinterpreted as index_t*, and `size` is divided by the
// source element's sizeof before being passed as the third argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function. `index_t` is not defined by the macro and must be
// visible where it is expanded. `size` is pasted unparenthesized in front of
// `/`; pass a simple expression.
#define COMPRESS_IMPL_SCALAR_SWITCH(dtype, func, dst, src, size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<index_t*>(dst), \
reinterpret_cast<const half_t*>(src), \
size / sizeof(half_t)); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<index_t*>(dst), \
reinterpret_cast<const float*>(src), size / sizeof(float)); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<index_t*>(dst), \
reinterpret_cast<const double*>(src), \
size / sizeof(double)); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

#define DECOMPRESS_IMPL_SWITCH(dtype, func, dst, src, compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<const uint16_t*>(src), compressed_size); \
reinterpret_cast<const uint32_t*>(src), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<const uint32_t*>(src), compressed_size); \
Expand All @@ -73,13 +108,43 @@ using kwargs_t = std::unordered_map<std::string, std::string>;
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with a float source buffer.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `dst` is reinterpreted
// as a pointer to the matching element type (half_t / float / double), while
// `src` is always reinterpreted as const float*. `compressed_size` is passed
// through unchanged as the third argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function.
#define DECOMPRESS_IMPL_SWITCH2(dtype, func, dst, src, compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<const float*>(src), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<const float*>(src), compressed_size); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<double*>(dst), \
reinterpret_cast<const float*>(src), compressed_size); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with an index_t source buffer.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `dst` is reinterpreted
// as a pointer to the matching element type (half_t / float / double), while
// `src` is always reinterpreted as const index_t*. `compressed_size` is passed
// through unchanged as the third argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function. `index_t` is not defined by the macro and must be
// visible where it is expanded.
#define DECOMPRESS_IMPL_SCALAR_SWITCH(dtype, func, dst, src, compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<const index_t*>(src), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<const index_t*>(src), compressed_size); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<double*>(dst), \
reinterpret_cast<const index_t*>(src), compressed_size); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

#define FAST_UPDATE_ERROR_IMPL_SWITCH(dtype, func, dst, src1, src2, \
compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<half_t*>(src1), \
reinterpret_cast<const uint16_t*>(src2), compressed_size); \
reinterpret_cast<const uint32_t*>(src2), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<float*>(src1), \
Expand All @@ -92,6 +157,44 @@ using kwargs_t = std::unordered_map<std::string, std::string>;
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with three same-typed buffers.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `dst`, `src1` and
// `src2` are all reinterpreted as pointers to the matching element type
// (half_t / float / double); only `src2` is cast to pointer-to-const, so
// `func` may write through `dst` and `src1`. `compressed_size` is passed
// through unchanged as the fourth argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function.
#define FAST_UPDATE_ERROR_IMPL_SWITCH2(dtype, func, dst, src1, src2, \
compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<half_t*>(src1), \
reinterpret_cast<const half_t*>(src2), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<float*>(src1), \
reinterpret_cast<const float*>(src2), compressed_size); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<double*>(dst), \
reinterpret_cast<double*>(src1), \
reinterpret_cast<const double*>(src2), compressed_size); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

// Dispatches `func` on the runtime `dtype` tag with an index_t third buffer.
// For BYTEPS_FLOAT16 / BYTEPS_FLOAT32 / BYTEPS_FLOAT64, `dst` and `src1` are
// reinterpreted as (non-const) pointers to the matching element type
// (half_t / float / double), while `src2` is always reinterpreted as
// const index_t*. `compressed_size` is passed through unchanged as the fourth
// argument.
// Any other dtype reaches BPS_CHECK(0) with an "Unsupported data type" message.
// NOTE: every case expands to a `return` statement, so this macro returns from
// the enclosing function. `index_t` is not defined by the macro and must be
// visible where it is expanded.
#define FAST_UPDATE_ERROR_IMPL_SCALAR_SWITCH(dtype, func, dst, src1, src2, \
compressed_size) \
switch (dtype) { \
case BYTEPS_FLOAT16: \
return func(reinterpret_cast<half_t*>(dst), \
reinterpret_cast<half_t*>(src1), \
reinterpret_cast<const index_t*>(src2), compressed_size); \
case BYTEPS_FLOAT32: \
return func(reinterpret_cast<float*>(dst), \
reinterpret_cast<float*>(src1), \
reinterpret_cast<const index_t*>(src2), compressed_size); \
case BYTEPS_FLOAT64: \
return func(reinterpret_cast<double*>(dst), \
reinterpret_cast<double*>(src1), \
reinterpret_cast<const index_t*>(src2), compressed_size); \
default: \
BPS_CHECK(0) << "Unsupported data type:" << dtype; \
}

} // namespace compressor
} // namespace common
} // namespace byteps
Expand Down
Loading