Skip to content

[Type Refactor] Use Java type system instead of custom one for typing tensors #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
static Type DataTypeOf(const Type& type) {
return Class("DataType", "org.tensorflow").add_parameter(type);
}
static Type ForDataType(DataType data_type) {
switch (data_type) {
case DataType::DT_BOOL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,39 +103,55 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
}
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
if (attribute.jni_type().name() == "DataType") {
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
} else {
out->push_back(attribute.jni_type());
}
if (attribute.has_default_value() &&
attribute.type().kind() == Type::GENERIC) {
out->push_back(Type::ForDataType(attribute.default_value()->type()));
}
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
if (optional_attribute.jni_type().name() == "DataType") {
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
} else {
out->push_back(optional_attribute.jni_type());
}
out->push_back(optional_attribute.var().type());
}
}

void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
SourceWriter* writer) {
string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
if (attr.iterable()) {
string array_name = attr.var().name() + "Array";
writer->AppendType(attr.jni_type())
.Append("[] " + array_name + " = new ")
.AppendType(attr.jni_type())
.Append("[" + var_name + ".size()];")
.EndLine()
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
.Append(array_name + "[i] = ");
writer->Append(var_name + ".get(i);");
writer->EndLine()
.EndBlock()
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(array_name + ");")
.EndLine();
} else {
if (attr.jni_type().name() == "DataType") {
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(var_name + ");")
.EndLine();
.Append(attr.iterable() ? "Operands.toDataTypes(" : "Operands.toDataType(")
.Append(attr.var().name() + "));")
.EndLine();
} else {
if (attr.iterable()) {
string array_name = attr.var().name() + "Array";
writer->AppendType(attr.jni_type())
.Append("[] " + array_name + " = new ")
.AppendType(attr.jni_type())
.Append("[" + var_name + ".size()];")
.EndLine()
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
.Append(array_name + "[i] = ");
writer->Append(var_name + ".get(i);");
writer->EndLine()
.EndBlock()
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(array_name + ");")
.EndLine();
} else {
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
.Append(var_name + ");")
.EndLine();
}
}
}

Expand Down Expand Up @@ -177,7 +193,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
if (attr.type().kind() == Type::GENERIC &&
default_types.find(attr.type().name()) != default_types.end()) {
factory_statement << default_types.at(attr.type().name()).name()
<< ".DTYPE";
<< ".class";
} else {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
factory_statement << attr.var().name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,19 @@ class TypeResolver {
std::pair<Type, Type> MakeTypePair(const Type& type) {
return std::make_pair(type, type);
}
Type NextGeneric() {
Type NextGeneric(const OpDef_AttrDef& attr_def) {
char generic_letter = next_generic_letter_++;
if (next_generic_letter_ > 'Z') {
next_generic_letter_ = 'A';
}
return Type::Generic(string(1, generic_letter))
.add_supertype(Type::Class("TType", "org.tensorflow.types.family"));
return Type::Generic(string(1, generic_letter));
}
Type TypeFamilyOf(const OpDef_AttrDef& attr_def) {
// TODO(karllessard) support more type families
if (IsRealNumbers(attr_def.allowed_values())) {
return Type::Interface("TNumber", "org.tensorflow.types.family");
}
return Type::Interface("TType", "org.tensorflow.types.family");
}
};

Expand Down Expand Up @@ -155,11 +161,9 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow"));

} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
if (IsRealNumbers(attr_def.allowed_values())) {
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
}
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
Type type = *iterable_out ? Type::Wildcard() : NextGeneric(attr_def);
type.add_supertype(TypeFamilyOf(attr_def));
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow.proto.framework"));

} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
Expand Down Expand Up @@ -305,7 +309,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
? Type::DataTypeOf(types.first)
? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
SourceWriter& SourceWriter::AppendType(const Type& type) {
if (type.wildcard()) {
Append("?");
WriteTypeBounds(type.supertypes());
} else {
Append(type.name());
if (!type.parameters().empty()) {
Expand Down Expand Up @@ -321,14 +322,27 @@ SourceWriter& SourceWriter::WriteGenerics(
Append(", ");
}
Append(pt->name());
if (!pt->supertypes().empty()) {
Append(" extends ").AppendType(pt->supertypes().front());
}
WriteTypeBounds(pt->supertypes());
first = false;
}
return Append(">");
}

SourceWriter& SourceWriter::WriteTypeBounds(
const std::list<Type>& bounds) {
bool first = true;
for (const Type& bound : bounds) {
if (first) {
Append(" extends ");
first = false;
} else {
Append(" & ");
}
AppendType(bound);
}
return *this;
}

SourceWriter::GenericNamespace* SourceWriter::PushGenericNamespace(
int modifiers) {
GenericNamespace* generic_namespace;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class SourceWriter {
SourceWriter& WriteJavadoc(const Javadoc& javadoc);
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
SourceWriter& WriteTypeBounds(const std::list<Type>& bounds);
GenericNamespace* PushGenericNamespace(int modifiers);
void PopGenericNamespace();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.experimental.DataServiceDataset;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data.experimental} operations as {@link Op Op}s
Expand Down Expand Up @@ -57,7 +57,7 @@ public final class DataExperimentalOps {
public DataServiceDataset dataServiceDataset(Operand<TInt64> datasetId,
Operand<TString> processingMode, Operand<TString> address, Operand<TString> protocol,
Operand<TString> jobName, Operand<TInt64> maxOutstandingRequests, Operand<?> iterationCounter,
List<DataType<?>> outputTypes, List<Shape> outputShapes,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes,
DataServiceDataset.Options... options) {
return DataServiceDataset.create(scope, datasetId, processingMode, address, protocol, jobName, maxOutstandingRequests, iterationCounter, outputTypes, outputShapes, options);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.tensorflow.op;

import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.data.AnonymousIterator;
Expand Down Expand Up @@ -49,6 +48,7 @@
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

/**
* An API for building {@code data} operations as {@link Op Op}s
Expand All @@ -75,7 +75,7 @@ public final class DataOps {
* @param outputShapes
* @return a new instance of AnonymousIterator
*/
public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
public AnonymousIterator anonymousIterator(List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes) {
return AnonymousIterator.create(scope, outputTypes, outputShapes);
}
Expand All @@ -93,8 +93,8 @@ public AnonymousIterator anonymousIterator(List<DataType<?>> outputTypes,
* @return a new instance of BatchDataset
*/
public BatchDataset batchDataset(Operand<?> inputDataset, Operand<TInt64> batchSize,
Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes,
BatchDataset.Options... options) {
Operand<TBool> dropRemainder, List<Class<? extends TType>> outputTypes,
List<Shape> outputShapes, BatchDataset.Options... options) {
return BatchDataset.create(scope, inputDataset, batchSize, dropRemainder, outputTypes, outputShapes, options);
}

Expand Down Expand Up @@ -129,7 +129,7 @@ public CSVDataset cSVDataset(Operand<TString> filenames, Operand<TString> compre
* @return a new instance of ConcatenateDataset
*/
public ConcatenateDataset concatenateDataset(Operand<?> inputDataset, Operand<?> anotherDataset,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ConcatenateDataset.create(scope, inputDataset, anotherDataset, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -164,8 +164,8 @@ public DeserializeIterator deserializeIterator(Operand<?> resourceHandle, Operan
* @param outputShapes
* @return a new instance of Iterator
*/
public Iterator iterator(String sharedName, String container, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public Iterator iterator(String sharedName, String container,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return Iterator.create(scope, sharedName, container, outputTypes, outputShapes);
}

Expand All @@ -177,8 +177,8 @@ public Iterator iterator(String sharedName, String container, List<DataType<?>>
* @param outputShapes
* @return a new instance of IteratorGetNext
*/
public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNext iteratorGetNext(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNext.create(scope, iterator, outputTypes, outputShapes);
}

Expand All @@ -191,7 +191,7 @@ public IteratorGetNext iteratorGetNext(Operand<?> iterator, List<DataType<?>> ou
* @return a new instance of IteratorGetNextAsOptional
*/
public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextAsOptional.create(scope, iterator, outputTypes, outputShapes);
}

Expand All @@ -208,8 +208,8 @@ public IteratorGetNextAsOptional iteratorGetNextAsOptional(Operand<?> iterator,
* @param outputShapes
* @return a new instance of IteratorGetNextSync
*/
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public IteratorGetNextSync iteratorGetNextSync(Operand<?> iterator,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return IteratorGetNextSync.create(scope, iterator, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -255,8 +255,8 @@ public OptionalFromValue optionalFromValue(Iterable<Operand<?>> components) {
* @param outputShapes
* @return a new instance of OptionalGetValue
*/
public OptionalGetValue optionalGetValue(Operand<?> optional, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public OptionalGetValue optionalGetValue(Operand<?> optional,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return OptionalGetValue.create(scope, optional, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -290,7 +290,7 @@ public OptionalNone optionalNone() {
* @return a new instance of RangeDataset
*/
public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
Operand<TInt64> step, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
Operand<TInt64> step, List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RangeDataset.create(scope, start, stop, step, outputTypes, outputShapes);
}

Expand All @@ -305,7 +305,7 @@ public RangeDataset rangeDataset(Operand<TInt64> start, Operand<TInt64> stop,
* @return a new instance of RepeatDataset
*/
public RepeatDataset repeatDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return RepeatDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand All @@ -332,7 +332,7 @@ public SerializeIterator serializeIterator(Operand<?> resourceHandle,
* @return a new instance of SkipDataset
*/
public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return SkipDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand All @@ -348,7 +348,7 @@ public SkipDataset skipDataset(Operand<?> inputDataset, Operand<TInt64> count,
* @return a new instance of TakeDataset
*/
public TakeDataset takeDataset(Operand<?> inputDataset, Operand<TInt64> count,
List<DataType<?>> outputTypes, List<Shape> outputShapes) {
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return TakeDataset.create(scope, inputDataset, count, outputTypes, outputShapes);
}

Expand Down Expand Up @@ -409,8 +409,8 @@ public TfRecordDataset tfRecordDataset(Operand<TString> filenames,
* @param outputShapes
* @return a new instance of ZipDataset
*/
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets, List<DataType<?>> outputTypes,
List<Shape> outputShapes) {
public ZipDataset zipDataset(Iterable<Operand<?>> inputDatasets,
List<Class<? extends TType>> outputTypes, List<Shape> outputShapes) {
return ZipDataset.create(scope, inputDatasets, outputTypes, outputShapes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
//
package org.tensorflow.op;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.dtypes.AsString;
import org.tensorflow.op.dtypes.Cast;
Expand Down Expand Up @@ -73,7 +72,7 @@ public <T extends TType> AsString asString(Operand<T> input, AsString.Options...
* @param options carries optional attributes values
* @return a new instance of Cast
*/
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U> DstT,
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, Class<U> DstT,
Cast.Options... options) {
return Cast.create(scope, x, DstT, options);
}
Expand Down Expand Up @@ -102,7 +101,7 @@ public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U>
* @return a new instance of Complex
*/
public <U extends TType, T extends TNumber> Complex<U> complex(Operand<T> real, Operand<T> imag,
DataType<U> Tout) {
Class<U> Tout) {
return Complex.create(scope, real, imag, Tout);
}

Expand Down
Loading