Skip to content

Add option to run TensorFlow job on the preferred device (via Scope) #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
writer->EndLine();
}
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @zaleslaw , I have another change to propose you that I just thought of. I think if we merge together scope.applyControlDependencies and this new step for adding the device string to the opBuilder would make the whole process of building an op simpler, especially with non-generated ones.

So what about that:

  1. We rename scope.applyControlDependencies(opBuilder) to scope.apply(opBuilder)
  2. In Scope.apply, we apply both control dependencies, the device string (and eventually more data if needed)

By doing so, we also won't need scope.makeDeviceString() anymore. Let me know what you think

// Add control dependencies, if any.
writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
writer->Append("opBuilder = scope.apply(opBuilder);");
writer->EndLine();

for (const AttributeSpec& attribute : op.attributes()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.nio.charset.Charset;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.DeviceSpec;
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
Expand Down Expand Up @@ -7724,6 +7725,15 @@ public Ops withName(String opName) {
return new Ops(scope.withName(opName));
}

/**
* Returns an API that uses the provided DeviceSpec for an op.
*
* @see {@link Scope#withDevice(DeviceSpec)}
*/
public Ops withDevice(DeviceSpec deviceSpec) {
return new Ops(scope.withDevice(deviceSpec));
}

/**
* Returns an API that adds operations to the graph with the provided control dependencies.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private Options() {
public static AudioSpectrogram create(Scope scope, Operand<TFloat32> input, Long windowSize, Long stride, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("AudioSpectrogram", scope.makeOpName("AudioSpectrogram"));
opBuilder.addInput(input.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("window_size", windowSize);
opBuilder.setAttr("stride", stride);
if (options != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private Options() {
public static DecodeWav create(Scope scope, Operand<TString> contents, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("DecodeWav", scope.makeOpName("DecodeWav"));
opBuilder.addInput(contents.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.desiredChannels != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static EncodeWav create(Scope scope, Operand<TFloat32> audio, Operand<TIn
OperationBuilder opBuilder = scope.env().opBuilder("EncodeWav", scope.makeOpName("EncodeWav"));
opBuilder.addInput(audio.asOutput());
opBuilder.addInput(sampleRate.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new EncodeWav(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public static Mfcc create(Scope scope, Operand<TFloat32> spectrogram, Operand<TI
OperationBuilder opBuilder = scope.env().opBuilder("Mfcc", scope.makeOpName("Mfcc"));
opBuilder.addInput(spectrogram.asOutput());
opBuilder.addInput(sampleRate.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.upperFrequencyLimit != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static <T extends TNumber> BitwiseAnd<T> create(Scope scope, Operand<T> x
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseAnd", scope.makeOpName("BitwiseAnd"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BitwiseAnd<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static <T extends TNumber> BitwiseOr<T> create(Scope scope, Operand<T> x,
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseOr", scope.makeOpName("BitwiseOr"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BitwiseOr<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static <T extends TNumber> BitwiseXor<T> create(Scope scope, Operand<T> x
OperationBuilder opBuilder = scope.env().opBuilder("BitwiseXor", scope.makeOpName("BitwiseXor"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BitwiseXor<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public final class Invert<T extends TNumber> extends RawOp implements Operand<T>
public static <T extends TNumber> Invert<T> create(Scope scope, Operand<T> x) {
OperationBuilder opBuilder = scope.env().opBuilder("Invert", scope.makeOpName("Invert"));
opBuilder.addInput(x.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new Invert<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static <T extends TNumber> LeftShift<T> create(Scope scope, Operand<T> x,
OperationBuilder opBuilder = scope.env().opBuilder("LeftShift", scope.makeOpName("LeftShift"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new LeftShift<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public static <T extends TNumber> RightShift<T> create(Scope scope, Operand<T> x
OperationBuilder opBuilder = scope.env().opBuilder("RightShift", scope.makeOpName("RightShift"));
opBuilder.addInput(x.asOutput());
opBuilder.addInput(y.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new RightShift<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static KMC2ChainInitialization create(Scope scope, Operand<TFloat32> dist
OperationBuilder opBuilder = scope.env().opBuilder("KMC2ChainInitialization", scope.makeOpName("KMC2ChainInitialization"));
opBuilder.addInput(distances.asOutput());
opBuilder.addInput(seed.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new KMC2ChainInitialization(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public static KmeansPlusPlusInitialization create(Scope scope, Operand<TFloat32>
opBuilder.addInput(numToSample.asOutput());
opBuilder.addInput(seed.asOutput());
opBuilder.addInput(numRetriesPerSample.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new KmeansPlusPlusInitialization(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private Options() {
public static <T extends TNumber> AllReduce<T> create(Scope scope, Operand<T> input, Long groupSize, Long groupKey, Long instanceKey, String mergeOp, String finalOp, List<Long> subdivOffsets, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("CollectiveReduce", scope.makeOpName("AllReduce"));
opBuilder.addInput(input.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("group_size", groupSize);
opBuilder.setAttr("group_key", groupKey);
opBuilder.setAttr("instance_key", instanceKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private Options() {
@Endpoint(describeByClass = true)
public static <T extends TType> BroadcastRecv<T> create(Scope scope, DataType<T> T, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastRecv", scope.makeOpName("BroadcastRecv"));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("T", T);
opBuilder.setAttr("group_size", groupSize);
opBuilder.setAttr("group_key", groupKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ private Options() {
public static <T extends TType> BroadcastSend<T> create(Scope scope, Operand<T> input, Long groupSize, Long groupKey, Long instanceKey, Shape shape, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("CollectiveBcastSend", scope.makeOpName("BroadcastSend"));
opBuilder.addInput(input.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("group_size", groupSize);
opBuilder.setAttr("group_key", groupKey);
opBuilder.setAttr("instance_key", instanceKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ private Options() {
@Endpoint(describeByClass = true)
public static Abort create(Scope scope, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("Abort", scope.makeOpName("Abort"));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.errorMsg != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static <T extends TNumber> All create(Scope scope, Operand<TBool> input,
OperationBuilder opBuilder = scope.env().opBuilder("All", scope.makeOpName("All"));
opBuilder.addInput(input.asOutput());
opBuilder.addInput(axis.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.keepDims != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static <T extends TNumber> Any create(Scope scope, Operand<TBool> input,
OperationBuilder opBuilder = scope.env().opBuilder("Any", scope.makeOpName("Any"));
opBuilder.addInput(input.asOutput());
opBuilder.addInput(axis.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.keepDims != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static AssertThat create(Scope scope, Operand<TBool> condition, Iterable<
OperationBuilder opBuilder = scope.env().opBuilder("Assert", scope.makeOpName("AssertThat"));
opBuilder.addInput(condition.asOutput());
opBuilder.addInputList(Operands.asOutputs(data));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.summarize != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public static <T extends TType> Assign<T> create(Scope scope, Operand<T> ref, Op
OperationBuilder opBuilder = scope.env().opBuilder("Assign", scope.makeOpName("Assign"));
opBuilder.addInput(ref.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.validateShape != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public static <T extends TType> AssignAdd<T> create(Scope scope, Operand<T> ref,
OperationBuilder opBuilder = scope.env().opBuilder("AssignAdd", scope.makeOpName("AssignAdd"));
opBuilder.addInput(ref.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.useLocking != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static <T extends TType> AssignAddVariableOp create(Scope scope, Operand<
OperationBuilder opBuilder = scope.env().opBuilder("AssignAddVariableOp", scope.makeOpName("AssignAddVariableOp"));
opBuilder.addInput(resource.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new AssignAddVariableOp(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public static <T extends TType> AssignSub<T> create(Scope scope, Operand<T> ref,
OperationBuilder opBuilder = scope.env().opBuilder("AssignSub", scope.makeOpName("AssignSub"));
opBuilder.addInput(ref.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.useLocking != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static <T extends TType> AssignSubVariableOp create(Scope scope, Operand<
OperationBuilder opBuilder = scope.env().opBuilder("AssignSubVariableOp", scope.makeOpName("AssignSubVariableOp"));
opBuilder.addInput(resource.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new AssignSubVariableOp(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static <T extends TType> AssignVariableOp create(Scope scope, Operand<?>
OperationBuilder opBuilder = scope.env().opBuilder("AssignVariableOp", scope.makeOpName("AssignVariableOp"));
opBuilder.addInput(resource.asOutput());
opBuilder.addInput(value.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new AssignVariableOp(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ private Options() {
@Endpoint(describeByClass = true)
public static Barrier create(Scope scope, List<DataType<?>> componentTypes, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("Barrier", scope.makeOpName("Barrier"));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
DataType[] componentTypesArray = new DataType[componentTypes.size()];
for (int i = 0; i < componentTypesArray.length; ++i) {
componentTypesArray[i] = componentTypes.get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private Options() {
public static BarrierClose create(Scope scope, Operand<TString> handle, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("BarrierClose", scope.makeOpName("BarrierClose"));
opBuilder.addInput(handle.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
if (options != null) {
for (Options opts : options) {
if (opts.cancelPendingEnqueues != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public final class BarrierIncompleteSize extends RawOp implements Operand<TInt32
public static BarrierIncompleteSize create(Scope scope, Operand<TString> handle) {
OperationBuilder opBuilder = scope.env().opBuilder("BarrierIncompleteSize", scope.makeOpName("BarrierIncompleteSize"));
opBuilder.addInput(handle.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BarrierIncompleteSize(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static <T extends TType> BarrierInsertMany create(Scope scope, Operand<TS
opBuilder.addInput(handle.asOutput());
opBuilder.addInput(keys.asOutput());
opBuilder.addInput(values.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("component_index", componentIndex);
return new BarrierInsertMany(opBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public final class BarrierReadySize extends RawOp implements Operand<TInt32> {
public static BarrierReadySize create(Scope scope, Operand<TString> handle) {
OperationBuilder opBuilder = scope.env().opBuilder("BarrierReadySize", scope.makeOpName("BarrierReadySize"));
opBuilder.addInput(handle.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BarrierReadySize(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public static BarrierTakeMany create(Scope scope, Operand<TString> handle, Opera
OperationBuilder opBuilder = scope.env().opBuilder("BarrierTakeMany", scope.makeOpName("BarrierTakeMany"));
opBuilder.addInput(handle.asOutput());
opBuilder.addInput(numElements.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
DataType[] componentTypesArray = new DataType[componentTypes.size()];
for (int i = 0; i < componentTypesArray.length; ++i) {
componentTypesArray[i] = componentTypes.get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private Options() {
public static Batch create(Scope scope, Iterable<Operand<?>> inTensors, Long numBatchThreads, Long maxBatchSize, Long batchTimeoutMicros, Long gradTimeoutMicros, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("Batch", scope.makeOpName("Batch"));
opBuilder.addInputList(Operands.asOutputs(inTensors));
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("num_batch_threads", numBatchThreads);
opBuilder.setAttr("max_batch_size", maxBatchSize);
opBuilder.setAttr("batch_timeout_micros", batchTimeoutMicros);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static <T extends TType, U extends TNumber> BatchToSpace<T> create(Scope
OperationBuilder opBuilder = scope.env().opBuilder("BatchToSpace", scope.makeOpName("BatchToSpace"));
opBuilder.addInput(input.asOutput());
opBuilder.addInput(crops.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("block_size", blockSize);
return new BatchToSpace<T>(opBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public static <T extends TType, U extends TNumber, V extends TNumber> BatchToSpa
opBuilder.addInput(input.asOutput());
opBuilder.addInput(blockShape.asOutput());
opBuilder.addInput(crops.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BatchToSpaceNd<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public final class Bitcast<U extends TType> extends RawOp implements Operand<U>
public static <U extends TType, T extends TType> Bitcast<U> create(Scope scope, Operand<T> input, DataType<U> type) {
OperationBuilder opBuilder = scope.env().opBuilder("Bitcast", scope.makeOpName("Bitcast"));
opBuilder.addInput(input.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
opBuilder.setAttr("type", type);
return new Bitcast<U>(opBuilder.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static <T extends TNumber> BroadcastDynamicShape<T> create(Scope scope, O
OperationBuilder opBuilder = scope.env().opBuilder("BroadcastArgs", scope.makeOpName("BroadcastDynamicShape"));
opBuilder.addInput(s0.asOutput());
opBuilder.addInput(s1.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BroadcastDynamicShape<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static <T extends TNumber> BroadcastGradientArgs<T> create(Scope scope, O
OperationBuilder opBuilder = scope.env().opBuilder("BroadcastGradientArgs", scope.makeOpName("BroadcastGradientArgs"));
opBuilder.addInput(s0.asOutput());
opBuilder.addInput(s1.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BroadcastGradientArgs<T>(opBuilder.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static <T extends TType, U extends TNumber> BroadcastTo<T> create(Scope s
OperationBuilder opBuilder = scope.env().opBuilder("BroadcastTo", scope.makeOpName("BroadcastTo"));
opBuilder.addInput(input.asOutput());
opBuilder.addInput(shape.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder = scope.apply(opBuilder);
return new BroadcastTo<T>(opBuilder.build());
}

Expand Down
Loading