Skip to content

Commit

Permalink
task_filter_ex tests
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Nov 2, 2024
1 parent 8701d9d commit 27d883c
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
@@ -1,42 +1,84 @@
context("utils")

# test changed row_roles$use (some removed, duplicates)
# test with task with trailing rows filtered out

test_that("task_filter_ex - Basic functionality", {
task = mlr_tasks$get("iris")

rowidx = as.integer(c(1, 2, 3, 2, 1, 2, 3, 2, 1)) # annoying and unnecessary mlr3 type strictness

# Equal to task$filter() in case of no duplicates
tfiltered_ex = task_filter_ex(task$clone(), unique(rowidx))
tfiltered = task$clone()$filter(unique(rowidx))
expect_equal(tfiltered_ex$data(), tfiltered$data())

# With duplicates
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))

# After selecting columns
task$select(c("Petal.Length", "Petal.Width"))
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
})

test_that("task_filter_ex - task with col role group", {
test_that("task_filter_ex - filtered trailing rows", {
task = as_task_classif(rbind(iris, iris, iris), target = "Species", id = "test")
task$filter(301:450)

# task = mlr_tasks$get("iris")
rowidx = as.integer(300 + c(1, 2, 3, 2, 1, 2, 3, 2, 1, 4))

tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
})

test_that("task_filter_ex - task with column role group", {
task = mlr_tasks$get("iris")
task$cbind(data.frame(grp = rep(c("A", "A", "B", "C", "D"), 30)))
task$set_col_roles("grp", "group")

rowidx = as.integer(300 + c(1, 2, 3, 2, 1, 2, 3, 2, 1, 4))
rowidx = as.integer(c(1, 2, 3, 2, 1, 2, 3, 2, 1, 4))

# Basic test
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
expect_equal(
table(tfiltered$groups$group),
table(c("A", "A", "A_1", "A_1", "A_2", "A_2", "A_3", "B", "B_1", "C"))
table(c("A", "A", "B", "A_1", "A_1", "A_2", "B_1", "A_3", "A_2", "C"))
)

# Name collision
task$cbind(data.frame(grp = rep(c("A", "A_1", "B", "C", "D"), 30)))
task$set_col_roles("grp", "group")

tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))
expect_equal(
tfiltered$groups$group,
c("A", "A_1", "B", "A_1_1", "A_2", "A_1_2", "B_1", "A_1_3", "A_3", "C")
)
})

test_that("task_filter_ex - changed row_roles$use", {
task = mlr_tasks$get("iris")

rowidx = as.integer(c(1, 2, 3, 2, 1, 2, 3, 2, 1))

task$row_roles$use = seq(1, 50)
tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))

task$row_roles$use = c(seq(1, 50), seq(1, 20))
tfiltered = task_filter_ex(task$clone(), 50L + rowidx)
expect_equal(tfiltered$data(), task$data(rows = 50L + rowidx))

# Need to define primary key "..row_id" explicitly because mlr3 fills it otherwise with task$row_ids
# and asserts uniqueness.
task$cbind(data.frame("..row_id" = seq(1, 70), grp = rep(c("A", "A", "B", "C", "D", "C", "A"), 10)))
task$set_col_roles("grp", "group")

tfiltered = task_filter_ex(task$clone(), rowidx)
expect_equal(tfiltered$data(), task$data(rows = rowidx))

})

# test name colision
# test for large data_set (allow_cartesian)

0 comments on commit 27d883c

Please sign in to comment.