Skip to content

Add option to run TensorFlow job on the preferred device (via Scope) #159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
writer->EndLine();
}
}
// Add Device String
writer->Append("opBuilder.setDevice(scope.makeDeviceString());");
writer->EndLine();

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @zaleslaw, I have another change to propose that I just thought of. I think merging scope.applyControlDependencies with this new step that adds the device string to the opBuilder would make the whole process of building an op simpler, especially for non-generated ops.

So what about that:

  1. We rename scope.applyControlDependencies(opBuilder) to scope.apply(opBuilder)
  2. In Scope.apply, we apply the control dependencies and the device string (and eventually more data if needed)

By doing so, we also won't need scope.makeDeviceString() anymore. Let me know what you think

// Add control dependencies, if any.
writer->Append("opBuilder = scope.applyControlDependencies(opBuilder);");
writer->EndLine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@

import java.nio.charset.Charset;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.EagerSession;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;

import org.tensorflow.*;
import org.tensorflow.ndarray.BooleanNdArray;
import org.tensorflow.ndarray.ByteNdArray;
import org.tensorflow.ndarray.DoubleNdArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public final class Abs<T extends TNumber> extends RawOp implements Operand<T> {
@Endpoint(describeByClass = true)
public static <T extends TNumber> Abs<T> create(Scope scope, Operand<T> x) {
OperationBuilder opBuilder = scope.env().opBuilder("Abs", scope.makeOpName("Abs"));
opBuilder.setDevice(scope.makeDeviceString());
opBuilder.addInput(x.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
return new Abs<T>(opBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ private static String EmptyOrWithPrefix(Integer i, String prefix) {
}

/** A Builder class for building {@link DeviceSpec} class. */
static class Builder {
public static class Builder {
private String job = null;
private Integer replica = null;
private Integer task = null;
Expand All @@ -123,27 +123,27 @@ static class Builder {

private Builder() {}

Builder job(String job) {
public Builder job(String job) {
this.job = job;
return this;
}

Builder replica(Integer replica) {
public Builder replica(Integer replica) {
this.replica = replica;
return this;
}

Builder task(Integer task) {
public Builder task(Integer task) {
this.task = task;
return this;
}

Builder deviceIndex(Integer deviceIndex) {
public Builder deviceIndex(Integer deviceIndex) {
this.deviceIndex = deviceIndex;
return this;
}

Builder deviceType(DeviceSpec.DeviceType deviceType) {
public Builder deviceType(DeviceSpec.DeviceType deviceType) {
this.deviceType = deviceType;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package org.tensorflow.op;

import java.util.ArrayList;

import org.tensorflow.DeviceSpec;
import org.tensorflow.ExecutionEnvironment;
import org.tensorflow.OperationBuilder;

Expand Down Expand Up @@ -83,7 +85,7 @@ public final class Scope {
* @param env The execution environment used by the scope.
*/
public Scope(ExecutionEnvironment env) {
this(env, new NameScope(), new ArrayList<>());
this(env, new NameScope(), new ArrayList<>(), DeviceSpec.newBuilder().build());
}

/** Returns the execution environment used by this scope. */
Expand All @@ -105,7 +107,7 @@ public ExecutionEnvironment env() {
* @throws IllegalArgumentException if the name is invalid
*/
public Scope withSubScope(String childScopeName) {
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies);
return new Scope(env, nameScope.withSubScope(childScopeName), controlDependencies, deviceSpec);
}

/**
Expand All @@ -121,7 +123,12 @@ public Scope withSubScope(String childScopeName) {
* @throws IllegalArgumentException if the name is invalid
*/
public Scope withName(String opName) {
return new Scope(env, nameScope.withName(opName), controlDependencies);
return new Scope(env, nameScope.withName(opName), controlDependencies, deviceSpec);
}

/**
 * Returns a new scope that carries the provided device specification.
 *
 * <p>The returned scope reuses this scope's execution environment, name scope and control
 * dependencies; only the device specification is replaced. This scope itself is not modified.
 * The device string later returned by {@link #makeDeviceString()} on the new scope is derived
 * from {@code deviceSpec}.
 *
 * @param deviceSpec device specification to store in the returned scope
 * @return a new scope using {@code deviceSpec}
 */
public Scope withDevice(DeviceSpec deviceSpec) {
return new Scope(env, nameScope, controlDependencies, deviceSpec);
}

/**
Expand Down Expand Up @@ -149,10 +156,11 @@ public String makeOpName(String defaultName) {
}

private Scope(
ExecutionEnvironment env, NameScope nameScope, Iterable<Op> controlDependencies) {
ExecutionEnvironment env, NameScope nameScope, Iterable<Op> controlDependencies, DeviceSpec deviceSpec) {
this.env = env;
this.nameScope = nameScope;
this.controlDependencies = controlDependencies;
this.deviceSpec = deviceSpec;
}

/**
Expand All @@ -165,7 +173,7 @@ private Scope(
* @return a new scope with the provided control dependencies
*/
public Scope withControlDependencies(Iterable<Op> controls) {
return new Scope(env, nameScope, controls);
return new Scope(env, nameScope, controls, deviceSpec);
}

/**
Expand All @@ -183,4 +191,9 @@ public OperationBuilder applyControlDependencies(OperationBuilder builder) {
private final ExecutionEnvironment env;
private final Iterable<Op> controlDependencies;
private final NameScope nameScope;
private final DeviceSpec deviceSpec;

/**
 * Returns the device string for ops created in this scope.
 *
 * <p>The value is the {@code toString()} form of the {@link DeviceSpec} held by this scope;
 * generated op factories pass it to {@code OperationBuilder.setDevice}.
 *
 * @return the string form of this scope's device specification
 */
public String makeDeviceString() {
return deviceSpec.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Collection;

import org.junit.jupiter.api.Test;
import org.tensorflow.AutoCloseableList;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.*;
import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
Expand All @@ -42,6 +42,7 @@
import org.tensorflow.ndarray.LongNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
Expand Down Expand Up @@ -70,6 +71,46 @@ public void createInts() {
}
}

/**
 * Checks that an op built through {@code Ops.withDevice} runs on the requested device and
 * produces the expected value: {@code abs(-1) == 1}.
 */
@Test
public void absDeviceSpec() {
  // Log placement decisions so the device chosen for each op shows up in the session output.
  ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance())
      .setLogDevicePlacement(true)
      .build();

  // The input tensor is a native resource too, so it is closed together with graph and session.
  try (Graph g = new Graph();
      Session sess = new Session(g, config);
      Tensor<TInt32> a = TInt32.scalarOf(-1)) {

    Ops tf = Ops.create(g).withSubScope("anotherJob");

    // Constant input, pinned explicitly to the local CPU through the raw builder API.
    Output<TInt32> aOps = g
        .opBuilder("Const", "aOps")
        .setAttr("dtype", a.dataType())
        .setAttr("value", a)
        .setDevice("/job:localhost/replica:0/task:0/device:CPU:0")
        .build()
        .output(0);

    // Same job/replica/task as the constant above: a default local session is expected to
    // expose only task 0, so asking for any other task could not be placed.
    DeviceSpec deviceSpec = DeviceSpec.newBuilder()
        .job("localhost")
        .replica(0)
        .task(0)
        .deviceType(DeviceSpec.DeviceType.CPU)
        .build();

    // Build the op under test via the scope-level device; abs(-1) yields the asserted 1.
    Output<TInt32> absOps = tf.withName("ABS_OPS")
        .withDevice(deviceSpec)
        .math.abs(aOps).asOutput();

    try (AutoCloseableList<Tensor<?>> t =
        new AutoCloseableList<>(sess.runner().fetch(absOps).run())) {
      assertEquals(1, t.get(0).rawData().asInts().getObject(0));
    }
  }
}

@Test
public void createFloats() {
FloatDataBuffer buffer = DataBuffers.of(1.0f, 2.0f, 3.0f, 4.0f);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ private static class OpsSpec {
ClassName.get("org.tensorflow", "ExecutionEnvironment");
private static final TypeName T_EAGER_SESSION = ClassName.get("org.tensorflow", "EagerSession");
private static final TypeName T_STRING = ClassName.get(String.class);
private static final TypeName T_DEVICE_SPEC = ClassName.get("org.tensorflow", "DeviceSpec");

private static final String LICENSE =
"Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n"
Expand Down Expand Up @@ -537,6 +538,18 @@ private static TypeSpec buildTopClass(OpsSpec spec) {
T_SCOPE)
.build());

// Generate Ops.withDevice(DeviceSpec): a public method that returns a new Ops instance
// wrapping scope.withDevice(deviceSpec); its Javadoc points back to the scope class.
opsBuilder.addMethod(
MethodSpec.methodBuilder("withDevice")
.addModifiers(Modifier.PUBLIC)
.addParameter(T_DEVICE_SPEC, "deviceSpec")
.returns(T_OPS)
.addStatement("return new Ops(scope.withDevice(deviceSpec))")
.addJavadoc(
"Returns an API that uses the provided DeviceSpec for an op.\n\n"
+ "@see {@link $T#withDevice(DeviceSpec)}\n",
T_SCOPE)
.build());

opsBuilder.addMethod(
MethodSpec.methodBuilder("withControlDependencies")
.addModifiers(Modifier.PUBLIC)
Expand Down