Skip to content

Commit

Permalink
Feat: adaptive avg pool
Browse files Browse the repository at this point in the history
  • Loading branch information
cxzhang4 authored Oct 18, 2024
1 parent c6def4f commit a7cfd2a
Show file tree
Hide file tree
Showing 66 changed files with 943 additions and 9 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Collate:
'PipeOpTorch.R'
'PipeOpTaskPreprocTorch.R'
'PipeOpTorchActivation.R'
'PipeOpTorchAdaptiveAvgPool.R'
'PipeOpTorchAvgPool.R'
'PipeOpTorchBatchNorm.R'
'PipeOpTorchBlock.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ export(ModelDescriptor)
export(PipeOpModule)
export(PipeOpTaskPreprocTorch)
export(PipeOpTorch)
export(PipeOpTorchAdaptiveAvgPool1D)
export(PipeOpTorchAdaptiveAvgPool2D)
export(PipeOpTorchAdaptiveAvgPool3D)
export(PipeOpTorchAvgPool1D)
export(PipeOpTorchAvgPool2D)
export(PipeOpTorchAvgPool3D)
Expand Down
126 changes: 126 additions & 0 deletions R/PipeOpTorchAdaptiveAvgPool.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
PipeOpTorchAdaptiveAvgPool = R6Class("PipeOpTorchAdaptiveAvgPool",
inherit = PipeOpTorch,
public = list(
initialize = function(id, d, param_vals = list()) {
private$.d = assert_int(d, lower = 1, upper = 3)
module_generator = switch(d, nn_adaptive_avg_pool1d, nn_adaptive_avg_pool2d, nn_adaptive_avg_pool3d)
check_vector = make_check_vector(private$.d)
param_set = ps(
output_size = p_uty(custom_check = check_vector, tags = c("required", "train"))
)

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
module_generator = module_generator
)
}
),
private = list(
.shapes_out = function(shapes_in, param_vals, task) {
list(adaptive_avg_output_shape(
shape_in = shapes_in[[1]],
conv_dim = private$.d,
output_size = param_vals$output_size
))
},
.d = NULL
)
)

adaptive_avg_output_shape = function(shape_in, conv_dim, output_size) {
shape_in = assert_integerish(shape_in, min.len = conv_dim, coerce = TRUE)

if (length(output_size) == 1) output_size = rep(output_size, conv_dim)

shape_head = utils::head(shape_in, -conv_dim)
if (length(shape_head) <= 1) warningf("Input tensor does not have batch dimension.")

shape_tail = output_size

c(shape_head, shape_tail)
}

#' @title 1D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool1d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool1d description
#'
#' @section Parameters:
#' * `output_size` :: `integer(1)`\cr
#' The target output size. A single number.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool1d()`][torch::nn_adaptive_avg_pool1d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool1D = R6Class("PipeOpTorchAdaptiveAvgPool1D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool1d", param_vals = list()) {
super$initialize(id = id, d = 1, param_vals = param_vals)
}
)
)

#' @title 2D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool2d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool2d description
#'
#' @section Parameters:
#' * `output_size` :: `integer()`\cr
#' The target output size. Can be a single number or a vector.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool2d()`][torch::nn_adaptive_avg_pool2d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool2D = R6Class("PipeOpTorchAdaptiveAvgPool2D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool2d", param_vals = list()) {
super$initialize(id = id, d = 2, param_vals = param_vals)
}
)
)

#' @title 3D Adaptive Average Pooling
#'
#' @templateVar id nn_adaptive_avg_pool3d
#' @template pipeop_torch_channels_default
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool3d description
#'
#' @section Parameters:
#' * `output_size` :: `integer()`\cr
#' The target output size. Can be a single number or a vector.
#'
#' @section Internals:
#' Calls [`nn_adaptive_avg_pool3d()`][torch::nn_adaptive_avg_pool3d] during training.
#' @export
PipeOpTorchAdaptiveAvgPool3D = R6Class("PipeOpTorchAdaptiveAvgPool3D", inherit = PipeOpTorchAdaptiveAvgPool,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
#' @template params_pipelines
initialize = function(id = "nn_adaptive_avg_pool3d", param_vals = list()) {
super$initialize(id = id, d = 3, param_vals = param_vals)
}
)
)

#' @include zzz.R
register_po("nn_adaptive_avg_pool1d", PipeOpTorchAdaptiveAvgPool1D)
register_po("nn_adaptive_avg_pool2d", PipeOpTorchAdaptiveAvgPool2D)
register_po("nn_adaptive_avg_pool3d", PipeOpTorchAdaptiveAvgPool3D)
6 changes: 3 additions & 3 deletions R/PipeOpTorchAvgPool.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ avg_output_shape = function(shape_in, conv_dim, padding, stride, kernel_size, ce
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool1d description
#' @inherit torch::nnf_avg_pool1d description
#'
#' @section Parameters:
#' * `kernel_size` :: (`integer()`)\cr
Expand Down Expand Up @@ -104,7 +104,7 @@ PipeOpTorchAvgPool1D = R6Class("PipeOpTorchAvgPool1D", inherit = PipeOpTorchAvgP
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool2d description
#' @inherit torch::nnf_avg_pool2d description
#'
#' @inheritSection mlr_pipeops_nn_avg_pool1d Parameters
#'
Expand All @@ -128,7 +128,7 @@ PipeOpTorchAvgPool2D = R6Class("PipeOpTorchAvgPool2D", inherit = PipeOpTorchAvgP
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#' @inherit torch::nnf_adaptive_avg_pool3d description
#' @inherit torch::nnf_avg_pool3d description
#'
#' @inheritSection mlr_pipeops_nn_avg_pool1d Parameters
#'
Expand Down
176 changes: 176 additions & 0 deletions man/mlr_pipeops_nn_adaptive_avg_pool1d.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a7cfd2a

Please sign in to comment.