#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_client/async_task.h"
#include "tensorflow/compiler/xla/xla_client/cache.h"
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/compiler/xla/xla_client/multi_wait.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch/csrc/autograd/variable.h"
#include "torch_xla/csrc/computation.h"
#include "torch_xla/csrc/cross_replica_reduces.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_util.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/view.h"
namespace torch_xla {
class XLATensor {
class DeviceContextArena;
struct Data;
public:
static XLATensor Create(const at::Tensor& tensor, const Device& device);
static XLATensor Create(
xla::ComputationClient::DataPtr xla_data,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor Create(
ir::Value ir_value, const Device& device,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
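// Hedged usage sketch (not part of the API surface): the Create overloads
// above build an XLATensor from a PyTorch CPU tensor, from device data, or
// from an IR value. Assuming Device can be constructed from a device string,
// a minimal round trip might look like:
//
//   at::Tensor cpu = at::ones({2, 2});
//   Device dev("TPU:0");  // hypothetical device string
//   XLATensor t = XLATensor::Create(cpu, dev);
//   at::Tensor back = t.ToTensor(/*detached=*/true);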
// Creates an empty/null tensor.
XLATensor() = default;
bool is_null() const { return data_ptr() == nullptr; }
size_t generation() const { return data()->generation; }
XLATensor alias() const { return XLATensor(data_ptr()); }
xla::int64 size(xla::int64 dim) const;
at::Tensor ToTensor(bool detached);
void ShallowCopyTo(XLATensor* dest) const;
// Assigns the tensor value to the XLA tensor.
void SetTensor(at::Tensor tensor);
void UpdateFromTensor(at::Tensor tensor, bool sync);
void UpdateFromTensorOut(at::Tensor tensor);
void UpdateFromTensorOut(const XLATensor& tensor);
at::ScalarType dtype() const;
c10::optional<at::ScalarType> dtype_optional() const;
// Sets the logical_element_type which is visible to upstream PyTorch.
void SetScalarType(c10::optional<at::ScalarType> logical_element_type);
xla::util::MaybeRef<xla::Shape> shape() const;
xla::Shape shape_with_layout() const;
const Device& GetDevice() const;
xla::int64 GetUniqueId() const;
// Retrieves an opaque ID of the alias object upon which the tensor's view is
// rooted, or 0 if this tensor is not a view.
std::ptrdiff_t GetViewAliasId() const;
// Fetches the XLA data behind the tensor. If the tensor has a graph defining
// its current value, executes the graph and fetches the XLA data result.
xla::ComputationClient::DataPtr GetXlaData();
// Fetches the current value of the XLA data, which can be missing (nullptr)
// in case the tensor has a graph defining its current value.
xla::ComputationClient::DataPtr CurrentXlaData() const;
void SetXlaData(xla::ComputationClient::DataPtr xla_data);
// Retrieves the current IR Node, or nullptr in case no active IR Node is
// available.
ir::Value CurrentIrValue() const;
// Retrieves the IR Node representing this XLATensor. One will be created if
// missing. Note that although this is a const API, it actually changes the
// internal state of the object.
ir::Value GetIrValue() const;
c10::optional<at::Tensor> CurrentTensorData() const;
// Applies the queue of operations in preparation for using the data.
void ApplyPendingGraph();
static ir::Value GetDeviceDataIrValue(const at::Scalar& value,
xla::PrimitiveType type,
const Device& device);
static ir::Value GetIrValueForScalar(const at::Scalar& value,
xla::PrimitiveType type,
const Device& device);
static ir::Value GetIrValueForScalar(const at::Scalar& value,
const Device& device);
static ir::Value GetIrValueForScalar(const at::Scalar& value,
xla::PrimitiveType type,
absl::Span<const xla::int64> dimensions,
const Device& device);
static ir::Value GetIrValueForScalar(const at::Scalar& value,
const xla::Shape& shape,
const Device& device);
static ir::Value GetIrValueForScalar(
const at::Scalar& value, const xla::Shape& shape,
c10::optional<at::ScalarType> logical_element_type, const Device& device);
static ir::Value GetRngSeed(const Device& device);
static void SetRngSeed(const Device& device, xla::uint64 seed);
static xla::uint64 GetRunningSeed(const Device& device);
// Dispatches a comparison operator, setting the logical type of the result
// appropriately.
static XLATensor DispatchComparisonOp(c10::Symbol kind,
const XLATensor& input,
const at::Scalar& other);
// Same as above, with the second input a tensor as well.
static XLATensor DispatchComparisonOp(c10::Symbol kind,
const XLATensor& input,
const XLATensor& other);
// Dumps the XLA HLO text of the computation accumulated in the graph which is
// attached to the tensors.
static std::string DumpHloComputation(const std::vector<XLATensor>& tensors);
// Retrieves the set of XLA tensors which are currently live in the system,
// for the given device. If device is nullptr, the live tensors for all
// devices will be returned. Returned tensors are sorted by device as primary
// key, and by unique ID as secondary key.
static std::vector<XLATensor> GetLiveTensors(const Device* device);
// Applies all the pending IR operations queued over the input tensors. All
// the tensors must be on the same device. If wait is true, the sync operation
// will be run synchronously. The devices argument, if not empty, specifies the
// devices which should be participating in the replicated computation.
static void SyncTensorsGraph(std::vector<XLATensor>* tensors,
absl::Span<const std::string> devices, bool wait,
bool sync_xla_data);
// Makes sure that any outstanding IR operation accumulated over live tensors
// gets turned into device data. If wait is true, the sync operation will be
// run synchronously. The devices argument, if not empty, specifies the devices
// which should be participating in the replicated computation.
static void SyncLiveTensorsGraph(const Device* device,
absl::Span<const std::string> devices,
bool wait);
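// Hedged usage sketch: forcing any IR accumulated on a set of tensors to be
// materialized as device data, blocking until the computation completes. The
// empty devices span below assumes the default (non-replicated) case.
//
//   std::vector<XLATensor> tensors = {a, b};
//   XLATensor::SyncTensorsGraph(&tensors, /*devices=*/{}, /*wait=*/true,
//                               /*sync_xla_data=*/true);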
// Marks an execution step, which allows the tensor framework to understand
// the computation boundaries.
static void MarkStep(const Device& device);
// Waits for all the outstanding operations on all the supplied devices.
// If devices is empty, the wait will happen for all local devices.
static void WaitDeviceOps(absl::Span<const std::string> devices);
// Retrieves the PyTorch CPU tensors behind the XLA tensors' IR operations.
// All the tensors must be on the same device.
static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors);
// Operation which creates XLA tensors out of PyTorch CPU tensors by batching
// the requests to the computation servers.
static std::vector<XLATensor> CreateTensors(
const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices);
//////////////////////////////////////////////////////////////////////////////
// XLA dedicated operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
static std::pair<XLATensor, ir::Value> all_reduce(
const XLATensor& input, const ir::Value& token, AllReduceType reduce_type,
double scale, std::vector<std::vector<xla::int64>> groups);
static ir::Value all_reduce_(XLATensor& input, const ir::Value& token,
AllReduceType reduce_type, double scale,
std::vector<std::vector<xla::int64>> groups);
static ir::Value all_reduce(std::vector<XLATensor>* inputs,
const ir::Value& token, AllReduceType reduce_type,
double scale,
std::vector<std::vector<xla::int64>> groups);
static std::pair<XLATensor, ir::Value> all_to_all(
const XLATensor& input, const ir::Value& token,
xla::int64 split_dimension, xla::int64 concat_dimension,
xla::int64 split_count, std::vector<std::vector<xla::int64>> groups);
static std::pair<XLATensor, ir::Value> collective_permute(
const XLATensor& input, const ir::Value& token,
std::vector<std::pair<xla::int64, xla::int64>> source_target_pairs);
static XLATensor get_dimensions_size(const XLATensor& input,
std::vector<xla::int64> dimensions);
static std::vector<XLATensor> user_computation(
const std::string& opname, absl::Span<const XLATensor> inputs,
ComputationPtr computation);
//////////////////////////////////////////////////////////////////////////////
// ATEN operators follow here, listed in alphabetical order.
//////////////////////////////////////////////////////////////////////////////
static void __ilshift__(XLATensor& input, const at::Scalar& other);
static void __ilshift__(XLATensor& input, const XLATensor& other);
static void __irshift__(XLATensor& input, const at::Scalar& other);
static void __irshift__(XLATensor& input, const XLATensor& other);
static XLATensor __lshift__(
const XLATensor& input, const at::Scalar& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor __lshift__(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor __rshift__(
const XLATensor& input, const at::Scalar& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor __rshift__(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor adaptive_avg_pool3d(const XLATensor& input,
std::vector<xla::int64> output_size);
static XLATensor adaptive_avg_pool3d_backward(const XLATensor& grad_output,
const XLATensor& input);
static XLATensor _adaptive_avg_pool2d(const XLATensor& input,
std::vector<xla::int64> output_size);
static XLATensor _adaptive_avg_pool2d_backward(const XLATensor& grad_output,
const XLATensor& input);
static void _amp_foreach_non_finite_check_and_unscale_(
std::vector<XLATensor> self, XLATensor& found_inf,
const XLATensor& inv_scale);
static void _amp_update_scale_(XLATensor& current_scale,
XLATensor& growth_tracker,
const XLATensor& found_inf,
double scale_growth_factor,
double scale_backoff_factor,
int growth_interval);
static XLATensor abs(const XLATensor& input);
static XLATensor acos(const XLATensor& input);
static XLATensor acosh(const XLATensor& input);
static XLATensor add(
const XLATensor& input, const XLATensor& other, const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor add(
const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor addcdiv(const XLATensor& input, const at::Scalar& value,
const XLATensor& tensor1, const XLATensor& tensor2);
static void addcdiv_(XLATensor& input, const at::Scalar& value,
const XLATensor& tensor1, const XLATensor& tensor2);
static XLATensor addcmul(const XLATensor& input, const at::Scalar& value,
const XLATensor& tensor1, const XLATensor& tensor2);
static XLATensor addmm(const XLATensor& input, const XLATensor& weight,
const XLATensor& bias);
static XLATensor all(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions);
static XLATensor any(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions);
static void arange_out(XLATensor& out, const at::Scalar& start,
const at::Scalar& end, const at::Scalar& step,
at::ScalarType scalar_type);
static XLATensor argmax(const XLATensor& input, xla::int64 dim, bool keepdim);
static XLATensor argmax(const XLATensor& input);
static XLATensor argmin(const XLATensor& input, xla::int64 dim, bool keepdim);
static XLATensor argmin(const XLATensor& input);
// Takes a slice from the input as R1 at the specified offset and reshapes it
// into the provided size.
static XLATensor as_strided(const XLATensor& input,
std::vector<xla::int64> size,
std::vector<xla::int64> stride,
c10::optional<xla::int64> storage_offset);
// In-place version of the method above.
static void as_strided_(XLATensor& input, std::vector<xla::int64> size,
std::vector<xla::int64> stride,
c10::optional<xla::int64> storage_offset);
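// Hedged illustration of the as_strided comment above: reading the input as a
// rank-1 buffer, a storage_offset of 2 with size {2, 2} would take elements
// [2..5] of the flattened input and reshape them into a 2x2 result, with the
// stride argument assumed to follow the usual at::as_strided convention.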
static XLATensor asin(const XLATensor& input);
static XLATensor asinh(const XLATensor& input);
static XLATensor atan(const XLATensor& input);
static XLATensor atanh(const XLATensor& input);
static XLATensor atan2(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor avg_pool_nd(const XLATensor& input,
xla::int64 spatial_dim_count,
std::vector<xla::int64> kernel_size,
std::vector<xla::int64> stride,
std::vector<xla::int64> padding, bool ceil_mode,
bool count_include_pad);
static XLATensor avg_pool_nd_backward(const XLATensor& out_backprop,
const XLATensor& input,
xla::int64 spatial_dim_count,
std::vector<xla::int64> kernel_size,
std::vector<xla::int64> stride,
std::vector<xla::int64> padding,
bool ceil_mode, bool count_include_pad);
static XLATensor baddbmm(const XLATensor& input, const XLATensor& batch1,
const XLATensor& batch2, const at::Scalar& beta,
const at::Scalar& alpha);
static XLATensor bernoulli(const XLATensor& input, double probability);
static XLATensor bernoulli(const XLATensor& input);
static void bernoulli_(XLATensor& input, double probability);
static void bernoulli_(XLATensor& input, const XLATensor& probability);
static XLATensor binary_cross_entropy(const XLATensor& input,
const XLATensor& target,
const XLATensor& weight,
xla::int64 reduction);
static XLATensor binary_cross_entropy_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
const XLATensor& weight,
xla::int64 reduction);
static void bitwise_and_out(XLATensor& out, const XLATensor& input,
const at::Scalar& other);
static void bitwise_and_out(XLATensor& out, const XLATensor& input,
const XLATensor& other);
static void bitwise_not_out(XLATensor& out, const XLATensor& input);
static void bitwise_or_out(XLATensor& out, const XLATensor& input,
const at::Scalar& other);
static void bitwise_or_out(XLATensor& out, const XLATensor& input,
const XLATensor& other);
static void bitwise_xor_out(XLATensor& out, const XLATensor& input,
const at::Scalar& other);
static void bitwise_xor_out(XLATensor& out, const XLATensor& input,
const XLATensor& other);
// Batch matrix multiplication. Both tensors must be 3D, the batch size must
// match and the remaining two dimensions must be compatible for matrix
// multiplication.
static XLATensor bmm(const XLATensor& batch1, const XLATensor& batch2);
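// Shape sketch for bmm: for batch1 of shape (b, n, m) and batch2 of shape
// (b, m, p), the result has shape (b, n, p), e.g.:
//
//   XLATensor c = XLATensor::bmm(batch1, batch2);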
// Broadcasts the given tensors according to broadcasting semantics.
static std::vector<XLATensor> broadcast_tensors(
absl::Span<const XLATensor> tensors);
static XLATensor cat(absl::Span<const XLATensor> tensors, xla::int64 dim);
static XLATensor ceil(const XLATensor& input);
static XLATensor cholesky(const XLATensor& input, bool upper);
static XLATensor clamp(const XLATensor& input,
const c10::optional<at::Scalar>& min,
const c10::optional<at::Scalar>& max);
static XLATensor clamp(const XLATensor& input,
const c10::optional<at::Tensor>& min,
const c10::optional<at::Tensor>& max);
static void clamp_out(XLATensor& out, const XLATensor& input,
const c10::optional<at::Tensor>& min,
const c10::optional<at::Tensor>& max);
static XLATensor clone(const XLATensor& input);
// Pads the input with the given value, by the amounts specified in the given
// list of low and high paddings.
static XLATensor constant_pad_nd(const XLATensor& input,
absl::Span<const xla::int64> pad,
const at::Scalar& value);
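// Hedged example, assuming the usual at::constant_pad_nd pad ordering
// (low/high pairs starting from the last dimension): for a 2-D input and
// pad = {1, 2, 3, 4}, the last dimension gets 1 low / 2 high elements of
// `value` and the second-to-last dimension gets 3 low / 4 high.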
static XLATensor convolution_overrideable(
const XLATensor& input, const XLATensor& weight, const XLATensor& bias,
std::vector<xla::int64> stride, std::vector<xla::int64> padding,
std::vector<xla::int64> dilation, bool transposed,
std::vector<xla::int64> output_padding, xla::int64 groups);
static std::tuple<XLATensor, XLATensor, XLATensor>
convolution_backward_overrideable(
const XLATensor& out_backprop, const XLATensor& input,
const XLATensor& weight, std::vector<xla::int64> stride,
std::vector<xla::int64> padding, std::vector<xla::int64> dilation,
bool transposed, std::vector<xla::int64> output_padding,
xla::int64 groups);
static XLATensor convolution_overrideable(
const XLATensor& input, const XLATensor& weight,
std::vector<xla::int64> stride, std::vector<xla::int64> padding,
std::vector<xla::int64> dilation, bool transposed,
std::vector<xla::int64> output_padding, xla::int64 groups);
static XLATensor cos(const XLATensor& input);
static XLATensor cosh(const XLATensor& input);
// Returns the cross product of the two input tensors in the given dimension.
// If the dimension is not given, it defaults to the first dimension found
// with the size 3.
static XLATensor cross(const XLATensor& input, const XLATensor& other,
c10::optional<xla::int64> dim);
// Returns the cumulative product of elements of input in the given dimension.
static XLATensor cumprod(const XLATensor& input, xla::int64 dim,
c10::optional<at::ScalarType> dtype);
// Returns the cumulative sum of elements of input in the given dimension.
static XLATensor cumsum(const XLATensor& input, xla::int64 dim,
c10::optional<at::ScalarType> dtype);
// If the input is a matrix (2-D tensor), returns a 1-D tensor with the
// diagonal elements of the input. If the input is a vector (1-D tensor),
// returns a 2-D square tensor with the elements of input as the diagonal.
static XLATensor diag(const XLATensor& input, xla::int64 offset);
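// Hedged illustration of diag: a 1-D input {1, 2, 3} with offset 0 yields a
// 3x3 tensor with 1, 2, 3 on the main diagonal; a 2-D input with offset 1
// yields the 1-D diagonal just above the main one.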
// Returns the diagonal of a matrix (2-D tensor) or batch of matrices. The
// matrix dimensions are specified by dim1 and dim2, the diagonal by offset.
static XLATensor diagonal(const XLATensor& input, xla::int64 offset,
xla::int64 dim1, xla::int64 dim2);
static XLATensor div(
const XLATensor& input, const XLATensor& other,
const c10::optional<c10::string_view>& rounding_mode = c10::nullopt,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor div(const XLATensor& input, const at::Scalar& other);
// A generalized contraction between tensors of arbitrary dimension defined by
// the given equation and applied to the input tensors.
static XLATensor einsum(const std::string& equation,
absl::Span<const XLATensor> tensors);
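// Hedged usage sketch of einsum, using standard Einstein-summation notation
// for a batched matrix multiply:
//
//   std::vector<XLATensor> operands = {a, b};
//   XLATensor c = XLATensor::einsum("bij,bjk->bik", operands);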
static XLATensor elu(const XLATensor& input, const at::Scalar& alpha,
const at::Scalar& scale, const at::Scalar& input_scale);
static void elu_(XLATensor& input, const at::Scalar& alpha,
const at::Scalar& scale, const at::Scalar& input_scale);
static XLATensor elu_backward(const XLATensor& grad_output,
const at::Scalar& alpha,
const at::Scalar& scale,
const at::Scalar& input_scale,
const XLATensor& output);
static XLATensor embedding_dense_backward(const XLATensor& grad_output,
const XLATensor& indices,
xla::int64 num_weights,
xla::int64 padding_idx,
bool scale_grad_by_freq);
static XLATensor eq(const XLATensor& input, const at::Scalar& other);
static XLATensor eq(const XLATensor& input, const XLATensor& other);
static XLATensor erf(const XLATensor& input);
static XLATensor erfc(const XLATensor& input);
static XLATensor erfinv(const XLATensor& input);
static XLATensor exp(const XLATensor& input);
static XLATensor expand(const XLATensor& input, std::vector<xla::int64> size);
static XLATensor expm1(const XLATensor& input);
static void exponential_(XLATensor& input, double lambd);
// Returns a 2-D tensor with ones on the diagonal and zeros elsewhere.
static XLATensor eye(xla::int64 lines, xla::int64 cols, const Device& device,
at::ScalarType element_type);
static void eye_out(XLATensor& out, xla::int64 lines, xla::int64 cols);
// Fills the input with the given value.
static void fill_(XLATensor& input, const at::Scalar& value);
// Flips (reverses) the values in the dimensions of the input tensor.
static XLATensor flip(const XLATensor& input,
absl::Span<const xla::int64> dims);
static XLATensor floor(const XLATensor& input);
static XLATensor fmod(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor fmod(
const XLATensor& input, const at::Scalar& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor frac(const XLATensor& input);
static XLATensor full(absl::Span<const xla::int64> size,
const at::Scalar& fill_value, const Device& device,
at::ScalarType scalar_type);
static XLATensor full_like(const XLATensor& input,
const at::Scalar& fill_value, const Device& device,
c10::optional<at::ScalarType> scalar_type);
static XLATensor gather(const XLATensor& input, xla::int64 dim,
const XLATensor& index);
static XLATensor ge(const XLATensor& input, const at::Scalar& other);
static XLATensor ge(const XLATensor& input, const XLATensor& other);
static XLATensor gelu(const XLATensor& input);
static XLATensor gelu_backward(const XLATensor& grad, const XLATensor& input);
static XLATensor ger(const XLATensor& input, const XLATensor& vec2);
static XLATensor gt(const XLATensor& input, const at::Scalar& other);
static XLATensor gt(const XLATensor& input, const XLATensor& other);
// Gather slices from input into a result with shape specified by indices. The
// shapes of the indices are first made consistent using broadcast semantics.
// For input of shape d1 x d2 x ... x dn and p indices of shape i1 x i2 x ...
// x ik, the output shape is d1 x ... x d(start_dim) x i1 x ... x ik x
// d(start_dim+p+1) x ... x dn.
static XLATensor index(const XLATensor& input,
absl::Span<const XLATensor> indices,
xla::int64 start_dim);
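// Hedged worked example of the shape rule above: for an input of shape
// 4 x 5 x 6 x 7, with start_dim chosen so that only d1 is kept in front, and
// two index tensors broadcasting to a common shape 2 x 3, the result shape
// would be 4 x 2 x 3 x 7.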
static XLATensor index_add(const XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& source);
static void index_add_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& source);
static XLATensor index_copy(const XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& source);
static void index_copy_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& source);
// Fills the elements of the base tensor with the given value in the given
// dimension, at positions given by the index. The index must be a rank-1
// tensor.
static XLATensor index_fill(const XLATensor& input, xla::int64 dim,
const XLATensor& index, const at::Scalar& value);
// Same as above, but the value is wrapped as a rank-0 tensor.
static XLATensor index_fill(const XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& value);
static void index_fill_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& value);
static void index_fill_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const at::Scalar& value);
// Puts values into the input tensor using the given indices (a tuple of
// tensors) and returns the result.
static XLATensor index_put(const XLATensor& input,
absl::Span<const XLATensor> indices,
xla::int64 start_dim, const XLATensor& values,
bool accumulate,
absl::Span<const xla::int64> result_permutation);
static void index_put_(XLATensor& input, const XLATensor& canonical_base,
absl::Span<const XLATensor> indices,
xla::int64 start_dim, const XLATensor& values,
bool accumulate,
absl::Span<const xla::int64> result_permutation);
static XLATensor index_select(const XLATensor& input, xla::int64 dim,
const XLATensor& index);
static XLATensor inverse(const XLATensor& input);
static XLATensor isnan(const XLATensor& input);
static XLATensor kl_div_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction, bool log_target);
static std::tuple<XLATensor, XLATensor> kthvalue(const XLATensor& input,
xla::int64 k, xla::int64 dim,
bool keepdim);
static XLATensor l1_loss(const XLATensor& input, const XLATensor& target,
xla::int64 reduction);
static XLATensor l1_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction);
static XLATensor le(const XLATensor& input, const at::Scalar& other);
static XLATensor le(const XLATensor& input, const XLATensor& other);
static XLATensor hardshrink(const XLATensor& input, const at::Scalar& lambda);
static XLATensor hardshrink_backward(const XLATensor& grad_out,
const XLATensor& input,
const at::Scalar& lambda);
static XLATensor hardsigmoid(const XLATensor& input);
static XLATensor hardsigmoid_backward(const XLATensor& grad_output,
const XLATensor& input);
static XLATensor hardtanh_backward(const XLATensor& grad_output,
const XLATensor& input,
const at::Scalar& min_val,
const at::Scalar& max_val);
static XLATensor leaky_relu(const XLATensor& input, double negative_slope);
static XLATensor leaky_relu_backward(const XLATensor& grad_output,
const XLATensor& input,
double negative_slope);
static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const XLATensor& weight);
static XLATensor lerp(const XLATensor& input, const XLATensor& end,
const at::Scalar& weight);
static XLATensor log(const XLATensor& input);
static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base);
static XLATensor log_sigmoid(const XLATensor& input);
static std::tuple<XLATensor, XLATensor> log_sigmoid_forward(
const XLATensor& input);
static XLATensor log_sigmoid_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& buffer);
static XLATensor log_softmax(const XLATensor& input, xla::int64 dim,
c10::optional<at::ScalarType> dtype);
static XLATensor log_softmax_backward(const XLATensor& grad_output,
const XLATensor& output,
xla::int64 dim);
static XLATensor log1p(const XLATensor& input);
static void log1p_(XLATensor& input);
static XLATensor logdet(const XLATensor& input);
static XLATensor logsumexp(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions);
static XLATensor lt(const XLATensor& input, const at::Scalar& other);
static XLATensor lt(const XLATensor& input, const XLATensor& other);
// Fills the elements of the input with the given value where the mask is true.
static void masked_fill_(XLATensor& input, const XLATensor& mask,
const at::Scalar& value);
static void masked_scatter_(XLATensor& input, const XLATensor& mask,
const XLATensor& source);
static XLATensor masked_select(const XLATensor& input, const XLATensor& mask);
static XLATensor matmul(const XLATensor& input, const XLATensor& other);
static XLATensor max(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor max(const XLATensor& input);
static std::tuple<XLATensor, XLATensor> max(const XLATensor& input,
xla::int64 dim, bool keepdim);
static void max_out(XLATensor& max, XLATensor& max_values,
const XLATensor& input, xla::int64 dim, bool keepdim);
static std::tuple<XLATensor, XLATensor> max_pool_nd(
const XLATensor& input, xla::int64 spatial_dim_count,
std::vector<xla::int64> kernel_size, std::vector<xla::int64> stride,
std::vector<xla::int64> padding, bool ceil_mode);
static XLATensor max_pool_nd_backward(const XLATensor& out_backprop,
const XLATensor& input,
xla::int64 spatial_dim_count,
std::vector<xla::int64> kernel_size,
std::vector<xla::int64> stride,
std::vector<xla::int64> padding,
bool ceil_mode);
static XLATensor max_unpool(const XLATensor& input, const XLATensor& indices,
std::vector<xla::int64> output_size);
static XLATensor max_unpool_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& indices,
std::vector<xla::int64> output_size);
static XLATensor mean(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions,
c10::optional<at::ScalarType> dtype);
static XLATensor min(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor min(const XLATensor& input);
static std::tuple<XLATensor, XLATensor> min(const XLATensor& input,
xla::int64 dim, bool keepdim);
static void min_out(XLATensor& min, XLATensor& min_indices,
const XLATensor& input, xla::int64 dim, bool keepdim);
static XLATensor mm(const XLATensor& input, const XLATensor& weight);
static XLATensor mse_loss(const XLATensor& input, const XLATensor& target,
xla::int64 reduction);
static XLATensor mse_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction);
static XLATensor mul(
const XLATensor& input, const XLATensor& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor mul(
const XLATensor& input, const at::Scalar& other,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor mv(const XLATensor& input, const XLATensor& vec);
static void mv_out(XLATensor& out, const XLATensor& input,
const XLATensor& vec);
// Returns a new tensor that is a narrowed view of the input in the given
// dimension.
static XLATensor narrow(const XLATensor& input, xla::int64 dim,
xla::int64 start, xla::int64 length);
// Like batch_norm, but returns additional save_mean and save_invstd used by
// the backward pass.
static std::tuple<XLATensor, XLATensor, XLATensor> native_batch_norm(
const XLATensor& input, const XLATensor& weight, const XLATensor& bias,
XLATensor& running_mean, XLATensor& running_var, bool training,
double momentum, double eps);
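// Hedged usage sketch of the tuple return:
//
//   XLATensor out, save_mean, save_invstd;
//   std::tie(out, save_mean, save_invstd) = XLATensor::native_batch_norm(
//       input, weight, bias, running_mean, running_var,
//       /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);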
// Returns the input, weight and bias gradients.
static std::tuple<XLATensor, XLATensor, XLATensor> native_batch_norm_backward(
const XLATensor& grad_out, const XLATensor& input,
const XLATensor& weight, const XLATensor& save_mean,
const XLATensor& save_invstd, bool training, double eps);
static XLATensor ne(const XLATensor& input, const at::Scalar& other);
static XLATensor ne(const XLATensor& input, const XLATensor& other);
static XLATensor neg(const XLATensor& input);
static XLATensor nll_loss(const XLATensor& input, const XLATensor& target,
const XLATensor& weight, xla::int64 reduction,
int ignore_index);
static XLATensor nll_loss2d(const XLATensor& input, const XLATensor& target,
const XLATensor& weight, xla::int64 reduction,
int ignore_index);
static XLATensor nll_loss2d_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
const XLATensor& weight,
xla::int64 reduction, int ignore_index,
const XLATensor& total_weight);
static XLATensor nll_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
const XLATensor& weight,
xla::int64 reduction, int ignore_index,
const XLATensor& total_weight);
static std::pair<XLATensor, XLATensor> nms(const XLATensor& boxes,
const XLATensor& scores,
const XLATensor& score_threshold,
const XLATensor& iou_threshold,
xla::int64 output_size);
static XLATensor nonzero(const XLATensor& input);
static XLATensor norm(const XLATensor& input,
const c10::optional<at::Scalar>& p,
c10::optional<at::ScalarType> dtype,
at::IntArrayRef dim, bool keepdim);
static XLATensor normal(double mean, const XLATensor& std);
static XLATensor normal(const XLATensor& mean, double std);
static XLATensor normal(const XLATensor& mean, const XLATensor& std);
static void normal_(XLATensor& input, double mean, double std);
static XLATensor not_supported(std::string description, xla::Shape shape,
const Device& device);
// Permute the dimensions of this tensor according to the given permutation.
static XLATensor permute(const XLATensor& input,
absl::Span<const xla::int64> dims);
static XLATensor pow(const XLATensor& input, const at::Scalar& exponent);
static XLATensor pow(const XLATensor& input, const XLATensor& exponent);
static XLATensor pow(const at::Scalar& input, const XLATensor& exponent);
static XLATensor prod(const XLATensor& input,
std::vector<xla::int64> dimensions,
bool keep_reduced_dimensions,
c10::optional<at::ScalarType> dtype);
static void put_(XLATensor& input, const XLATensor& index,
const XLATensor& source, bool accumulate);
static std::tuple<XLATensor, XLATensor> qr(const XLATensor& input, bool some);
static void random_(XLATensor& input, int64_t from, int64_t to);
static XLATensor randperm(xla::int64 n, const Device& device,
at::ScalarType scalar_type);
static XLATensor reciprocal(const XLATensor& input);
static XLATensor reflection_pad2d(const XLATensor& input,
std::vector<xla::int64> padding);
static XLATensor reflection_pad2d_backward(const XLATensor& grad_output,
const XLATensor& input,
std::vector<xla::int64> padding);
static XLATensor relu(const XLATensor& input);
static void relu_(XLATensor& input);
static XLATensor remainder(const XLATensor& input, const XLATensor& other);
static XLATensor remainder(const XLATensor& input, const at::Scalar& other);
// Repeats the input tensor along each dimension by the given number of
// repeats.
static XLATensor repeat(const XLATensor& input,
std::vector<xla::int64> repeats);
static XLATensor replication_pad1d(const XLATensor& input,
std::vector<xla::int64> padding);
static XLATensor replication_pad1d_backward(const XLATensor& grad_output,
const XLATensor& input,
std::vector<xla::int64> padding);
static XLATensor replication_pad2d(const XLATensor& input,
std::vector<xla::int64> padding);
static XLATensor replication_pad2d_backward(const XLATensor& grad_output,
const XLATensor& input,
std::vector<xla::int64> padding);
static void resize_(XLATensor& input, std::vector<xla::int64> size);
static XLATensor round(const XLATensor& input);
static XLATensor rrelu_with_noise(const XLATensor& input, XLATensor& noise,
const at::Scalar& lower,
const at::Scalar& upper, bool training);
static XLATensor rrelu_with_noise_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& noise,
const at::Scalar& lower,
const at::Scalar& upper,
bool training);
static XLATensor rsqrt(const XLATensor& input);
static XLATensor rsub(
const XLATensor& input, const XLATensor& other, const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static XLATensor rsub(
const XLATensor& input, const at::Scalar& other, const at::Scalar& alpha,
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
static void copy_(XLATensor& input, XLATensor& src);
static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
const XLATensor& src);
static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
const at::Scalar& value);
static void scatter_add_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& src);
static XLATensor select(const XLATensor& input, xla::int64 dim,
xla::int64 index);
static void silu_out(XLATensor& input, XLATensor& out);
static XLATensor sigmoid(const XLATensor& input);
static XLATensor sigmoid_backward(const XLATensor& grad_output,
const XLATensor& output);
static XLATensor sign(const XLATensor& input);
static XLATensor sin(const XLATensor& input);
static XLATensor sinh(const XLATensor& input);
static XLATensor slice(const XLATensor& input, xla::int64 dim,
xla::int64 start, xla::int64 end, xla::int64 step);
// Computes a loss that uses a squared term if the absolute element-wise error
// falls below beta and an L1 term otherwise.
static XLATensor smooth_l1_loss(const XLATensor& input,
const XLATensor& target, xla::int64 reduction,
double beta);
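// Hedged formula sketch, assuming the standard PyTorch smooth L1 definition
// with threshold beta, for element-wise error x = input - target:
//
//   loss(x) = 0.5 * x^2 / beta   if |x| < beta
//   loss(x) = |x| - 0.5 * beta   otherwise
//
// with the per-element losses combined according to `reduction`.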
// Returns the gradient of the input of a smooth_l1_loss operation.
static XLATensor smooth_l1_loss_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64 reduction, double beta);
static XLATensor softmax(const XLATensor& input, xla::int64 dim,
c10::optional<at::ScalarType> dtype);
static XLATensor softmax_backward(const XLATensor& grad_output,
const XLATensor& output, xla::int64 dim);
static XLATensor softplus(const XLATensor& input, const at::Scalar& beta,
const at::Scalar& threshold);
static XLATensor softplus_backward(const XLATensor& grad_output,
const XLATensor& input,
const at::Scalar& beta,
const at::Scalar& threshold,
const XLATensor& output);
static XLATensor softshrink(const XLATensor& input, const at::Scalar& lambda);
static XLATensor softshrink_backward(const XLATensor& grad_out,
const XLATensor& input,
const at::Scalar& lambda);
static std::vector<XLATensor> split(const XLATensor& input,
xla::int64 split_size, xla::int64 dim);
static std::vector<XLATensor> split_with_sizes(
const XLATensor& input, std::vector<xla::int64> split_size,
xla::int64 dim);
static XLATensor sqrt(const XLATensor& input);
// Squeeze out all trivial (size 1) dimensions.
static XLATensor squeeze(const XLATensor& input);
// Squeeze out the specified dimension index, if trivial (size 1). Returns
// unchanged input otherwise.
static XLATensor squeeze(const XLATensor& input, xla::int64 dim);
// In-place versions of the methods above.
static void squeeze_(XLATensor& input);
static void squeeze_(XLATensor& input, xla::int64 dim);
static XLATensor stack(absl::Span<const XLATensor> tensors, xla::int64 dim);
static XLATensor std(const XLATensor& input,