Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwendt committed Dec 20, 2024
1 parent e6bf74e commit 2297d55
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions cpp/tests/text/subword_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/debug_utilities.hpp>

#include <cudf/column/column.hpp>
#include <cudf/strings/strings_column_view.hpp>
Expand Down Expand Up @@ -461,14 +462,21 @@ TEST(TextSubwordTest, WP1)

// Exercises nvtext::wordpiece_tokenize on a small hand-written vocabulary and
// compares the returned lists column of token ids with a literal expectation.
//
// NOTE(review): this block is shown as a diff with the +/- markers stripped, so
// `vocabulary`, `input` and `expected` each appear twice. In each pair the first
// declaration looks like the removed (pre-commit) line and the second like its
// added replacement -- consistent with the "11 additions and 3 deletions"
// summary in the commit header. Confirm against the repository before treating
// this text as compilable code.
TEST(TextSubwordTest, WP2)
{
// Earlier vocabulary: 7 entries with "[UNK]" at position 0.
cudf::test::strings_column_wrapper vocabulary({"[UNK]", "!", "a", "I", "GP", "have", "##U"});
// Replacement vocabulary: 9 entries. An empty string now occupies position 0
// (moving "[UNK]" to position 1) and "GP" is split into "G" and "##P".
// Presumably the "##" prefix marks a piece that continues a word -- TODO confirm
// against the nvtext documentation.
cudf::test::strings_column_wrapper vocabulary(
{"", "[UNK]", "!", "a", "I", "G", "have", "##P", "##U"});
auto vocab = nvtext::load_wordpiece_vocabulary(cudf::strings_column_view(vocabulary));

// Earlier input: a single row.
auto input = cudf::test::strings_column_wrapper({"I have a GPU !"});
// Replacement input: three rows; the second and third contain words ("do",
// "not", "no", lowercase "gpu") that are not present in the vocabulary.
auto input =
cudf::test::strings_column_wrapper({"I have a GPU !", "do not have a gpu", "no gpu"});
auto sv = cudf::strings_column_view(input);
// NOTE(review): the meaning of the third argument (10) is not visible here;
// presumably a per-row limit on words/tokens -- verify against the API docs.
auto results = nvtext::wordpiece_tokenize(sv, *vocab, 10);
// Disabled debug print of the tokenizer output, kept for local troubleshooting.
// cudf::test::print(results->view());

using LCW = cudf::test::lists_column_wrapper<cudf::size_type>;
// Earlier expectation: the ids line up with 0-based positions in the earlier
// vocabulary (I=3, have=5, a=2, GP=4, ##U=6, !=1).
LCW expected({LCW{3, 5, 2, 4, 6, 1}});
// Replacement expectation, one list per input row; the ids line up with 0-based
// positions in the replacement vocabulary:
//   row 0: I=4, have=6, a=3, G=5, ##P=7, ##U=8, !=2
//   row 1: do=1, not=1, have=6, a=3, gpu=1
//   row 2: no=1, gpu=1
// Words absent from the vocabulary come back as 1, the position of "[UNK]".
// Lowercase "gpu" yields 1 while "GPU" yields 5, 7, 8, which suggests matching
// is case-sensitive -- confirm.
// clang-format off
LCW expected({LCW{4, 6, 3, 5, 7, 8, 2},
LCW{1, 1, 6, 3, 1},
LCW{1, 1}});
// clang-format on
CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected);
}

0 comments on commit 2297d55

Please sign in to comment.