Skip to content

Commit

Permalink
[BP] Test ellpack categorical feature with missing values. (#10906) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 18, 2024
1 parent 40742a9 commit d6059f4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ __global__ void CompressBinEllpackKernel(
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float* feature_cuts = &cuts[cut_ptrs[feature]];
int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature];
bool is_cat = common::IsCat(feature_types, ifeature);
bool is_cat = common::IsCat(feature_types, feature);
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
if (is_cat) {
Expand Down
39 changes: 36 additions & 3 deletions tests/cpp/data/test_ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

#include <utility>

#include "../../../src/common/categorical.h"
#include "../../../src/common/categorical.h" // for AsCat
#include "../../../src/common/compressed_iterator.h" // for CompressedByteT
#include "../../../src/common/hist_util.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/ellpack_page.h"
#include "../../../src/tree/param.h" // TrainParam
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
#include "../../../src/tree/param.h" // TrainParam
#include "../helpers.h"
#include "../histogram_helpers.h"
#include "gtest/gtest.h"
Expand Down Expand Up @@ -91,7 +93,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
auto& h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);

Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
auto ellpack = EllpackPage(&ctx, m.get(), p);
auto accessor = ellpack.Impl()->GetDeviceAccessor(FstCU());
Expand Down Expand Up @@ -122,6 +124,37 @@ TEST(EllpackPage, FromCategoricalBasic) {
}
}

/**
 * Builds an ELLPACK page from a tiny matrix that has an all-NaN column placed before a
 * categorical column, then checks the compressed bin indices read back on the host.
 *
 * Because the NaN column is dropped from each stored row, the categorical value sits at
 * position 1 within the row while its column index is 2. The test therefore exercises the
 * case where "position in row" and "feature index" differ for a categorical feature.
 */
TEST(EllpackPage, FromCategoricalMissing) {
  auto ctx = MakeCUDACtx(0);

  std::shared_ptr<common::HistogramCuts> cuts;
  auto nan = std::numeric_limits<float>::quiet_NaN();
  // 2 rows and 3 columns, row-major. The second column is nan, so each row keeps only two
  // valid entries and row_stride is 2.
  std::vector<float> data{{0.1, nan, 1, 0.2, nan, 0}};
  auto p_fmat = GetDMatrixFromData(data, 2, 3);
  // Only the last column is categorical.
  p_fmat->Info().feature_types.HostVector() = {FeatureType::kNumerical, FeatureType::kNumerical,
                                               FeatureType::kCategorical};
  p_fmat->Info().feature_types.SetDevice(ctx.Device());

  auto p = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
  // Copy the histogram cuts out of the gradient index page(s).
  // NOTE(review): `cuts` is not read again after the SetDevice calls below; the loop is kept
  // because iterating the GHistIndexMatrix batches may have side effects on `p_fmat` --
  // confirm before removing.
  for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) {
    cuts = std::make_shared<common::HistogramCuts>(page.Cuts());
  }
  // Fail the test cleanly instead of dereferencing a null pointer if no page was produced.
  ASSERT_TRUE(cuts != nullptr);
  cuts->cut_ptrs_.SetDevice(ctx.Device());
  cuts->cut_values_.SetDevice(ctx.Device());
  cuts->min_vals_.SetDevice(ctx.Device());

  for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
    auto h_acc = page.Impl()->GetHostAccessor(p_fmat->Info().feature_types.ConstDeviceSpan());
    ASSERT_EQ(h_acc.n_rows, 2);
    ASSERT_EQ(h_acc.row_stride, 2);
    // Entries are laid out row by row with two slots per row (row_stride == 2).
    // Expected values are presumably global bin indices across all features -- TODO confirm.
    ASSERT_EQ(h_acc.gidx_iter[0], 0);
    ASSERT_EQ(h_acc.gidx_iter[1], 4);  // cat 1
    ASSERT_EQ(h_acc.gidx_iter[2], 1);
    ASSERT_EQ(h_acc.gidx_iter[3], 3);  // cat 0
  }
}

struct ReadRowFunction {
EllpackDeviceAccessor matrix;
int row;
Expand Down

0 comments on commit d6059f4

Please sign in to comment.