Skip to content

Commit fa7426e

Browse files
committed
Add ones op (tensorflow#162)
1 parent acdef3d commit fa7426e

File tree

2 files changed

+92
-1
lines changed
  • tensorflow-core/tensorflow-core-api/src

2 files changed

+92
-1
lines changed

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@
128128
import org.tensorflow.op.core.NextIteration;
129129
import org.tensorflow.op.core.NoOp;
130130
import org.tensorflow.op.core.OneHot;
131+
import org.tensorflow.op.core.Ones;
131132
import org.tensorflow.op.core.OnesLike;
132133
import org.tensorflow.op.core.OrderedMapClear;
133134
import org.tensorflow.op.core.OrderedMapIncompleteSize;
@@ -3426,6 +3427,19 @@ public <U extends TType, T extends TNumber> OneHot<U> oneHot(Operand<T> indices,
34263427
return OneHot.create(scope, indices, depth, onValue, offValue, options);
34273428
}
34283429

3430+
/**
3431+
* Creates a one valued tensor given its type and shape.
3432+
*
3433+
* @param scope is a scope used to add the underlying operation
3434+
* @param dims a 1-D operand that represents the shape of the output tensor
3435+
* @param type the output tensor datatype. Can not be TString.
3436+
* @return a constant tensor initialized with ones
3437+
* @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones.
3438+
*/
3439+
public <T extends TType, U extends TNumber> Ones<T> ones(Operand<U> dims, DataType<T> type) {
3440+
return Ones.create(scope, dims, type);
3441+
}
3442+
34293443
/**
34303444
* Returns a tensor of ones with the same shape and type as x.
34313445
*
@@ -7726,7 +7740,7 @@ public Ops withName(String opName) {
77267740
}
77277741

77287742
/**
7729-
* Returns an API that uses the provided DeviceSpec for an op.
7743+
* Returns an API that places the created operations on the device(s) matching the provided spec.
77307744
*
77317745
* @see {@link Scope#withDevice(DeviceSpec)}
77327746
*/
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
package org.tensorflow.op.core;
16+
17+
import org.tensorflow.DataType;
18+
import org.tensorflow.Operand;
19+
import org.tensorflow.Operation;
20+
import org.tensorflow.Output;
21+
import org.tensorflow.op.Op;
22+
import org.tensorflow.op.Scope;
23+
import org.tensorflow.op.annotation.Endpoint;
24+
import org.tensorflow.op.annotation.Operator;
25+
import org.tensorflow.op.dtypes.Cast;
26+
import org.tensorflow.types.TString;
27+
import org.tensorflow.types.family.TNumber;
28+
import org.tensorflow.types.family.TType;
29+
30+
/**
31+
* An operator creating a constant initialized with ones of the shape given by `dims`.
32+
*
33+
* <p>For example, the following expression
34+
* <pre>{@code tf.ones(tf.constant(shape), TFloat32.DTYPE)}</pre>
35+
* is the equivalent of
36+
* <pre>{@code tf.fill(tf.constant(shape), tf.constant(1.0f))}</pre>
37+
*
38+
* @param <T> constant type
39+
*/
40+
@Operator
41+
public final class Ones<T extends TType> implements Op, Operand<T> {
42+
43+
/**
44+
* Creates a one valued tensor given its type and shape.
45+
*
46+
* @param scope is a scope used to add the underlying operation
47+
* @param dims a 1-D operand that represents the shape of the output tensor
48+
* @param type the output tensor datatype. Can not be TString.
49+
* @return a constant tensor initialized with ones
50+
* @throws IllegalArgumentException if the tensor type or shape cannot be initialized with ones.
51+
*/
52+
@Endpoint
53+
public static <T extends TType, U extends TNumber> Ones<T> create(Scope scope, Operand<U> dims, DataType<T> type) {
54+
Scope onesScope = scope.withSubScope("Ones");
55+
if (type == TString.DTYPE) {
56+
throw new IllegalArgumentException("Can't create Ones of String DataType");
57+
}
58+
Operand<T> one = Cast.create(onesScope.withName("One"), Constant.scalarOf(onesScope, 1), type);
59+
return new Ones<>(Fill.create(onesScope.withName("Fill"), dims, one));
60+
}
61+
62+
@Override
63+
public Operation op() {
64+
return fill.op();
65+
}
66+
67+
@Override
68+
public Output<T> asOutput() {
69+
return fill.asOutput();
70+
}
71+
72+
private final Fill<T> fill;
73+
74+
private Ones(Fill<T> fill) {
75+
this.fill = fill;
76+
}
77+
}

0 commit comments

Comments
 (0)