Skip to content

Metrics Phase 1 #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 67 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
ad466ee
Initial checkin
JimClarke5 Jan 1, 2021
092b47d
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
4887b5b
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
04eeea6
JavaDoc cleanup
JimClarke5 Jan 1, 2021
dcb2414
Javadoc fixes
JimClarke5 Jan 3, 2021
82f18bf
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
9aa1511
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
1097722
reformat code
JimClarke5 Jan 7, 2021
bc0f468
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
41876d5
Change type to resultType
JimClarke5 Jan 11, 2021
61af528
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
c121c07
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
e9ee98f
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
9788983
fix error message
JimClarke5 Jan 11, 2021
8857a66
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
34a779f
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
748f16d
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
212541b
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
f0d72d2
reformat code
JimClarke5 Jan 12, 2021
8b49c60
Fis=x Javadoc
JimClarke5 Jan 13, 2021
20c6e98
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
d3d7ee9
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
fe86b0b
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
0edd114
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
7d78fd3
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
af1b49f
Nicer error messages for mode-forbidden ops (#169)
rnett Jan 2, 2021
7732601
Initialization imprvements (#178)
rnett Jan 7, 2021
a737334
Clairify tensorOf lifetime requirements (#190)
rnett Jan 19, 2021
253cc73
Remove extra generics from op generation (#193)
rnett Jan 26, 2021
22cb5b2
Add Java 11 support - Initial Phase (#185)
JimClarke5 Jan 26, 2021
4d1aa20
Update manual ops for new codegen (#196)
rnett Jan 26, 2021
2b7f6ed
Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy
JimClarke5 Jan 20, 2021
3800b71
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
3045999
Initial checkin
JimClarke5 Jan 1, 2021
9eb5adf
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
187c17c
Initial checkin and sync with master
JimClarke5 Jan 1, 2021
050fe28
JavaDoc cleanup
JimClarke5 Jan 1, 2021
b640406
Javadoc fixes
JimClarke5 Jan 3, 2021
3715513
Change LossInterface to LossMetric.
JimClarke5 Jan 5, 2021
a1c1976
Removed hashmap for variables, they are not needed as the variables o…
JimClarke5 Jan 7, 2021
6641fca
reformat code
JimClarke5 Jan 7, 2021
fa76043
Add tests for assertBroadcastable
JimClarke5 Jan 11, 2021
e136f4d
Change type to resultType
JimClarke5 Jan 11, 2021
e00f2ef
Added V data type for sampleWeights so that it is not forced to be th…
JimClarke5 Jan 11, 2021
bc6c64b
change 'type' to 'resultType'
JimClarke5 Jan 11, 2021
02da963
clean up mean and fix assert assertBroadcastable
JimClarke5 Jan 11, 2021
44cdc35
fix error message
JimClarke5 Jan 11, 2021
49370b9
Change sampleWeights to have its own generic type <S extends TNumber>
JimClarke5 Jan 12, 2021
24b4125
Add commment about invalid tests expecting IllegalArgumentExceptions
JimClarke5 Jan 12, 2021
43c6b7b
Add this exception instead of the more generic IllegalArgumentExcepti…
JimClarke5 Jan 12, 2021
78e9dab
change IllegalArgumentException to NotBroadcastableException.
JimClarke5 Jan 12, 2021
5508969
reformat code
JimClarke5 Jan 12, 2021
c662524
Fis=x Javadoc
JimClarke5 Jan 13, 2021
512a153
Fix Reduce to use boradcastWeights,
JimClarke5 Jan 17, 2021
0663c3c
Added comment to count to indicate that it may be weighted.
JimClarke5 Jan 17, 2021
122e06b
Added SetsOps and fixed AssertBroadcastable to use SetsOps methods,
JimClarke5 Jan 19, 2021
b7b14b1
Fixed based on various PR comments.
JimClarke5 Jan 19, 2021
13639d3
Deleted, no longer needed after change to Variable handling in Metrics.
JimClarke5 Jan 19, 2021
561322f
Fix Losses to use CHANNELS_FIRST/LAST for CategoricalCrossentropy
JimClarke5 Jan 20, 2021
2a13012
Fix SetOps to properly convert sparse tensor to dense tensor using tf…
JimClarke5 Jan 30, 2021
36f3a69
Merge remote-tracking branch 'upstream/metrics1' into metrics1
JimClarke5 Jan 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class BinaryCrossentropy<U extends TNumber, T extends TNumber>
*
* @param tf the TensorFlow Ops
* @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
* @param fromLogits Whether to interpret predictions as a tensor of logit values or not.
* @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
* @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When &gt; 0,
* compute the loss between the predicted labels and a smoothed version of the true labels,
* where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class CategoricalCrossentropy<U extends TNumber, T extends TNumber>
*
* @param tf the TensorFlow Ops
* @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
* @param fromLogits Whether to interpret predictions as a tensor of logit values or not.
* @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution.
* @param labelSmoothing value used to smooth labels, When &gt; 0, label values are smoothed,
* meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code>
* means that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9
Expand All @@ -68,7 +68,7 @@ public CategoricalCrossentropy(
*
* @param tf the TensorFlow Ops
* @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
* @param fromLogits Whether to interpret predictions as a tensor of logit values or not.
* @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
* @param labelSmoothing value used to smooth labels, When &gt; 0, label values are smoothed,
* meaning the confidence on label values are relaxed. e.g. <code>labelSmoothing=0.2</code>
* means that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,19 @@ public static <T extends TNumber, U extends TNumber> Operand<T> topKCategoricalA
* @param tf the TensorFlow Ops
* @param labels The ground truth values.
* @param predictions The prediction values.
* @param axis The dimension along which the cosine similarity is computed.
* @param axes The dimensions along which the cosine similarity is computed.
* @param <U> the data type for the labels
* @param <T> the data type for the predictions and result
* @return Cosine similarity value.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we should standardize on "similarity" or "proximity"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vote that we make Metrics CosineSimularity and for losses CosineSimularityLoss. My justification was previously mentioned in this thread.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it.

*/
public static <T extends TNumber, U extends TNumber> Operand<T> cosineProximity(
Ops tf, Operand<U> labels, Operand<T> predictions, int[] axis) {
Ops tf, Operand<U> labels, Operand<T> predictions, int[] axes) {
Operand<T> labelsNorm = CastHelper.cast(tf, labels, predictions.type());
labelsNorm = l2Normalize(tf, labelsNorm, axis);
labelsNorm = l2Normalize(tf, labelsNorm, axes);

Operand<T> predictionsNorm = l2Normalize(tf, predictions, axis);
Operand<T> predictionsNorm = l2Normalize(tf, predictions, axes);
Operand<T> mathMul = tf.math.mul(labelsNorm, predictionsNorm);
return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE));
return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,30 @@ public class SparseCategoricalCrossentropy<U extends TNumber, T extends TNumber>
extends MeanMetricWrapper<U, T> implements LossMetric<T> {

private final boolean fromLogits;
private final int axes;
private final int axis;

/**
* Creates a SparseCategoricalCrossentropy metric
*
* @param tf the TensorFlow Ops
* @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
* @param fromLogits Whether to interpret predictions as a tensor of logit values or not.
* @param axes The dimension along which the entropy is computed.
* @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
* @param axis The dimension along which the entropy is computed.
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the type for the variables and result
*/
public SparseCategoricalCrossentropy(
Ops tf, String name, boolean fromLogits, int axes, long seed, Class<T> type) {
Ops tf, String name, boolean fromLogits, int axis, long seed, Class<T> type) {
super(tf, name, seed, type);
setLoss(this);
this.fromLogits = fromLogits;
this.axes = axes;
this.axis = axis;
}

/** {@inheritDoc} */
@Override
public <V extends TNumber> Operand<T> call(Operand<V> labels, Operand<T> predictions) {
return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axes);
return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ protected void setLoss(LossMetric<T> loss) {
* [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of
* predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <V> the datatype of the predictions
* @param <V> the datatype of the labels
* @param <S> the data type for sampleWeights
* @return a List of control operations that updates the Mean state variables.
*/
Expand Down

This file was deleted.

Loading