Skip to content

Commit 9d0bb15

Browse files
committed
[TF] Add bullet operator (•) for matrix multiplication (#17173)
Add `•` operator for matmul, and remove `Tensor.dot` and `⊗` completely.
1 parent 4b6adec commit 9d0bb15

File tree

5 files changed

+18
-25
lines changed

5 files changed

+18
-25
lines changed

stdlib/public/TensorFlow/Ops.swift

+3-10
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,14 @@ public func matmul<Scalar : Numeric>(
265265
return Raw.matMul(left, right)
266266
}
267267

268-
infix operator : MultiplicationPrecedence
268+
infix operator : MultiplicationPrecedence
269269

270270
public extension Tensor where Scalar : Numeric {
271-
@_inlineable @inline(__always)
272-
@available(*, renamed: "matmul(_:_:)")
273-
func dot(_ other: Tensor) -> Tensor {
274-
return matmul(self, other)
275-
}
276-
277271
/// Performs matrix multiplication between two tensors and produces the
278272
/// result.
279273
@_inlineable @inline(__always)
280-
@available(*, renamed: "matmul(_:_:)")
281-
static func (lhs: Tensor, rhs: Tensor) -> Tensor {
282-
return lhs.dot(rhs)
274+
static func (lhs: Tensor, rhs: Tensor) -> Tensor {
275+
return matmul(lhs, rhs)
283276
}
284277
}
285278

test/TensorFlow/crashers.swift

+6-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public func postdom_crash1(w1: Tensor<Float>, inputBatch: Tensor<Float>) {
3333
// expected-warning @-2 {{'inputBatch' implicitly copied to the accelerator}}
3434
let iterationCount = 1000
3535
for _ in 0..<iterationCount {
36-
_ = inputBatch w1 // expected-note 2 {{value used here}}
36+
_ = inputBatch w1 // expected-note 2 {{value used here}}
3737
}
3838
}
3939

@@ -91,10 +91,10 @@ public func testStraightLineXORTraining() {
9191

9292
// Training loop
9393
for _ in 0..<iterationCount {
94-
let mmul1 = inputBatch w1
94+
let mmul1 = inputBatch w1
9595
let l1 = mmul1 + b1
9696
let o1 = sigmoid(l1)
97-
let mmul2 = o1 w2
97+
let mmul2 = o1 w2
9898
let l2 = mmul2 + b2
9999
let pred = sigmoid(l2)
100100

@@ -109,15 +109,15 @@ public func testStraightLineXORTraining() {
109109
let dL2 = dPred * pred * (1 - pred)
110110
let dMmul2 = dL2
111111
let dB2 = dL2
112-
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
113-
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
112+
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
113+
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
114114
let dL1 = dO1 * l1 * (1 - l1)
115115
let dMmul1 = dL1
116116
let dB1 = dL1
117117

118118
// Statically detected shape mismatch!
119119
// expected-error @+1 {{(op: 'MatMul') with input shapes: [4,2], [4,4]}}
120-
let dW1 = inputBatch dMmul1
120+
let dW1 = inputBatch dMmul1
121121

122122
// Descent
123123
w1 -= (dW1 * learningRate)

test/TensorFlow/no_copy.swift

+7-7
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ struct Classifier {
167167
var b2 = Tensor<Float>(zeros: [1, 10])
168168

169169
func prediction(for input: Tensor<Float>) -> Tensor<Float> {
170-
let h1 = sigmoid(input w1 + b1)
171-
return sigmoid(h1 w2 + b2)
170+
let h1 = sigmoid(input w1 + b1)
171+
return sigmoid(h1 w2 + b2)
172172
}
173173

174174
mutating func train(images: Tensor<Float>, labels: Tensor<Float>,
@@ -177,17 +177,17 @@ struct Classifier {
177177
var epochCount = epochCount
178178
repeat {
179179
// Forward pass
180-
let z1 = images w1 + b1
180+
let z1 = images w1 + b1
181181
let h1 = sigmoid(z1)
182-
let z2 = h1 w2 + b2
182+
let z2 = h1 w2 + b2
183183
let pred = sigmoid(z2)
184184

185185
// Backward pass
186186
let dz2 = pred - labels
187-
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
187+
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
188188
let db2 = dz2.sum(squeezingAxes: 0)
189-
let dz1 = dz2.dot(w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
190-
let dw1 = images.transposed(withPermutations: 1, 0) dz1
189+
let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
190+
let dw1 = images.transposed(withPermutations: 1, 0) dz1
191191
let db1 = dz1.sum(squeezingAxes: 0)
192192

193193
// Gradient descent

test/TensorFlowRuntime/tensor_debuglog.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ TensorTests.testAllBackends("XWPlusB") {
2626
// Shape: 2
2727
let b = Tensor<Float>([0.5, 0.5])
2828
// Do xW+b!
29-
let result = x w + b
29+
let result = x w + b
3030
expectEqual([1, 2], result.shape)
3131
expectEqual([12.5, 6.5], result.scalars)
3232
}

test/TensorFlowRuntime/tensor_xla_debuglog.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ XLATests.test("XWPlusB_XLA") {
2727
// Shape: 2
2828
let b = Tensor<Float>([0.5, 0.5])
2929
// Do xW+b!
30-
let result = x w + b
30+
let result = x w + b
3131
expectEqual([1, 2], result.shape)
3232
expectEqual([12.5, 6.5], result.scalars)
3333
#endif

0 commit comments

Comments
 (0)