Skip to content

Commit f101e15

Browse files
vpaprotsk authored and
Sandhya Viswanathan committed
8333583: Crypto-XDH.generateSecret regression after JDK-8329538
Reviewed-by: sviswanathan, kvn, ascarpino
1 parent b3bf31a commit f101e15

File tree

9 files changed

+73 -90 lines changed

make/jdk/src/classes/build/tools/intpoly/FieldGen.java

+2 -8
Original file line number | Diff line number | Diff line change
@@ -778,7 +778,7 @@ private String generate(FieldParams params) throws IOException {
778778
result.appendLine("}");
779779

780780
result.appendLine("@Override");
781-
result.appendLine("protected int mult(long[] a, long[] b, long[] r) {");
781+
result.appendLine("protected void mult(long[] a, long[] b, long[] r) {");
782782
result.incrIndent();
783783
for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
784784
result.appendIndent();
@@ -804,9 +804,6 @@ private String generate(FieldParams params) throws IOException {
804804
}
805805
}
806806
result.append(");\n");
807-
result.appendIndent();
808-
result.append("return 0;");
809-
result.appendLine();
810807
result.decrIndent();
811808
result.appendLine("}");
812809

@@ -836,7 +833,7 @@ private String generate(FieldParams params) throws IOException {
836833
// }
837834
// }
838835
result.appendLine("@Override");
839-
result.appendLine("protected int square(long[] a, long[] r) {");
836+
result.appendLine("protected void square(long[] a, long[] r) {");
840837
result.incrIndent();
841838
for (int i = 0; i < 2 * params.getNumLimbs() - 1; i++) {
842839
result.appendIndent();
@@ -877,9 +874,6 @@ private String generate(FieldParams params) throws IOException {
877874
}
878875
}
879876
result.append(");\n");
880-
result.appendIndent();
881-
result.append("return 0;");
882-
result.appendLine();
883877
result.decrIndent();
884878
result.appendLine("}");
885879

src/hotspot/cpu/x86/stubGenerator_x86_64_poly_mont.cpp

-1
Original file line number | Diff line number | Diff line change
@@ -249,7 +249,6 @@ address StubGenerator::generate_intpoly_montgomeryMult_P256() {
249249
const Register tmp = r9;
250250

251251
montgomeryMultiply(aLimbs, bLimbs, rLimbs, tmp, _masm);
252-
__ mov64(rax, 0x1); // Return 1 (Fig. 5, Step 6 [1] skipped in montgomeryMultiply)
253252

254253
__ leave();
255254
__ ret(0);

src/hotspot/share/classfile/vmIntrinsics.hpp

+2 -2
Original file line number | Diff line number | Diff line change
@@ -529,8 +529,8 @@ class methodHandle;
529529
/* support for sun.security.util.math.intpoly.MontgomeryIntegerPolynomialP256 */ \
530530
do_class(sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, "sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256") \
531531
do_intrinsic(_intpoly_montgomeryMult_P256, sun_security_util_math_intpoly_MontgomeryIntegerPolynomialP256, intPolyMult_name, intPolyMult_signature, F_R) \
532-
do_name(intPolyMult_name, "mult") \
533-
do_signature(intPolyMult_signature, "([J[J[J)I") \
532+
do_name(intPolyMult_name, "multImpl") \
533+
do_signature(intPolyMult_signature, "([J[J[J)V") \
534534
\
535535
do_class(sun_security_util_math_intpoly_IntegerPolynomial, "sun/security/util/math/intpoly/IntegerPolynomial") \
536536
do_intrinsic(_intpoly_assign, sun_security_util_math_intpoly_IntegerPolynomial, intPolyAssign_name, intPolyAssign_signature, F_S) \

src/hotspot/share/opto/library_call.cpp

-2
Original file line number | Diff line number | Diff line change
@@ -7580,8 +7580,6 @@ bool LibraryCallKit::inline_intpoly_montgomeryMult_P256() {
75807580
OptoRuntime::intpoly_montgomeryMult_P256_Type(),
75817581
stubAddr, stubName, TypePtr::BOTTOM,
75827582
a_start, b_start, r_start);
7583-
Node* result = _gvn.transform(new ProjNode(call, TypeFunc::Parms));
7584-
set_result(result);
75857583
return true;
75867584
}
75877585

src/hotspot/share/opto/runtime.cpp

+3 -3
Original file line number | Diff line number | Diff line change
@@ -1435,8 +1435,8 @@ const TypeFunc* OptoRuntime::intpoly_montgomeryMult_P256_Type() {
14351435

14361436
// result type needed
14371437
fields = TypeTuple::fields(1);
1438-
fields[TypeFunc::Parms + 0] = TypeInt::INT; // carry bits in output
1439-
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms+1, fields);
1438+
fields[TypeFunc::Parms + 0] = nullptr; // void
1439+
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields);
14401440
return TypeFunc::make(domain, range);
14411441
}
14421442

@@ -1455,7 +1455,7 @@ const TypeFunc* OptoRuntime::intpoly_assign_Type() {
14551455

14561456
// result type needed
14571457
fields = TypeTuple::fields(1);
1458-
fields[TypeFunc::Parms + 0] = NULL; // void
1458+
fields[TypeFunc::Parms + 0] = nullptr; // void
14591459
const TypeTuple* range = TypeTuple::make(TypeFunc::Parms, fields);
14601460
return TypeFunc::make(domain, range);
14611461
}

src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial.java

+13 -11
Original file line number | Diff line number | Diff line change
@@ -90,12 +90,11 @@ public abstract sealed class IntegerPolynomial implements IntegerFieldModuloP
9090
* store the result in an IntegerPolynomial representation in a. Requires
9191
* that a.length == numLimbs.
9292
*/
93-
protected int multByInt(long[] a, long b) {
93+
protected void multByInt(long[] a, long b) {
9494
for (int i = 0; i < a.length; i++) {
9595
a[i] *= b;
9696
}
9797
reduce(a);
98-
return 0;
9998
}
10099

101100
/**
@@ -104,15 +103,15 @@ protected int multByInt(long[] a, long b) {
104103
* a.length == b.length == r.length == numLimbs. It is allowed for a and r
105104
* to be the same array.
106105
*/
107-
protected abstract int mult(long[] a, long[] b, long[] r);
106+
protected abstract void mult(long[] a, long[] b, long[] r);
108107

109108
/**
110109
* Multiply an IntegerPolynomial representation (a) with itself and store
111110
* the result in an IntegerPolynomialRepresentation (r). Requires that
112111
* a.length == r.length == numLimbs. It is allowed for a and r
113112
* to be the same array.
114113
*/
115-
protected abstract int square(long[] a, long[] r);
114+
protected abstract void square(long[] a, long[] r);
116115

117116
IntegerPolynomial(int bitsPerLimb,
118117
int numLimbs,
@@ -622,8 +621,8 @@ public ImmutableElement multiply(IntegerModuloP genB) {
622621
}
623622

624623
long[] newLimbs = new long[limbs.length];
625-
int numAdds = mult(limbs, b.limbs, newLimbs);
626-
return new ImmutableElement(newLimbs, numAdds);
624+
mult(limbs, b.limbs, newLimbs);
625+
return new ImmutableElement(newLimbs, 0);
627626
}
628627

629628
@Override
@@ -635,8 +634,8 @@ public ImmutableElement square() {
635634
}
636635

637636
long[] newLimbs = new long[limbs.length];
638-
int numAdds = IntegerPolynomial.this.square(limbs, newLimbs);
639-
return new ImmutableElement(newLimbs, numAdds);
637+
IntegerPolynomial.this.square(limbs, newLimbs);
638+
return new ImmutableElement(newLimbs, 0);
640639
}
641640

642641
public void addModPowerTwo(IntegerModuloP arg, byte[] result) {
@@ -751,7 +750,8 @@ public MutableElement setProduct(IntegerModuloP genB) {
751750
b.numAdds = 0;
752751
}
753752

754-
numAdds = mult(limbs, b.limbs, limbs);
753+
mult(limbs, b.limbs, limbs);
754+
numAdds = 0;
755755
return this;
756756
}
757757

@@ -764,7 +764,8 @@ public MutableElement setProduct(SmallValue v) {
764764
}
765765

766766
int value = ((Limb)v).value;
767-
numAdds += multByInt(limbs, value);
767+
multByInt(limbs, value);
768+
numAdds = 0;
768769
return this;
769770
}
770771

@@ -824,7 +825,8 @@ public MutableElement setSquare() {
824825
numAdds = 0;
825826
}
826827

827-
numAdds = IntegerPolynomial.this.square(limbs, limbs);
828+
IntegerPolynomial.this.square(limbs, limbs);
829+
numAdds = 0;
828830
return this;
829831
}
830832

src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomial1305.java

+2 -4
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ private IntegerPolynomial1305() {
5050
super(BITS_PER_LIMB, NUM_LIMBS, 1, MODULUS);
5151
}
5252

53-
protected int mult(long[] a, long[] b, long[] r) {
53+
protected void mult(long[] a, long[] b, long[] r) {
5454

5555
// Use grade-school multiplication into primitives to avoid the
5656
// temporary array allocation. This is equivalent to the following
@@ -73,7 +73,6 @@ protected int mult(long[] a, long[] b, long[] r) {
7373
long c8 = (a[4] * b[4]);
7474

7575
carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
76-
return 0;
7776
}
7877

7978
private void carryReduce(long[] r, long c0, long c1, long c2, long c3,
@@ -100,7 +99,7 @@ private void carryReduce(long[] r, long c0, long c1, long c2, long c3,
10099
}
101100

102101
@Override
103-
protected int square(long[] a, long[] r) {
102+
protected void square(long[] a, long[] r) {
104103
// Use grade-school multiplication with a simple squaring optimization.
105104
// Multiply into primitives to avoid the temporary array allocation.
106105
// This is equivalent to the following code:
@@ -123,7 +122,6 @@ protected int square(long[] a, long[] r) {
123122
long c8 = (a[4] * a[4]);
124123

125124
carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
126-
return 0;
127125
}
128126

129127
@Override

src/java.base/share/classes/sun/security/util/math/intpoly/IntegerPolynomialModBinP.java

+2 -4
Original file line number | Diff line number | Diff line change
@@ -131,12 +131,11 @@ private void multOnly(long[] a, long[] b, long[] c) {
131131
}
132132

133133
@Override
134-
protected int mult(long[] a, long[] b, long[] r) {
134+
protected void mult(long[] a, long[] b, long[] r) {
135135

136136
long[] c = new long[2 * numLimbs];
137137
multOnly(a, b, c);
138138
carryReduce(c, r);
139-
return 0;
140139
}
141140

142141
private void modReduceInBits(long[] limbs, int index, int bits, long x) {
@@ -189,7 +188,7 @@ protected void reduce(long[] a) {
189188
}
190189

191190
@Override
192-
protected int square(long[] a, long[] r) {
191+
protected void square(long[] a, long[] r) {
193192

194193
long[] c = new long[2 * numLimbs];
195194
for (int i = 0; i < numLimbs; i++) {
@@ -200,7 +199,6 @@ protected int square(long[] a, long[] r) {
200199
}
201200

202201
carryReduce(c, r);
203-
return 0;
204202
}
205203

206204
/**

src/java.base/share/classes/sun/security/util/math/intpoly/MontgomeryIntegerPolynomialP256.java

+49 -55
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
import sun.security.util.math.IntegerFieldModuloP;
3232
import java.math.BigInteger;
3333
import jdk.internal.vm.annotation.IntrinsicCandidate;
34+
import jdk.internal.vm.annotation.ForceInline;
3435

3536
// Reference:
3637
// - [1] Shay Gueron and Vlad Krasnov "Fast Prime Field Elliptic Curve
@@ -103,8 +104,8 @@ public ImmutableElement getElement(BigInteger v) {
103104
setLimbsValuePositive(v, vLimbs);
104105

105106
// Convert to Montgomery domain
106-
int numAdds = mult(vLimbs, h, montLimbs);
107-
return new ImmutableElement(montLimbs, numAdds);
107+
mult(vLimbs, h, montLimbs);
108+
return new ImmutableElement(montLimbs, 0);
108109
}
109110

110111
@Override
@@ -114,24 +115,6 @@ public SmallValue getSmallValue(int value) {
114115
return super.getSmallValue(value);
115116
}
116117

117-
/*
118-
* This function is used by IntegerPolynomial.setProduct(SmallValue v) to
119-
* multiply by a small constant (i.e. (int) 1,2,3,4). Instead of doing a
120-
* montgomery conversion followed by a montgomery multiplication, just use
121-
* the spare top (64-BITS_PER_LIMB) bits to multiply by a constant. (See [1]
122-
* Section 4 )
123-
*
124-
* Will return an unreduced value
125-
*/
126-
@Override
127-
protected int multByInt(long[] a, long b) {
128-
assert (b < (1 << BITS_PER_LIMB));
129-
for (int i = 0; i < a.length; i++) {
130-
a[i] *= b;
131-
}
132-
return (int) (b - 1);
133-
}
134-
135118
@Override
136119
public ImmutableIntegerModuloP fromMontgomery(ImmutableIntegerModuloP n) {
137120
assert n.getField() == MontgomeryIntegerPolynomialP256.ONE;
@@ -163,19 +146,27 @@ private void halfLimbs(long[] a, long[] r) {
163146
}
164147

165148
@Override
166-
protected int square(long[] a, long[] r) {
167-
return mult(a, a, r);
149+
protected void square(long[] a, long[] r) {
150+
mult(a, a, r);
168151
}
169152

153+
170154
/**
171155
* Unrolled Word-by-Word Montgomery Multiplication r = a * b * 2^-260 (mod P)
172156
*
173157
* See [1] Figure 5. "Algorithm 2: Word-by-Word Montgomery Multiplication
174158
* for a Montgomery Friendly modulus p". Note: Step 6. Skipped; Instead use
175159
* numAdds to reuse existing overflow logic.
176160
*/
161+
@Override
162+
protected void mult(long[] a, long[] b, long[] r) {
163+
multImpl(a, b, r);
164+
reducePositive(r);
165+
}
166+
167+
@ForceInline
177168
@IntrinsicCandidate
178-
protected int mult(long[] a, long[] b, long[] r) {
169+
private void multImpl(long[] a, long[] b, long[] r) {
179170
long aa0 = a[0];
180171
long aa1 = a[1];
181172
long aa2 = a[2];
@@ -408,36 +399,16 @@ protected int mult(long[] a, long[] b, long[] r) {
408399
d4 += n4 & LIMB_MASK;
409400

410401
c5 += d1 + dd0 + (d0 >>> BITS_PER_LIMB);
411-
c6 += d2 + dd1 + (c5 >>> BITS_PER_LIMB);
412-
c7 += d3 + dd2 + (c6 >>> BITS_PER_LIMB);
413-
c8 += d4 + dd3 + (c7 >>> BITS_PER_LIMB);
414-
c9 = dd4 + (c8 >>> BITS_PER_LIMB);
415-
416-
c5 &= LIMB_MASK;
417-
c6 &= LIMB_MASK;
418-
c7 &= LIMB_MASK;
419-
c8 &= LIMB_MASK;
420-
421-
// At this point, the result could overflow by one modulus.
422-
c0 = c5 - modulus[0];
423-
c1 = c6 - modulus[1] + (c0 >> BITS_PER_LIMB);
424-
c0 &= LIMB_MASK;
425-
c2 = c7 - modulus[2] + (c1 >> BITS_PER_LIMB);
426-
c1 &= LIMB_MASK;
427-
c3 = c8 - modulus[3] + (c2 >> BITS_PER_LIMB);
428-
c2 &= LIMB_MASK;
429-
c4 = c9 - modulus[4] + (c3 >> BITS_PER_LIMB);
430-
c3 &= LIMB_MASK;
431-
432-
long mask = c4 >> BITS_PER_LIMB; // Signed shift!
433-
434-
r[0] = ((c5 & mask) | (c0 & ~mask));
435-
r[1] = ((c6 & mask) | (c1 & ~mask));
436-
r[2] = ((c7 & mask) | (c2 & ~mask));
437-
r[3] = ((c8 & mask) | (c3 & ~mask));
438-
r[4] = ((c9 & mask) | (c4 & ~mask));
439-
440-
return 0;
402+
c6 += d2 + dd1;
403+
c7 += d3 + dd2;
404+
c8 += d4 + dd3;
405+
c9 = dd4;
406+
407+
r[0] = c5;
408+
r[1] = c6;
409+
r[2] = c7;
410+
r[3] = c8;
411+
r[4] = c9;
441412
}
442413

443414
@Override
@@ -516,8 +487,8 @@ public ImmutableElement getElement(byte[] v, int offset, int length,
516487
super.encode(v, offset, length, highByte, vLimbs);
517488

518489
// Convert to Montgomery domain
519-
int numAdds = mult(vLimbs, h, montLimbs);
520-
return new ImmutableElement(montLimbs, numAdds);
490+
mult(vLimbs, h, montLimbs);
491+
return new ImmutableElement(montLimbs, 0);
521492
}
522493

523494
/*
@@ -556,4 +527,27 @@ protected void reduceIn(long[] limbs, long v, int i) {
556527
limbs[i - 5] += (v << 4) & LIMB_MASK;
557528
limbs[i - 4] += v >> 48;
558529
}
530+
531+
// Used when limbs a could overflow by one modulus.
532+
@ForceInline
533+
protected void reducePositive(long[] a) {
534+
long aa0 = a[0];
535+
long aa1 = a[1] + (aa0>>BITS_PER_LIMB);
536+
long aa2 = a[2] + (aa1>>BITS_PER_LIMB);
537+
long aa3 = a[3] + (aa2>>BITS_PER_LIMB);
538+
long aa4 = a[4] + (aa3>>BITS_PER_LIMB);
539+
540+
long c0 = a[0] - modulus[0];
541+
long c1 = a[1] - modulus[1] + (c0 >> BITS_PER_LIMB);
542+
long c2 = a[2] - modulus[2] + (c1 >> BITS_PER_LIMB);
543+
long c3 = a[3] - modulus[3] + (c2 >> BITS_PER_LIMB);
544+
long c4 = a[4] - modulus[4] + (c3 >> BITS_PER_LIMB);
545+
long mask = c4 >> BITS_PER_LIMB; // Signed shift!
546+
547+
a[0] = ((aa0 & mask) | (c0 & ~mask)) & LIMB_MASK;
548+
a[1] = ((aa1 & mask) | (c1 & ~mask)) & LIMB_MASK;
549+
a[2] = ((aa2 & mask) | (c2 & ~mask)) & LIMB_MASK;
550+
a[3] = ((aa3 & mask) | (c3 & ~mask)) & LIMB_MASK;
551+
a[4] = ((aa4 & mask) | (c4 & ~mask));
552+
}
559553
}

0 commit comments

Comments
 (0)