Skip to content

Commit

Permalink
Test ellpack categorical feature with missing values. (#10906)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 18, 2024
1 parent d5f9886 commit 8af7062
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions tests/cpp/data/test_ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

#include <utility>

#include "../../../src/common/categorical.h"
#include "../../../src/common/categorical.h" // for AsCat
#include "../../../src/common/compressed_iterator.h" // for CompressedByteT
#include "../../../src/common/hist_util.h"
#include "../../../src/common/ref_resource_view.cuh" // for MakeCudaGrowOnly
#include "../../../src/data/device_adapter.cuh" // for CupyAdapter
#include "../../../src/data/device_adapter.cuh" // for CupyAdapter
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/ellpack_page.h"
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
Expand Down Expand Up @@ -98,7 +98,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
auto& h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);

Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
auto ellpack = EllpackPage(&ctx, m.get(), p);
auto accessor = ellpack.Impl()->GetDeviceAccessor(&ctx);
Expand Down Expand Up @@ -128,6 +128,37 @@ TEST(EllpackPage, FromCategoricalBasic) {
}
}

// Checks the ELLPACK gradient index for a small matrix that combines numerical columns, a column
// that is missing in every row, and a categorical column.
TEST(EllpackPage, FromCategoricalMissing) {
  auto ctx = MakeCUDACtx(0);
  auto batch_param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};

  // 2 rows and 3 columns: {0.1, nan, 1} and {0.2, nan, 0}. The second column is always missing,
  // so every row keeps two valid entries and the expected row stride is 2.
  auto missing = std::numeric_limits<float>::quiet_NaN();
  std::vector<float> values{{0.1, missing, 1, 0.2, missing, 0}};
  auto dmat = GetDMatrixFromData(values, 2, 3);

  // Only the last column is categorical; the types are also made available on the test device
  // because the host accessor below is given the device span.
  auto& feature_types = dmat->Info().feature_types;
  feature_types.HostVector() = {FeatureType::kNumerical, FeatureType::kNumerical,
                                FeatureType::kCategorical};
  feature_types.SetDevice(ctx.Device());

  // Take a copy of the histogram cuts from the CPU gradient index page.
  std::shared_ptr<common::HistogramCuts> cuts;
  for (auto const& gidx_page : dmat->GetBatches<GHistIndexMatrix>(&ctx, batch_param)) {
    cuts = std::make_shared<common::HistogramCuts>(gidx_page.Cuts());
  }
  cuts->SetDevice(ctx.Device());

  // Expected compressed bin indices, in storage order.
  int const expected_gidx[] = {
      0,  // row 0, first column
      4,  // row 0, cat 1
      1,  // row 1, first column
      3,  // row 1, cat 0
  };

  for (auto const& ellpack : dmat->GetBatches<EllpackPage>(&ctx, batch_param)) {
    std::vector<common::CompressedByteT> host_storage;
    auto accessor =
        ellpack.Impl()->GetHostAccessor(&ctx, &host_storage, feature_types.ConstDeviceSpan());

    ASSERT_EQ(accessor.n_rows, 2);
    ASSERT_EQ(cuts->NumFeatures(), 3);
    ASSERT_EQ(accessor.row_stride, 2);
    for (std::size_t i = 0; i < 4; ++i) {
      ASSERT_EQ(accessor.gidx_iter[i], expected_gidx[i]);
    }
  }
}

struct ReadRowFunction {
EllpackDeviceAccessor matrix;
int row;
Expand Down

0 comments on commit 8af7062

Please sign in to comment.