diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc
index 83069eb7a..6e46d9d25 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/hkv_hashtable_op_gpu.cu.cc
@@ -54,6 +54,8 @@ using tensorflow::lookup::LookupInterface;
 template
 class HkvHashTableOfTensorsGpu final : public LookupInterface {
  private:
+  std::unique_ptr allocator_ptr_;
+
  public:
   HkvHashTableOfTensorsGpu(OpKernelContext* ctx, OpKernel* kernel) {
     OP_REQUIRES_OK(ctx,
@@ -112,7 +114,9 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
     if (table_) {
       return;
     }
-    OP_REQUIRES_OK(ctx, this->CreateTable(options, &table_));
+    allocator_ptr_ = std::make_unique(ctx);
+    OP_REQUIRES_OK(ctx,
+                   this->CreateTable(options, allocator_ptr_.get(), &table_));
     OP_REQUIRES(ctx, (table_ != nullptr),
                 errors::InvalidArgument("HashTable on GPU is created failed!"));
     LOG(INFO) << "GPU table max capacity was created on max_capacity: "
@@ -129,8 +133,9 @@ class HkvHashTableOfTensorsGpu final : public LookupInterface {
   }
 
   Status CreateTable(gpu::TableWrapperInitOptions& options,
+                     nv::merlin::BaseAllocator* allocator,
                      gpu::TableWrapper** pptable) {
-    return gpu::CreateTableImpl(pptable, options, runtime_dim_);
+    return gpu::CreateTableImpl(pptable, options, allocator, runtime_dim_);
   }
 
   size_t size() const override {
@@ -919,14 +924,12 @@ class HashTableSaveToFileSystemGpuOp : public OpKernel {
     std::string filepath = io::JoinPath(dirpath, file_name);
     FileSystem* fs;
     const auto env = ctx->env();
-    OP_REQUIRES_OK(ctx,
-                   env->GetFileSystemForFile(filepath, &fs));
-    OP_REQUIRES_OK(ctx,
-                   fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));
-
+    OP_REQUIRES_OK(ctx, env->GetFileSystemForFile(filepath, &fs));
     OP_REQUIRES_OK(
-        ctx, table_hkv->ExportValuesToFile(ctx, filepath,
-                                           buffer_size_));
+        ctx, fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));
+
+    OP_REQUIRES_OK(ctx,
+                   table_hkv->ExportValuesToFile(ctx, filepath, buffer_size_));
   }
 
  private:
@@ -976,17 +979,14 @@ class HashTableLoadFromFileSystemGpuOp : public OpKernel {
     std::string filepath = io::JoinPath(dirpath, file_name);
     FileSystem* fs;
     const auto env = ctx->env();
-    OP_REQUIRES_OK(ctx,
-                   env->GetFileSystemForFile(filepath, &fs));
-    OP_REQUIRES_OK(ctx,
-                   fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));
-
+    OP_REQUIRES_OK(ctx, env->GetFileSystemForFile(filepath, &fs));
+    OP_REQUIRES_OK(
+        ctx, fs->RecursivelyCreateDir(std::string(fs->Dirname(filepath))));
 
     lookup::HkvHashTableOfTensorsGpu* table_hkv =
         (lookup::HkvHashTableOfTensorsGpu*)table;
     OP_REQUIRES_OK(
-        ctx, table_hkv->ImportValuesFromFile(ctx, filepath,
-                                             buffer_size_));
+        ctx, table_hkv->ImportValuesFromFile(ctx, filepath, buffer_size_));
   }
 
  private:
@@ -1063,7 +1063,6 @@ REGISTER_KERNEL(int64, int32);
 REGISTER_KERNEL(int64, int64);
 REGISTER_KERNEL(int64, Eigen::half);
 
-
 #undef REGISTER_KERNEL
 
 #define SINGLE_ATTR_REGISTER_KERNEL(key_dtype, value_type) \
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h
index 03011c3b8..a44b6d542 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/lookup_impl/lookup_table_op_hkv.h
@@ -16,14 +16,20 @@ limitations under the License.
 #ifndef TFRA_CORE_KERNELS_LOOKUP_TABLE_OP_GPU_H_
 #define TFRA_CORE_KERNELS_LOOKUP_TABLE_OP_GPU_H_
 
-#include
 #include
 #include
-#include
-#include
 #include
+
 #include
+#include
+#include
+#include
+#include "merlin/allocator.cuh"
+#include "merlin/types.cuh"
+#include "merlin/utils.cuh"
+#include "merlin_hashtable.cuh"
+#include "merlin_localfile.hpp"
 
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/lookup_interface.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -35,10 +41,6 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/thread_annotations.h"
-#include "merlin_hashtable.cuh"
-#include "merlin_localfile.hpp"
-#include "merlin/types.cuh"
-#include "merlin/utils.cuh"
 
 namespace tensorflow {
 namespace recommenders_addons {
@@ -50,9 +52,7 @@ class KVOnlyFile : public nv::merlin::BaseKVFile {
  public:
   KVOnlyFile() : keys_fp_(nullptr), values_fp_(nullptr) {}
 
-  ~KVOnlyFile() {
-    close();
-  }
+  ~KVOnlyFile() { close(); }
 
   bool open(const std::string& keys_path, const std::string& values_path,
             const char* mode) {
@@ -80,21 +80,23 @@ class KVOnlyFile : public nv::merlin::BaseKVFile {
     }
   }
 
-  size_t read(const size_t n, const size_t dim, K* keys, V* vectors, M* metas) override {
+  size_t read(const size_t n, const size_t dim, K* keys, V* vectors,
+              M* metas) override {
     size_t nread_keys =
         fread(keys, sizeof(K), static_cast(n), keys_fp_);
     size_t nread_vecs =
         fread(vectors, sizeof(V) * dim, static_cast(n), values_fp_);
     if (nread_keys != nread_vecs) {
-      LOG(INFO) << "Partially read failed. " << nread_keys << " kv pairs by KVOnlyFile.";
+      LOG(INFO) << "Partially read failed. " << nread_keys
+                << " kv pairs by KVOnlyFile.";
       return 0;
     }
     LOG(INFO) << "Partially read " << nread_keys << " kv pairs by KVOnlyFile.";
     return nread_keys;
   }
 
-  size_t write(const size_t n, const size_t dim, const K* keys, const V* vectors,
-               const M* metas) override {
+  size_t write(const size_t n, const size_t dim, const K* keys,
+               const V* vectors, const M* metas) override {
     size_t nwritten_keys =
         fwrite(keys, sizeof(K), static_cast(n), keys_fp_);
     size_t nwritten_vecs =
@@ -102,7 +104,8 @@ class KVOnlyFile : public nv::merlin::BaseKVFile {
     if (nwritten_keys != nwritten_vecs) {
       return 0;
     }
-    LOG(INFO) << "Partially write " << nwritten_keys << " kv pairs by KVOnlyFile.";
+    LOG(INFO) << "Partially write " << nwritten_keys
+              << " kv pairs by KVOnlyFile.";
     return nwritten_keys;
   }
 
@@ -113,7 +116,8 @@ class KVOnlyFile : public nv::merlin::BaseKVFile {
 
 // template to avoid multidef in compile time only.
 template
-__global__ void gpu_u64_to_i64_kernel(const uint64_t* u64, int64* i64, size_t len) {
+__global__ void gpu_u64_to_i64_kernel(const uint64_t* u64, int64* i64,
+                                      size_t len) {
   size_t tid = (blockIdx.x * blockDim.x) + threadIdx.x;
   if (tid < len) {
     i64[tid] = static_cast(u64[tid]);
@@ -129,10 +133,12 @@ __global__ void broadcast_kernel(T* data, T val, size_t n) {
 }
 
 template
-void gpu_cast_u64_to_i64(const uint64_t* u64, int64* i64, size_t len, cudaStream_t stream) {
+void gpu_cast_u64_to_i64(const uint64_t* u64, int64* i64, size_t len,
+                         cudaStream_t stream) {
   size_t block_size = nv::merlin::SAFE_GET_BLOCK_SIZE(1024);
   size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size);
-  gpu_u64_to_i64_kernel<<>>(u64, i64, len);
+  gpu_u64_to_i64_kernel
+      <<>>(u64, i64, len);
 }
 
 using GPUDevice = Eigen::ThreadPoolDevice;
@@ -147,51 +153,139 @@ struct TableWrapperInitOptions {
   int io_block_size;
 };
 
-template
-__global__ void gpu_fill_default_values(V* d_vals, V* d_def_val, size_t len, size_t dim) {
+template
+__global__ void gpu_fill_default_values(V* d_vals, V* d_def_val, size_t len,
+                                        size_t dim) {
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   for (int i = index; i < len * dim; i += stride) {
-      int row = i / dim;
-      int col = i % dim;
-      d_vals[row * dim + col] = *d_def_val;
+    int row = i / dim;
+    int col = i % dim;
+    d_vals[row * dim + col] = *d_def_val;
   }
 }
 
+class TFOrDefaultAllocator : public nv::merlin::BaseAllocator {
+ private:
+  using NMMemType = nv::merlin::MemoryType;
+  // tensorflow::Allocator* tf_host_allocator_ = nullptr;
+  tensorflow::Allocator* tf_device_allocator_ = nullptr;
+  std::unique_ptr default_allocator_ = nullptr;
+  bool use_default_allocator_ = false;
+  // bool tf_async_allocator_stream_set_ = false;
+  static constexpr size_t kAllocatorAlignment = 8;
+
+ public:
+  TFOrDefaultAllocator() : use_default_allocator_(true) {
+    default_allocator_ = std::make_unique();
+  }
+
+  TFOrDefaultAllocator(OpKernelContext* ctx) {
+    if (ctx) {
+      tensorflow::AllocatorAttributes tf_alloc_attrs;
+      tf_alloc_attrs.set_on_host(false);
+      tf_device_allocator_ = ctx->get_allocator(tf_alloc_attrs);
+    } else {
+      use_default_allocator_ = true;
+      default_allocator_ = std::make_unique();
+    }
+  }
+
+  ~TFOrDefaultAllocator() override {}
+
+  void alloc(const NMMemType type, void** ptr, size_t size,
+             unsigned int pinned_flags = cudaHostAllocDefault) override {
+    if (!use_default_allocator_) {
+      switch (type) {
+        case NMMemType::Device:
+          *ptr = tf_device_allocator_->AllocateRaw(kAllocatorAlignment, size);
+          break;
+        case NMMemType::Pinned:
+          CUDA_CHECK(cudaMallocHost(ptr, size, pinned_flags));
+          break;
+        case NMMemType::Host:
+          *ptr = std::malloc(size);
+          break;
+      }
+    } else {
+      default_allocator_->alloc(type, ptr, size, pinned_flags);
+    }
+  }
+
+  void alloc_async(const NMMemType type, void** ptr, size_t size,
+                   cudaStream_t stream) override {
+    if (!use_default_allocator_) {
+      if (NMMemType::Device == type) {
+        *ptr = tf_device_allocator_->AllocateRaw(kAllocatorAlignment, size);
+      }
+    } else {
+      default_allocator_->alloc_async(type, ptr, size, stream);
+    }
+  }
+
+  void free(const NMMemType type, void* ptr) override {
+    if (!use_default_allocator_) {
+      switch (type) {
+        case NMMemType::Device:
+          tf_device_allocator_->DeallocateRaw(ptr);
+          break;
+        case NMMemType::Pinned:
+          CUDA_CHECK(cudaFreeHost(ptr));
+          break;
+        case NMMemType::Host:
+          std::free(ptr);
+          break;
+      }
+    } else {
+      default_allocator_->free(type, ptr);
+    }
+  }
+
+  void free_async(const NMMemType type, void* ptr,
+                  cudaStream_t stream) override {
+    if (!use_default_allocator_) {
+      if (NMMemType::Device == type) {
+        tf_device_allocator_->DeallocateRaw(ptr);
+      }
+    } else {
+      default_allocator_->free_async(type, ptr, stream);
+    }
+  }
+};
+
 template
 class TableWrapper {
  private:
-  //using M = uint64_t;
+  // using M = uint64_t;
   using Table = nv::merlin::HashTable;
-  nv::merlin::HashTableOptions mkv_options;
+  nv::merlin::HashTableOptions mkv_options_;
 
  public:
   TableWrapper(TableWrapperInitOptions& init_options, size_t dim) {
     max_capacity_ = init_options.max_capacity;
     dim_ = dim;
-    // nv::merlin::HashTableOptions mkv_options;
-    mkv_options.init_capacity = std::min(init_options.init_capacity, max_capacity_);
-    mkv_options.max_capacity = max_capacity_;
+    // nv::merlin::HashTableOptions mkv_options_;
+    mkv_options_.init_capacity =
+        std::min(init_options.init_capacity, max_capacity_);
+    mkv_options_.max_capacity = max_capacity_;
     // Since currently GPU nodes are not compatible to fast
     // pcie connections for D2H non-continous wirte, so just
     // use pure hbm mode now.
-    // mkv_options.max_hbm_for_vectors = std::numeric_limits::max();
-    mkv_options.max_hbm_for_vectors = init_options.max_hbm_for_vectors;
-    mkv_options.max_load_factor = 0.5;
-    mkv_options.block_size = nv::merlin::SAFE_GET_BLOCK_SIZE(1024);
-    mkv_options.dim = dim;
-    mkv_options.evict_strategy = nv::merlin::EvictStrategy::kCustomized;
-    block_size_ = mkv_options.block_size;
+    // mkv_options_.max_hbm_for_vectors = std::numeric_limits::max();
+    mkv_options_.max_hbm_for_vectors = init_options.max_hbm_for_vectors;
+    mkv_options_.max_load_factor = 0.5;
+    mkv_options_.block_size = nv::merlin::SAFE_GET_BLOCK_SIZE(1024);
+    mkv_options_.dim = dim;
+    mkv_options_.evict_strategy = nv::merlin::EvictStrategy::kCustomized;
+    block_size_ = mkv_options_.block_size;
     table_ = new Table();
   }
 
-  Status init() {
+  Status init(nv::merlin::BaseAllocator* allocator) {
     try {
-      table_->init(mkv_options);
-    } catch(std::runtime_error& e)
-    {
+      table_->init(mkv_options_, allocator);
+    } catch (std::runtime_error& e) {
       return Status(tensorflow::error::INTERNAL, e.what());
     }
     return Status::OK();
@@ -203,60 +297,74 @@ class TableWrapper {
                cudaStream_t stream) {
     uint64_t t0 = (uint64_t)time(NULL);
     uint64_t* timestamp_metas = nullptr;
-    CUDA_CHECK(cudaMallocAsync(&timestamp_metas, len * sizeof(uint64_t), stream));
-    CUDA_CHECK(cudaMemsetAsync(timestamp_metas, 0, len * sizeof(uint64_t), stream));
+    CUDA_CHECK(
+        cudaMallocAsync(&timestamp_metas, len * sizeof(uint64_t), stream));
+    CUDA_CHECK(
+        cudaMemsetAsync(timestamp_metas, 0, len * sizeof(uint64_t), stream));
     size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size_);
-    broadcast_kernel<<>>(timestamp_metas, t0, len);
+    broadcast_kernel
+        <<>>(timestamp_metas, t0, len);
 
-    table_->insert_or_assign(len, d_keys, d_vals, /*d_metas=*/timestamp_metas, stream);
+    table_->insert_or_assign(len, d_keys, d_vals, /*d_metas=*/timestamp_metas,
+                             stream);
     CUDA_CHECK(cudaFreeAsync(timestamp_metas, stream));
   }
 
-  void accum(const K* d_keys, const V* d_vals_or_deltas,
-             const bool* d_exists, size_t len, cudaStream_t stream) {
+  void accum(const K* d_keys, const V* d_vals_or_deltas, const bool* d_exists,
+             size_t len, cudaStream_t stream) {
     uint64_t t0 = (uint64_t)time(NULL);
     uint64_t* timestamp_metas = nullptr;
-    CUDA_CHECK(cudaMallocAsync(&timestamp_metas, len * sizeof(uint64_t), stream));
-    CUDA_CHECK(cudaMemsetAsync(timestamp_metas, 0, len * sizeof(uint64_t), stream));
+    CUDA_CHECK(
+        cudaMallocAsync(&timestamp_metas, len * sizeof(uint64_t), stream));
+    CUDA_CHECK(
+        cudaMemsetAsync(timestamp_metas, 0, len * sizeof(uint64_t), stream));
     size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size_);
-    broadcast_kernel<<>>(timestamp_metas, t0, len);
-    table_->accum_or_assign(len, d_keys, d_vals_or_deltas, d_exists, /*d_metas=*/timestamp_metas, stream);
+    broadcast_kernel
+        <<>>(timestamp_metas, t0, len);
+    table_->accum_or_assign(len, d_keys, d_vals_or_deltas, d_exists,
+                            /*d_metas=*/timestamp_metas, stream);
     CUDA_CHECK(cudaFreeAsync(timestamp_metas, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
   }
 
-  void dump(K* d_key, V* d_val, const size_t offset,
-            const size_t search_length, size_t* d_dump_counter,
-            cudaStream_t stream) const {
-    table_->export_batch(search_length, offset, d_dump_counter, d_key, d_val, /*d_metas=*/nullptr, stream);
+  void dump(K* d_key, V* d_val, const size_t offset, const size_t search_length,
+            size_t* d_dump_counter, cudaStream_t stream) const {
+    table_->export_batch(search_length, offset, d_dump_counter, d_key, d_val,
+                         /*d_metas=*/nullptr, stream);
   }
 
-  void dump_with_metas(K* d_key, V* d_val, uint64_t* d_metas, const size_t offset,
-                       const size_t search_length, size_t* d_dump_counter,
-                       cudaStream_t stream) const {
-    table_->export_batch(search_length, offset, d_dump_counter, d_key, d_val, d_metas, stream);
+  void dump_with_metas(K* d_key, V* d_val, uint64_t* d_metas,
+                       const size_t offset, const size_t search_length,
+                       size_t* d_dump_counter, cudaStream_t stream) const {
+    table_->export_batch(search_length, offset, d_dump_counter, d_key, d_val,
+                         d_metas, stream);
   }
 
-  void dump_keys_and_metas(K* keys, int64* metas, size_t len,
-                           size_t split_len, cudaStream_t stream) const {
+  void dump_keys_and_metas(K* keys, int64* metas, size_t len, size_t split_len,
+                           cudaStream_t stream) const {
     V* values_buf = nullptr;
     size_t offset = 0;
     size_t real_offset = 0;
     size_t skip = split_len;
     uint64_t* metas_u64 = reinterpret_cast(metas);
     size_t span_len = table_->capacity();
-    CUDA_CHECK(cudaMallocAsync(&values_buf, sizeof(V) * dim_ * split_len, stream));
-    CUDA_CHECK(cudaMemsetAsync(values_buf, 0, sizeof(V) * dim_ * split_len, stream));
+    CUDA_CHECK(
+        cudaMallocAsync(&values_buf, sizeof(V) * dim_ * split_len, stream));
+    CUDA_CHECK(
+        cudaMemsetAsync(values_buf, 0, sizeof(V) * dim_ * split_len, stream));
     for (; offset < span_len; offset += split_len) {
       if (offset + skip > span_len) {
        skip = span_len - offset;
      }
       // TODO: overlap the loop
-      size_t h_dump_counter = table_->export_batch(skip, offset, keys + real_offset, values_buf, metas_u64 + real_offset, stream);
+      size_t h_dump_counter =
+          table_->export_batch(skip, offset, keys + real_offset, values_buf,
+                               metas_u64 + real_offset, stream);
       CudaCheckError();
       if (h_dump_counter > 0) {
-        gpu_cast_u64_to_i64(metas_u64 + real_offset, metas + real_offset, h_dump_counter, stream);
+        gpu_cast_u64_to_i64(metas_u64 + real_offset, metas + real_offset,
+                            h_dump_counter, stream);
         real_offset += h_dump_counter;
       }
       CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -267,12 +375,12 @@ class TableWrapper {
   }
 
   // TODO (LinGeLin) support metas
-  bool is_valid_metas(const std::string& keyfile, const std::string& metafile) const {
+  bool is_valid_metas(const std::string& keyfile,
+                      const std::string& metafile) const {
     return false;
   }
 
-  void dump_to_file(const string filepath, size_t dim,
-                    cudaStream_t stream,
+  void dump_to_file(const string filepath, size_t dim, cudaStream_t stream,
                     const size_t buffer_size) const {
     LOG(INFO) << "dump_to_file, filepath: " << filepath << ", dim: " << dim
               << ", stream: " << stream << ", buffer_size: " << buffer_size;
@@ -285,79 +393,91 @@ class TableWrapper {
     if (is_valid_metas(keyfile, metafile)) {
       wfile.reset(new nv::merlin::LocalKVFile);
-      open_ok = reinterpret_cast*>(wfile.get())->open(keyfile, valuefile, metafile, "wb");
+      open_ok = reinterpret_cast*>(
+                    wfile.get())
+                    ->open(keyfile, valuefile, metafile, "wb");
       has_metas = true;
     } else {
       wfile.reset(new KVOnlyFile);
-      open_ok = reinterpret_cast*>(wfile.get())->open(keyfile, valuefile, "wb");
+      open_ok = reinterpret_cast*>(wfile.get())
+                    ->open(keyfile, valuefile, "wb");
     }
     if (!open_ok) {
-      std::string error_msg = "Failed to dump to file to " + keyfile + ", " + valuefile + ", " + metafile;
+      std::string error_msg = "Failed to dump to file to " + keyfile + ", " +
+                              valuefile + ", " + metafile;
       throw std::runtime_error(error_msg);
     }
     size_t n_saved = table_->save(wfile.get(), buffer_size, stream);
     if (has_metas) {
-      LOG(INFO) << "[op] Load " << n_saved << " pairs from keyfile: "
-                << keyfile << ", and valuefile: " << valuefile
-                << ", and metafile" << metafile;
+      LOG(INFO) << "[op] Load " << n_saved << " pairs from keyfile: " << keyfile
+                << ", and valuefile: " << valuefile << ", and metafile"
+                << metafile;
     } else {
-      LOG(INFO) << "[op] Load " << n_saved << " pairs from keyfile: "
-                << keyfile << ", and valuefile: " << valuefile;
+      LOG(INFO) << "[op] Load " << n_saved << " pairs from keyfile: " << keyfile
+                << ", and valuefile: " << valuefile;
     }
     CUDA_CHECK(cudaStreamSynchronize(stream));
     // wfile->close();
   }
 
-  void load_from_file(const string filepath,
-                      size_t key_num, size_t dim, cudaStream_t stream,
-                      const size_t buffer_size) {
+  void load_from_file(const string filepath, size_t key_num, size_t dim,
+                      cudaStream_t stream, const size_t buffer_size) {
     std::unique_ptr> rfile;
     string keyfile = filepath + "-keys";
     string valuefile = filepath + "-values";
     string metafile = filepath + "-metas";
-    //rfile.reset(new TimestampV1CompatFile);
+    // rfile.reset(new TimestampV1CompatFile);
    bool has_metas = false;
     bool open_ok = false;
     if (is_valid_metas(keyfile, metafile)) {
       rfile.reset(new nv::merlin::LocalKVFile);
-      open_ok = reinterpret_cast*>(rfile.get())->open(keyfile, valuefile, metafile, "rb");
+      open_ok = reinterpret_cast*>(
+                    rfile.get())
+                    ->open(keyfile, valuefile, metafile, "rb");
       has_metas = true;
     } else {
       rfile.reset(new KVOnlyFile);
-      open_ok = reinterpret_cast*>(rfile.get())->open(keyfile, valuefile, "rb");
+      open_ok = reinterpret_cast*>(rfile.get())
+                    ->open(keyfile, valuefile, "rb");
     }
     if (!open_ok) {
-      std::string error_msg = "Failed to load from file " + keyfile + ", " + valuefile + ", " + metafile;
+      std::string error_msg = "Failed to load from file " + keyfile + ", " +
+                              valuefile + ", " + metafile;
       throw std::runtime_error(error_msg);
     }
     size_t n_loaded = table_->load(rfile.get(), buffer_size, stream);
     if (has_metas) {
-      LOG(INFO) << "[op] Load " << n_loaded << " pairs from keyfile: "
-                << keyfile << ", and valuefile: " << valuefile
-                << ", and metafile" << metafile;
+      LOG(INFO) << "[op] Load " << n_loaded
+                << " pairs from keyfile: " << keyfile
+                << ", and valuefile: " << valuefile << ", and metafile"
+                << metafile;
     } else {
-      LOG(INFO) << "[op] Load " << n_loaded << " pairs from keyfile: "
-                << keyfile << ", and valuefile: " << valuefile;
+      LOG(INFO) << "[op] Load " << n_loaded
+                << " pairs from keyfile: " << keyfile
+                << ", and valuefile: " << valuefile;
     }
     CUDA_CHECK(cudaStreamSynchronize(stream));
     if (has_metas) {
-      reinterpret_cast*>(rfile.get())->close();
+      reinterpret_cast*>(rfile.get())
+          ->close();
     } else {
       reinterpret_cast*>(rfile.get())->close();
     }
   }
 
-  void get(const K* d_keys, V* d_vals, bool* d_status, size_t len,
-           V* d_def_val, cudaStream_t stream,
-           bool is_full_size_default) const {
+  void get(const K* d_keys, V* d_vals, bool* d_status, size_t len, V* d_def_val,
+           cudaStream_t stream, bool is_full_size_default) const {
     if (is_full_size_default) {
-      CUDA_CHECK(cudaMemcpyAsync(d_vals, d_def_val, sizeof(V) * dim_ * len, cudaMemcpyDeviceToDevice, stream));
+      CUDA_CHECK(cudaMemcpyAsync(d_vals, d_def_val, sizeof(V) * dim_ * len,
+                                 cudaMemcpyDeviceToDevice, stream));
     } else {
       size_t grid_size = nv::merlin::SAFE_GET_GRID_SIZE(len, block_size_);
-      gpu_fill_default_values<<>>(d_vals, d_def_val, len, dim_);
+      gpu_fill_default_values
+          <<>>(
+              d_vals, d_def_val, len, dim_);
     }
     table_->find(len, d_keys, d_vals, d_status, /*d_metas=*/nullptr, stream);
   }
@@ -373,9 +493,7 @@ class TableWrapper {
     CUDA_CHECK(cudaStreamSynchronize(stream));
   }
 
-  size_t get_size(cudaStream_t stream) const {
-    return table_->size(stream);
-  }
+  size_t get_size(cudaStream_t stream) const { return table_->size(stream); }
 
   size_t get_capacity() const { return table_->capacity(); }
 
@@ -394,10 +512,12 @@ class TableWrapper {
 };
 
 template
-Status CreateTableImpl(TableWrapper** pptable, TableWrapperInitOptions& options,
-                       size_t runtime_dim) {
+Status CreateTableImpl(TableWrapper** pptable,
+                       TableWrapperInitOptions& options,
+                       nv::merlin::BaseAllocator* allocator,
+                       size_t runtime_dim) {
   *pptable = new TableWrapper(options, runtime_dim);
-  return (*pptable)->init();
+  return (*pptable)->init(allocator);
 }
 
 }  // namespace gpu
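
Note for reviewers: the heart of this change is TFOrDefaultAllocator, an adapter that satisfies nv::merlin::BaseAllocator by routing device allocations through TensorFlow's allocator when an OpKernelContext is available and falling back to a default allocator (or plain malloc/free) otherwise; CreateTableImpl then threads the chosen allocator into TableWrapper::init(). The snippet below is a minimal, self-contained sketch of that fallback pattern only. It compiles on its own and uses simplified stand-in types (BaseAllocatorLike, TfAllocatorLike are illustrative assumptions, not the real merlin or TensorFlow interfaces).

// Illustrative sketch only: simplified stand-ins for nv::merlin::BaseAllocator
// and tensorflow::Allocator, showing the "use the framework allocator if we
// have a context, otherwise fall back" pattern used by TFOrDefaultAllocator.
#include <cstddef>
#include <cstdlib>
#include <iostream>

struct BaseAllocatorLike {  // stand-in for the merlin allocator interface
  virtual ~BaseAllocatorLike() = default;
  virtual void* alloc(std::size_t size) = 0;
  virtual void free(void* ptr) = 0;
};

struct TfAllocatorLike {  // stand-in for tensorflow::Allocator
  void* AllocateRaw(std::size_t /*alignment*/, std::size_t size) {
    return std::malloc(size);  // real TF allocators honor the alignment
  }
  void DeallocateRaw(void* ptr) { std::free(ptr); }
};

class TfOrDefaultAllocatorLike : public BaseAllocatorLike {
 public:
  explicit TfOrDefaultAllocatorLike(TfAllocatorLike* tf) : tf_(tf) {}

  void* alloc(std::size_t size) override {
    // With a framework allocator the allocation is tracked by the framework;
    // without one we fall back to plain malloc, like the default path.
    return tf_ ? tf_->AllocateRaw(kAlignment, size) : std::malloc(size);
  }

  void free(void* ptr) override {
    if (tf_) {
      tf_->DeallocateRaw(ptr);
    } else {
      std::free(ptr);
    }
  }

 private:
  static constexpr std::size_t kAlignment = 8;
  TfAllocatorLike* tf_ = nullptr;  // nullptr models "no OpKernelContext"
};

int main() {
  TfAllocatorLike tf;
  TfOrDefaultAllocatorLike with_tf(&tf);         // framework-backed path
  TfOrDefaultAllocatorLike standalone(nullptr);  // default fallback path
  void* a = with_tf.alloc(1024);
  void* b = standalone.alloc(1024);
  std::cout << "allocated " << a << " and " << b << std::endl;
  with_tf.free(a);
  standalone.free(b);
  return 0;
}

In the actual patch the same decision is made per nv::merlin::MemoryType (Device, Pinned, Host), and the allocator instance is owned by the op via allocator_ptr_ and handed to the table through CreateTableImpl(..., allocator, ...) and init(allocator).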