Skip to content

Commit

Permalink
use tf allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
LinGeLin committed Sep 22, 2023
1 parent acf8f0d commit 48a0d1e
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ using tensorflow::lookup::LookupInterface;
template <class K, class V>
class HkvHashTableOfTensorsGpu final : public LookupInterface {
private:
std::unique_ptr<nv::merlin::BaseAllocator> allocator_ptr_;

public:
HkvHashTableOfTensorsGpu(OpKernelContext* ctx, OpKernel* kernel) {
OP_REQUIRES_OK(ctx,
Expand Down Expand Up @@ -112,7 +114,9 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
if (table_) {
return;
}
OP_REQUIRES_OK(ctx, this->CreateTable(options, &table_));
allocator_ptr_ = std::make_unique<gpu::TFOrDefaultAllocator>(ctx);
OP_REQUIRES_OK(ctx,
this->CreateTable(options, allocator_ptr_.get(), &table_));
OP_REQUIRES(ctx, (table_ != nullptr),
errors::InvalidArgument("HashTable on GPU is created failed!"));
LOG(INFO) << "GPU table max capacity was created on max_capacity: "
Expand All @@ -129,8 +133,9 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
}

Status CreateTable(gpu::TableWrapperInitOptions& options,
nv::merlin::BaseAllocator* allocator,
gpu::TableWrapper<K, V>** pptable) {
return gpu::CreateTableImpl(pptable, options, runtime_dim_);
return gpu::CreateTableImpl(pptable, options, allocator, runtime_dim_);
}

size_t size() const override {
Expand Down Expand Up @@ -919,14 +924,12 @@ class HashTableSaveToFileSystemGpuOp : public OpKernel {
std::string filepath = io::JoinPath(dirpath, file_name);
FileSystem* fs;
const auto env = ctx->env();
OP_REQUIRES_OK(ctx,
env->GetFileSystemForFile(filepath, &fs));
OP_REQUIRES_OK(ctx,
fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));

OP_REQUIRES_OK(ctx, env->GetFileSystemForFile(filepath, &fs));
OP_REQUIRES_OK(
ctx, table_hkv->ExportValuesToFile(ctx, filepath,
buffer_size_));
ctx, fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));

OP_REQUIRES_OK(ctx,
table_hkv->ExportValuesToFile(ctx, filepath, buffer_size_));
}

private:
Expand Down Expand Up @@ -976,17 +979,14 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel {
std::string filepath = io::JoinPath(dirpath, file_name);
FileSystem* fs;
const auto env = ctx->env();
OP_REQUIRES_OK(ctx,
env->GetFileSystemForFile(filepath, &fs));
OP_REQUIRES_OK(ctx,
fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));

OP_REQUIRES_OK(ctx, env->GetFileSystemForFile(filepath, &fs));
OP_REQUIRES_OK(
ctx, fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));

lookup::HkvHashTableOfTensorsGpu<K, V>* table_hkv =
(lookup::HkvHashTableOfTensorsGpu<K, V>*)table;
OP_REQUIRES_OK(
ctx, table_hkv->ImportValuesFromFile(ctx, filepath,
buffer_size_));
ctx, table_hkv->ImportValuesFromFile(ctx, filepath, buffer_size_));
}

private:
Expand Down Expand Up @@ -1063,7 +1063,6 @@ REGISTER_KERNEL(int64, int32);
REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Eigen::half);


#undef REGISTER_KERNEL

#define SINGLE_ATTR_REGISTER_KERNEL(key_dtype, value_type) \
Expand Down
Loading

0 comments on commit 48a0d1e

Please sign in to comment.