diff --git a/R/TaskClassif_tiny_imagenet.R b/R/TaskClassif_tiny_imagenet.R index fbd27f9e..e1c0e744 100644 --- a/R/TaskClassif_tiny_imagenet.R +++ b/R/TaskClassif_tiny_imagenet.R @@ -31,7 +31,7 @@ NULL # The cache_dir/datasets/tiny_imagenet folder. constructor_tiny_imagenet = function(path) { # path points to {cache_dir, tempfile}/data/tiny_imagenet - torchvision::tiny_imagenet_dataset(root = file.path(path ), download = TRUE) + torchvision::tiny_imagenet_dataset(root = file.path(path), download = TRUE) download_folder = file.path(path, "tiny-imagenet-200") lookup = fread(sprintf("%s/words.txt", download_folder), header = FALSE) diff --git a/vignettes/articles/image_classification.Rmd b/vignettes/articles/image_classification.Rmd new file mode 100644 index 00000000..ec30a59a --- /dev/null +++ b/vignettes/articles/image_classification.Rmd @@ -0,0 +1,52 @@ +--- +title: "Image Classification" +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>" +) +``` + +In the *Get Started* vignette, we have already explained how to train a simple neural network on tabular data. +In this article you will learn how to work with image data. +For that `mlr3torch` relies on the functionality provided by the [`torchvision`](https://github.com/mlverse/torchvision) package. +As an example task, we will use the "tiny imagenet" dataset, which is a subset of the [ImageNet](http://www.image-net.org/) dataset. +It consists of 200 classes with 500 training images each. +The goal is to predict the class of an image from the pixels. + +```{r setup} +set.seed(314) +library(mlr3torch) + +task = tsk("tiny_imagenet") +task +``` + +In `mlr3torch`, image data is represented using the `mlr3torch::imageuri` class. +It is essentially a character vector, containing paths to images on the file system. +Note that this task has to be downloaded from the internet when accessing the data. +In order to download the dataset only once, you must set the `mlr3torch.cache` option to `TRUE`, or a specific path to be used as the cache folder. + +```{r, eval = FALSE} +options(mlr3torch.cache = TRUE) +``` + +We can e.g. print the path to the first image as follows: + +```{r} +task$data(1, "image") +``` + + +As a learner, we we will use the famous AlexNet classification network, which sparked the "Deep Learning revolution" in 2012. + + +```{r} +alexnet = lrn("classif.alexnet") +alexnet +``` + + +