Skip to content

Commit

Permalink
improv evignettes
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jul 27, 2023
2 parents 4403a7d + 672779b commit 77f58de
Show file tree
Hide file tree
Showing 26 changed files with 1,239 additions and 2,400 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ mlr3torch*.tgz
*~
docs
inst/doc
*.html
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ Authors@R:
role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0001-8867-762X")))
Description: Provides torch models as learners for the mlr3 ecosystem.
Description: Deep Learning library that extends the mlr3 framework by building
upon the 'torch' package. It allows to conveniently build, train,
and evaluate deep learning models without having to worry about low level
details. Custom architectures can be created using the graph language
defined in 'mlr3pipelines'.
License: LGPL (>= 3)
Depends:
mlr3 (>= 0.16.0),
Expand Down
6 changes: 4 additions & 2 deletions R/CallbackSetHistory.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
stopf("No eligible measures to plot for set '%s'.", set)
}


epoch = score = measure = NULL
if (ncol(data) == 2L) {
ggplot2::ggplot(data = data, ggplot2::aes_string(x = "epoch", y = measures)) +
ggplot2::ggplot(data = data, ggplot2::aes(x = epoch, y = !!rlang::sym(measures))) +
ggplot2::geom_line() +
ggplot2::geom_point() +
ggplot2::labs(
Expand All @@ -82,7 +84,7 @@ CallbackSetHistory = R6Class("CallbackSetHistory",
theme
} else {
data = melt(data, id.vars = "epoch", variable.name = "measure", value.name = "score")
ggplot2::ggplot(data = data, ggplot2::aes_string(x = "epoch", y = "score", color = "measure")) +
ggplot2::ggplot(data = data, ggplot2::aes_string(x = epoch, y = score, color = measure)) +
viridis::scale_color_viridis(discrete = TRUE) +
ggplot2::geom_line() +
ggplot2::geom_point() +
Expand Down
3 changes: 2 additions & 1 deletion R/ModelDescriptor.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#' @family Model Configuration
#' @family Graph Network
#' @return (`ModelDescriptor`)
#' @export
ModelDescriptor = function(graph, ingress, task, optimizer = NULL, loss = NULL, callbacks = NULL, .pointer = NULL,
.pointer_shape = NULL) {
assert_r6(graph, "Graph")
Expand Down Expand Up @@ -95,7 +96,7 @@ print.ModelDescriptor = function(x, ...) {
catn(sprintf("<ModelDescriptor: %d ops>", length(x$graph$pipeops)))
catn(str_indent("* Ingress: ", ingress_shapes))
catn(str_indent("* Task: ", paste0(x$task$id, " [", x$task$task_type, "]")))
catn(str_indent("* Callbacks: ", if (!is.null(x$callbacks)) as_short_string(map_chr(x$callbacks, "label"), 100L) else "N/A")) # nolint
catn(str_indent("* Callbacks: ", if (!is.null(x$callbacks) && length(x$callbacks)) as_short_string(map_chr(x$callbacks, "label"), 100L) else "N/A")) # nolint
catn(str_indent("* Optimizer: ", if (!is.null(x$optimizer)) as_short_string(x$optimizer$label) else "N/A"))
catn(str_indent("* Loss: ", if (!is.null(x$loss)) as_short_string(x$loss$label) else "N/A"))
catn(str_indent("* .pointer: ", if (is.null(x$.pointer)) "" else { # nolint
Expand Down
11 changes: 6 additions & 5 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ lgr::get_logger("mlr3")$set_threshold("warn")



# mlr3torch
# mlr3torch <img src="man/figures/logo.png" align="right" width = "120" />

Package website: [dev](https://mlr3torch.mlr-org.com/)

Deep Learning with torch and mlr3.

<!-- badges: start -->
[![Lifecycle: experimental](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://lifecycle.r-lib.org/articles/stages.html#experimental)
Expand All @@ -30,15 +34,12 @@ lgr::get_logger("mlr3")$set_threshold("warn")
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
<!-- badges: end -->

Deep Learning with torch and mlr3.

## Installation

```{r eval = FALSE}
remotes::install_github("mlr-org/mlr3torch")
```


## Status

`mlr3torch` is currently still unstable and many things are missing.
Expand All @@ -49,7 +50,7 @@ Not everything will work yet and the API might change without notice.

`mlr3torch` is a deep learning framework for the [`mlr3`](https://mlr-org.com) ecosystem built on top of [`torch`](https://torch.mlverse.org/).
It allows to easily build, train and evaluate deep learning models in a few lines of codes, without needing to worry about low-level details.
Off-the-shelf learners are readily available, but custom architectures can be defined by connection `PipeOpTorch` operators in an `mlr3pipelines::Graph`.
Off-the-shelf learners are readily available, but custom architectures can be defined by connecting `PipeOpTorch` operators in an `mlr3pipelines::Graph`.

Using predefined learners such as a simple multi layer perceptron (MLP) works just like any other mlr3 `Learner`.

Expand Down
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@

<!-- README.md is generated from README.Rmd. Please edit that file -->

# mlr3torch
# mlr3torch <img src="man/figures/logo.png" align="right" width = "120" />

Package website: [dev](https://mlr3torch.mlr-org.com/)

Deep Learning with torch and mlr3.

<!-- badges: start -->

Expand All @@ -14,8 +18,6 @@ status](https://www.r-pkg.org/badges/version/mlr3torch)](https://CRAN.R-project.
[![Mattermost](https://img.shields.io/badge/chat-mattermost-orange.svg)](https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/)
<!-- badges: end -->

Deep Learning with torch and mlr3.

## Installation

``` r
Expand All @@ -34,7 +36,7 @@ everything will work yet and the API might change without notice.
[`torch`](https://torch.mlverse.org/). It allows to easily build, train
and evaluate deep learning models in a few lines of codes, without
needing to worry about low-level details. Off-the-shelf learners are
readily available, but custom architectures can be defined by connection
readily available, but custom architectures can be defined by connecting
`PipeOpTorch` operators in an `mlr3pipelines::Graph`.

Using predefined learners such as a simple multi layer perceptron (MLP)
Expand Down
581 changes: 0 additions & 581 deletions attic/vignettes/intro.html

This file was deleted.

740 changes: 0 additions & 740 deletions attic/vignettes/torchops.html

This file was deleted.

Binary file added man/figures/logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
121 changes: 121 additions & 0 deletions man/figures/logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion man/mlr3torch-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion _pkgdown.yml → pkgdown/_pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,19 @@ navbar:
mattermost:
icon: fa fa-comments
href: https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/
github:
icon: fa fa-github
href: https://github.com/mlr-org/mlr3extralearners
book:
text: mlr3book
icon: fa fa-link
href: https://mlr3book.mlr-org.com
stackoverflow:
icon: fab fa-stack-overflow
href: https://stackoverflow.com/questions/tagged/mlr3
rss:
icon: fa-rss
href: https://mlr-org.com/

articles:
- title: Get Started
Expand All @@ -41,7 +47,7 @@ articles:
- title: Image Classification
navbar: ~
contents:
- image_classification
- articles/image_classification
- title: Building a Neural Network using PipeOps
navbar: ~
contents:
Expand Down
Binary file added pkgdown/favicon/apple-touch-icon-120x120.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/apple-touch-icon-152x152.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/apple-touch-icon-180x180.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/apple-touch-icon-60x60.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/apple-touch-icon-76x76.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/apple-touch-icon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/favicon-16x16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/favicon-32x32.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pkgdown/favicon/favicon.ico
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/testthat/test_TorchDescriptor.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ test_that("TorchDescriptor basic checks", {
expect_identical(descriptor$generator, nn_mse_loss)
expect_identical(descriptor$id, "mse")
expect_identical(descriptor$param_set$ids(), "reduction")
expect_set_equal(descriptor$packages, c("R6", "mlr3torch", "torch"))
expect_set_equal(descriptor$packages, c("R6", "torch"))
expect_identical(descriptor$man, "torch::nn_mse_loss")

expect_class(descriptor, "TorchDescriptor")
Expand All @@ -22,7 +22,7 @@ test_that("TorchDescriptor basic checks", {
"<TorchDescriptor:mse> MSE Loss",
"* Generator: nn_mse_loss",
"* Parameters: list()",
"* Packages: R6,torch,mlr3torch"
"* Packages: R6,torch"
)
expect_identical(observed, expected)

Expand Down
14 changes: 6 additions & 8 deletions vignettes/articles/image_classification.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ In `mlr3torch`, image data is represented using the `mlr3torch::imageuri` class.
It is essentially a character vector, containing paths to images on the file system.
When listing the available task feature types (after loading the `mlr3torch` package), we can see that this class is available.

```{r}
```{r, message = FALSE}
library(mlr3torch)
mlr_reflections$task_feature_types
```

Creating a vector (in this case of length 1) is as simple as passing the image paths to the `imageuri` function.
Creating a vector (in this case of length 1) is as simple as passing the image paths to the `imageuri()` function.

```{r}
image_vec = imageuri("/path/to/your/image")
Expand All @@ -30,20 +30,17 @@ For the processing of images, `mlr3torch` relies mostly on the functionality pro
As an example task, we will use the "tiny imagenet" dataset, which is a subset of the [ImageNet](http://www.image-net.org/) dataset.
It consists of 200 classes with 500 training images each.
The goal is to predict the class of an image from the pixels.
For more information you can access the tasks's help page.
For more information you can access the tasks's help page by calling `$help()`.

```{r setup}
set.seed(314)
library(mlr3torch)
tsk_tiny = tsk("tiny_imagenet")
tsk_tiny
```

The first time this task is accessed, the data is downloaded from the internet.
In order to download the dataset only once, you can set the `mlr3torch.cache` option to either `TRUE` or a specific path to be used as the cache folder.

```{r, eval = FALSE}
```{r}
options(mlr3torch.cache = TRUE)
```

Expand All @@ -61,5 +58,6 @@ alexnet = lrn("classif.alexnet")
alexnet
```

We can now train this learner like any other learner on the task at hand, while `mlr3torch` internally creates a dataloader from the image paths.
We could now train this learner like any other learner on the task at hand, while `mlr3torch` internally creates a dataloader from the image paths.
We could also download and use predefined weights, by specifying the `pretrained` parameter to `TRUE`.
For computational reasons, we cannot demonstrate the actual training of the learner in this article.
2 changes: 1 addition & 1 deletion vignettes/get_started.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ We assume that you are familiar with the `mlr3` framework, for a detailed descri
As a first example, we will train a simple multi-layer perceptron (MLP) on the well-known "mtcars" task.
We first set a seed for reproducibility, load the library and construct the task.

```{r}
```{r, message = FALSE}
set.seed(314)
library(mlr3torch)
task = tsk("mtcars")
Expand Down
Loading

0 comments on commit 77f58de

Please sign in to comment.