-
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f830bc9
commit bcb9866
Showing
2 changed files
with
53 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
--- | ||
title: "Image Classification" | ||
--- | ||
|
||
```{r, include = FALSE} | ||
knitr::opts_chunk$set( | ||
collapse = TRUE, | ||
comment = "#>" | ||
) | ||
``` | ||
|
||
In the *Get Started* vignette, we have already explained how to train a simple neural network on tabular data. | ||
In this article you will learn how to work with image data. | ||
For that `mlr3torch` relies on the functionality provided by the [`torchvision`](https://github.com/mlverse/torchvision) package. | ||
As an example task, we will use the "tiny imagenet" dataset, which is a subset of the [ImageNet](http://www.image-net.org/) dataset. | ||
It consists of 200 classes with 500 training images each. | ||
The goal is to predict the class of an image from the pixels. | ||
|
||
```{r setup} | ||
set.seed(314) | ||
library(mlr3torch) | ||
task = tsk("tiny_imagenet") | ||
task | ||
``` | ||
|
||
In `mlr3torch`, image data is represented using the `mlr3torch::imageuri` class. | ||
It is essentially a character vector, containing paths to images on the file system. | ||
Note that this task has to be downloaded from the internet when accessing the data. | ||
In order to download the dataset only once, you must set the `mlr3torch.cache` option to `TRUE`, or a specific path to be used as the cache folder. | ||
|
||
```{r, eval = FALSE} | ||
options(mlr3torch.cache = TRUE) | ||
``` | ||
|
||
We can e.g. print the path to the first image as follows: | ||
|
||
```{r} | ||
task$data(1, "image") | ||
``` | ||
|
||
|
||
As a learner, we we will use the famous AlexNet classification network, which sparked the "Deep Learning revolution" in 2012. | ||
|
||
|
||
```{r} | ||
alexnet = lrn("classif.alexnet") | ||
alexnet | ||
``` | ||
|
||
|
||
|