forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Conv_miopen.cpp
1278 lines (1069 loc) · 47.3 KB
/
Conv_miopen.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/native/ConvUtils.h>
// TODO: Remove the condition on AT_ROCM_ENABLED entirely,
// don't build this file as part of CPU build.
#include <ATen/cuda/CUDAConfig.h>
#if !AT_ROCM_ENABLED()
namespace at { namespace native {
// See Note [ATen preprocessor philosophy]
at::Tensor miopen_convolution(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt /* optional */,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_backward_input(
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution_backward_input: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution_backward_weight: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_backward_bias(
const at::Tensor& grad_output) {
AT_ERROR("miopen_convolution_backward_bias: ATen not compiled with MIOpen support");
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
AT_ERROR("miopen_convolution_backward: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_transpose(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt /* optional */,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution_transpose: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_transpose_backward_input(
const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support");
}
at::Tensor miopen_convolution_transpose_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
AT_ERROR("miopen_convolution_transpose_backward_weight: ATen not compiled with MIOpen support");
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
AT_ERROR("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support");
}
at::Tensor miopen_depthwise_convolution(
const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt /* optional */,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic) {
AT_ERROR("miopen_depthwise_convolution: ATen not compiled with MIOpen support");
}
at::Tensor miopen_depthwise_convolution_backward_input(
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
AT_ERROR("miopen_depthwise_convolution_backward_input: ATen not compiled with MIOpen support");
}
at::Tensor miopen_depthwise_convolution_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
AT_ERROR("miopen_depthwise_convolution_backward_weight: ATen not compiled with MIOpen support");
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_depthwise_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
AT_ERROR("miopen_depthwise_convolution_backward: ATen not compiled with MIOpen support");
}
}}
#else // AT_ROCM_ENABLED
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/miopen/Descriptors.h>
#include <ATen/miopen/Types.h>
#include <ATen/miopen/Utils.h>
#include <ATen/hip/EmptyTensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/ConvUtils.h>
#include <c10/util/irange.h>
#include <c10/hip/HIPCachingAllocator.h>
#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <unordered_map>
namespace at { namespace native {
// Returns the `group_idx`-th slice of `t` along dimension `dim`, where that
// dimension is cut into `groups` chunks of size t.size(dim) / groups.
// Integer division: if the size is not divisible by `groups`, the trailing
// remainder belongs to no group.
Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
  const auto chunk = t.size(dim) / groups;
  const auto first = group_idx * chunk;
  return t.narrow(dim, first, chunk);
}
// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------
// Used on pad, stride and dilation.
//
// Checks that `args` has exactly `expected_size` entries and that none of
// them is negative. On failure it raises an error naming `arg_name` and the
// checking context `c`.
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");

  // Compare in the element type (int64_t). Narrowing to int first would let
  // out-of-range values wrap and slip past, or spuriously trip, the check.
  const bool has_negative =
      std::any_of(args.begin(), args.end(), [](int64_t x) { return x < 0; });
  if (has_negative) {
    std::stringstream ss;
    ss << arg_name << " should be greater than zero but got (";
    // args is non-empty here (it holds a negative value), so end() - 1 is valid.
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
// see NOTE [ Convolution checks] in src/Aten/native/cudnn/Conv.cpp
//
// Validates the spatial hyper-parameters and the geometry of input, weight
// and output for a convolution; raises on the first violation found.
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  // padding needs one entry per dimension after the first two; stride and
  // dilation must then have as many entries as padding.
  const size_t spatial_dims = input->dim() - 2;
  check_args(c, padding, spatial_dims, "padding");
  const size_t per_dim_count = padding.size();
  check_args(c, stride, per_dim_count, "stride");
  check_args(c, dilation, per_dim_count, "dilation");

  // Input: rank in [3, 6) and channel size equal to weight->size(1) * groups.
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight and output must have the same rank as the input.
  checkSameDim(c, input, weight);
  checkSameDim(c, input, output);
}
// This POD struct is used to let us easily compute hashes of the
// parameters
//
// It is the key type of every BenchmarkCache. ParamsHash and ParamsEqual
// operate on its raw bytes (all sizeof(ConvolutionParams) of them), so every
// field takes part in the cache key. That includes unused array slots and
// padding bytes, which is why setConvolutionParams memsets the struct to zero
// before filling it.
struct ConvolutionParams
{
miopenHandle_t handle;
miopenDataType_t dataType;
// Per-dimension sizes/strides for up to 2 + max_dim dimensions; entries past
// input.dim() stay zero.
int input_size[2 + max_dim];
int input_stride[2 + max_dim];
int weight_size[2 + max_dim];
// Spatial parameters; only the first padding.size() entries are filled.
int padding[max_dim];
int stride[max_dim];
int dilation[max_dim];
int64_t groups;
bool deterministic;
int device_id; //This is needed to distinguish between miopen handles of multiple gpus.
// NB: transposed purposely omitted: transposed just swaps
// forward and backward, so you can reuse the benchmark entry.
};
// ConvolutionParams must be a POD because we read out its memory
// contents as bytes when hashing and comparing.
static_assert(std::is_pod<ConvolutionParams>::value, "ConvolutionParams not POD");
// Fills `params` with everything that identifies a convolution configuration
// for the benchmark caches: data type, handle, input/weight sizes, input
// strides, padding/stride/dilation, groups, the deterministic flag and the
// current HIP device.
//
// The struct is zeroed first because ParamsHash / ParamsEqual compare its raw
// bytes, so unused array slots and padding bytes must be deterministic.
void setConvolutionParams(
    ConvolutionParams* params, miopenHandle_t handle,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic) {
  // The arrays in ConvolutionParams are fixed-size: refuse shapes that would
  // write past them (or read past stride/dilation) instead of corrupting memory.
  TORCH_CHECK(input.dim() <= 2 + max_dim,
              "setConvolutionParams: input has ", input.dim(),
              " dimensions but at most ", 2 + max_dim, " are supported");
  TORCH_CHECK(padding.size() <= static_cast<size_t>(max_dim),
              "setConvolutionParams: got ", padding.size(),
              " padding values but at most ", max_dim, " are supported");
  TORCH_CHECK(stride.size() == padding.size() && dilation.size() == padding.size(),
              "setConvolutionParams: padding, stride and dilation must have the same length");

  miopenDataType_t dataType = getMiopenDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->dataType = dataType;
  params->handle = handle;
  // weight is expected to have the same rank as input (checked by callers).
  for (int i = 0; i != input.dim(); ++i) {
    params->input_size[i] = static_cast<int>(input.size(i));
    params->input_stride[i] = static_cast<int>(input.stride(i));
    params->weight_size[i] = static_cast<int>(weight.size(i));
  }
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = static_cast<int>(padding[i]);
    params->stride[i] = static_cast<int>(stride[i]);
    params->dilation[i] = static_cast<int>(dilation[i]);
  }
  params->groups = groups;
  params->deterministic = deterministic;
  // Part of the key so that caches on different GPUs do not collide.
  int device_id;
  HIP_CHECK(hipGetDevice(&device_id));
  params->device_id = device_id;
}
// Convenience struct for passing around descriptors and data
// pointers.
//
// Bundles what one MIOpen convolution call needs: the handle, the cache key
// (params), the tensor/filter/convolution descriptors, and the three tensors.
// Member types are written out one per line to make them explicit: `input`
// refers to the caller's tensor, while `output` and `weight` are held by
// value (Tensor handles sharing the caller's TensorImpl).
struct ConvolutionArgs {
  miopenHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc;
  TensorDescriptor odesc;
  FilterDescriptor wdesc;
  const Tensor& input;
  const Tensor output;
  const Tensor weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(const Tensor& in, const Tensor& out, const Tensor& w)
      : input(in), output(out), weight(w) {}
};
// ---------------------------------------------------------------------
//
// Benchmarking
//
// ---------------------------------------------------------------------
// Hashing machinery for ConvolutionParams.
//
// 32-bit FNV-1a over the raw bytes of the struct. This is only sound because
// ConvolutionParams is a POD that setConvolutionParams zero-fills before
// populating, so padding and unused slots hash deterministically.
struct ParamsHash {
  std::size_t operator()(const ConvolutionParams& params) const {
    // `&params` (the source had a mis-encoded "¶ms" here, which does not compile).
    const auto* bytes = reinterpret_cast<const uint8_t*>(&params);
    uint32_t value = 0x811C9DC5;      // FNV-1a 32-bit offset basis
    for (size_t i = 0; i < sizeof(ConvolutionParams); ++i) {
      value ^= bytes[i];
      value *= 0x01000193;            // FNV-1a 32-bit prime
    }
    return static_cast<size_t>(value);
  }
};
// Equality companion to ParamsHash: two parameter sets are equal exactly when
// their object representations match byte for byte.
struct ParamsEqual {
  bool operator()(const ConvolutionParams& lhs, const ConvolutionParams& rhs) const {
    const void* lhs_bytes = &lhs;
    const void* rhs_bytes = &rhs;
    return memcmp(lhs_bytes, rhs_bytes, sizeof(ConvolutionParams)) == 0;
  }
};
// Map from a convolution configuration to a cached value of type T (an
// algorithm choice or a workspace size). Every access holds `mutex`.
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash, ParamsEqual> map;

  // On a hit, copies the cached value into *results and returns true.
  // On a miss, returns false and leaves *results untouched.
  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> lock(mutex);
    const auto entry = map.find(params);
    const bool hit = (entry != map.end());
    if (hit) {
      *results = entry->second;
    }
    return hit;
  }

  // Stores `results` under `params`, overwriting any existing entry.
  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> lock(mutex);
    map[params] = results;
  }
};
// Process-wide caches keyed by ConvolutionParams, one pair per convolution
// direction. The *_algos caches hold the selected algorithm; the *_wssizes
// caches hold the workspace size stored alongside it by findAlgorithm /
// chooseAlgorithm.
BenchmarkCache<miopenConvFwdAlgorithm_t> fwd_algos;
BenchmarkCache<miopenConvBwdDataAlgorithm_t> bwd_data_algos;
BenchmarkCache<miopenConvBwdWeightsAlgorithm_t> bwd_filter_algos;
BenchmarkCache<size_t> fwd_wssizes;
BenchmarkCache<size_t> bwd_data_wssizes;
BenchmarkCache<size_t> bwd_filter_wssizes;
// Owner of a scratch buffer handed to MIOpen. The buffer is obtained from
// the HIP caching allocator on construction and returned to it on
// destruction.
//
// Move-only. A move transfers the buffer and leaves the source empty
// (data == nullptr, size == 0). Defaulted moves would copy the raw pointer,
// and both objects would then free the same buffer.
struct Workspace {
  Workspace(size_t size) : size(size), data(nullptr) {
    data = c10::hip::HIPCachingAllocator::raw_alloc(size);
  }
  Workspace(const Workspace&) = delete;
  Workspace& operator=(const Workspace&) = delete;

  Workspace(Workspace&& other) noexcept : size(other.size), data(other.data) {
    other.size = 0;
    other.data = nullptr;
  }

  Workspace& operator=(Workspace&& other) noexcept {
    if (this != &other) {
      // Release the buffer we currently own before adopting the other one.
      if (data) {
        c10::hip::HIPCachingAllocator::raw_delete(data);
      }
      size = other.size;
      data = other.data;
      other.size = 0;
      other.data = nullptr;
    }
    return *this;
  }

  ~Workspace() {
    if (data) {
      c10::hip::HIPCachingAllocator::raw_delete(data);
    }
  }

  size_t size;
  void* data;
};
// Primary template, deliberately empty. Each MIOpen algorithm enum (forward,
// backward-data, backward-weights) gets an explicit specialization that
// supplies its caches, default algorithm and search routine.
template <typename algo_t>
struct algorithm_search {};
// Workspace size (bytes) reported by MIOpen for a forward convolution over the
// descriptors in `args`. The algorithm parameter only selects this overload;
// its value is unused. The returned MIOpen status is not checked. If the
// query fails, the result is whatever MIOpen left in the out-parameter
// (initialized to 0).
size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvFwdAlgorithm_t)
{
  size_t workspace_bytes = 0;
  miopenConvolutionForwardGetWorkSpaceSize(
      args.handle,
      args.wdesc.desc(),   // weights
      args.idesc.desc(),   // x
      args.cdesc.desc(),   // convolution
      args.odesc.desc(),   // y
      &workspace_bytes);
  return workspace_bytes;
}
// Workspace size (bytes) reported by MIOpen for a backward-data convolution
// over the descriptors in `args`. The algorithm parameter only selects this
// overload. The MIOpen status is not checked; the result starts at 0.
size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvBwdDataAlgorithm_t)
{
  size_t workspace_bytes = 0;
  miopenConvolutionBackwardDataGetWorkSpaceSize(
      args.handle,
      args.odesc.desc(),   // dy
      args.wdesc.desc(),   // weights
      args.cdesc.desc(),   // convolution
      args.idesc.desc(),   // dx
      &workspace_bytes);
  return workspace_bytes;
}
// Workspace size (bytes) reported by MIOpen for a backward-weights convolution
// over the descriptors in `args`. The algorithm parameter only selects this
// overload. The MIOpen status is not checked; the result starts at 0.
size_t getWorkspaceSize(
    const ConvolutionArgs& args, const miopenConvBwdWeightsAlgorithm_t)
{
  size_t workspace_bytes = 0;
  miopenConvolutionBackwardWeightsGetWorkSpaceSize(
      args.handle,
      args.odesc.desc(),   // dy
      args.idesc.desc(),   // x
      args.cdesc.desc(),   // convolution
      args.wdesc.desc(),   // dw
      &workspace_bytes);
  return workspace_bytes;
}
// Returns the first entry of `perfResults`. `deterministic` and `n_algo` are
// accepted for interface compatibility but currently ignored.
template<typename perf_t>
perf_t getBestAlgorithm(perf_t *perfResults, bool deterministic, int n_algo) {
  (void)deterministic;
  (void)n_algo;
  return *perfResults;
}
// Forward-convolution specialization: binds the forward algorithm enum to
// its caches, its fallback algorithm (GEMM) and MIOpen's forward search.
template<>
struct algorithm_search<miopenConvFwdAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvFwdAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionFwdAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return fwd_algos; }
  static BenchmarkCache<size_t>& wsscache() { return fwd_wssizes; }

  // Runs miopenFindConvolutionForwardAlgorithm (non-exhaustive) with a
  // scratch buffer sized by getWorkspaceSize and returns the fastest result.
  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count = 0;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionForwardAlgorithm(
        args.handle,
        args.idesc.desc(), args.input.data_ptr(),
        args.wdesc.desc(), args.weight.data_ptr(),
        args.cdesc.desc(),
        args.odesc.desc(), args.output.data_ptr(),
        1,        // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));  // exhaustiveSearch
    // perf_results is only written when MIOpen reports a result; never hand
    // back an uninitialized struct.
    TORCH_CHECK(perf_count > 0,
                "miopenFindConvolutionForwardAlgorithm returned no algorithms");
    return perf_results;
  }
};
// Backward-data specialization: binds the backward-data algorithm enum to
// its caches, its fallback algorithm (GEMM) and MIOpen's backward-data search.
template<>
struct algorithm_search<miopenConvBwdDataAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdDataAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdDataAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return bwd_data_algos; }
  static BenchmarkCache<size_t>& wsscache() { return bwd_data_wssizes; }

  // Runs miopenFindConvolutionBackwardDataAlgorithm (non-exhaustive) with a
  // scratch buffer sized by getWorkspaceSize and returns the fastest result.
  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count = 0;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionBackwardDataAlgorithm(
        args.handle,
        args.odesc.desc(), args.output.data_ptr(),
        args.wdesc.desc(), args.weight.data_ptr(),
        args.cdesc.desc(),
        args.idesc.desc(), args.input.data_ptr(),
        1,        // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));  // exhaustiveSearch
    // perf_results is only written when MIOpen reports a result; never hand
    // back an uninitialized struct.
    TORCH_CHECK(perf_count > 0,
                "miopenFindConvolutionBackwardDataAlgorithm returned no algorithms");
    return perf_results;
  }
};
// Backward-weights specialization: binds the backward-weights algorithm enum
// to its caches, its fallback algorithm (GEMM) and MIOpen's backward-weights
// search.
template<>
struct algorithm_search<miopenConvBwdWeightsAlgorithm_t> {
  using perf_t = miopenConvAlgoPerf_t;
  using algo_t = miopenConvBwdWeightsAlgorithm_t;

  static constexpr auto DEFAULT_ALGO = miopenConvolutionBwdWeightsAlgoGEMM;
  static BenchmarkCache<algo_t>& cache() { return bwd_filter_algos; }
  static BenchmarkCache<size_t>& wsscache() { return bwd_filter_wssizes; }

  // Runs miopenFindConvolutionBackwardWeightsAlgorithm (non-exhaustive) with
  // a scratch buffer sized by getWorkspaceSize and returns the fastest result.
  static perf_t findAlgorithm(const ConvolutionArgs& args) {
    int perf_count = 0;
    perf_t perf_results;
    size_t max_ws_size = getWorkspaceSize(args, DEFAULT_ALGO);
    Workspace ws(max_ws_size);
    MIOPEN_CHECK(miopenFindConvolutionBackwardWeightsAlgorithm(
        args.handle,
        args.odesc.desc(), args.output.data_ptr(),
        args.idesc.desc(), args.input.data_ptr(),
        args.cdesc.desc(),
        args.wdesc.desc(), args.weight.data_ptr(),
        1,        // just return the fastest
        &perf_count,
        &perf_results,
        ws.data,
        ws.size,
        false));  // exhaustiveSearch
    // perf_results is only written when MIOpen reports a result; never hand
    // back an uninitialized struct.
    TORCH_CHECK(perf_count > 0,
                "miopenFindConvolutionBackwardWeightsAlgorithm returned no algorithms");
    return perf_results;
  }
};
// Selects an algorithm for the configuration in args.params and writes it to
// *algo. Results are memoized in search::cache() (the algorithm) and
// search::wsscache() (the workspace size reported by the search), so the
// MIOpen Find call normally runs once per configuration.
template<typename algo_t>
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, algo_t* algo) {
using search = algorithm_search<algo_t>;
auto& cache = search::cache();
auto& wsscache = search::wsscache();
// Fast path: this configuration was already searched.
if (cache.find(args.params, algo)) {
return;
}
// NOTE(review): there is no `return` here, so on every normal path *algo is
// overwritten below (by the cache or by the search result) and this branch
// has no lasting effect. Returning early would also require populating
// wsscache, because chooseAlgorithm reads the workspace size from it.
if (args.params.deterministic && !benchmark) {
*algo = search::DEFAULT_ALGO;
}
if (cache.find(args.params, algo)) {
// re-check cache since another thread may have benchmarked the algorithm
return;
}
auto perfResults = search::findAlgorithm(args);
// NOTE(review): reinterprets the start of the perf struct as algo_t;
// presumably relies on the algorithm enum being its first member — confirm
// against miopenConvAlgoPerf_t's layout.
*algo = reinterpret_cast<algo_t&>(perfResults);
cache.insert(args.params, *algo);
wsscache.insert(args.params, perfResults.memory);
// Empties the HIP caching allocator after each search (presumably to give
// back memory held by the search workspace).
c10::hip::HIPCachingAllocator::emptyCache();
}
// Chooses an algorithm for `args` (via findAlgorithm), stores it in *algo,
// and returns a workspace large enough for it.
//
// If allocating that workspace throws (e.g. out of memory), it falls back
// to search::DEFAULT_ALGO, records the fallback in both caches so later calls
// avoid the failing allocation, and returns the fallback's workspace.
template<typename algo_t>
Workspace chooseAlgorithm(
    const ConvolutionArgs& args,
    bool benchmark,
    algo_t* algo)
{
  findAlgorithm(args, benchmark, algo);

  using search = algorithm_search<algo_t>;
  // findAlgorithm usually leaves a wsscache entry, but another thread can
  // observe the algorithm entry before the workspace entry is inserted.
  // Never use an uninitialized size: ask MIOpen if the cache misses.
  size_t workspace_size = 0;
  if (!search::wsscache().find(args.params, &workspace_size)) {
    workspace_size = getWorkspaceSize(args, *algo);
  }
  try {
    return Workspace(workspace_size);
  } catch (const std::exception&) {
    (void)hipGetLastError(); // clear OOM error
    // switch to default algorithm and record it in the cache to prevent
    // further OOM errors
    *algo = search::DEFAULT_ALGO;
    workspace_size = getWorkspaceSize(args, *algo);
    search::cache().insert(args.params, *algo);
    search::wsscache().insert(args.params, workspace_size);
    return Workspace(workspace_size);
  }
}
// ---------------------------------------------------------------------
//
// Bias addition
//
// ---------------------------------------------------------------------
// In-place!
//
// Adds `bias` to `output` along output_channels_dim. It first checks that
// both tensors have the same type, live on the same GPU, and that bias has
// exactly output->size(output_channels_dim) elements.
void miopen_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const TensorArg& bias)
{
  checkAllSameType(c, {output, bias});
  checkAllSameGPU(c, {output, bias});
  checkSize(c, bias, { output->size(output_channels_dim) });

  // Reshape bias to 1 x C x 1 x ... so it broadcasts against output, laid
  // out in output's suggested memory format.
  auto memory_format = output->suggest_memory_format();
  std::vector<int64_t> shape(output->dim(), 1);
  shape[output_channels_dim] = -1;
  at::Tensor bias_contig = bias->reshape(shape).contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  bias_contig.resize_(bias_contig.sizes(), memory_format);

  // TODO: Workaround since MIOpen does not support NHWC bias
  // See #64426
  output->add_(bias_contig);

  /* MIOpen does not support NHWC bias; Activate once support is added.
  TensorDescriptor bdesc, odesc;
  bdesc.set( bias_contig );
  odesc.set(*output);

  auto handle = getMiopenHandle();
  auto dataType = getMiopenDataType(*bias);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  MIOPEN_CHECK(miopenConvolutionForwardBias(handle, &one, bdesc.desc(), bias->data_ptr(),
                                     &zero, odesc.desc(), output->data_ptr()));
  */
}
// see NOTE [ Convolution design ] in src/Aten/native/cudnn/Conv.cpp
// ---------------------------------------------------------------------
//
// Convolution forward / Transposed convolution backward
//
// ---------------------------------------------------------------------
// The raw API directly invokes MIOpen.
//
// There are a few reasons this should never be directly exposed
// via ATen:
//
// - It takes output as a parameter (this should be computed!)
// - It doesn't do input checking
// - It doesn't resize output (it is assumed to be correctly sized)
//
void raw_miopen_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
auto dataType = getMiopenDataType(input);
miopenConvolutionMode_t c_mode = miopenConvolution;
ConvolutionArgs args{ input, output, weight };
args.handle = getMiopenHandle();
setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(weight, input.suggest_memory_format(), 0);
args.odesc.set(output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
miopenConvFwdAlgorithm_t fwdAlg;
Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
MIOPEN_CHECK(miopenConvolutionForward(
args.handle,
&one, args.idesc.desc(), input.data_ptr(),
args.wdesc.desc(), weight.data_ptr(),
args.cdesc.desc(), fwdAlg, &zero,
args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size));
}
// Checked forward convolution: validates input/weight, allocates an output of
// the size computed by conv_output_size, and runs the raw MIOpen forward into
// it. Returns the output, which may be empty.
Tensor miopen_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  // Channels-last only applies to non-5-d weights; the 5-d variant
  // (ChannelsLast3d) is disabled and stays Contiguous.
  at::MemoryFormat memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight) && weight->ndimension() != 5) {
    memory_format = at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
      input->options().memory_format(memory_format));

  // Nothing to compute for an empty result.
  if (output_t.numel() == 0) {
    return output_t;
  }

  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // See #4500
  // Densify both operands in the chosen format; resize_ with unchanged sizes
  // makes sure that NC11 strides follow formula.
  Tensor weight_contig = weight->contiguous(memory_format);
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor input_contig = input->contiguous(memory_format);
  input_contig.resize_(input_contig.sizes(), memory_format);

  raw_miopen_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);
  return *output;
}
// MIOpen convolution entry point: checked forward convolution, followed by
// an in-place bias add when a bias tensor is supplied.
Tensor miopen_convolution(
    const Tensor& input_t, const Tensor& weight_t, const c10::optional<Tensor>& bias_t_opt,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;

  CheckedFrom c = "miopen_convolution";
  TensorArg input{ input_t, "input", 1 };
  TensorArg weight{ weight_t, "weight", 2 };
  TensorArg bias{ bias_t, "bias", 3 };

  Tensor output_t = miopen_convolution_forward(
      c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}
//Depthwise Convolutions
void raw_miopen_depthwise_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
auto dataType = getMiopenDataType(input);
miopenConvolutionMode_t c_mode = miopenDepthwise;
ConvolutionArgs args{ input, output, weight };
args.handle = getMiopenHandle();
setConvolutionParams(&args.params, args.handle, input, weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(weight, input.suggest_memory_format(), 0);
args.odesc.set(output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
miopenConvFwdAlgorithm_t fwdAlg;
Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
MIOPEN_CHECK(miopenConvolutionForward(
args.handle,
&one, args.idesc.desc(), input.data_ptr(),
args.wdesc.desc(), weight.data_ptr(),
args.cdesc.desc(), fwdAlg, &zero,
args.odesc.desc(), output.data_ptr(), workspace.data, workspace.size));
}
// Checked depthwise forward convolution: validates input/weight, allocates the
// output, and runs the raw MIOpen depthwise forward into it. Returns the
// output tensor, which may be empty.
Tensor miopen_depthwise_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight)) {
    // ChannelsLast3d is disabled, so 5-d weights stay Contiguous.
    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
      conv_output_size(input->sizes(), weight->sizes(),
                       padding, stride, dilation),
      input->options().memory_format(memory_format));

  // Same early exit as miopen_convolution_forward: an empty result needs no
  // computation, and this avoids calling MIOpen with zero-element tensors.
  if (output_t.numel() == 0) {
    return output_t;
  }

  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // See #4500
  Tensor weight_contig = weight->contiguous(memory_format);
  // Make sure that NC11 strides follow formula
  weight_contig.resize_(weight_contig.sizes(), memory_format);
  Tensor input_contig = input->contiguous(memory_format);
  input_contig.resize_(input_contig.sizes(), memory_format);

  raw_miopen_depthwise_convolution_forward_out(
      *output, input_contig, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);
  return *output;
}
// MIOpen depthwise convolution entry point: checked depthwise forward,
// followed by an in-place bias add when a bias tensor is supplied.
Tensor miopen_depthwise_convolution(
    const Tensor& input_t, const Tensor& weight_t, const c10::optional<Tensor>& bias_t_opt,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;

  CheckedFrom c = "miopen_depthwise_convolution";
  TensorArg input{ input_t, "input", 1 };
  TensorArg weight{ weight_t, "weight", 2 };
  TensorArg bias{ bias_t, "bias", 3 };

  Tensor output_t = miopen_depthwise_convolution_forward(
      c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    miopen_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}
// ---------------------------------------------------------------------
//
// Convolution backward (bias)
//
// ---------------------------------------------------------------------
// Bias gradient: grad_output summed over every dimension except
// output_channels_dim, returned as a 1-d tensor with one entry per channel.
Tensor miopen_convolution_backward_bias(
    const Tensor& grad_output_t)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 };

  // TODO: Workaround since MIOpen does not support NHWC bias
  // See #64426
  std::vector<int64_t> reduce_dims;
  const int64_t ndim = grad_output_t.dim();
  for (int64_t d = 0; d < ndim; ++d) {
    if (d != output_channels_dim) {
      reduce_dims.push_back(d);
    }
  }
  Tensor outputBias = at::squeeze(at::sum(grad_output_t, reduce_dims, true));

  // always return a tensor of shape [_]: with a single channel, squeeze also
  // removed the channel dimension, so restore it.
  return outputBias.dim() == 0 ? outputBias.unsqueeze(0) : outputBias;

  /* MIOpen does not support NHWC bias. Activate once support is added.
  auto grad_bias_t = at::empty( { grad_output->size(output_channels_dim) }, grad_output->options());

  TensorArg grad_bias{ grad_bias_t, "result", 0 };

  TensorDescriptor bdesc{grad_bias->expand({1, grad_bias->size(0)}),
                         static_cast<size_t>(grad_output->dim())};
  TensorDescriptor odesc{*grad_output};

  auto handle = getMiopenHandle();
  auto dataType = getMiopenDataType(*grad_bias);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  MIOPEN_CHECK(miopenConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(),
                                             &zero, bdesc.desc(), grad_bias->data_ptr()));
  return *grad_bias;
  */
}
// ---------------------------------------------------------------------
//
// Convolution backward (weight)
//
// ---------------------------------------------------------------------
void raw_miopen_convolution_backward_weight_out(
const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
auto dataType = getMiopenDataType(input);
miopenConvolutionMode_t c_mode = miopenConvolution;
ConvolutionArgs args{ input, grad_output, grad_weight };
args.handle = getMiopenHandle();
setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
miopenConvBwdWeightsAlgorithm_t bwdFilterAlg;
Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
MIOPEN_CHECK(miopenConvolutionBackwardWeights(
args.handle,
&one, args.odesc.desc(), grad_output.data_ptr(),
args.idesc.desc(), input.data_ptr(),
args.cdesc.desc(), bwdFilterAlg, &zero,
args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size));
}
//Depthwise backward weights.
void raw_miopen_depthwise_convolution_backward_weight_out(
const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic) {
auto dataType = getMiopenDataType(input);
miopenConvolutionMode_t c_mode = miopenDepthwise;
ConvolutionArgs args{ input, grad_output, grad_weight };
args.handle = getMiopenHandle();
setConvolutionParams(&args.params, args.handle, input, grad_weight, padding, stride, dilation, groups, deterministic);
args.idesc.set(input);
args.wdesc.set(grad_weight, input.suggest_memory_format(), 0);
args.odesc.set(grad_output);
args.cdesc.set(dataType, c_mode, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);
miopenConvBwdWeightsAlgorithm_t bwdFilterAlg;
Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlg);
Constant one(dataType, 1);
Constant zero(dataType, 0);
MIOPEN_CHECK(miopenConvolutionBackwardWeights(
args.handle,
&one, args.odesc.desc(), grad_output.data_ptr(),
args.idesc.desc(), input.data_ptr(),
args.cdesc.desc(), bwdFilterAlg, &zero,
args.wdesc.desc(), grad_weight.data_ptr(), workspace.data, workspace.size));
}
// Computes the depthwise-convolution weight gradient for the given operands.
//
// Validates that grad_output and input share a scalar type and GPU. Both are
// brought into a common memory layout, and a fresh gradient tensor of shape
// weight_size is allocated in that layout. After a shape check, the tensor is
// filled by raw_miopen_depthwise_convolution_backward_weight_out.
Tensor miopen_depthwise_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  // Channels-last is only selected for non-5-d tensors. The ChannelsLast3d
  // alternative for 5-d tensors is not used, so those stay contiguous.
  auto layout = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *grad_output) && input->ndimension() != 5) {
    layout = at::MemoryFormat::ChannelsLast;
  }

  // Materialize both operands in the chosen layout. Re-applying the memory
  // format via resize_ makes sure NC11-style strides follow the formula.
  Tensor dy = grad_output->contiguous(layout);
  dy.resize_(dy.sizes(), layout);
  Tensor x = input->contiguous(layout);
  x.resize_(x.sizes(), layout);

  Tensor dw = at::empty(weight_size, dy.options(), layout);

  // "result" at position 0 keeps error messages uniform with the other
  // entry points, although grad_weight would be unambiguous too.
  TensorArg dy_arg{ dy, "grad_output", 1 };
  TensorArg dw_arg{ dw, "result", 0 };
  convolution_shape_check(c, input, dw_arg, dy_arg, padding, stride, dilation, groups);

  raw_miopen_depthwise_convolution_backward_weight_out(
      dw, dy, x, padding, stride, dilation, groups, benchmark, deterministic);
  return dw;
}
// Tensor-level entry point for the depthwise weight gradient.
//
// Wraps the plain tensors in TensorArgs (so argument checks can report names
// and positions) and forwards to the CheckedFrom overload. Errors are
// attributed to "miopen_depthwise_convolution_backward_weight".
Tensor miopen_depthwise_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  const TensorArg grad_output_arg{ grad_output_t, "grad_output", 1 };
  const TensorArg input_arg{ input_t, "input", 2 };
  const char* const checked_from = "miopen_depthwise_convolution_backward_weight";
  return miopen_depthwise_convolution_backward_weight(
      checked_from, weight_size, grad_output_arg, input_arg,
      padding, stride, dilation, groups, benchmark, deterministic);
}
// Computes the weight gradient of a regular convolution.
//
// grad_output and input must share a scalar type and GPU. Both are copied
// (if needed) into one common memory format, and the gradient is allocated as
// a new tensor of shape weight_size in that format. After the shape check,
// the tensor is filled by raw_miopen_convolution_backward_weight_out.
Tensor miopen_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  const at::MemoryFormat memory_format = [&] {
    if (!miopen_conv_use_channels_last(*input, *grad_output)) {
      return at::MemoryFormat::Contiguous;
    }
    // at::MemoryFormat::ChannelsLast3d would be the 5-d counterpart, but it is
    // not used here: 5-d tensors stay contiguous.
    return input->ndimension() == 5 ? at::MemoryFormat::Contiguous
                                    : at::MemoryFormat::ChannelsLast;
  }();

  Tensor grad_output_dense = grad_output->contiguous(memory_format);
  // Make sure that NC11 strides follow formula.
  grad_output_dense.resize_(grad_output_dense.sizes(), memory_format);
  Tensor input_dense = input->contiguous(memory_format);
  input_dense.resize_(input_dense.sizes(), memory_format);

  Tensor grad_weight_out = at::empty(weight_size, grad_output_dense.options(), memory_format);

  // Named "result" at position 0 for uniformity with the other entry points.
  TensorArg grad_output_arg{ grad_output_dense, "grad_output", 1 };
  TensorArg grad_weight_arg{ grad_weight_out, "result", 0 };
  convolution_shape_check(c, input, grad_weight_arg, grad_output_arg, padding, stride, dilation, groups);

  raw_miopen_convolution_backward_weight_out(
      grad_weight_out, grad_output_dense, input_dense,
      padding, stride, dilation, groups, benchmark, deterministic);
  return grad_weight_out;
}
// Tensor-level entry point for the regular-convolution weight gradient.
//
// Wraps grad_output/input in TensorArgs at positions 1 and 2, then delegates
// to the CheckedFrom overload under the name
// "miopen_convolution_backward_weight".
Tensor miopen_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  const TensorArg go{ grad_output_t, "grad_output", 1 };
  const TensorArg in{ input_t, "input", 2 };
  return miopen_convolution_backward_weight(
      "miopen_convolution_backward_weight", weight_size, go, in,
      padding, stride, dilation, groups, benchmark, deterministic);
}
// Input gradient of a transposed convolution.
//
// Delegates to miopen_convolution_forward: grad_output is convolved with
// weight using the same padding/stride/dilation/groups. Argument checks are
// reported under "miopen_convolution_transpose_backward_input".
Tensor miopen_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg grad_output_arg{ grad_output_t, "grad_output", 1 };
  TensorArg weight_arg{ weight_t, "weight", 2 };
  const char* const checked_from = "miopen_convolution_transpose_backward_input";
  return miopen_convolution_forward(
      checked_from, grad_output_arg, weight_arg,
      padding, stride, dilation, groups, benchmark, deterministic);
}
Tensor miopen_convolution_transpose_backward_weight(
IntArrayRef weight_size,
const Tensor& grad_output_t,
const Tensor& input_t,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic)
{
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
input{ input_t, "input", 2 };
return miopen_convolution_backward_weight(