diff --git a/dev/articles/callbacks.html b/dev/articles/callbacks.html index c55d112d..17260a19 100644 --- a/dev/articles/callbacks.html +++ b/dev/articles/callbacks.html @@ -225,7 +225,7 @@
torch
Primerinput = torch_randn(2, 3)
input
#> torch_tensor
-#> 0.9197 0.6295 -0.9055
-#> -2.5884 0.7595 1.2294
+#> -1.3766 -0.5136 0.3212
+#> -0.1381 0.5962 0.2744
#> [ CPUFloatType{2,3} ]
A nn_module
is constructed from a
nn_module_generator
. nn_linear
is one of the
@@ -117,8 +117,8 @@
torch
Primeroutput = module_1(input)
output
#> torch_tensor
-#> 0.6356 -0.0022 0.3491 -1.0918
-#> -1.5718 1.8213 -1.6298 0.2341
+#> 0.2026 -0.6605 0.1249 0.6521
+#> 0.4501 -0.2476 0.1827 0.3494
#> [ CPUFloatType{2,4} ][ grad_fn = <AddmmBackward0> ]
A neural network with one (4-unit) hidden layer and two outputs needs the following ingredients
@@ -134,8 +134,8 @@torch
Primeroutput = softmax(output)
output
#> torch_tensor
-#> 0.2569 0.4082 0.3350
-#> 0.3488 0.3890 0.2623
+#> 0.3464 0.1966 0.4570
+#> 0.3344 0.1942 0.4714
#> [ CPUFloatType{2,3} ][ grad_fn = <SoftmaxBackward0> ]
We will now continue with showing how such a neural network can be
represented in mlr3torch
.
Note we only use the $train()
, since torch modules do
not have anything that maps to the state
(it is filled by
@@ -196,8 +196,8 @@
While this object allows to easily perform a forward pass, it does
not inherit from nn_module
, which is useful for various
@@ -245,8 +245,8 @@
graph_module(input)
#> torch_tensor
-#> 0.2569 0.4082 0.3350
-#> 0.3488 0.3890 0.2623
+#> 0.3464 0.1966 0.4570
+#> 0.3344 0.1942 0.4714
#> [ CPUFloatType{2,3} ][ grad_fn = <SoftmaxBackward0> ]
ModelDescriptor
to
small_module(batch$x[[1]])
#> torch_tensor
-#> 1.7664 4.3770 -3.4375 -0.1343
-#> 1.7621 4.0245 -3.1681 -0.1304
-#> 1.6213 4.0500 -3.1483 -0.0796
+#> 2.5036 3.6613 -0.2827 -0.3227
+#> 2.2548 3.3180 -0.4614 -0.4464
+#> 2.2830 3.3477 -0.3081 -0.3362
#> [ CPUFloatType{3,4} ][ grad_fn = <AddmmBackward0> ]The first linear layer that takes “Sepal” input
("linear1"
) creates a 2x4 tensor (batch size 2, 4 units),
@@ -690,14 +689,14 @@
We observe that the po("nn_merge_cat")
concatenates
these, as expected:
The printed output of the data descriptor informs us about:
What happens during materialize(lt[1])
is the
following:
We see that the $graph
has a new pipeop with id
"poly.x"
and the output pointer
points to
poly.x
. Also we see that the shape of the tensor is now
diff --git a/dev/pkgdown.yml b/dev/pkgdown.yml
index a84a9cac..540cf350 100644
--- a/dev/pkgdown.yml
+++ b/dev/pkgdown.yml
@@ -7,7 +7,7 @@ articles:
articles/internals_pipeop_torch: internals_pipeop_torch.html
articles/lazy_tensor: lazy_tensor.html
articles/pipeop_torch: pipeop_torch.html
-last_built: 2024-10-29T11:35Z
+last_built: 2024-11-08T09:47Z
urls:
reference: https://mlr3torch.mlr-org.com/reference
article: https://mlr3torch.mlr-org.com/articles
diff --git a/dev/reference/DataDescriptor.html b/dev/reference/DataDescriptor.html
index 7ddde268..e032eba3 100644
--- a/dev/reference/DataDescriptor.html
+++ b/dev/reference/DataDescriptor.html
@@ -263,14 +263,14 @@
lt1 = as_lazy_tensor(torch_randn(10, 3))
materialize(lt1, rbind = TRUE)
#> torch_tensor
-#> -1.3420 -0.7548 -0.2027
-#> 0.8642 -0.6203 -1.1163
-#> -0.1052 0.5600 -0.3974
-#> -0.3943 -0.0705 -1.7041
-#> 1.5337 0.1588 -0.8624
-#> 1.4815 -0.1104 -0.3846
-#> 2.3301 1.1064 0.2867
-#> -1.1371 -1.4983 0.6197
-#> 0.3968 1.8850 0.5676
-#> 1.4147 1.3346 -1.3396
+#> 0.2337 0.5821 -0.4541
+#> -0.0929 -0.0327 -0.0007
+#> 0.7647 1.1990 1.1332
+#> -1.2069 0.9962 -0.2510
+#> 0.4641 -0.2855 -0.6539
+#> -0.2911 0.3456 -0.2388
+#> -1.2894 0.6156 1.3607
+#> -0.2922 -2.0550 1.3359
+#> -0.3164 0.3312 -1.1026
+#> -1.4172 -0.1749 1.6354
#> [ CPUFloatType{10,3} ]
materialize(lt1, rbind = FALSE)
#> [[1]]
#> torch_tensor
-#> -1.3420
-#> -0.7548
-#> -0.2027
+#> 0.2337
+#> 0.5821
+#> -0.4541
#> [ CPUFloatType{3} ]
#>
#> [[2]]
#> torch_tensor
-#> 0.8642
-#> -0.6203
-#> -1.1163
+#> 0.01 *
+#> -9.2944
+#> -3.2728
+#> -0.0712
#> [ CPUFloatType{3} ]
#>
#> [[3]]
#> torch_tensor
-#> -0.1052
-#> 0.5600
-#> -0.3974
+#> 0.7647
+#> 1.1990
+#> 1.1332
#> [ CPUFloatType{3} ]
#>
#> [[4]]
#> torch_tensor
-#> -0.3943
-#> -0.0705
-#> -1.7041
+#> -1.2069
+#> 0.9962
+#> -0.2510
#> [ CPUFloatType{3} ]
#>
#> [[5]]
#> torch_tensor
-#> 1.5337
-#> 0.1588
-#> -0.8624
+#> 0.4641
+#> -0.2855
+#> -0.6539
#> [ CPUFloatType{3} ]
#>
#> [[6]]
#> torch_tensor
-#> 1.4815
-#> -0.1104
-#> -0.3846
+#> -0.2911
+#> 0.3456
+#> -0.2388
#> [ CPUFloatType{3} ]
#>
#> [[7]]
#> torch_tensor
-#> 2.3301
-#> 1.1064
-#> 0.2867
+#> -1.2894
+#> 0.6156
+#> 1.3607
#> [ CPUFloatType{3} ]
#>
#> [[8]]
#> torch_tensor
-#> -1.1371
-#> -1.4983
-#> 0.6197
+#> -0.2922
+#> -2.0550
+#> 1.3359
#> [ CPUFloatType{3} ]
#>
#> [[9]]
#> torch_tensor
-#> 0.3968
-#> 1.8850
-#> 0.5676
+#> -0.3164
+#> 0.3312
+#> -1.1026
#> [ CPUFloatType{3} ]
#>
#> [[10]]
#> torch_tensor
-#> 1.4147
-#> 1.3346
-#> -1.3396
+#> -1.4172
+#> -0.1749
+#> 1.6354
#> [ CPUFloatType{3} ]
#>
lt2 = as_lazy_tensor(torch_randn(10, 4))
@@ -219,184 +220,185 @@ Examplesmaterialize(d, rbind = TRUE)
#> $lt1
#> torch_tensor
-#> -1.3420 -0.7548 -0.2027
-#> 0.8642 -0.6203 -1.1163
-#> -0.1052 0.5600 -0.3974
-#> -0.3943 -0.0705 -1.7041
-#> 1.5337 0.1588 -0.8624
-#> 1.4815 -0.1104 -0.3846
-#> 2.3301 1.1064 0.2867
-#> -1.1371 -1.4983 0.6197
-#> 0.3968 1.8850 0.5676
-#> 1.4147 1.3346 -1.3396
+#> 0.2337 0.5821 -0.4541
+#> -0.0929 -0.0327 -0.0007
+#> 0.7647 1.1990 1.1332
+#> -1.2069 0.9962 -0.2510
+#> 0.4641 -0.2855 -0.6539
+#> -0.2911 0.3456 -0.2388
+#> -1.2894 0.6156 1.3607
+#> -0.2922 -2.0550 1.3359
+#> -0.3164 0.3312 -1.1026
+#> -1.4172 -0.1749 1.6354
#> [ CPUFloatType{10,3} ]
#>
#> $lt2
#> torch_tensor
-#> -0.1006 1.3168 0.1388 0.5232
-#> 0.2336 -0.0717 -1.0036 0.6683
-#> 0.7921 1.4544 0.1958 -0.2346
-#> 0.6385 -0.0318 0.6482 -0.1380
-#> 1.3797 0.7916 -0.2373 -1.0011
-#> 0.3849 -0.2163 -0.0237 -0.0863
-#> 0.4032 0.1983 -1.6133 -2.2197
-#> 0.2452 0.7073 0.3545 0.0833
-#> -0.8829 0.2502 -0.3007 -0.1676
-#> 1.9140 -0.7256 -0.7844 -1.4341
+#> -0.3209 -0.9323 -0.4262 0.4701
+#> -0.6289 -1.5827 1.4104 -0.9415
+#> -0.6591 -1.3246 -0.3064 0.8200
+#> -1.0953 -0.7571 2.1173 -0.8233
+#> 0.2794 -1.1853 -1.6339 0.4039
+#> 1.4075 1.4175 -0.1898 -0.2829
+#> -0.2797 -0.2720 0.2012 0.6381
+#> -0.1676 2.0197 -0.6833 1.2002
+#> 1.7206 -0.0778 0.1583 -1.5801
+#> 0.1183 -0.4630 -0.5408 -0.7640
#> [ CPUFloatType{10,4} ]
#>
materialize(d, rbind = FALSE)
#> $lt1
#> $lt1[[1]]
#> torch_tensor
-#> -1.3420
-#> -0.7548
-#> -0.2027
+#> 0.2337
+#> 0.5821
+#> -0.4541
#> [ CPUFloatType{3} ]
#>
#> $lt1[[2]]
#> torch_tensor
-#> 0.8642
-#> -0.6203
-#> -1.1163
+#> 0.01 *
+#> -9.2944
+#> -3.2728
+#> -0.0712
#> [ CPUFloatType{3} ]
#>
#> $lt1[[3]]
#> torch_tensor
-#> -0.1052
-#> 0.5600
-#> -0.3974
+#> 0.7647
+#> 1.1990
+#> 1.1332
#> [ CPUFloatType{3} ]
#>
#> $lt1[[4]]
#> torch_tensor
-#> -0.3943
-#> -0.0705
-#> -1.7041
+#> -1.2069
+#> 0.9962
+#> -0.2510
#> [ CPUFloatType{3} ]
#>
#> $lt1[[5]]
#> torch_tensor
-#> 1.5337
-#> 0.1588
-#> -0.8624
+#> 0.4641
+#> -0.2855
+#> -0.6539
#> [ CPUFloatType{3} ]
#>
#> $lt1[[6]]
#> torch_tensor
-#> 1.4815
-#> -0.1104
-#> -0.3846
+#> -0.2911
+#> 0.3456
+#> -0.2388
#> [ CPUFloatType{3} ]
#>
#> $lt1[[7]]
#> torch_tensor
-#> 2.3301
-#> 1.1064
-#> 0.2867
+#> -1.2894
+#> 0.6156
+#> 1.3607
#> [ CPUFloatType{3} ]
#>
#> $lt1[[8]]
#> torch_tensor
-#> -1.1371
-#> -1.4983
-#> 0.6197
+#> -0.2922
+#> -2.0550
+#> 1.3359
#> [ CPUFloatType{3} ]
#>
#> $lt1[[9]]
#> torch_tensor
-#> 0.3968
-#> 1.8850
-#> 0.5676
+#> -0.3164
+#> 0.3312
+#> -1.1026
#> [ CPUFloatType{3} ]
#>
#> $lt1[[10]]
#> torch_tensor
-#> 1.4147
-#> 1.3346
-#> -1.3396
+#> -1.4172
+#> -0.1749
+#> 1.6354
#> [ CPUFloatType{3} ]
#>
#>
#> $lt2
#> $lt2[[1]]
#> torch_tensor
-#> -0.1006
-#> 1.3168
-#> 0.1388
-#> 0.5232
+#> -0.3209
+#> -0.9323
+#> -0.4262
+#> 0.4701
#> [ CPUFloatType{4} ]
#>
#> $lt2[[2]]
#> torch_tensor
-#> 0.2336
-#> -0.0717
-#> -1.0036
-#> 0.6683
+#> -0.6289
+#> -1.5827
+#> 1.4104
+#> -0.9415
#> [ CPUFloatType{4} ]
#>
#> $lt2[[3]]
#> torch_tensor
-#> 0.7921
-#> 1.4544
-#> 0.1958
-#> -0.2346
+#> -0.6591
+#> -1.3246
+#> -0.3064
+#> 0.8200
#> [ CPUFloatType{4} ]
#>
#> $lt2[[4]]
#> torch_tensor
-#> 0.6385
-#> -0.0318
-#> 0.6482
-#> -0.1380
+#> -1.0953
+#> -0.7571
+#> 2.1173
+#> -0.8233
#> [ CPUFloatType{4} ]
#>
#> $lt2[[5]]
#> torch_tensor
-#> 1.3797
-#> 0.7916
-#> -0.2373
-#> -1.0011
+#> 0.2794
+#> -1.1853
+#> -1.6339
+#> 0.4039
#> [ CPUFloatType{4} ]
#>
#> $lt2[[6]]
#> torch_tensor
-#> 0.3849
-#> -0.2163
-#> -0.0237
-#> -0.0863
+#> 1.4075
+#> 1.4175
+#> -0.1898
+#> -0.2829
#> [ CPUFloatType{4} ]
#>
#> $lt2[[7]]
#> torch_tensor
-#> 0.4032
-#> 0.1983
-#> -1.6133
-#> -2.2197
+#> -0.2797
+#> -0.2720
+#> 0.2012
+#> 0.6381
#> [ CPUFloatType{4} ]
#>
#> $lt2[[8]]
#> torch_tensor
-#> 0.2452
-#> 0.7073
-#> 0.3545
-#> 0.0833
+#> -0.1676
+#> 2.0197
+#> -0.6833
+#> 1.2002
#> [ CPUFloatType{4} ]
#>
#> $lt2[[9]]
#> torch_tensor
-#> -0.8829
-#> 0.2502
-#> -0.3007
-#> -0.1676
+#> 1.7206
+#> -0.0778
+#> 0.1583
+#> -1.5801
#> [ CPUFloatType{4} ]
#>
#> $lt2[[10]]
#> torch_tensor
-#> 1.9140
-#> -0.7256
-#> -0.7844
-#> -1.4341
+#> 0.1183
+#> -0.4630
+#> -0.5408
+#> -0.7640
#> [ CPUFloatType{4} ]
#>
#>
diff --git a/dev/reference/mlr_learners.torchvision.html b/dev/reference/mlr_learners.torchvision.html
index f81e4292..4717a642 100644
--- a/dev/reference/mlr_learners.torchvision.html
+++ b/dev/reference/mlr_learners.torchvision.html
@@ -133,9 +133,9 @@ ArgumentsUsage
optimizer = NULL,
loss = NULL,
callbacks = list(),
- packages = c("torchvision", "magick"),
+ packages = c("torchvision"),
man,
properties = NULL,
predict_types = NULL
diff --git a/dev/reference/mlr_learners_torch_model.html b/dev/reference/mlr_learners_torch_model.html
index 9dcd7ab2..710171ea 100644
--- a/dev/reference/mlr_learners_torch_model.html
+++ b/dev/reference/mlr_learners_torch_model.html
@@ -243,14 +243,14 @@ Exampleslearner$train(task, ids$train)
learner$predict(task, ids$test)
#> <PredictionClassif> for 50 observations:
-#> row_ids truth response
-#> 3 setosa setosa
-#> 9 setosa setosa
-#> 10 setosa setosa
-#> --- --- ---
-#> 147 virginica setosa
-#> 149 virginica setosa
-#> 150 virginica setosa
+#> row_ids truth response
+#> 3 setosa versicolor
+#> 9 setosa versicolor
+#> 10 setosa versicolor
+#> --- --- ---
+#> 147 virginica versicolor
+#> 149 virginica versicolor
+#> 150 virginica versicolor