Skip to content

Commit

Permalink
Add first 4 bins test
Browse files Browse the repository at this point in the history
  • Loading branch information
rmontanana committed Jun 4, 2024
1 parent 1d02f25 commit f90fd14
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
54 changes: 37 additions & 17 deletions tests/BinDisc_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@
#include "../BinDisc.h"

namespace mdlp {
const float margin = 1e-4;
class TestBinDisc3U : public BinDisc, public testing::Test {
public:
TestBinDisc3U(int n_bins = 3) : BinDisc(n_bins, strategy_t::UNIFORM) {};
float margin = 1e-4;
};
class TestBinDisc3Q : public BinDisc, public testing::Test {
public:
TestBinDisc3Q(int n_bins = 3) : BinDisc(n_bins, strategy_t::QUANTILE) {};
float margin = 1e-4;
};
class TestBinDisc4U : public BinDisc, public testing::Test {
public:
TestBinDisc4U(int n_bins = 4) : BinDisc(n_bins, strategy_t::UNIFORM) {};
};
class TestBinDisc4Q : public BinDisc, public testing::Test {
public:
TestBinDisc4Q(int n_bins = 4) : BinDisc(n_bins, strategy_t::QUANTILE) {};
};
TEST_F(TestBinDisc3U, Easy3BinsUniform)
{
Expand Down Expand Up @@ -165,21 +172,34 @@ namespace mdlp {
EXPECT_EQ(expected, labels);
EXPECT_EQ(3.0, X[0]); // X is not modified
}
// TEST(TestBinDisc_Gen, Easy4Bins)
// {
// auto disc = BinDisc(4);
// samples_t X = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 };
// disc.fit(X);
// auto cuts = disc.getCutPoints();
// EXPECT_EQ(3.0, cuts[0]);
// EXPECT_EQ(6.0, cuts[1]);
// EXPECT_EQ(9.0, cuts[2]);
// EXPECT_EQ(numeric_limits<size_t>::max(), cuts[3]);
// EXPECT_EQ(4, cuts.size());
// auto labels = disc.transform(X);
// labels_t expected = { 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3 };
// EXPECT_EQ(expected, labels);
// }
TEST_F(TestBinDisc4U, Easy4BinsUniform)
{
samples_t X = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 };
fit(X);
auto cuts = getCutPoints();
EXPECT_EQ(3.75, cuts[0]);
EXPECT_EQ(6.5, cuts[1]);
EXPECT_EQ(9.25, cuts[2]);
EXPECT_EQ(numeric_limits<float>::max(), cuts[3]);
EXPECT_EQ(4, cuts.size());
auto labels = transform(X);
labels_t expected = { 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3 };
EXPECT_EQ(expected, labels);
}
TEST_F(TestBinDisc4Q, Easy4BinsQuantile)
{
samples_t X = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0 };
fit(X);
auto cuts = getCutPoints();
EXPECT_EQ(3.75, cuts[0]);
EXPECT_EQ(6.5, cuts[1]);
EXPECT_EQ(9.25, cuts[2]);
EXPECT_EQ(numeric_limits<float>::max(), cuts[3]);
EXPECT_EQ(4, cuts.size());
auto labels = transform(X);
labels_t expected = { 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3 };
EXPECT_EQ(expected, labels);
}
// TEST(TestBinDisc_Gen, X13Bins)
// {
// auto disc = BinDisc(4);
Expand Down
10 changes: 10 additions & 0 deletions tests/testKbins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def test(clf, X, expected, title):
clf3q = KBinsDiscretizer(
n_bins=3, encode="ordinal", strategy="quantile", subsample=200_000
)
clf4u = KBinsDiscretizer(
n_bins=4, encode="ordinal", strategy="uniform", subsample=200_000
)
clf4q = KBinsDiscretizer(
n_bins=4, encode="ordinal", strategy="quantile", subsample=200_000
)
#
X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
Expand Down Expand Up @@ -51,6 +57,10 @@ def test(clf, X, expected, title):
test(clf3u, X, labels, title="EasyRepeatedUniform")
test(clf3q, X, labels2, title="EasyRepeatedQuantile")
#
X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
test(clf4u, X, labels, title="Easy4BinsUniform")
test(clf4q, X, labels, title="Easy4BinsQuantile")
# X = [15.0, 13.0, 12.0, 14.0, 6.0, 1.0, 8.0, 11.0, 10.0, 9.0, 7.0, 4.0, 3.0, 5.0, 2.0]
# X = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
# X = [[x] for x in X]
Expand Down

0 comments on commit f90fd14

Please sign in to comment.