Can't predict on randomForest when test set contains NA's in features #1515

Open
florianfendt opened this issue Feb 13, 2017 · 9 comments

@florianfendt
Contributor

I don't know whether this is a bug of some sort or whether I'm overlooking something, but it baffled @ja-thomas and me a bit this morning.
Consider a simple case where the test set contains a missing value somewhere, as in this example:

library(mlr)

lrn.rf = makeLearner("classif.randomForest")
mod = train(lrn.rf, iris.task)
test.df = getTaskData(iris.task)
# introduce a single missing value in the first feature of the first row
test.df[1L, 1L] = NA

mlr then throws an error when you try to predict on this data set, whereas randomForest's own predict method does not:

# throws error: row names contain missing values
predict(mod, newdata = test.df)
# if I'm directly using the predict method from randomForest it works
predict(mod$learner.model, test.df)

I tried printing out .newdata in predictLearner.classif.randomForest to see whether we do something unwanted with the data.frame before passing it to the learner's predict method, but the row names, the str() output, etc. all look fine.
Any ideas?

@larskotthoff
Member

This sounds like a bug. Could you make a unit test that reproduces it please?

@berndbischl
Member

Before you produce a unit test:
please always show the actual output in cases like this.

@berndbischl
Member

> This sounds like a bug. Could you make a unit test that reproduces it please?

Why are you asking him that? He has basically already posted that test.

The real question is: what is happening here?

@berndbischl
Member

berndbischl commented Feb 13, 2017

The real problem seems to be this:
mlr declares that randomForest does not handle missing values. So if you put the NA into the task data,
neither training NOR prediction would work, and you would get a meaningful error message.

But here the prediction data frame is handled directly, and the underlying RF simply produces an NA prediction for that row. This is also related to issue
#1499

We should probably create the task description internally and sanity-check it.

@QuayAu
Contributor

QuayAu commented Mar 10, 2017

This problem does not only occur with randomForest. I ran every available multiclass classification learner on this problem, and here is what I found:

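library(mlr)

test.df = getTaskData(iris.task)
test.df[1L, 1L] = NA
learners = listLearners(obj = "classif", properties = "multiclass")$class
# train each learner on the complete task, then predict on the data with one NA;
# keep either the Prediction object or the error message
res = lapply(learners, function(x) {
  mod = train(makeLearner(x), iris.task)
  tryCatch(predict(mod, newdata = test.df), error = function(e) conditionMessage(e))
})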

So the problem is that only the test set contains missing values: some learners handle that, others don't. Here is a short summary of the results (the prediction for the first row, which contains the NA, or the error message):

| learner | predict (first row or error message) |
|---|---|
| classif.bdk | setosa |
| classif.boosting | setosa |
| classif.C50 | setosa |
| classif.cforest | setosa |
| classif.ctree | setosa |
| classif.cvglmnet | NA |
| classif.dbnDNN | NA |
| classif.earth | invalid subscript type 'list' |
| classif.evtree | setosa |
| classif.extraTrees | setosa |
| classif.featureless | versicolor |
| classif.fnn | no missing values are allowed |
| classif.gausspr | arguments imply differing number of rows: 150, 149 |
| classif.gbm | setosa |
| classif.geoDA | NA |
| classif.glmnet | setosa |
| classif.h2o.deeplearning | setosa |
| classif.h2o.gbm | setosa |
| classif.h2o.randomForest | setosa |
| classif.IBk | setosa |
| classif.J48 | setosa |
| classif.JRip | setosa |
| classif.kknn | arguments imply differing number of rows: 150, 149 |
| classif.knn | no missing values are allowed |
| classif.ksvm | arguments imply differing number of rows: 150, 149 |
| classif.lda | NA |
| classif.LiblineaRL1L2SVC | NA/NaN/Inf in foreign function call (arg 2) |
| classif.LiblineaRL1LogReg | NA/NaN/Inf in foreign function call (arg 2) |
| classif.LiblineaRL2L1SVC | NA/NaN/Inf in foreign function call (arg 2) |
| classif.LiblineaRL2LogReg | NA/NaN/Inf in foreign function call (arg 2) |
| classif.LiblineaRL2SVC | NA/NaN/Inf in foreign function call (arg 2) |
| classif.LiblineaRMultiClassSVC | NA/NaN/Inf in foreign function call (arg 2) |
| classif.linDA | NA |
| classif.lssvm | arguments imply differing number of rows: 150, 149 |
| classif.lvq1 | no missing values are allowed |
| classif.mda | arguments imply differing number of rows: 150, 149 |
| classif.mlp | missing values in 'x' |
| classif.multinom | NA |
| classif.naiveBayes | setosa |
| classif.nnet | NA |
| classif.nnTrain | NA |
| classif.OneR | setosa |
| classif.PART | setosa |
| classif.qda | NA |
| classif.quaDA | NA |
| classif.randomForest | row names contain missing values |
| classif.randomForestSRC | setosa |
| classif.ranger | Missing data in columns: Sepal.Length. |
| classif.rda | virginica |
| classif.rFerns | NAs in predictors. |
| classif.rknn | no missing values are allowed |
| classif.rpart | setosa |
| classif.RRF | row names contain missing values |
| classif.rrlda | invalid subscript type 'list' |
| classif.saeDNN | NA |
| classif.sda | NA |
| classif.sparseLDA | NA |
| classif.svm | arguments imply differing number of rows: 150, 149 |
| classif.xgboost | setosa |
| classif.xyf | setosa |
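A table like the one above can be built from `res` with a small helper. This is a sketch, not the code that produced the table: `summarizeFirstRow` is a hypothetical name, and it assumes each successful result is an mlr Prediction object whose `$data` data frame has a `response` column, while each failure is the character message returned by the `tryCatch()` handler. The usage lines rely on mock objects so the snippet runs without mlr.

```r
# Hypothetical helper (not part of mlr): turn the list returned by the
# lapply()/tryCatch() loop into one row per learner, holding either the
# predicted class of the first row (possibly NA) or the caught error message.
summarizeFirstRow <- function(res, learners) {
  outcome <- vapply(res, function(r) {
    if (is.character(r)) {
      r[1L]  # tryCatch() handler returned conditionMessage(e)
    } else {
      # assumption: an mlr Prediction object with a `response` column in $data
      as.character(r$data$response[1L])
    }
  }, character(1L))
  data.frame(learner = learners, predict = outcome, stringsAsFactors = FALSE)
}

# Usage with mock results standing in for real Prediction objects
mock <- list(
  list(data = data.frame(response = factor("setosa"))),
  "no missing values are allowed",
  list(data = data.frame(response = factor(NA, levels = "setosa")))
)
summarizeFirstRow(mock, c("classif.rpart", "classif.knn", "classif.lda"))
```

With the real `res` and `learners` from the loop above, the same call would give one row per learner.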

@larskotthoff
Member

Well, we already have the "missings" learner property for that, so we just need to check it in predict().
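A minimal sketch of such a check, written in base R so it runs on its own. The function name, its arguments and the error text are hypothetical and not mlr's API; inside mlr the `supports.missings` flag would presumably be derived from the learner's "missings" property (for example via `getLearnerProperties()`), and `features` from the task description.

```r
# Hypothetical pre-prediction check (not mlr's actual implementation):
# fail early with an informative message if newdata has NAs in any feature
# column and the learner does not declare support for missing values.
checkNewdataMissings <- function(newdata, features, supports.missings) {
  if (supports.missings) {
    return(invisible(TRUE))
  }
  has.na <- vapply(newdata[features], anyNA, logical(1L))
  if (any(has.na)) {
    stop(sprintf(
      "Learner does not support missing values, but newdata has NAs in: %s",
      paste(features[has.na], collapse = ", ")
    ), call. = FALSE)
  }
  invisible(TRUE)
}

# Usage: mirror the example from the original report on plain iris
test.df <- iris
test.df[1L, 1L] <- NA
features <- setdiff(names(test.df), "Species")
msg <- tryCatch(
  checkNewdataMissings(test.df, features, supports.missings = FALSE),
  error = function(e) conditionMessage(e)
)
print(msg)
```

Failing before the learner's predict method is called would give the same kind of meaningful error that mlr already reports when the NA is in the task data, instead of learner-specific messages or silent NA predictions.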

@mb706
Contributor

mb706 commented Dec 9, 2017

The bug for classif.randomForest in particular seems to be that it puts names on its predictions and gives the NA prediction an NA name, which the cbind in makePrediction trips over. Adding a

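# strip the names that randomForest attaches to its predictions
if (is.matrix(p)) {
  rownames(p) = NULL  # probability matrix: keep colnames, they are the class levels
} else {
  names(p) = NULL     # response factor
}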

anywhere in between would fix this. There are still other learners, however, that throw errors when the prediction data set contains NAs.
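The mechanism described above can be reproduced in base R, without mlr or randomForest. This is a sketch under two assumptions: the named factor `p` is a hand-made stand-in for randomForest's return value (the 2019 reprex further down shows the first prediction labelled `<NA>`), and the prediction is ultimately assembled with `data.frame()`, as the call in that reprex's error message suggests.

```r
# Stand-in for randomForest's output: predictions named after the rows of
# newdata, with NA as both the prediction and the name of the row that
# contained a missing feature value.
p <- factor(c(NA, "setosa", "setosa"),
            levels = c("setosa", "versicolor", "virginica"))
names(p) <- c(NA, "2", "3")

# data.frame() takes its row names from the vector's names and rejects the NA
err <- tryCatch(data.frame(id = 1:3, response = p),
                error = function(e) conditionMessage(e))
print(err)

# Stripping the names first avoids the error; the NA stays as an ordinary
# missing prediction for that row.
names(p) <- NULL
pred <- data.frame(id = 1:3, response = p)
print(pred)
```

Removing the names only fixes the crash: the learner still returns NA for that row, which is why a check against the learner's "missings" property is still useful.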

@pat-s
Member

pat-s commented Jun 6, 2019

@mb706 Did #2099 solve this particular issue for randomForest?

@pat-s
Member

pat-s commented Dec 31, 2019

I am still getting the same error as reported in the original post.

I'll have to look into it.

library(mlr)
#> Loading required package: ParamHelpers

lrn.rf = makeLearner("classif.randomForest")
mod = train(lrn.rf, iris.task)
test.df = getTaskData(iris.task)
test.df[1L, 1L] = NA

# throws error: row names contain missing values
predict(mod, newdata = test.df)
#> Error in (function (..., row.names = NULL, check.rows = FALSE, check.names = TRUE, : row names contain missing values
# if I'm directly using the predict method from randomForest it works
predict(mod$learner.model, test.df)
#>       <NA>          2          3          4          5          6          7 
#>       <NA>     setosa     setosa     setosa     setosa     setosa     setosa 
#>          8          9         10         11         12         13         14 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         15         16         17         18         19         20         21 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         22         23         24         25         26         27         28 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         29         30         31         32         33         34         35 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         36         37         38         39         40         41         42 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         43         44         45         46         47         48         49 
#>     setosa     setosa     setosa     setosa     setosa     setosa     setosa 
#>         50         51         52         53         54         55         56 
#>     setosa versicolor versicolor versicolor versicolor versicolor versicolor 
#>         57         58         59         60         61         62         63 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         64         65         66         67         68         69         70 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         71         72         73         74         75         76         77 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         78         79         80         81         82         83         84 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         85         86         87         88         89         90         91 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         92         93         94         95         96         97         98 
#> versicolor versicolor versicolor versicolor versicolor versicolor versicolor 
#>         99        100        101        102        103        104        105 
#> versicolor versicolor  virginica  virginica  virginica  virginica  virginica 
#>        106        107        108        109        110        111        112 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        113        114        115        116        117        118        119 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        120        121        122        123        124        125        126 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        127        128        129        130        131        132        133 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        134        135        136        137        138        139        140 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        141        142        143        144        145        146        147 
#>  virginica  virginica  virginica  virginica  virginica  virginica  virginica 
#>        148        149        150 
#>  virginica  virginica  virginica 
#> Levels: setosa versicolor virginica

Created on 2019-12-31 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                                      
#>  version  R version 3.6.2 Patched (2019-12-12 r77564)
#>  os       macOS Mojave 10.14.6                       
#>  system   x86_64, darwin15.6.0                       
#>  ui       X11                                        
#>  language (EN)                                       
#>  collate  en_US.UTF-8                                
#>  ctype    en_US.UTF-8                                
#>  tz       Europe/Berlin                              
#>  date     2019-12-31                                 
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version     date       lib
#>  assertthat     0.2.1       2019-03-21 [1]
#>  backports      1.1.5       2019-10-02 [1]
#>  BBmisc         1.11        2017-03-10 [1]
#>  callr          3.4.0       2019-12-09 [1]
#>  checkmate      1.9.4       2019-07-04 [1]
#>  cli            2.0.0.9000  2019-12-21 [1]
#>  colorspace     1.4-1       2019-03-18 [1]
#>  crayon         1.3.4       2017-09-16 [1]
#>  data.table     1.12.8      2019-12-09 [1]
#>  desc           1.2.0       2018-05-01 [1]
#>  devtools       2.2.1       2019-09-24 [1]
#>  digest         0.6.23      2019-11-23 [1]
#>  dplyr          0.8.3       2019-07-04 [1]
#>  ellipsis       0.3.0       2019-09-20 [1]
#>  evaluate       0.14        2019-05-28 [1]
#>  fansi          0.4.0       2018-10-05 [1]
#>  fastmatch      1.1-0       2017-01-28 [1]
#>  fs             1.3.1       2019-05-06 [1]
#>  ggplot2        3.2.1       2019-08-10 [1]
#>  glue           1.3.1       2019-03-12 [1]
#>  gtable         0.3.0       2019-03-25 [1]
#>  highr          0.8         2019-03-20 [1]
#>  htmltools      0.4.0       2019-10-04 [1]
#>  knitr          1.26        2019-11-12 [1]
#>  lattice        0.20-38     2018-11-04 [2]
#>  lazyeval       0.2.2       2019-03-15 [1]
#>  lifecycle      0.1.0       2019-08-01 [1]
#>  magrittr       1.5         2014-11-22 [1]
#>  Matrix         1.2-18      2019-11-27 [2]
#>  memoise        1.1.0       2017-04-21 [1]
#>  mlr          * 2.16.0.9000 2019-12-11 [1]
#>  munsell        0.5.0       2018-06-12 [1]
#>  parallelMap    1.4.0.9000  2019-12-19 [1]
#>  ParamHelpers * 1.13.0.9000 2019-12-11 [1]
#>  pillar         1.4.3       2019-12-20 [1]
#>  pkgbuild       1.0.6       2019-10-09 [1]
#>  pkgconfig      2.0.3       2019-09-22 [1]
#>  pkgload        1.0.2       2018-10-29 [1]
#>  prettyunits    1.0.2       2015-07-13 [1]
#>  processx       3.4.1       2019-07-18 [1]
#>  ps             1.3.0       2018-12-21 [1]
#>  purrr          0.3.3       2019-10-18 [1]
#>  R6             2.4.1       2019-11-12 [1]
#>  randomForest   4.6-14      2018-03-25 [1]
#>  Rcpp           1.0.3       2019-11-08 [1]
#>  remotes        2.1.0       2019-06-24 [1]
#>  rlang          0.4.2.9000  2019-12-25 [1]
#>  rmarkdown      2.0         2019-12-12 [1]
#>  rprojroot      1.3-2       2018-01-03 [1]
#>  scales         1.1.0       2019-11-18 [1]
#>  sessioninfo    1.1.1       2018-11-05 [1]
#>  stringi        1.4.3       2019-03-12 [1]
#>  stringr        1.4.0       2019-02-10 [1]
#>  survival       3.1-8       2019-12-03 [2]
#>  testthat       2.3.1       2019-12-01 [1]
#>  tibble         2.1.3       2019-06-06 [1]
#>  tidyselect     0.2.5       2018-10-11 [1]
#>  usethis        1.5.1.9000  2019-12-14 [1]
#>  withr          2.1.2       2018-03-15 [1]
#>  xfun           0.11        2019-11-12 [1]
#>  XML            3.98-1.20   2019-06-06 [1]
#>  yaml           2.2.0       2018-07-25 [1]
#>  source                                   
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (r-lib/cli@0293ae7)               
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  local                                    
#>  CRAN (R 3.6.1)                           
#>  local                                    
#>  Github (berndbischl/ParamHelpers@c2d989c)
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (r-lib/rlang@ce4f717)             
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.2)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  Github (r-lib/usethis@b2e894e)           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.1)                           
#>  CRAN (R 3.6.0)                           
#> 
#> [1] /Users/pjs/Library/R/3.6/library
#> [2] /Library/Frameworks/R.framework/Versions/3.6/Resources/library
