forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FractionalMaxPool3d.cpp
431 lines (379 loc) · 13.5 KB
/
FractionalMaxPool3d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
#include <tuple>
#include <vector>
namespace at {
namespace meta {
/*
 * Meta (shape-checking) function for fractional_max_pool3d.
 *
 * Validates pool_size / output_size and the input tensor, declares the two
 * outputs (pooled values, and int64 indices of the selected elements) and
 * returns the precomputed sizes that are handed to the kernel.
 *
 * A 4D input is read as (planes, T, H, W); a 5D input is read as
 * (batch, planes, T, H, W). randomSamples is not inspected here.
 */
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
    const at::Tensor& input_,
    IntArrayRef pool_size,
    IntArrayRef output_size,
    const at::Tensor& randomSamples
) {
  TORCH_CHECK(
      pool_size.size() == 3,
      "fractional_max_pool3d: kernel_size must either be a single Int or tuple of three Ints");
  TORCH_CHECK(
      output_size.size() == 3,
      "fractional_max_pool3d: output_size must either be a single Int or tuple of three Ints");

  int64_t outputT = output_size[0];
  int64_t outputH = output_size[1];
  int64_t outputW = output_size[2];
  int64_t poolSizeT = pool_size[0];
  int64_t poolSizeH = pool_size[1];
  int64_t poolSizeW = pool_size[2];

  // A zero-sized window has no maximum and a negative one places the window
  // start outside of the input, so reject both up front.
  TORCH_CHECK(
      poolSizeT > 0 && poolSizeH > 0 && poolSizeW > 0,
      "fractional_max_pool3d_out(): pool sizes must be positive, but got ",
      pool_size);

  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t timeDim = 1;
  int64_t heightDim = 2;
  int64_t widthDim = 3;

  int64_t ndims = input_.ndimension();
  TORCH_CHECK(ndims == 4 || ndims == 5,
      "fractional_max_pool3d_out(): Expected 4D or 5D tensor, but got: ",
      input_.sizes());
  // Dimension 0 is deliberately skipped by this loop.
  for (const auto i : c10::irange(1, ndims)) {
    TORCH_CHECK(input_.size(i) > 0,
        "fractional_max_pool3d_out(): Expected input to have non-zero size for non-batch dimensions, but got ",
        input_.sizes(), " with dimension ", i, " being empty.");
  }

  // With a leading batch dimension every other dimension shifts by one.
  if (ndims == 5) {
    numBatch = input_.size(0);
    planeDim++;
    timeDim++;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  int64_t numPlanes = input_.size(planeDim);
  int64_t inputT = input_.size(timeDim);
  int64_t inputH = input_.size(heightDim);
  int64_t inputW = input_.size(widthDim);

  TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
      "fractional_max_pool3d_out(): pool time ", poolSizeT,
      " too large relative to input time ", inputT);
  TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
      "fractional_max_pool3d_out(): pool width ", poolSizeW,
      " too large relative to input width ", inputW);
  TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
      "fractional_max_pool3d_out(): pool height ", poolSizeH,
      " too large relative to input height ", inputH);

  if (ndims == 4) {
    /* resize output */
    set_output(0, {numPlanes, outputT, outputH, outputW}, input_.options());
    /* indices will contain the locations for each output point */
    set_output(1, {numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong));
  } else {
    set_output(0, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options());
    /* indices will contain the locations for each output point */
    set_output(1, {numBatch, numPlanes, outputT, outputH, outputW}, input_.options().dtype(kLong));
  }

  return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
      .set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
      .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
}
} // namespace meta
namespace native {
namespace {
/*
 * Computes the start offset of each of the outputSize pooling windows along
 * one dimension.
 *
 * All but the last start are derived from `sample` and the step
 * (inputSize - poolSize) / (outputSize - 1); the last start is always
 * inputSize - poolSize. An outputSize of 0 yields an empty vector.
 */
template<typename scalar_t>
static std::vector<int> generate_intervals(
    scalar_t sample,
    int64_t inputSize,
    int64_t outputSize,
    int64_t poolSize) {
  std::vector<int> starts(outputSize);
  if (outputSize <= 0) {
    return starts;
  }

  const int64_t lastSlot = outputSize - 1;
  const int64_t lastStart = inputSize - poolSize;

  if (lastSlot > 0) {
    const scalar_t step =
        static_cast<scalar_t>(lastStart) / static_cast<scalar_t>(lastSlot);
    // Subtracted from every entry so that the first start is 0.
    const int base = static_cast<int>(sample * step);
    for (const auto slot : c10::irange(lastSlot)) {
      starts[slot] = static_cast<int>((slot + sample) * step) - base;
    }
  }

  starts[lastSlot] = lastStart;
  return starts;
}
/*
 * Forward pass for one batch element.
 *
 * For every plane, three random samples (read from randomSamples at
 * plane * 3) produce the window start sequences for T, H and W. Each output
 * element receives the maximum of its poolSizeT x poolSizeH x poolSizeW
 * window, and `indices` receives the flat offset of that maximum within the
 * plane. A NaN inside the window replaces the current maximum.
 *
 * All pointers are indexed as densely packed arrays.
 */
template<typename scalar_t>
static void fractional_max_pool3d_out_single_batch_frame(
    scalar_t* input,
    scalar_t* output,
    int64_t* indices,
    scalar_t* randomSamples,
    int64_t numPlanes,
    int64_t inputT, int64_t inputH, int64_t inputW,
    int64_t outputT, int64_t outputH, int64_t outputW,
    int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
  at::parallel_for(0, numPlanes, 0, [&](int64_t first, int64_t last) {
    for (const auto plane : c10::irange(first, last)) {
      /* three samples per plane: T first, then H, then W */
      const scalar_t* samples = randomSamples + plane * 3;

      /* window start positions along each dimension */
      const auto startsT = generate_intervals<scalar_t>(
          samples[0], inputT, outputT, poolSizeT);
      const auto startsH = generate_intervals<scalar_t>(
          samples[1], inputH, outputH, poolSizeH);
      const auto startsW = generate_intervals<scalar_t>(
          samples[2], inputW, outputW, poolSizeW);

      scalar_t* src = input + plane * inputT * inputH * inputW;
      scalar_t* dst = output + plane * outputT * outputH * outputW;
      int64_t* dstIndex = indices + plane * outputT * outputH * outputW;

      /* loop over output */
      for (int64_t t = 0; t < outputT; ++t) {
        const int64_t inputTStart = startsT[t];

        for (int64_t h = 0; h < outputH; ++h) {
          const int64_t inputHStart = startsH[h];

          for (int64_t w = 0; w < outputW; ++w) {
            const int64_t inputWStart = startsW[w];

            // Start from -inf and the first element of the window.
            scalar_t best = -std::numeric_limits<scalar_t>::infinity();
            int64_t bestIndex =
                inputTStart * inputH * inputW + inputHStart * inputW + inputWStart;

            for (int64_t t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) {
              for (int64_t h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
                for (int64_t w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
                  AT_ASSERT(t2 >= 0 && t2 < inputT);
                  AT_ASSERT(h2 >= 0 && h2 < inputH);
                  AT_ASSERT(w2 >= 0 && w2 < inputW);

                  const int64_t flat = t2 * inputH * inputW + h2 * inputW + w2;
                  const scalar_t candidate = src[flat];
                  if (candidate > best || std::isnan(candidate)) {
                    best = candidate;
                    bestIndex = flat;
                  }
                }
              }
            }

            const int64_t outFlat = t * outputH * outputW + h * outputW + w;
            dst[outFlat] = best;
            dstIndex[outFlat] = bestIndex;
          }
        }
      }
    }
  });
}
/*
 * Forward pass over all batch elements.
 *
 * A single batch element is handed straight to the single-batch kernel.
 * Otherwise the batch range is split with at::parallel_for and the
 * single-batch kernel is invoked on each element's slice of input, output,
 * indices and randomSamples.
 */
template<typename scalar_t>
static void fractional_max_pool3d_out_frame(
    scalar_t* input,
    scalar_t* output,
    int64_t* indices,
    scalar_t* randomSamples,
    int64_t numBatch, int64_t numPlanes,
    int64_t inputT, int64_t inputH, int64_t inputW,
    int64_t outputT, int64_t outputH, int64_t outputW,
    int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
  if (numBatch == 1) {
    fractional_max_pool3d_out_single_batch_frame<scalar_t>(
        input, output, indices, randomSamples,
        numPlanes,
        inputT, inputH, inputW,
        outputT, outputH, outputW,
        poolSizeT, poolSizeH, poolSizeW);
    return;
  }

  /* number of elements one batch entry occupies in each buffer */
  const int64_t inputPerBatch = numPlanes * inputW * inputH * inputT;
  const int64_t outputPerBatch = numPlanes * outputW * outputH * outputT;
  const int64_t samplesPerBatch = numPlanes * 3;

  at::parallel_for(0, numBatch, 0, [&](int64_t first, int64_t last) {
    for (const auto batch : c10::irange(first, last)) {
      fractional_max_pool3d_out_single_batch_frame<scalar_t>(
          input + batch * inputPerBatch,
          output + batch * outputPerBatch,
          indices + batch * outputPerBatch,
          randomSamples + batch * samplesPerBatch,
          numPlanes,
          inputT, inputH, inputW,
          outputT, outputH, outputW,
          poolSizeT, poolSizeH, poolSizeW);
    }
  });
}
} // anonymous namespace
/*
 * CPU implementation of fractional_max_pool3d.
 *
 * Receives the sizes precomputed by the meta function, validates the layout
 * of randomSamples and dispatches the forward kernel over the floating-point
 * types. `output` and `indices` are written through their raw data pointers.
 *
 * The kernel reads randomSamples as a dense array and fetches three values
 * at offset batch * numPlanes * 3 + plane * 3. The tensor therefore has to
 * be 3-dimensional with at least numBatch entries in dim 0, exactly
 * numPlanes entries in dim 1 and exactly 3 entries in dim 2, and it must be
 * contiguous when it is read.
 */
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
    const at::Tensor& input_,
    int64_t poolSizeT,
    int64_t poolSizeH,
    int64_t poolSizeW,
    int64_t outputT,
    int64_t outputH,
    int64_t outputW,
    const at::Tensor& randomSamples,
    int64_t numBatch,
    int64_t numPlanes,
    int64_t inputT,
    int64_t inputH,
    int64_t inputW,
    const at::Tensor& output,
    const at::Tensor& indices) {
  // Without these checks a wrongly shaped randomSamples makes the kernel
  // read past the end of its buffer.
  TORCH_CHECK(
      randomSamples.dim() == 3,
      "fractional_max_pool3d_out(): Expected _random_samples to have 3 dimensions, but got ",
      randomSamples.dim());
  TORCH_CHECK(
      randomSamples.size(0) >= numBatch,
      "fractional_max_pool3d_out(): Expected _random_samples.size(0) to be no less than the input batch size ",
      numBatch, ", but got ", randomSamples.size(0));
  TORCH_CHECK(
      randomSamples.size(1) == numPlanes,
      "fractional_max_pool3d_out(): Expected _random_samples.size(1) to equal the number of input planes ",
      numPlanes, ", but got ", randomSamples.size(1));
  TORCH_CHECK(
      randomSamples.size(2) == 3,
      "fractional_max_pool3d_out(): Expected _random_samples.size(2) to equal 3, but got ",
      randomSamples.size(2));

  /* get contiguous input */
  auto input = input_.contiguous();
  /* the kernel ignores strides, so the samples must be dense as well */
  auto samples = randomSamples.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fractional_max_pool3d_out_frame",
    [&] {
      fractional_max_pool3d_out_frame<scalar_t>(
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        indices.data_ptr<int64_t>(),
        samples.data_ptr<scalar_t>(),
        numBatch, numPlanes,
        inputT, inputH, inputW,
        outputT, outputH, outputW,
        poolSizeT, poolSizeH, poolSizeW
      );
    }
  );
}
namespace {
/*
 * Backward pass for one batch element.
 *
 * For every plane, each gradOutput element is added to the gradInput element
 * whose flat in-plane offset is stored at the same position in `indices`.
 * gradInput is only accumulated into here, never cleared.
 */
template<typename scalar_t>
static void fractional_max_pool3d_backward_out_single_batch_frame(
    scalar_t* gradInput,
    scalar_t* gradOutput,
    int64_t* indices,
    int64_t numPlanes,
    int64_t inputT, int64_t inputH, int64_t inputW,
    int64_t outputT, int64_t outputH, int64_t outputW) {
  at::parallel_for(0, numPlanes, 0, [&](int64_t first, int64_t last) {
    for (const auto plane : c10::irange(first, last)) {
      scalar_t* planeGradInput = gradInput + plane * inputT * inputH * inputW;
      scalar_t* planeGradOutput = gradOutput + plane * outputT * outputH * outputW;
      int64_t* planeIndices = indices + plane * outputT * outputH * outputW;

      for (int64_t t = 0; t < outputT; ++t) {
        for (int64_t h = 0; h < outputH; ++h) {
          for (int64_t w = 0; w < outputW; ++w) {
            const int64_t outputIndex = t * outputH * outputW + h * outputW + w;
            const int64_t index = planeIndices[outputIndex];
            AT_ASSERT(index >= 0 && index < inputT * inputH * inputW);

            planeGradInput[index] += planeGradOutput[outputIndex];
          }
        }
      }
    }
  });
}
/*
 * Backward pass over all batch elements.
 *
 * A single batch element is handed straight to the single-batch kernel.
 * Otherwise the batch range is split with at::parallel_for and the
 * single-batch kernel runs on each element's slice of gradInput, gradOutput
 * and indices.
 */
template<typename scalar_t>
static void fractional_max_pool3d_backward_out_frame(
    scalar_t* gradInput,
    scalar_t* gradOutput,
    int64_t* indices,
    int64_t numBatch, int64_t numPlanes,
    int64_t inputT, int64_t inputH, int64_t inputW,
    int64_t outputT, int64_t outputH, int64_t outputW) {
  if (numBatch == 1) {
    fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
        gradInput, gradOutput, indices,
        numPlanes,
        inputT, inputH, inputW,
        outputT, outputH, outputW);
    return;
  }

  /* number of elements one batch entry occupies in each buffer */
  const int64_t inputPerBatch = numPlanes * inputW * inputH * inputT;
  const int64_t outputPerBatch = numPlanes * outputW * outputH * outputT;

  at::parallel_for(0, numBatch, 0, [&](int64_t first, int64_t last) {
    for (const auto batch : c10::irange(first, last)) {
      fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
          gradInput + batch * inputPerBatch,
          gradOutput + batch * outputPerBatch,
          indices + batch * outputPerBatch,
          numPlanes,
          inputT, inputH, inputW,
          outputT, outputH, outputW);
    }
  });
}
void fractional_max_pool3d_backward_out_cpu_template(
const Tensor& input,
const Tensor& gradOutput_,
Tensor& gradInput,
IntArrayRef output_size,
IntArrayRef pool_size /* unused */,
const Tensor& indices) {
int64_t outputT = output_size[0];
int64_t outputH = output_size[1];
int64_t outputW = output_size[2];
int64_t numBatch = 1;
int64_t planeDim = 0;
int64_t timeDim = 1;
int64_t heightDim = 2;
int64_t widthDim = 3;
int64_t ndims = input.ndimension();
if (ndims == 5) {
numBatch = input.size(0);
planeDim = 1;
heightDim++;
widthDim++;
timeDim++;
}
/* sizes */
int64_t numPlanes = input.size(planeDim);
int64_t inputT = input.size(timeDim);
int64_t inputH = input.size(heightDim);
int64_t inputW = input.size(widthDim);
TORCH_CHECK(outputT == gradOutput_.size(timeDim),
"fractional_max_pool3d_backward_out(): gradOutput time unexpected");
TORCH_CHECK(outputH == gradOutput_.size(heightDim),
"fractional_max_pool3d_backward_out(): ",
"gradOutput height unexpected");
TORCH_CHECK(outputW == gradOutput_.size(widthDim),
"fractional_max_pool3d_backward_out(): gradOutput width unexpected");
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
/* resize */
gradInput.resize_as_(input);
gradInput.zero_();
/* backprop */
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"fractional_max_pool3d_backward_out_frame",
[&]{
fractional_max_pool3d_backward_out_frame<scalar_t>(
gradInput.data_ptr<scalar_t>(),
gradOutput.data_ptr<scalar_t>(),
indices.data_ptr<int64_t>(),
numBatch, numPlanes,
inputT, inputH, inputW,
outputT, outputH, outputW
);
}
);
}
}// anonymous namespace
/*
 * Out-variant of the CPU backward pass: writes the gradient into the
 * caller-supplied gradInput and returns a reference to it.
 */
Tensor& fractional_max_pool3d_backward_out_cpu(
    const at::Tensor& gradOutput_,
    const at::Tensor& input,
    IntArrayRef pool_size,
    IntArrayRef output_size,
    const at::Tensor& indices,
    at::Tensor& gradInput) {
  // The shared template takes its arguments in a different order.
  fractional_max_pool3d_backward_out_cpu_template(
      input, gradOutput_, gradInput, output_size, pool_size, indices);
  return gradInput;
}
/*
 * Functional variant of the CPU backward pass: allocates the result tensor
 * with the options of `input` and returns it by value.
 */
Tensor fractional_max_pool3d_backward_cpu(
    const at::Tensor& gradOutput_,
    const at::Tensor& input,
    IntArrayRef pool_size,
    IntArrayRef output_size,
    const at::Tensor& indices) {
  // Starts out empty; the shared template resizes it to match `input`.
  auto gradInput = at::empty({0}, input.options());
  fractional_max_pool3d_backward_out_cpu_template(
      input, gradOutput_, gradInput, output_size, pool_size, indices);
  return gradInput;
}
}// native
}// at