// Linear.cpp (forked from pytorch/pytorch)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xnnpack/Engine.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/util/irange.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/TensorSubclassLikeUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_trilinear.h>
#include <ATen/ops/_trilinear_native.h>
#include <ATen/ops/add.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/bilinear_native.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/einsum_native.h>
#include <ATen/ops/linear_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/mkldnn_linear.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/tensordot_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like_ops.h>
#endif
#include <cctype>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace at { namespace native {
// Reports whether the environment variable "TORCH_LINEAR_FLATTEN_3D" is set
// to exactly "1".
//
// The environment is consulted only on the first call; every later call
// returns the cached answer. The cache is a function-local static with an
// initializer, so the one-time lookup is performed safely even when several
// threads reach this function concurrently (the previous hand-rolled
// "-1 means uninitialized" flag was read and written without synchronization).
static inline bool parseLinearFlatten3d() {
  static const bool enabled = []() {
    const char* env_str = std::getenv("TORCH_LINEAR_FLATTEN_3D");
    // Only the literal string "1" enables the flag; unset or any other value
    // disables it.
    return env_str != nullptr && strcmp(env_str, "1") == 0;
  }();
  return enabled;
}
// Fused linear for a 3D `input`: the two leading dimensions are collapsed
// into one, a single addmm against the transposed weight is issued, and the
// leading dimensions are restored on the result.
static inline Tensor _flatten_3d_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  const auto dims = input.sym_sizes();
  // (d0, d1, d2) -> (d0 * d1, d2)
  const auto flat_input = input.view_symint({dims[0] * dims[1], dims[2]});
  const auto flat_output = at::addmm(bias, flat_input, weight.t());
  // (d0 * d1, out) -> (d0, d1, out)
  return flat_output.view_symint({dims[0], dims[1], flat_output.sym_size(1)});
}
// Computes `input @ weight.t()` and adds `bias` when one is given, routing to
// a backend-specific or fused kernel whenever one applies:
//   * MKLDNN inputs go to at::mkldnn_linear.
//   * On mobile builds, XNNPACK is used when xnnpack::use_linear accepts the
//     arguments.
//   * 2D input with a bias uses a single addmm.
//   * 3D non-XLA input with a bias uses the flattened addmm path when the
//     input is contiguous, or when TORCH_LINEAR_FLATTEN_3D forces it.
//   * Everything else falls back to matmul followed by a bias add.
Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(c10::in_place);

  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  if (xnnpack::use_linear(input, weight, *bias)) {
    return xnnpack::linear(input, weight, *bias);
  }
#endif

  const bool has_bias = bias->defined();
  const auto input_dim = input.dim();

  if (has_bias) {
    if (input_dim == 2) {
      // Fused op is marginally faster.
      return at::addmm(*bias, input, weight.t());
    }
    // Reshaping/flattening has some performance implications on xla, so the
    // fused 3D path is skipped for that backend.
    if (input_dim == 3 && !input.is_xla()) {
      if (input.is_contiguous()) {
        return _flatten_3d_linear(input, weight, *bias);
      }
      if (parseLinearFlatten3d()) {
        // The user forces flattening via the env var: pay for a contiguous copy.
        return _flatten_3d_linear(input.contiguous(), weight, *bias);
      }
    }
  }

  auto output = at::matmul(input, weight.t());
  if (!has_bias) {
    return output;
  }

  // For composite compliance use the out-of-place version of `add` when the
  // bias is a tensor subclass or carries a forward gradient.
  const bool needs_out_of_place_add =
      isTensorSubclassLike(*bias) || bias->_fw_grad(/*level*/ 0).defined();
  if (needs_out_of_place_add) {
    output = at::add(output, *bias);
  } else {
    output.add_(*bias);
  }
  return output;
}
// Out-variant of linear: writes `input @ weight.t()` (plus `bias`, if given)
// into `output` and returns it. MKLDNN inputs are rejected.
Tensor& linear_out(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt, Tensor& output) {
  TORCH_CHECK(!input.is_mkldnn(), "linear doesn't support out for MKLDNN tensors");

  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(c10::in_place);

  const bool has_bias = bias->defined();
  if (has_bias && input.dim() == 2) {
    // Fused op is marginally faster.
    return at::addmm_out(output, *bias, input, weight.t());
  }

  output = at::matmul_out(output, input, weight.t());
  if (has_bias) {
    output.add_(*bias);
  }
  return output;
}
// sumproduct_pair computes `(left*right).sum(sumdims)` by means of permutation and
// batch matrix multiplication.
// Its main purpose is to provide a pairwise reduction for einsum.
//
// Arguments:
//   left_, right_ : operands with the same number of dimensions; they are
//                   assumed to be pre-unsqueezed so that all dimensions match
//                   after broadcasting.
//   sum_dims_     : dimensions to sum over. If empty, the plain elementwise
//                   product is returned. The entries are also used directly as
//                   indices when building the output permutation, so they are
//                   expected to be non-negative.
//   keepdim       : when false, the summed dimensions are removed from the
//                   result; when true they are kept with size 1.
//
// Raises (via TORCH_CHECK) if the operands differ in rank or if a dimension
// that is non-trivial in both operands has mismatching sizes.
static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArrayRef sum_dims_, bool keepdim) {
  // assumes that tensors have been pre-unsqueezed (so that all dimensions match - after broadcasting)
  // but makes no other assumptions on the order of dimensions
  TORCH_CHECK(left_.dim()==right_.dim(), "number of dimensions must match");
  if (sum_dims_.empty())
    return at::mul(left_, right_);
  int64_t dim = left_.dim();
  auto sum_dims = at::dim_list_to_bitset(sum_dims_, dim);
  // dimensions that will be part of the output (i.e. not summed over) in three vectors:
  // dims in lro appear in left, right and output, similarly, lo: left and output, ro: right and output
  // also the sizes are kept track of for reshaping
  std::vector<int64_t> lro, lo, ro;
  SymInt lro_size = 1, lo_size = 1, ro_size = 1, sum_size = 1;
  Tensor left = left_;
  Tensor right = right_;
  for (const auto i : c10::irange(dim)) {
    auto sl = left.sym_size(i)!=1;
    auto sr = right.sym_size(i)!=1;
    if (sum_dims[i]) { // first dimensions that will be summed over after multiplication
      if (sl && sr) {  // dimensions nontrivially in both left and right must be of the same size
        TORCH_CHECK(left.sym_size(i)==right.sym_size(i), "non-broadcast dimensions must match");
        sum_size *= left.sym_size(i);
      } else if (sl) { // if it is only in one of left and right, we can sum right away
        left = left.sum(i, true);
      } else if (sr) {
        right = right.sum(i, true);
      }
    } else if (sl && sr) { // now deal with dimensions that will be in the output
      // dimensions nontrivially in both left and right must be of the same size
      TORCH_CHECK(left.sym_size(i)==right.sym_size(i), "non-broadcast dimensions must match");
      lro.push_back(i);
      lro_size *= left.sym_size(i);
    } else if (sl) { // keep track of dimensions appearing only once
      lo.push_back(i);
      lo_size *= left.sym_size(i);
    } else {
      ro.push_back(i);
      ro_size *= right.sym_size(i);
    }
  }
  // we now work with the following permutations / shapes.
  // the pipeline is permute inputs -> reshape inputs -> batch matrix mul -> reshape(view) output -> permute output
  // output: "lro, lo, 1-for-summed-dims, ro" with original shape dimensions
  // left:   "lro, lo, summed" permuted with lpermutation and the three flattened
  // right:  "lro, summed, ro" permuted with rpermutation and the three flattened
  // then the permuted output is a view of bmm(left, right)
  // finally, opermutation reverts the permutation to the original order of dimensions
  auto out_num_dim = lro.size() + lo.size() + sum_dims_.size() + ro.size();
  std::vector<SymInt> out_size;
  out_size.reserve(out_num_dim);
  for (auto& d : lro) out_size.push_back(left.sym_size(d));
  for (auto& d : lo) out_size.push_back(left.sym_size(d));
  // every summed dimension is kept as a size-1 placeholder in the bmm output
  for (size_t k = 0; k < sum_dims_.size(); ++k) out_size.emplace_back(1);
  for (auto& d : ro) out_size.push_back(right.sym_size(d));

  std::vector<int64_t> lpermutation(lro);
  lpermutation.insert(lpermutation.end(), lo.begin(), lo.end());
  lpermutation.insert(lpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  lpermutation.insert(lpermutation.end(), ro.begin(), ro.end());

  std::vector<int64_t> rpermutation(lro);
  rpermutation.insert(rpermutation.end(), sum_dims_.begin(), sum_dims_.end());
  rpermutation.insert(rpermutation.end(), ro.begin(), ro.end());
  rpermutation.insert(rpermutation.end(), lo.begin(), lo.end());

  // opermutation[original_dim] = position of that dim in the bmm output layout
  std::vector<int64_t> opermutation(out_num_dim, -1);
  {
    int64_t i = 0;

    for (auto it = lro.cbegin(); it != lro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = lo.cbegin(); it != lo.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = sum_dims_.cbegin(); it != sum_dims_.cend(); i++, it++) {
      opermutation[*it] = i;
    }
    for (auto it = ro.cbegin(); it != ro.cend(); i++, it++) {
      opermutation[*it] = i;
    }
  }

  // now we can execute the operations above
  left = left.permute(lpermutation).reshape_symint({lro_size, std::move(lo_size), sum_size});
  right = right.permute(rpermutation).reshape_symint({std::move(lro_size), std::move(sum_size), std::move(ro_size)});
  Tensor result = at::bmm(left, right);
  result = result.view_symint(out_size).permute(opermutation);

  // finally squeeze summed dimensions if desired
  if (! keepdim) {
    // Stay in SymInt land here as well, consistent with the rest of this
    // function, so that symbolic shapes are not forced to concrete integers.
    auto sizes = result.sym_sizes().vec();
    // Walk from the back so that erasing does not shift pending indices.
    for (auto i = dim-1; i>=0; i--) {
      if (sum_dims[i]) {
        sizes.erase(sizes.begin() + i);
      }
    }
    result = result.view_symint(sizes);
  }
  return result;
}
// There are roughly three parts to computing einsum:
// 1. Parse equation to extract the labels for each input operand and output
// 2. Unsqueeze missing dimensions from input operands and permute to align them
// 3. Compute result by multiplying input operands and summing contraction
//    dimensions. We do the last part by reducing to bmm.
// If a path is specified, we reduce in the order specified by the path, else we
// default to going left => right. The path is a list of indices processed the same
// way as opt-einsum: https://optimized-einsum.readthedocs.io/en/stable/path_finding.html#format-of-the-path
//
// Arguments:
//   equation : subscripts in [a-zA-Z], "," between operands, optional "..."
//              (at most one per operand and one in the output) and an
//              optional "->" followed by the explicit output subscripts.
//              Spaces are ignored.
//   operands : the input tensors; at least one is required and the count must
//              match the number of comma-separated terms in the equation.
//   path     : optional contraction order; when given its size must be 1 for a
//              single operand and 2 * (num_ops - 1) otherwise.
//
// All validation failures are reported through TORCH_CHECK.
Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArrayRef path) {
  TORCH_CHECK(!operands.empty(), "einsum(): must provide at least one operand");
  const auto num_ops = operands.size();

  if (path.has_value()) {
    const auto path_size = num_ops == 1 ? 1 : (num_ops - 1) * 2;
    TORCH_CHECK(
        path->size() == path_size,
        "einsum(): expected contraction path given in path parameter to have size ",
        path_size,
        " but got ",
        path->size());
  }

  // Labels must be in range [A-Za-z]
  constexpr uint8_t NUM_OF_LETTERS = 'z' - 'a' + 1;
  constexpr uint8_t TOTAL_LABELS = NUM_OF_LETTERS * 2;

  // Code used to identify ELLIPSIS ("...")
  constexpr uint8_t ELLIPSIS = TOTAL_LABELS;

  // Convert label in [A-Za-z] to subscript in [0, TOTAL_LABELS)
  auto label_to_subscript = [=](unsigned char label) -> uint8_t {
    return std::isupper(label) ? label - 'A' : label - 'a' + NUM_OF_LETTERS;
  };

  // Convert subscript in [0, TOTAL_LABELS) to label in [A-Za-z]
  auto subscript_to_label = [=](uint8_t s) -> unsigned char {
    return s < NUM_OF_LETTERS ? s + 'A' : s + 'a' - NUM_OF_LETTERS;
  };

  // Find arrow (->) to split equation into lhs and rhs
  const auto arrow_pos = equation.find("->");
  const auto lhs = equation.substr(0, arrow_pos);

  // Convert labels for input operands into an index in [0, 52) and store
  // them in op_labels for each operand along with ELLIPSIS if present.
  std::vector<std::vector<uint8_t>> op_labels(num_ops);
  bool ell_in_input = false;
  std::size_t curr_op = 0;
  for (std::size_t i = 0; i < lhs.length(); ++i) {
    const unsigned char label = lhs[i];
    switch (label) {
      case ' ':
        // Ignore spaces
        break;

      case '.':
        TORCH_CHECK(
            // Only one ellipsis per operand can be given
            !ell_in_input,
            "einsum(): found \'.\' for operand ",
            curr_op,
            " for which an ellipsis was already found");
        TORCH_CHECK(
            // Ensure it's a valid ellipsis
            i + 2 < lhs.length() && lhs[++i] == '.' && lhs[++i] == '.',
            "einsum(): found \'.\' for operand ",
            curr_op,
            " that is not part of any ellipsis");
        op_labels[curr_op].push_back(ELLIPSIS);
        ell_in_input = true;
        break;

      case ',':
        // Move onto next operand
        ++curr_op;
        TORCH_CHECK(
            curr_op < num_ops,
            "einsum(): fewer operands were provided than specified in the equation");
        ell_in_input = false;
        break;

      default:
        // Parse label
        TORCH_CHECK(
            std::isalpha(label),
            "einsum(): invalid subscript given at index ",
            i,
            " in the equation string, subscripts must be in [a-zA-Z]");
        op_labels[curr_op].push_back(label_to_subscript(label));
    }
  }

  TORCH_CHECK(
      curr_op == num_ops - 1,
      "einsum(): more operands were provided than specified in the equation");

  std::vector<int64_t> label_count(TOTAL_LABELS, 0);

  // The maximum number of dimensions covered by any ellipsis, needed when
  // unsqueezing missing dimensions from operands to permute and broadcast
  int64_t ell_num_dim = 0;

  // Compute label frequency and number of dimensions covered by ellipsis
  // We do this after parsing labels to make it more readable and simpler
  // to compute the number of dimensions covered by ellipsis.
  for(const auto i : c10::irange(num_ops)) {
    const auto& operand = operands[i];
    // Bind by reference: the labels are only read here, no copy is needed.
    const auto& labels = op_labels[i];
    const auto ndims = operand.dim();
    int64_t nlabels = static_cast<int64_t>(labels.size());
    bool has_ellipsis = false;

    for (const auto& label : labels) {
      if (label == ELLIPSIS) {
        --nlabels;
        has_ellipsis = true;
        ell_num_dim = std::max(ell_num_dim, ndims - nlabels);
      } else {
        ++label_count[label];
      }
    }

    TORCH_CHECK(
        has_ellipsis ? nlabels <= ndims : nlabels == ndims,
        "einsum(): the number of subscripts in the equation (",
        nlabels,
        has_ellipsis ? ") is more than the number of dimensions ("
                     : ") does not match the number of dimensions (",
        ndims,
        ") for operand ",
        i,
        has_ellipsis ? "" : " and no ellipsis was given");
  }

  // We want to align the dimensions of every input tensor to have
  // shape out_dims + sum_dims. For this, we create a mapping of label
  // to index into the permuted shape.
  std::vector<int64_t> label_perm_index(TOTAL_LABELS, -1);

  // Current index in the permuted shape
  int64_t perm_index = 0;

  // Start index of ellipsis dimensions in the permuted shape
  int64_t ell_index = 0;
  bool ell_in_output = false;

  if (arrow_pos == std::string::npos) {
    // Implicit output is ellipsis (...) + labels seen only once
    perm_index = ell_num_dim;
    // ell_in_output is used to stop us from reducing ellipses dims later
    ell_in_output = true;
    for (const auto label : c10::irange(TOTAL_LABELS)) {
      if (label_count[label] == 1) {
        label_perm_index[label] = perm_index++;
      }
    }
  } else {
    // Parse explicit output
    const auto rhs = equation.substr(arrow_pos + 2);
    for (std::size_t i = 0; i < rhs.length(); ++i) {
      const unsigned char label = rhs[i];
      switch (label) {
        case ' ':
          // Ignore spaces
          break;

        case '.':
          TORCH_CHECK(
              // There can only be one ellipsis in the output
              !ell_in_output,
              "einsum(): found \'.\' for output but an ellipsis (...) was already found");
          TORCH_CHECK(
              // Ensure ellipsis is correct
              i + 2 < rhs.length() && rhs[++i] == '.' && rhs[++i] == '.',
              "einsum(): found \'.\' for output that is not part of any ellipsis (...)");
          ell_index = perm_index;
          perm_index += ell_num_dim;
          ell_in_output = true;
          break;

        default:
          TORCH_CHECK(
              std::isalpha(label),
              "einsum(): invalid subscript given at index ",
              lhs.size() + 2 + i,
              " in the equation string, subscripts must be in [a-zA-Z]");
          const auto index = label_to_subscript(label);
          TORCH_CHECK(
              // Ensure label appeared at least once for some input operand and at
              // most once for the output
              label_count[index] > 0 && label_perm_index[index] == -1,
              "einsum(): output subscript ",
              label,
              label_perm_index[index] > -1
                  ? " appears more than once in the output"
                  : " does not appear in the equation for any input operand");
          label_perm_index[index] = perm_index++;
      }
    }
  }

  // Save number of dimensions in output before adding contraction dims (dims to sum out)
  const int64_t out_num_dim = perm_index;

  // If ellipsis is not part of the output, add to contraction dimensions
  if (!ell_in_output) {
    ell_index = perm_index;
    perm_index += ell_num_dim;
  }

  // Add contraction labels (labels not present in output)
  for (const auto label : c10::irange(TOTAL_LABELS)) {
    if (label_count[label] > 0 && label_perm_index[label] == -1) {
      label_perm_index[label] = perm_index++;
    }
  }

  // Next: we check the sizes, take diagonals for repeated labels, unsqueeze
  // missing dimensions so all operands have the same dimensions and permute
  // the operands to align the dimensions following the indices computed above.
  // We also count how many operands have dimension with size != 1 for each
  // label used to identify which dimensions can be contracted.
  std::vector<SymInt> label_size(TOTAL_LABELS, 1);
  std::vector<SymInt> ell_sizes(ell_num_dim, 1);
  std::vector<uint64_t> dim_counts(perm_index, 0);
  std::deque<Tensor> ops;
  for (const auto i : irange(num_ops)) {
    auto op = operands[i];
    std::vector<int64_t> permutation(perm_index, -1);
    std::int64_t dim = 0;
    for (const auto s : op_labels[i]) {
      if (s == ELLIPSIS) {
        // Iterate over each dimension covered by ellipsis
        const auto ndim = operands[i].ndimension() - (static_cast<int64_t>(op_labels[i].size()) - 1);
        for (auto j = ell_num_dim - ndim; j < ell_num_dim; ++j) {
          if (op.sym_size(dim) != 1) {
            // Update ellipsis size
            TORCH_CHECK(
                ell_sizes[j] == 1 || ell_sizes[j] == op.sym_size(dim),
                "einsum(): dimension ",
                dim,
                " covered by ellipsis in operand ",
                i,
                " has size ",
                op.sym_size(dim),
                " which does not broadcast with previously seen ellipsis with size ",
                ell_sizes[j],
                " for the respective dimension");
            ell_sizes[j] = op.sym_size(dim);
            ++dim_counts[ell_index + j];
          }
          permutation[ell_index + j] = dim++;
        }
      } else if (permutation[label_perm_index[s]] == -1) {
        if (op.sym_size(dim) != 1) {
          // Update subscript
          TORCH_CHECK(
              label_size[s] == 1 || label_size[s] == op.sym_size(dim),
              "einsum(): subscript ",
              subscript_to_label(s),
              " has size ",
              op.sym_size(dim),
              " for operand ",
              i,
              " which does not broadcast with previously seen size ",
              label_size[s]);
          label_size[s] = op.sym_size(dim);
          ++dim_counts[label_perm_index[s]];
        }
        permutation[label_perm_index[s]] = dim++;
      } else {
        // Repeated label, take diagonal
        const auto prev_dim = permutation[label_perm_index[s]];
        TORCH_CHECK(
          op.sym_size(dim) == op.sym_size(prev_dim),
            "einsum(): subscript ",
            subscript_to_label(s),
            " is repeated for operand ",
            i,
            " but the sizes don't match, ",
            op.sym_size(dim),
            " != ",
            op.sym_size(prev_dim));
        op = op.diagonal(0, prev_dim, dim).movedim(-1, prev_dim);
      }
    }

    // Add dimensions for missing labels
    for (auto& val : permutation) {
      if (val == -1) {
        op = op.unsqueeze(dim);
        val = dim++;
      }
    }
    ops.emplace_back(op.permute(permutation));
  }

  const auto contract_path = path.value_or(std::vector<int64_t>{});
  auto it = contract_path.begin();

  // Contract
  while (ops.size() > 1) {
    // Without a path the two front operands are contracted (left => right).
    int64_t i = 0;
    int64_t j = 1;

    if (path.has_value()) {
      i = *it++;
      j = *it++;
      if (j < i) {
        std::swap(i, j);
      }

      TORCH_CHECK(
          i != j && i >= 0 && j < static_cast<int64_t>(ops.size()),
          "einsum(): invalid contraction (",
          i,
          ", ",
          j,
          i == j ? ") cannot contract an operand with itself"
                 : ") operand index is out of bounds");
    }

    auto a = ops[i];
    auto b = ops[j];
    // Erase the larger index first so the smaller one stays valid.
    ops.erase(ops.begin() + j);
    ops.erase(ops.begin() + i);

    // Collect dimensions that can be summed now
    std::vector<int64_t> sum_dims;
    SmallVector<int64_t, 5> a_dims_to_sum;
    SmallVector<int64_t, 5> b_dims_to_sum;
    for (auto dim = out_num_dim; dim < perm_index; ++dim) {
      if (a.sym_size(dim) != 1 && b.sym_size(dim) != 1) {
        if (--dim_counts[dim] == 1) {
          sum_dims.push_back(dim);
          dim_counts[dim] = 0;
        }
      } else if (dim_counts[dim] == 1) {
        if (a.sym_size(dim) != 1) {
          a_dims_to_sum.push_back(dim);
          dim_counts[dim] = 0;
        } else if (b.sym_size(dim) != 1) {
          b_dims_to_sum.push_back(dim);
          dim_counts[dim] = 0;
        }
      }
    }

    // Sum multiple dims at a time to minimize the number of kernel calls to sum
    if (!a_dims_to_sum.empty()) {
      a = a.sum(a_dims_to_sum, true);
    }
    if (!b_dims_to_sum.empty()) {
      b = b.sum(b_dims_to_sum, true);
    }

    // With an explicit path the intermediate goes to the back of the list;
    // otherwise it goes to the front so it is contracted next.
    if (path.has_value()) {
      ops.emplace_back(sumproduct_pair(a, b, sum_dims, true));
    } else {
      ops.emplace_front(sumproduct_pair(a, b, sum_dims, true));
    }
  }

  // Sum out contraction dims
  if (perm_index - out_num_dim > 0) {
    // if there were ops to contract, we would have already done so
    // in the previous loop and all the dims to sum are now 1
    // NB: use view instead of squeeze (or sum) for faster (mps) performance
    if (num_ops > 1) {
      auto sizes = ops[0].sym_sizes().vec();
      for (auto dim = perm_index - 1; dim >= out_num_dim; --dim) {
        sizes.erase(sizes.begin() + dim);
      }
      return ops[0].view_symint(sizes);
    } else {
      std::vector<int64_t> sum_dims(perm_index - out_num_dim);
      std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
      return ops[0].sum(sum_dims);
    }
  }

  return ops[0];
}
// _trilinear computes a trilinear einstein sum with an unrolled dimension.
// The result is
//   `(i1.unsqueeze(expand1)*i2.unsqueeze(expand2)*i3.unsqueeze(expand3)).sum(sumdim)`
// and the computation is unrolled (one slice at a time) along `unroll_dim`.
// Its main purpose is to unify the computations in bilinear and bilinear_backward.
//
// `unroll_dim` must lie in [0, i1_.dim() + expand1_.size() - 1]; this is
// enforced with TORCH_CHECK.
Tensor _trilinear(const Tensor& i1_, const Tensor& i2_, const Tensor& i3_,
                  IntArrayRef expand1_, IntArrayRef expand2_, IntArrayRef expand3_,
                  IntArrayRef sumdim_, int64_t unroll_dim) {
  const int64_t total_dim = i1_.dim() + expand1_.size();
  TORCH_CHECK((unroll_dim >= 0) && (unroll_dim < total_dim), "unroll_dim must be in [0,", total_dim-1, "]");

  const auto expand1 = at::dim_list_to_bitset(expand1_, total_dim);
  const auto expand2 = at::dim_list_to_bitset(expand2_, total_dim);
  const auto expand3 = at::dim_list_to_bitset(expand3_, total_dim);
  const auto sumdim = at::dim_list_to_bitset(sumdim_, total_dim);

  Tensor i1 = i1_;
  Tensor i2 = i2_;
  Tensor i3 = i3_;
  std::vector<int64_t> output_size;
  std::vector<int64_t> sum_dims_12, sum_dims_23;
  int64_t unroll_size = -1;

  // Bring all three operands to `total_dim` dimensions and record, per
  // dimension, the output extent and where the dimension gets summed.
  for (const auto d : c10::irange(total_dim)) {
    // Extent of dimension d, taken from the last operand that is not expanded there.
    int64_t extent = 0;
    if (expand1[d]) {
      i1 = i1.unsqueeze(d);
    } else {
      extent = i1.size(d);
    }
    if (expand2[d]) {
      i2 = i2.unsqueeze(d);
    } else {
      extent = i2.size(d);
    }

    // The unrolled dimension is handled by the slice loop below, never by
    // sumproduct_pair.
    const bool summed_in_pair = sumdim[d] && (d != unroll_dim);
    if (expand3[d]) {
      i3 = i3.unsqueeze(d);
      // i3 does not carry this dim, so it can already be summed in i1 * i2.
      if (summed_in_pair) {
        sum_dims_12.push_back(d);
      }
    } else {
      extent = i3.size(d);
      if (summed_in_pair) {
        sum_dims_23.push_back(d);
      }
    }

    output_size.push_back(sumdim[d] ? 1 : extent);
    if (d == unroll_dim) {
      unroll_size = extent;
    }
  }

  // An operand that was expanded along unroll_dim has size 1 there, so its
  // slice index must stay at 0 (multiplier 0); otherwise it advances with k.
  const int64_t step1 = expand1[unroll_dim] ? 0 : 1;
  const int64_t step2 = expand2[unroll_dim] ? 0 : 1;
  const int64_t step3 = expand3[unroll_dim] ? 0 : 1;

  auto output = at::zeros(output_size, i1.options());

  // This function is meant to work for both forward and backward, which
  // changes the dimensions of the inputs, hence all three operands are checked.
  // Note that if output has zero elems is because (at least) one of i1, i2, i3 has zero elems.
  const bool any_input_empty = i1.numel() == 0 || i2.numel() == 0 || i3.numel() == 0;
  if (!any_input_empty) {
    const bool unroll_dim_is_summed = sumdim[unroll_dim];
    for (const auto k : c10::irange(unroll_size)) {
      Tensor partial = at::native::sumproduct_pair(i1.narrow(unroll_dim, k * step1, 1),
                                                   i2.narrow(unroll_dim, k * step2, 1),
                                                   sum_dims_12, true);
      partial = at::native::sumproduct_pair(partial, i3.narrow(unroll_dim, k * step3, 1), sum_dims_23, true);
      if (unroll_dim_is_summed) {
        // Every slice accumulates into the whole output.
        output.add_(partial);
      } else {
        // Each slice lands in its own position along unroll_dim.
        output.narrow(unroll_dim, k, 1).add_(partial);
      }
    }
  }

  // Drop the summed dimensions, back to front so indices stay valid.
  int64_t d = output.dim();
  while (d-- > 0) {
    if (sumdim[d]) {
      output.squeeze_(d);
    }
  }
  return output;
}
// Bilinear transformation of `input1` and `input2` with `weight`, plus an
// optional `bias`. The inputs are validated, flattened to 2D, contracted with
// at::_trilinear and reshaped to
// (input1 sizes without the last dim..., weight.size(0)).
Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  // --- shape validation ---
  TORCH_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
  const auto num_batch_dims = input1.dim() - 1;
  for (const auto i : c10::irange(num_batch_dims)) {
    TORCH_CHECK(input1.size(i) == input2.size(i),
                "bilinear(): input batch dimensions do not match at dim ", i, ": got ", input1.size(i), " and ", input2.size(i));
  }
  TORCH_CHECK(input1.size(input1.dim() - 1) == weight.size(1),
              "bilinear(): input1 size does not match weight size: got ",
              input1.size(input1.dim() - 1), " but expected ", weight.size(1));
  TORCH_CHECK(input2.size(input2.dim() - 1) == weight.size(2),
              "bilinear(): input2 size does not match weight size: got ",
              input2.size(input2.dim() - 1), " but expected ", weight.size(2));
  TORCH_CHECK(!bias.defined() || bias.size(0) == weight.size(0),
              "bilinear(): bias size does not match weight size: got ",
              bias.size(0), " but expected ", weight.size(0));

  // --- output shape: leading dims of input1 followed by weight.size(0) ---
  const auto in1_sizes = input1.sizes();
  std::vector<int64_t> output_size(in1_sizes.begin(), in1_sizes.end() - 1);
  output_size.push_back(weight.size(0));

  // --- contraction on 2D views of the inputs ---
  const auto flat1 = input1.reshape({-1, input1.size(-1)});
  const auto flat2 = input2.reshape({-1, input2.size(-1)});
  Tensor output = at::_trilinear(flat1, weight, flat2, {1,3}, {0}, {1,2}, {2,3}).reshape(output_size);

  if (!bias.defined()) {
    return output;
  }
  // The bias is added out of place.
  output = output + bias;
  return output;
}
// tensordot: a matrix-multiplication-like contraction of `input1` and
// `input2`, where the dimensions to contract are given by the two dimension
// lists (`dims1[k]` of input1 is paired with `dims2[k]` of input2).
// Both lists must have the same length and both inputs the same dtype.
// The result has the non-contracted sizes of input1 followed by the
// non-contracted sizes of input2.
Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
  TORCH_CHECK(dims1.size() == dims2.size(), "both dimension lists should have same length");
  TORCH_CHECK(input1.scalar_type() == input2.scalar_type(), "both inputs should have same dtype");

  // Product of the sizes of the contracted dimensions.
  SymInt contracted_numel = 1;
  Tensor lhs = input1;
  Tensor rhs = input2;

  const auto num_contracted = dims1.size();
  for (const auto k : c10::irange(num_contracted)) {
    const auto d1 = dims1[k];
    const auto d2 = dims2[k];
    SymInt s1 = input1.sym_size(d1);
    SymInt s2 = input2.sym_size(d2);
    if (s2 == 1) {
      // Broadcasted dimension: the other side can be summed right away.
      lhs = lhs.sum(d1, true);
    } else if (s1 == 1) {
      rhs = rhs.sum(d2, true);
    } else {
      TORCH_CHECK(s1 == s2, "contracted dimensions need to match, but first has size ", s1, " in dim ", d1,
                  " and second has size ", s2, " in dim ", d2);
      contracted_numel *= s1;
    }
  }

  const auto contracted1 = at::dim_list_to_bitset(dims1, input1.dim());
  const auto contracted2 = at::dim_list_to_bitset(dims2, input2.dim());

  std::vector<int64_t> perm1, perm2;   // input permutations
  std::vector<SymInt> result_sizes;    // sizes of the result
  perm1.reserve(input1.dim());
  perm2.reserve(input2.dim());
  result_sizes.reserve(input1.dim() + input2.dim() - static_cast<int64_t>(dims1.size()));

  SymInt free_numel1 = 1;  // number of non-contracted elements in input1
  SymInt free_numel2 = 1;  // number of non-contracted elements in input2

  // input1 is permuted to (free dims..., contracted dims...)
  for (const auto d : c10::irange(input1.dim())) {
    if (contracted1[d]) {
      continue;
    }
    perm1.emplace_back(d);
    free_numel1 *= lhs.sym_size(d);
    result_sizes.emplace_back(lhs.sym_size(d));
  }
  perm1.insert(perm1.end(), dims1.begin(), dims1.end());

  // input2 is permuted to (contracted dims..., free dims...)
  perm2.insert(perm2.end(), dims2.begin(), dims2.end());
  for (const auto d : c10::irange(input2.dim())) {
    if (contracted2[d]) {
      continue;
    }
    perm2.emplace_back(d);
    free_numel2 *= rhs.sym_size(d);
    result_sizes.emplace_back(rhs.sym_size(d));
  }

  // Permute and reshape both sides into matrices for a plain mm.
  lhs = lhs.permute(perm1).reshape_symint({free_numel1, contracted_numel});
  rhs = rhs.permute(perm2).reshape_symint({contracted_numel, free_numel2});

  // Multiply and reshape to the target size.
  return at::mm(lhs, rhs).reshape_symint(result_sizes);
}
// Out-variant of tensordot: computes the contraction into a temporary,
// validates that `result` lives on the same device as both inputs and has the
// dtype of the computed tensor, then resizes `result` and copies into it.
Tensor &tensordot_out(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
  const Tensor computed = at::native::tensordot(input1, input2, dims1, dims2);

  // check if the input & output tensors are on the same device.
  const auto out_device = result.device();
  const auto lhs_device = input1.device();
  const auto rhs_device = input2.device();
  TORCH_CHECK(
    (out_device == lhs_device) && (lhs_device == rhs_device),
    "tensordot: Expected the output and input tensors to be on the "
    "same device, but got the output tensor on ", out_device,
    ", input tensor a on ", lhs_device, ", and input tensor b on ", rhs_device);

  // check if the computed result has the same dtype as the out tensor
  // (because tensordot does not support type promotion)
  const auto computed_dtype = computed.scalar_type();
  const auto out_dtype = result.scalar_type();
  TORCH_CHECK(
    computed_dtype == out_dtype, "tensordot",
    ": Expected the output tensor to have dtype ", computed_dtype,
    ", but got an output tensor with dtype ", out_dtype);

  at::native::resize_output(result, computed.sizes());
  result.copy_(computed);
  return result;
}
}} // namespace at::native