pad before writing to rocksDB (#3245)
Summary:
Before writing embeddings to RocksDB, first pad the embedding dimension to max_D, then write.
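
A minimal Python sketch of the idea, mirroring the C++ change below (the function name pad_to_max_d and its arguments are illustrative, not library API):

    import torch

    def pad_to_max_d(weights: torch.Tensor, max_D: int) -> torch.Tensor:
        # Zero-pad the embedding (last) dim on the right up to max_D.
        pad_right = max_D - weights.size(1)
        if pad_right == 0:
            return weights
        return torch.constant_pad_nd(weights, [0, pad_right, 0, 0], 0)

    # e.g. a (1000, 128) batch becomes (1000, 256) before the RocksDB write.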

Pull Request resolved: #3245

X-link: facebookresearch/FBGEMM#344

Reviewed By: xunnanxu, duduyi2013

Differential Revision: D64272137

fbshipit-source-id: 9c6b49c28bc033d2e642cea62f495ef9f874e903
Yulu Jia authored and facebook-github-bot committed Oct 13, 2024
1 parent 27c9382 commit b260e98
Showing 2 changed files with 21 additions and 10 deletions.
@@ -414,7 +414,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
 
   at::Tensor narrow(int64_t dim, int64_t start, int64_t length) {
     CHECK_EQ(dim, 0) << "Only narrow on dim 0 is supported";
-    CHECK_EQ(db_->get_max_D(), shape_[1]);
+    CHECK_GE(db_->get_max_D(), shape_[1]);
     CHECK_TRUE(snapshot_handle_ != nullptr);
     auto t = at::empty(c10::IntArrayRef({length, db_->get_max_D()}), options_);
     db_->get_range_from_snapshot(
@@ -430,8 +430,16 @@
       const int64_t length,
       const at::Tensor& weights) {
     CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
-    CHECK_EQ(db_->get_max_D(), shape_[1]);
-    db_->set_range(weights, start + row_offset_, length);
+    CHECK_GE(db_->get_max_D(), shape_[1]);
+    int pad_right = db_->get_max_D() - weights.size(1);
+    if (pad_right == 0) {
+      db_->set_range(weights, start + row_offset_, length);
+    } else {
+      std::vector<int64_t> padding = {0, pad_right, 0, 0};
+      auto padded_weights = torch::constant_pad_nd(weights, padding, 0);
+      CHECK_EQ(db_->get_max_D(), padded_weights.size(1));
+      db_->set_range(padded_weights, start + row_offset_, length);
+    }
   }
 
   c10::IntArrayRef size() {
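
Note on the padding spec above: torch::constant_pad_nd reads the pad list in reverse dimension order as (last-dim left, last-dim right, second-to-last left, second-to-last right), so {0, pad_right, 0, 0} widens each row on the right and leaves the row count unchanged. A quick sanity check of the same op from Python (shapes chosen for illustration):

    import torch

    w = torch.arange(6, dtype=torch.float32).view(2, 3)  # (N=2, D=3)
    padded = torch.constant_pad_nd(w, [0, 2, 0, 0], 0)   # right-pad last dim by 2
    assert padded.shape == (2, 5)
    assert torch.equal(padded[:, :3], w)
    assert bool(padded[:, 3:].eq(0).all())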
fbgemm_gpu/test/tbe/ssd/kv_tensor_wrapper_test.py (17 changes: 10 additions & 7 deletions)
@@ -36,6 +36,7 @@ class KvTensorWrapperTest(TestCase):
     def test_read_tensor_using_wrapper_from_db(self) -> None:
         E = int(1e4)
         D = 128
+        max_D = 256 # max emb dimension seen by rocksDB
         N = 1000
         weights_precision = SparseType.FP32
         weights_dtype = weights_precision.as_dtype()
@@ -49,7 +50,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None:
             0, # ssd_memtable_flush_period,
             0, # ssd_memtable_flush_offset,
             4, # ssd_l0_files_per_compact,
-            D, # embedding_dim
+            max_D, # embedding_dim
             0, # ssd_rate_limit_mbps,
             1, # ssd_size_ratio,
             8, # ssd_compaction_trigger,
@@ -65,13 +66,14 @@ def test_read_tensor_using_wrapper_from_db(self) -> None:
         indices = torch.randperm(N)
         # insert the weights with the corresponding indices into the table
         weights = torch.arange(N * D, dtype=weights_dtype).view(N, D)
-        output_weights = torch.empty_like(weights)
+        padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+        output_weights = torch.empty_like(padded_weights)
         count = torch.tensor([N])
-        ssd_db.set(indices, weights, count)
+        ssd_db.set(indices, padded_weights, count)
 
         # force waiting for set to complete
         ssd_db.get(indices, output_weights, torch.tensor(indices.shape[0]))
-        torch.testing.assert_close(weights, output_weights)
+        torch.testing.assert_close(padded_weights, output_weights)
 
         # create a view tensor wrapper
         snapshot = ssd_db.create_snapshot()
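
In the updated read path of this test, torch.nn.functional.pad(weights, (0, max_D - D)) zero-pads only the last dimension on the right, so the (N, D) reference becomes (N, max_D) and matches the max_D-wide rows the DB hands back. A caller that wants the original width can slice the padding off again (illustrative, not part of the test):

    restored = output_weights[:, :D]  # back to (N, D)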
@@ -104,6 +106,7 @@ def test_read_tensor_using_wrapper_from_db(self) -> None:
     def test_write_tensor_to_db(self) -> None:
         E = int(1e4) # num total rows
         D = 128 # emb dimension
+        max_D = 256 # max emb dimension seen by rocksDB
         N = 1000 # window size
         weights_precision = SparseType.FP32
         weights_dtype = weights_precision.as_dtype()
@@ -117,7 +120,7 @@ def test_write_tensor_to_db(self) -> None:
             0, # ssd_memtable_flush_period,
             0, # ssd_memtable_flush_offset,
             4, # ssd_l0_files_per_compact,
-            D, # embedding_dim
+            max_D, # embedding_dim
             0, # ssd_rate_limit_mbps,
             1, # ssd_size_ratio,
             8, # ssd_compaction_trigger,
@@ -128,9 +131,9 @@ def test_write_tensor_to_db(self) -> None:
             32, # row_storage_bitwidth
             10 * (2**20), # block cache size
         )
-
         weights = torch.arange(N * D, dtype=weights_dtype).view(N, D)
-        output_weights = torch.empty_like(weights)
+        padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+        output_weights = torch.empty_like(padded_weights)
 
         # no snapshot needed for writing to rocksdb
         tensor_wrapper0 = torch.classes.fbgemm.KVTensorWrapper(
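
With the C++ change above, set_range also accepts weights narrower than max_D and pads them internally before the RocksDB write, so the write test can hand the wrapper the unpadded (N, D) tensor directly. A sketch of such a call, with the argument order (dim, start, length, weights) assumed from the C++ signature:

    tensor_wrapper0.set_range(0, 0, N, weights)  # (N, 128) is padded to (N, 256) internally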
