Skip to content

Commit e013353

Browse files
authored
Add Regularizers 1 (#216)
1 parent 62fa275 commit e013353

File tree

11 files changed

+861
-47
lines changed

11 files changed

+861
-47
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.regularizers;
16+
17+
import org.tensorflow.op.Ops;
18+
19+
/**
20+
* A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator)
21+
* Regression, regularization penalty.
22+
*
23+
* <p>The L1 regularization penalty is computed as: <code>loss = l1 * reduceSum(abs(x))</code>
24+
*/
25+
public class L1 extends L1L2 {
26+
27+
/**
28+
* Create a regularizer that applies an L1 regularization penalty of {@link
29+
* #DEFAULT_REGULARIZATION_PENALTY}
30+
*
31+
* @param tf the TensorFlow Ops
32+
*/
33+
public L1(Ops tf) {
34+
this(tf, DEFAULT_REGULARIZATION_PENALTY);
35+
}
36+
37+
/**
38+
* Create a regularizer that applies an L1 regularization penalty
39+
*
40+
* @param tf the TensorFlow Ops
41+
* @param l1 the L1 regularization penalty
42+
* @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite.
43+
*/
44+
public L1(Ops tf, float l1) {
45+
super(tf, l1, 0f);
46+
}
47+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.regularizers;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.framework.losses.impl.LossesHelper;
19+
import org.tensorflow.op.Ops;
20+
import org.tensorflow.types.family.TNumber;
21+
22+
/**
23+
* A regularizer that applies both L1 and L2 regularization penalties.
24+
*
25+
* <p>The L1 regularization penalty is computed as:
26+
*
27+
* <pre>loss = l1 * reduceSum(abs(x))</pre>
28+
*
29+
* <p>The L2 regularization penalty is computed as
30+
*
31+
* <pre>loss = l2 * reduceSum(square(x))</pre>
32+
*
33+
*/
34+
public class L1L2 extends Regularizer {
35+
36+
private final float l1;
37+
private final float l2;
38+
39+
/**
40+
* Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty
41+
*
42+
* @param tf the TensorFlow Ops
43+
*/
44+
public L1L2(Ops tf) {
45+
this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY);
46+
}
47+
48+
/**
49+
* Creates an L1L2 regularizer
50+
*
51+
* @param tf the TensorFlow Ops
52+
* @param l1 L1 regularization factor, if null it is set to 0.
53+
* @param l2 L2 regularization factor, if null it is set to 0.
54+
* @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN}
55+
* of {@link Float#isInfinite}
56+
*/
57+
public L1L2(Ops tf, float l1, float l2) {
58+
super(tf);
59+
if (Float.isNaN(l1) || Float.isInfinite(l1)) {
60+
throw new IllegalArgumentException(
61+
String.format(
62+
"L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
63+
l1));
64+
}
65+
this.l1 = l1;
66+
67+
if (Float.isNaN(l2) || Float.isInfinite(l2)) {
68+
throw new IllegalArgumentException(
69+
String.format(
70+
"L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value",
71+
l2));
72+
}
73+
this.l2 = l2;
74+
}
75+
76+
77+
/** {@inheritDoc} */
78+
@Override
79+
public <R extends TNumber> Operand<R> call(Operand<R> input) {
80+
Ops tf = getTF();
81+
if (this.getL1() == 0f && this.getL2() == 0f) {
82+
return tf.dtypes.cast(tf.constant(0), input.type());
83+
}
84+
Operand<R> regularization = tf.dtypes.cast(tf.constant(0), input.type());
85+
86+
if (this.getL1() != 0.f) {
87+
Operand<R> l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type());
88+
Operand<R> abs = tf.math.abs(input);
89+
Operand<R> reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input));
90+
regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum));
91+
}
92+
93+
if (this.getL2() != 0.f) {
94+
Operand<R> l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type());
95+
Operand<R> sqr = tf.math.square(input);
96+
Operand<R> reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input));
97+
regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum));
98+
}
99+
100+
return regularization;
101+
}
102+
103+
/**
104+
* Gets the L1 regularization factor
105+
*
106+
* @return the L1 regularization factor
107+
*/
108+
public float getL1() {
109+
return l1;
110+
}
111+
112+
/**
113+
* Gets the L2 regularization factor
114+
*
115+
* @return the L2 regularization factor
116+
*/
117+
public float getL2() {
118+
return l2;
119+
}
120+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.regularizers;
16+
17+
import org.tensorflow.op.Ops;
18+
19+
/**
20+
* A regularizer that applies a L2 (Ridge Regression) regularization penalty.
21+
*
22+
* <p>The L2 regularization penalty is computed as: <code>loss = l2 * reduceSum(square(x))</code>
23+
*/
24+
public class L2 extends L1L2 {
25+
26+
/**
27+
* Create a regularizer that applies an L2 regularization penalty of {@link
28+
* #DEFAULT_REGULARIZATION_PENALTY}
29+
*
30+
* @param tf the TensorFlow Ops
31+
*/
32+
public L2(Ops tf) {
33+
this(tf, DEFAULT_REGULARIZATION_PENALTY);
34+
}
35+
36+
/**
37+
* Create a regularizer that applies an L1 regularization penalty
38+
*
39+
* @param tf the TensorFlow Ops
40+
* @param l2 the L2 regularization penalty
41+
* @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
42+
*/
43+
public L2(Ops tf, float l2) {
44+
super(tf, 0f, l2);
45+
}
46+
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.regularizers;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.framework.losses.Loss;
19+
import org.tensorflow.op.Ops;
20+
import org.tensorflow.types.family.TNumber;
21+
22+
/**
23+
* Base class for Regularizers
24+
*
25+
* <p>Regularizers allow you to apply penalties on layer parameters or layer activity during
26+
* optimization. These penalties are summed into the loss function that the network optimizes.
27+
*/
28+
public abstract class Regularizer {
29+
30+
public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f;
31+
32+
private final Ops tf;
33+
private final String name;
34+
35+
/**
36+
* Creates a Regularizer, using {@link Class#getSimpleName()} for the name
37+
*
38+
* @param tf the TensorFlow ops.
39+
*/
40+
protected Regularizer(Ops tf) {
41+
this(tf, null);
42+
}
43+
/**
44+
* Creates a Regularizer
45+
*
46+
* @param tf the TensorFlow ops.
47+
* @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the
48+
* name.
49+
*/
50+
protected Regularizer(Ops tf, String name) {
51+
this.tf = tf;
52+
this.name = name == null ? this.getClass().getSimpleName() : name;
53+
}
54+
55+
/**
56+
* Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only
57+
* sampleWeights are applied to the regularizer.
58+
*
59+
* @return this Regularizer as a Loss
60+
*/
61+
public Loss asLoss() {
62+
return new RegularizerLoss(this.tf, this);
63+
}
64+
65+
/**
66+
* Computes a regularization penalty from an input.
67+
*
68+
* @param input the weighted input
69+
* @return the result of computing the regularization penalty
70+
* @param <R> the data type of the input and result
71+
*/
72+
public abstract <R extends TNumber> Operand<R> call(Operand<R> input);
73+
74+
/**
75+
* Gets the TensorFlow Ops
76+
*
77+
* @return the TensorFlow Ops
78+
*/
79+
public Ops getTF() {
80+
return tf;
81+
}
82+
83+
/**
84+
* Gets the name for this regularizer
85+
*
86+
* @return the name for this regularizer
87+
*/
88+
public String getName() {
89+
return name;
90+
}
91+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.regularizers;
16+
17+
import org.tensorflow.Operand;
18+
import org.tensorflow.framework.losses.Loss;
19+
import org.tensorflow.op.Ops;
20+
import org.tensorflow.types.family.TNumber;
21+
22+
/**
23+
* A Regularizer call wrapped as a Loss instance
24+
*
25+
* <p>This class facilitates using a regularizer as a loss, only <code>sampleWeights</code> are
26+
* regularized.
27+
*/
28+
class RegularizerLoss extends Loss {
29+
30+
private final Regularizer regularizer;
31+
32+
/**
33+
* Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
34+
* Loss#REDUCTION_DEFAULT}
35+
*
36+
* @param tf the TensorFlow Ops
37+
* @param regularizer the regularizer used to calculate the loss
38+
*/
39+
public RegularizerLoss(Ops tf, Regularizer regularizer) {
40+
this(tf, null, regularizer);
41+
}
42+
43+
/**
44+
* Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}
45+
*
46+
* @param tf the TensorFlow Ops
47+
* @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}.
48+
* @param regularizer the regularizer used to calculate the loss
49+
*/
50+
public RegularizerLoss(Ops tf, String name, Regularizer regularizer) {
51+
super(tf, name);
52+
this.regularizer = regularizer;
53+
}
54+
55+
/** {@inheritDoc} */
56+
@Override
57+
public <T extends TNumber> Operand<T> call(
58+
Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
59+
if (sampleWeights == null) {
60+
throw new IllegalArgumentException("sampleWeights cannot be null");
61+
}
62+
return regularizer.call(sampleWeights);
63+
}
64+
}

0 commit comments

Comments
 (0)