[REVIEW] Add tfidf bm25 #2353
Open
jperez999 wants to merge 97 commits into rapidsai:branch-24.12 from jperez999:add-tfidf-bm25
Changes from 79 commits
Commits
a6677ca
update master references
ajschmidt8 ad2d7d7
REL DOC Updates for main branch switch
mike-wendt e3c9344
Merge pull request #272 from rapidsai/branch-21.06
ajschmidt8 3b0a6d2
Merge pull request #321 from rapidsai/branch-21.08
ajschmidt8 309ea1a
REL v21.08.00 release
GPUtester 3740998
Merge pull request #612 from rapidsai/branch-22.04
raydouglass e987ec8
REL v22.04.00 release
GPUtester 0b55c32
Add `conda` compilers (#702)
ajschmidt8 229b9f8
update changelog
raydouglass 0eded98
Merge pull request #708 from rapidsai/branch-22.06
raydouglass 3e5a625
FIX update-version.sh
raydouglass ad50a7f
Merge pull request #709 from rapidsai/branch-22.06
raydouglass ed2c529
REL v22.06.00 release
GPUtester aae5e34
Merge pull request #782 from rapidsai/branch-22.08
raydouglass 87a7d16
REL v22.08.00 release
GPUtester 1de93ba
Merge pull request #908 from rapidsai/branch-22.10
raydouglass 31ae597
REL v22.10.00 release
GPUtester 08abc72
[HOTFIX] Update cuda-python dependency to 11.7.1 (#963)
cjnolet c6e6ce8
Merge pull request #988 from rapidsai/branch-22.10
raydouglass f7d2335
REL v22.10.01 release
GPUtester c16fa56
Merge pull request #1063 from rapidsai/branch-22.12
raydouglass 9a716b7
REL v22.12.00 release
GPUtester 60936ba
Merge pull request #1101 from rapidsai/branch-22.12
raydouglass a655c9a
REL v22.12.01 release
GPUtester 9a66f42
Merge pull request #1250 from rapidsai/branch-23.02
raydouglass 69dce2d
REL v23.02.00 release
raydouglass 1467154
Merge pull request #1405 from rapidsai/branch-23.04
raydouglass 7d1057e
REL v23.04.00 release
raydouglass dc800d6
REL v23.04.01 release
raydouglass 520e12c
REL Merge pull request #1486 from rapidsai/branch-23.04
raydouglass f626bf1
Merge pull request #1549 from rapidsai/branch-23.06
raydouglass c931b61
REL v23.06.00 release
raydouglass af1515d
Merge pull request #1589 from rapidsai/branch-23.06
raydouglass 9147c90
REL v23.06.01 release
raydouglass 59ae9d6
Merge pull request #1636 from rapidsai/branch-23.06
raydouglass 7dd2f6d
REL v23.06.02 release
raydouglass 5797ef5
Merge pull request #1692 from rapidsai/branch-23.08
raydouglass e588d7b
REL v23.08.00 release
raydouglass 51f52c1
Merge pull request #1863 from rapidsai/branch-23.10
raydouglass afdddfb
REL v23.10.00 release
raydouglass e9f9aa8
Merge pull request #2020 from rapidsai/branch-23.12
raydouglass 599651e
REL v23.12.00 release
raydouglass 9e2d627
REL Revert update-version.sh changes for release
raydouglass 1143113
Merge pull request #2134 from rapidsai/branch-24.02
raydouglass 698d6c7
REL v24.02.00 release
raydouglass e0d40e5
Merge pull request #2240 from rapidsai/branch-24.04
raydouglass fa44bcc
REL v24.04.00 release
raydouglass 41938c4
Merge pull request #2341 from rapidsai/branch-24.06
raydouglass 63a506d
REL v24.06.00 release
raydouglass 427ea26
add in support for preprocessing with bm25 and tfidf
jperez999 ffbfbc7
add in test cases and header file
jperez999 2d82aca
add tfidf coo support
jperez999 dc01bc1
add in header for coo tfidf
jperez999 6f4745d
add bm25 test support coo in and refactor tfidf support
jperez999 987ff5e
add in long test for coo to csr convert test
jperez999 c46008c
remove unneeded print statement
jperez999 81bb89d
remove unneeded test
jperez999 ff1991f
add csr and coo matrix bfknn apis
jperez999 c593f4e
add knn to preprocess tests
jperez999 0febb55
all tests in place and refactor code
jperez999 6477cd4
add in cmake for test files
jperez999 c836ba8
adjust tests, coo now passes all checks
jperez999 ce8253e
csr and coo tests passing, refactor feature preprocessing
jperez999 442cd7a
refactor names to make more generic
jperez999 b1720c7
further refactor to feature and id variable names
jperez999 3365ec3
add documentation and refactor to use num rows and num cols from matrix
jperez999 06b6df2
update tests to reflect values given refactor
jperez999 034d2c5
add documentation
jperez999 04bb007
removed unnecessary imports and variables
jperez999 3747291
fix function docs to reflect behavior more correctly
jperez999 281a029
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 3d66d4b
Update docs/source/contributing.md
jperez999 2b70436
Update .github/PULL_REQUEST_TEMPLATE.md
jperez999 84ffc8b
Update .github/PULL_REQUEST_TEMPLATE.md
jperez999 63607bd
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 0f462a9
Merge branch 'branch-24.08' into add-tfidf-bm25
jperez999 1fc27f3
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 82cfb1f
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 1155609
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 6302957
Merge branch 'branch-24.10' into add-tfidf-bm25
cjnolet 05f4af2
fix preprocessing and make tests run on r random at generation
jperez999 a1e3a48
remove unnecessary imports
jperez999 44f3e1c
remove log for tf
jperez999 e25e2de
added more template changes
jperez999 187e148
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 ec4e4a2
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 e6d2c1c
remove excess thrust calls
jperez999 5120c97
add better comment on inputs for tests
jperez999 81e2a41
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999 90373ab
Merge branch 'branch-24.10' into add-tfidf-bm25
jperez999 87a729c
fixed scale errors
jperez999 63576b0
remove vector based public apis
jperez999 c123acb
add in bfknn tests for csr and coo sparse matrices
jperez999 29f14d9
Merge branch 'branch-24.12' into add-tfidf-bm25
rhdong 0ca6e10
Merge branch 'branch-24.12' into add-tfidf-bm25
jperez999 b000065
remove unused functions
jperez999 3507771
Merge branch 'add-tfidf-bm25' of https://github.com/jperez999/raft in…
jperez999
cpp/include/raft/sparse/matrix/detail/preprocessing.cuh
463 changes: 463 additions & 0 deletions
Large diffs are not rendered by default.
@@ -0,0 +1,143 @@
/*
 * Copyright (c) 2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/sparse/matrix/detail/preprocessing.cuh>

#include <optional>

namespace raft::sparse::matrix {

/**
 * @brief Use BM25 algorithm to encode features in COO sparse matrix
 * @param handle: raft resource handle
 * @param rows: Input COO rows array
 * @param columns: Input COO columns array
 * @param values: Input COO values array
 * @param values_out: Output COO values array
 * @param n_rows: Number of rows in matrix
 * @param n_cols: Number of columns in matrix
 * @param k_param: K value to use for BM25 algorithm
 * @param b_param: B value to use for BM25 algorithm
 */
template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
                 raft::device_vector_view<T1, IdxT> rows,
                 raft::device_vector_view<T1, IdxT> columns,
                 raft::device_vector_view<T2, IdxT> values,
                 raft::device_vector_view<T2, IdxT> values_out,
                 int n_rows,
                 int n_cols,
                 float k_param = 1.6f,
                 float b_param = 0.75)
{
  return matrix::detail::base_encode_bm25<T1, T2, IdxT>(
    handle, rows, columns, values, values_out, n_rows, n_cols, k_param, b_param);
}

/**
 * @brief Use BM25 algorithm to encode features in COO sparse matrix
 * @param handle: raft resource handle
 * @param coo_in: Input COO matrix
 * @param values_out: Output values array
 * @param k_param: K value to use for BM25 algorithm
 * @param b_param: B value to use for BM25 algorithm
 */
template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
                 raft::device_coo_matrix_view<T2, T1, T1, T1> coo_in,
                 raft::device_vector_view<T2, IdxT> values_out,
                 float k_param = 1.6f,
                 float b_param = 0.75)
{
  return matrix::detail::encode_bm25<T1, T2, IdxT>(handle, coo_in, values_out, k_param, b_param);
}

/**
 * @brief Use BM25 algorithm to encode features in CSR sparse matrix
 * @param handle: raft resource handle
 * @param csr_in: Input CSR matrix
 * @param values_out: Output values array
 * @param k_param: K value to use for BM25 algorithm
 * @param b_param: B value to use for BM25 algorithm
 */
template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
                 raft::device_csr_matrix_view<T2, T1, T1, T1> csr_in,
                 raft::device_vector_view<T2, IdxT> values_out,
                 float k_param = 1.6f,
                 float b_param = 0.75)
{
  return matrix::detail::encode_bm25<T1, T2, IdxT>(handle, csr_in, values_out, k_param, b_param);
}

/**
 * @brief Use TFIDF algorithm to encode features in COO sparse matrix
 * @param handle: raft resource handle
 * @param rows: Input COO rows array
 * @param columns: Input COO columns array
 * @param values: Input COO values array
 * @param values_out: Output COO values array
 * @param n_rows: Number of rows in matrix
 * @param n_cols: Number of columns in matrix
 */
template <typename T1, typename T2, typename IdxT>
void encode_tfidf(raft::resources& handle,
                  raft::device_vector_view<T1, IdxT> rows,
                  raft::device_vector_view<T1, IdxT> columns,
                  raft::device_vector_view<T2, IdxT> values,
                  raft::device_vector_view<T2, IdxT> values_out,
                  int n_rows,
                  int n_cols)
{
  return matrix::detail::base_encode_tfidf<T1, T2, IdxT>(
    handle, rows, columns, values, values_out, n_rows, n_cols);
}

/**
 * @brief Use TFIDF algorithm to encode features in COO sparse matrix
 * @param handle: raft resource handle
 * @param coo_in: Input COO matrix
 * @param values_out: Output COO values array
 */
template <typename T1, typename T2, typename IdxT>
void encode_tfidf(raft::resources& handle,
                  raft::device_coo_matrix_view<T2, T1, T1, T1> coo_in,
                  raft::device_vector_view<T2, IdxT> values_out)
{
  return matrix::detail::encode_tfidf<T1, T2, IdxT>(handle, coo_in, values_out);
}

/**
 * @brief Use TFIDF algorithm to encode features in CSR sparse matrix
 * @param handle: raft resource handle
 * @param csr_in: Input CSR matrix
 * @param values_out: Output values array
 */
template <typename T1, typename T2, typename IdxT>
void encode_tfidf(raft::resources& handle,
                  raft::device_csr_matrix_view<T2, T1, T1, T1> csr_in,
                  raft::device_vector_view<T2, IdxT> values_out)
{
  return matrix::detail::encode_tfidf<T1, T2, IdxT>(handle, csr_in, values_out);
}

}  // namespace raft::sparse::matrix
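For readers who want to sanity-check the encoders on small inputs, here is a host-side reference sketch. It uses the textbook TF-IDF and Okapi BM25 formulas, which may differ from what `detail/preprocessing.cuh` actually computes, since that file's diff is not rendered on this page. `CooHost`, `CorpusStats`, `corpus_stats`, `tfidf_reference` and `bm25_reference` are hypothetical names, and the sketch assumes rows are documents, columns are terms, and each (row, col) pair appears at most once. The `k` and `b` defaults mirror the `k_param` and `b_param` defaults in the header above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side COO container for illustration; not a RAFT type.
struct CooHost {
  std::vector<int> rows;    // document index of each nonzero (assumed meaning)
  std::vector<int> cols;    // term index of each nonzero (assumed meaning)
  std::vector<float> vals;  // raw term counts
  int n_rows;
  int n_cols;
};

// Per-document lengths, per-term document frequencies, mean document length.
struct CorpusStats {
  std::vector<float> doc_len;
  std::vector<int> doc_freq;
  float avg_len;
};

inline CorpusStats corpus_stats(const CooHost& m)
{
  assert(m.rows.size() == m.vals.size() && m.cols.size() == m.vals.size());
  CorpusStats s{std::vector<float>(m.n_rows, 0.0f), std::vector<int>(m.n_cols, 0), 0.0f};
  float total = 0.0f;
  for (std::size_t i = 0; i < m.vals.size(); ++i) {
    s.doc_len[m.rows[i]] += m.vals[i];
    s.doc_freq[m.cols[i]] += 1;  // assumes each (row, col) pair appears once
    total += m.vals[i];
  }
  s.avg_len = m.n_rows > 0 ? total / static_cast<float>(m.n_rows) : 0.0f;
  return s;
}

// Textbook TF-IDF: tf * log(N / df), one output per nonzero.
inline std::vector<float> tfidf_reference(const CooHost& m)
{
  const CorpusStats s = corpus_stats(m);
  const float n_docs  = static_cast<float>(m.n_rows);
  std::vector<float> out(m.vals.size());
  for (std::size_t i = 0; i < m.vals.size(); ++i) {
    const float df = static_cast<float>(s.doc_freq[m.cols[i]]);
    out[i]         = m.vals[i] * std::log(n_docs / df);
  }
  return out;
}

// Textbook Okapi BM25 with the "+1 inside the log" IDF variant.
inline std::vector<float> bm25_reference(const CooHost& m, float k = 1.6f, float b = 0.75f)
{
  const CorpusStats s = corpus_stats(m);
  const float n_docs  = static_cast<float>(m.n_rows);
  std::vector<float> out(m.vals.size());
  for (std::size_t i = 0; i < m.vals.size(); ++i) {
    const float tf   = m.vals[i];
    const float df   = static_cast<float>(s.doc_freq[m.cols[i]]);
    const float idf  = std::log((n_docs - df + 0.5f) / (df + 0.5f) + 1.0f);
    const float norm = 1.0f - b + b * s.doc_len[m.rows[i]] / s.avg_len;
    out[i]           = idf * tf * (k + 1.0f) / (tf + k * norm);
  }
  return out;
}
```

Like the `values_out` argument of the encoders above, both reference functions return one value per nonzero in the input's order, so results can be compared element by element once the device output is copied back to the host.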
This is a COO format, and I'd prefer to see it accept our raft::sparse::coo_matrix_view instead of accepting these arrays individually. You'll notice this is not widespread yet because we only just created these new formats, but new sparse functions should use them.
Since we support the sparse formats below, we should now remove the overloads that accept the raw arrays.
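The design point in these comments can be sketched on the host without any RAFT types. The `coo_view` struct below is a hypothetical stand-in, not `raft::device_coo_matrix_view`, and `row_sums_raw` / `row_sums` are illustrative names. It shows why a single view argument is less error-prone than separate arrays: the shape and nonzero count travel with the data instead of being passed alongside it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical non-owning COO view for illustration; not a RAFT type.
template <typename ValueT, typename IndexT>
struct coo_view {
  const IndexT* rows;
  const IndexT* cols;
  const ValueT* vals;
  std::size_t nnz;
  IndexT n_rows;
  IndexT n_cols;
};

// Raw-array style: the caller must keep pointers, nnz and n_rows consistent.
template <typename ValueT, typename IndexT>
std::vector<ValueT> row_sums_raw(const IndexT* rows,
                                 const ValueT* vals,
                                 std::size_t nnz,
                                 IndexT n_rows)
{
  std::vector<ValueT> out(static_cast<std::size_t>(n_rows), ValueT{0});
  for (std::size_t i = 0; i < nnz; ++i) {
    assert(rows[i] >= 0 && rows[i] < n_rows);
    out[static_cast<std::size_t>(rows[i])] += vals[i];
  }
  return out;
}

// View style: one argument carries the arrays together with their shape.
template <typename ValueT, typename IndexT>
std::vector<ValueT> row_sums(coo_view<ValueT, IndexT> m)
{
  return row_sums_raw(m.rows, m.vals, m.nnz, m.n_rows);
}
```

With the view overload in place, the raw-array entry point becomes an implementation detail, which is the direction the second comment asks for.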