diff --git a/impl/src/main/java/io/jsonwebtoken/impl/lang/Conditions.java b/impl/src/main/java/io/jsonwebtoken/impl/lang/Conditions.java index 759406311..42ead1544 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/lang/Conditions.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/lang/Conditions.java @@ -15,8 +15,6 @@ */ package io.jsonwebtoken.impl.lang; -import io.jsonwebtoken.lang.Assert; - /** * @since JJWT_RELEASE_VERSION */ @@ -25,37 +23,38 @@ public final class Conditions { private Conditions() { } - public static final Condition TRUE = of(true); - - public static Condition of(boolean val) { - return new BooleanCondition(val); - } - - public static Condition not(Condition c) { - return new NotCondition(c); - } - - public static Condition exists(CheckedSupplier s) { - return new ExistsCondition(s); - } - - public static Condition notExists(CheckedSupplier s) { - return not(exists(s)); - } - - private static final class NotCondition implements Condition { - - private final Condition c; - - private NotCondition(Condition c) { - this.c = Assert.notNull(c, "Condition cannot be null."); - } - - @Override - public boolean test() { - return !c.test(); - } - } + public static final Condition TRUE = new BooleanCondition(true); + public static final Condition FALSE = new BooleanCondition(false); + +// public static Condition of(boolean val) { +// return new BooleanCondition(val); +// } + +// public static Condition not(Condition c) { +// return new NotCondition(c); +// } + +// public static Condition exists(CheckedSupplier s) { +// return new ExistsCondition(s); +// } +// +// public static Condition notExists(CheckedSupplier s) { +// return not(exists(s)); +// } + +// private static final class NotCondition implements Condition { +// +// private final Condition c; +// +// private NotCondition(Condition c) { +// this.c = Assert.notNull(c, "Condition cannot be null."); +// } +// +// @Override +// public boolean test() { +// return !c.test(); +// } +// } private static final class BooleanCondition implements Condition { private final boolean value; @@ -70,21 +69,21 @@ public boolean test() { } } - private static final class ExistsCondition implements Condition { - private final CheckedSupplier supplier; - - ExistsCondition(CheckedSupplier supplier) { - this.supplier = Assert.notNull(supplier, "CheckedSupplier cannot be null."); - } - - @Override - public boolean test() { - Object value = null; - try { - value = supplier.get(); - } catch (Throwable ignored) { - } - return value != null; - } - } +// private static final class ExistsCondition implements Condition { +// private final CheckedSupplier supplier; +// +// ExistsCondition(CheckedSupplier supplier) { +// this.supplier = Assert.notNull(supplier, "CheckedSupplier cannot be null."); +// } +// +// @Override +// public boolean test() { +// Object value = null; +// try { +// value = supplier.get(); +// } catch (Throwable ignored) { +// } +// return value != null; +// } +// } } diff --git a/impl/src/main/java/io/jsonwebtoken/impl/lang/ConstantFunction.java b/impl/src/main/java/io/jsonwebtoken/impl/lang/ConstantFunction.java index 2c56e6dd6..973c60a82 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/lang/ConstantFunction.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/lang/ConstantFunction.java @@ -24,19 +24,12 @@ */ public final class ConstantFunction implements Function { - private static final Function NULL = new ConstantFunction<>(null); - private final R value; public ConstantFunction(R value) { this.value = value; } - @SuppressWarnings("unchecked") - public static Function forNull() { - return (Function) NULL; - } - @Override public R apply(T t) { return this.value; diff --git a/impl/src/main/java/io/jsonwebtoken/impl/lang/Functions.java b/impl/src/main/java/io/jsonwebtoken/impl/lang/Functions.java index d4b88cf10..85523a4b8 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/lang/Functions.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/lang/Functions.java @@ -22,10 +22,6 @@ public final class Functions { private Functions() { } - public static Function forNull() { - return ConstantFunction.forNull(); - } - public static Function identity() { return new Function() { @Override diff --git a/impl/src/main/java/io/jsonwebtoken/impl/lang/OptionalCtorInvoker.java b/impl/src/main/java/io/jsonwebtoken/impl/lang/OptionalCtorInvoker.java deleted file mode 100644 index 2225980f8..000000000 --- a/impl/src/main/java/io/jsonwebtoken/impl/lang/OptionalCtorInvoker.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright © 2023 jsonwebtoken.io - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.jsonwebtoken.impl.lang; - -import io.jsonwebtoken.lang.Arrays; -import io.jsonwebtoken.lang.Assert; -import io.jsonwebtoken.lang.Classes; - -import java.lang.reflect.Constructor; -import java.util.ArrayList; -import java.util.List; - -public class OptionalCtorInvoker extends ReflectionFunction { - - private final Constructor CTOR; - - public OptionalCtorInvoker(String fqcn, Object... ctorArgTypesOrFqcns) { - Assert.hasText(fqcn, "fqcn cannot be null."); - Constructor ctor = null; - try { - Class clazz = Classes.forName(fqcn); - Class[] ctorArgTypes = null; - if (Arrays.length(ctorArgTypesOrFqcns) > 0) { - ctorArgTypes = new Class[ctorArgTypesOrFqcns.length]; - List> l = new ArrayList<>(ctorArgTypesOrFqcns.length); - for (Object ctorArgTypeOrFqcn : ctorArgTypesOrFqcns) { - Class ctorArgClass; - if (ctorArgTypeOrFqcn instanceof Class) { - ctorArgClass = (Class) ctorArgTypeOrFqcn; - } else { - String typeFqcn = Assert.isInstanceOf(String.class, ctorArgTypeOrFqcn, "ctorArgTypesOrFcqns array must contain Class or String instances."); - ctorArgClass = Classes.forName(typeFqcn); - } - l.add(ctorArgClass); - } - ctorArgTypes = l.toArray(ctorArgTypes); - } - ctor = Classes.getConstructor(clazz, ctorArgTypes); - } catch (Exception ignored) { - } - this.CTOR = ctor; - } - - @Override - protected boolean supports(Object input) { - return CTOR != null; - } - - @Override - protected T invoke(Object input) { - Object[] args = null; - if (input instanceof Object[]) { - args = (Object[]) input; - } else if (input != null) { - args = new Object[]{input}; - } - return Classes.instantiate(CTOR, args); - } -} diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/EdwardsCurve.java b/impl/src/main/java/io/jsonwebtoken/impl/security/EdwardsCurve.java index ab79f37b1..dfdeb8673 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/EdwardsCurve.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/EdwardsCurve.java @@ -16,7 +16,6 @@ package io.jsonwebtoken.impl.security; import io.jsonwebtoken.impl.lang.Bytes; -import io.jsonwebtoken.impl.lang.CheckedFunction; import io.jsonwebtoken.impl.lang.Function; import io.jsonwebtoken.lang.Assert; import io.jsonwebtoken.lang.Collections; @@ -28,11 +27,9 @@ import io.jsonwebtoken.security.UnsupportedKeyException; import java.security.Key; -import java.security.KeyFactory; import java.security.PrivateKey; import java.security.Provider; import java.security.PublicKey; -import java.security.spec.InvalidKeySpecException; import java.security.spec.KeySpec; import java.security.spec.PKCS8EncodedKeySpec; import java.security.spec.X509EncodedKeySpec; @@ -45,21 +42,11 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier { private static final String OID_PREFIX = "1.3.101."; - // DER-encoded edwards keys have this exact sequence identifying the type of key that follows. The trailing + // ASN.1-encoded edwards keys have this exact sequence identifying the type of key that follows. The trailing // byte is the exact edwards curve subsection OID terminal node id. - private static final byte[] DER_OID_PREFIX = new byte[]{0x06, 0x03, 0x2B, 0x65}; - -// private static final String NAMED_PARAM_SPEC_FQCN = "java.security.spec.NamedParameterSpec"; // JDK >= 11 -// private static final String XEC_PRIV_KEY_SPEC_FQCN = "java.security.spec.XECPrivateKeySpec"; // JDK >= 11 -// private static final String EDEC_PRIV_KEY_SPEC_FQCN = "java.security.spec.EdECPrivateKeySpec"; // JDK >= 15 -// - private static final Function CURVE_NAME_FINDER = new NamedParameterSpecValueFinder(); -// private static final OptionalCtorInvoker NAMED_PARAM_SPEC_CTOR = -// new OptionalCtorInvoker<>(NAMED_PARAM_SPEC_FQCN, String.class); -// static final OptionalCtorInvoker XEC_PRIV_KEY_SPEC_CTOR = -// new OptionalCtorInvoker<>(XEC_PRIV_KEY_SPEC_FQCN, AlgorithmParameterSpec.class, byte[].class); -// static final OptionalCtorInvoker EDEC_PRIV_KEY_SPEC_CTOR = -// new OptionalCtorInvoker<>(EDEC_PRIV_KEY_SPEC_FQCN, NAMED_PARAM_SPEC_FQCN, byte[].class); + private static final byte[] ASN1_OID_PREFIX = new byte[]{0x06, 0x03, 0x2B, 0x65}; + + private static final Function CURVE_NAME_FINDER = new NamedParameterSpecValueFinder(); public static final EdwardsCurve X25519 = new EdwardsCurve("X25519", 110); // Requires JDK >= 11 or BC public static final EdwardsCurve X448 = new EdwardsCurve("X448", 111); // Requires JDK >= 11 or BC @@ -76,14 +63,14 @@ public class EdwardsCurve extends DefaultCurve implements KeyLengthSupplier { REGISTRY = new LinkedHashMap<>(8); BY_OID_TERMINAL_NODE = new LinkedHashMap<>(4); for (EdwardsCurve curve : VALUES) { - int subcategoryId = curve.DER_OID[curve.DER_OID.length - 1]; + int subcategoryId = curve.ASN1_OID[curve.ASN1_OID.length - 1]; BY_OID_TERMINAL_NODE.put(subcategoryId, curve); REGISTRY.put(curve.getId(), curve); REGISTRY.put(curve.OID, curve); // add OID as an alias for alg/id lookups } } - private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, boolean ber) { + private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] ASN1_OID, boolean ber) { byte[] keyPrefix = ber ? new byte[]{0x04, (byte) (byteLength + 2), 0x04, (byte) byteLength} : // correct @@ -92,10 +79,10 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, bool return Bytes.concat( new byte[]{ 0x30, - (byte) (5 + DER_OID.length + keyPrefix.length + byteLength), + (byte) (5 + ASN1_OID.length + keyPrefix.length + byteLength), 0x02, 0x01, 0x00, // encoding version 1 (integer, 1 byte, value 0) - 0x30, 0x05}, // DER SEQUENCE of 5 bytes to follow (i.e. the OID) - DER_OID, + 0x30, 0x05}, // ASN.1 SEQUENCE of 5 bytes to follow (i.e. the OID) + ASN1_OID, keyPrefix ); } @@ -103,14 +90,14 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, bool private final String OID; /** - * The byte sequence within an DER-encoded key that indicates an Edwards curve encoded key follows. DER (hex) + * The byte sequence within an ASN.1-encoded key that indicates an Edwards curve encoded key follows. ASN.1 (hex) * notation: *
      * 06 03       ;   OBJECT IDENTIFIER (3 bytes long)
      * |  2B 65 $I ;     "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (decimal 110, 111, 112, or 113)
      * 
*/ - final byte[] DER_OID; + final byte[] ASN1_OID; private final int keyBitLength; @@ -119,39 +106,39 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, bool private final int encodedKeyByteLength; /** - * X.509 (BER) encoding of a public key associated with this curve as a prefix (that is, without the + * X.509 (ASN.1) encoding of a public key associated with this curve as a prefix (that is, without the * actual encoded key material at the end). Appending the public key material directly to the end of this value - * results in a complete X.509 (DER) encoded public key. BER (hex) notation: + * results in a complete X.509 (ASN.1) encoded public key. ASN.1 (hex) notation: *
-     * 30 $M               ; DER SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 10
-     *    30 05            ;   DER SEQUENCE (5 bytes long)
+     * 30 $M               ; ASN.1 SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 10
+     *    30 05            ;   ASN.1 SEQUENCE (5 bytes long)
      *       06 03         ;     OBJECT IDENTIFIER (3 bytes long)
      *          2B 65 $I   ;       "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal)
-     *    03 $S            ;   DER BIT STRING ($S bytes long), where $S = encodedKeyByteLength + 1
-     *       00            ;     DER bit string marker indicating zero unused bits at the end of the bit string
+     *    03 $S            ;   ASN.1 BIT STRING ($S bytes long), where $S = encodedKeyByteLength + 1
+     *       00            ;     ASN.1 bit string marker indicating zero unused bits at the end of the bit string
      *       XX XX XX ...  ;     encoded key material (not included in this PREFIX byte array variable)
      * 
*/ - private final byte[] PUBLIC_KEY_BER_PREFIX; + private final byte[] PUBLIC_KEY_ASN1_PREFIX; /** - * PKCS8 (BER) Version 1 encoding of a private key associated with this curve, as a prefix (that is, + * PKCS8 (ASN.1) Version 1 encoding of a private key associated with this curve, as a prefix (that is, * without actual encoded key material at the end). Appending the private key material directly to the - * end of this value results in a complete PKCS8 (BER) V1 encoded private key. BER (hex) notation: + * end of this value results in a complete PKCS8 (ASN.1) V1 encoded private key. ASN.1 (hex) notation: *
-     * 30 $M                  ; DER SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 14
-     *    02 01               ;   DER INTEGER (1 byte long)
+     * 30 $M                  ; ASN.1 SEQUENCE ($M bytes long), where $M = encodedKeyByteLength + 14
+     *    02 01               ;   ASN.1 INTEGER (1 byte long)
      *       00               ;     zero (private key encoding version V1)
-     *    30 05               ;   DER SEQUENCE (5 bytes long)
+     *    30 05               ;   ASN.1 SEQUENCE (5 bytes long)
      *       06 03            ;     OBJECT IDENTIFIER (3 bytes long). This is the edwards algorithm ID.
      *          2B 65 $I      ;       "1.3.101.$I" for Edwards alg OID, where $I = 6E, 6F, 70, or 71 (110, 111, 112, or 113 decimal)
-     *    04 $B               ;   DER SEQUENCE ($B bytes long, where $B = encodedKeyByteLength + 2
-     *       04 $K            ;     DER SEQUENCE ($K bytes long), where $K = encodedKeyByteLength
+     *    04 $B               ;   ASN.1 SEQUENCE ($B bytes long, where $B = encodedKeyByteLength + 2
+     *       04 $K            ;     ASN.1 SEQUENCE ($K bytes long), where $K = encodedKeyByteLength
      *          XX XX XX ...  ;       encoded key material (not included in this PREFIX byte array variable)
      * 
*/ - private final byte[] PRIVATE_KEY_BER_PREFIX; - private final byte[] PRIVATE_KEY_DER_PREFIX; // https://bugs.openjdk.org/browse/JDK-8213363 + private final byte[] PRIVATE_KEY_ASN1_PREFIX; + private final byte[] PRIVATE_KEY_JDK11_PREFIX; // https://bugs.openjdk.org/browse/JDK-8213363 /** * {@code true} IFF the curve is used for digital signatures, {@code false} if used for key agreement @@ -185,22 +172,22 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, bool this.OID = OID_PREFIX + oidTerminalNode; this.signatureCurve = (oidTerminalNode == 112 || oidTerminalNode == 113); byte[] suffix = new byte[]{(byte) oidTerminalNode}; - this.DER_OID = Bytes.concat(DER_OID_PREFIX, suffix); + this.ASN1_OID = Bytes.concat(ASN1_OID_PREFIX, suffix); this.encodedKeyByteLength = (this.keyBitLength + 7) / 8; - this.PUBLIC_KEY_BER_PREFIX = Bytes.concat( + this.PUBLIC_KEY_ASN1_PREFIX = Bytes.concat( new byte[]{ 0x30, (byte) (this.encodedKeyByteLength + 10), - 0x30, 0x05}, // DER SEQUENCE of 5 bytes to follow (i.e. the OID) - this.DER_OID, + 0x30, 0x05}, // ASN.1 SEQUENCE of 5 bytes to follow (i.e. the OID) + this.ASN1_OID, new byte[]{ 0x03, (byte) (this.encodedKeyByteLength + 1), 0x00} ); - this.PRIVATE_KEY_BER_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.DER_OID, true); - this.PRIVATE_KEY_DER_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.DER_OID, false); + this.PRIVATE_KEY_ASN1_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.ASN1_OID, true); + this.PRIVATE_KEY_JDK11_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.ASN1_OID, false); // The Sun CE KeyPairGenerator implementation that we'll use to derive PublicKeys with is problematic here: // @@ -226,14 +213,6 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] DER_OID, bool this.KEY_PAIR_GENERATOR_BIT_LENGTH = this.keyBitLength >= 448 ? 448 : 255; } -// // visible for testing -// protected static Function paramKeySpecFactory(AlgorithmParameterSpec spec, boolean signatureCurve) { -// if (spec == null) { -// return Functions.forNull(); -// } -// return new ParameterizedKeySpecFactory(spec, signatureCurve ? EDEC_PRIV_KEY_SPEC_CTOR : XEC_PRIV_KEY_SPEC_CTOR); -// } - @Override public int getKeyBitLength() { return this.keyBitLength; @@ -246,39 +225,39 @@ public byte[] getKeyMaterial(Key key) { if (t instanceof KeyException) { //propagate throw (KeyException) t; } - String msg = "Invalid " + getId() + " DER encoding: " + t.getMessage(); + String msg = "Invalid " + getId() + " ASN.1 encoding: " + t.getMessage(); throw new InvalidKeyException(msg, t); } } /** - * Parses the DER-encoding of the specified key + * Parses the ASN.1-encoding of the specified key * * @param key the Edwards curve key * @return the key value, encoded according to RFC 8032 - * @throws RuntimeException if the key's encoded bytes do not reflect a validly DER-encoded edwards key + * @throws RuntimeException if the key's encoded bytes do not reflect a validly ASN.1-encoded edwards key */ protected byte[] doGetKeyMaterial(Key key) { byte[] encoded = KeysBridge.getEncoded(key); - int i = Bytes.indexOf(encoded, DER_OID); + int i = Bytes.indexOf(encoded, ASN1_OID); Assert.gt(i, -1, "Missing or incorrect algorithm OID."); - i = i + DER_OID.length; + i = i + ASN1_OID.length; int keyLen = 0; if (encoded[i] == 0x05) { // NULL terminator, next should be zero byte indicator int unusedBytes = encoded[++i]; Assert.eq(unusedBytes, 0, "OID NULL terminator should indicate zero unused bytes."); i++; } - if (encoded[i] == 0x03) { // DER bit stream, Public Key + if (encoded[i] == 0x03) { // ASN.1 bit stream, Public Key i++; keyLen = encoded[i++]; int unusedBytes = encoded[i++]; Assert.eq(unusedBytes, 0, "BIT STREAM should not indicate unused bytes."); keyLen--; - } else if (encoded[i] == 0x04) { // DER octet sequence, Private Key. Key length follows as next byte. + } else if (encoded[i] == 0x04) { // ASN.1 octet sequence, Private Key. Key length follows as next byte. i++; keyLen = encoded[i++]; - if (encoded[i] == 0x04) { // DER octet sequence, key length follows as next byte. + if (encoded[i] == 0x04) { // ASN.1 octet sequence, key length follows as next byte. i++; // skip sequence marker keyLen = encoded[i++]; // next byte is length } @@ -302,41 +281,23 @@ private void assertLength(byte[] raw, boolean isPublic) { public PublicKey toPublicKey(byte[] x, Provider provider) { assertLength(x, true); - final byte[] encoded = Bytes.concat(this.PUBLIC_KEY_BER_PREFIX, x); + final byte[] encoded = Bytes.concat(this.PUBLIC_KEY_ASN1_PREFIX, x); final X509EncodedKeySpec spec = new X509EncodedKeySpec(encoded); JcaTemplate template = new JcaTemplate(getJcaName(), provider); - return template.withKeyFactory(new CheckedFunction() { - @Override - public PublicKey apply(KeyFactory keyFactory) throws Exception { - return keyFactory.generatePublic(spec); - } - }); + return template.generatePublic(spec); } - KeySpec privateKeySpec(byte[] d, boolean ber) { - byte[] prefix = ber ? this.PRIVATE_KEY_BER_PREFIX : this.PRIVATE_KEY_DER_PREFIX; + KeySpec privateKeySpec(byte[] d, boolean standard) { + byte[] prefix = standard ? this.PRIVATE_KEY_ASN1_PREFIX : this.PRIVATE_KEY_JDK11_PREFIX; byte[] encoded = Bytes.concat(prefix, d); return new PKCS8EncodedKeySpec(encoded); } public PrivateKey toPrivateKey(final byte[] d, Provider provider) { assertLength(d, false); + KeySpec spec = privateKeySpec(d, true); JcaTemplate template = new JcaTemplate(getJcaName(), provider); - return template.withKeyFactory(new CheckedFunction() { - @Override - public PrivateKey apply(KeyFactory keyFactory) throws Exception { - KeySpec spec = privateKeySpec(d, true); // BER-encoding is RFC-correct - try { - return keyFactory.generatePrivate(spec); - } catch (InvalidKeySpecException e) { - if (!isSignatureCurve()) { // https://bugs.openjdk.org/browse/JDK-8213363 (X25519 and X448) - spec = privateKeySpec(d, false); // adjust for Sun Provider bug - return keyFactory.generatePrivate(spec); - } - throw e; // propagate - } - } - }); + return template.generatePrivate(spec); } /** @@ -389,7 +350,7 @@ public static EdwardsCurve findByKey(Key key) { curve = findById(alg); } if (curve == null) { // Fall back to key encoding if possible: - // Try to find the Key DER algorithm OID: + // Try to find the Key ASN.1 algorithm OID: byte[] encoded = KeysBridge.findEncoded(key); if (!Bytes.isEmpty(encoded)) { int oidTerminalNode = findOidTerminalNode(encoded); @@ -403,9 +364,9 @@ public static EdwardsCurve findByKey(Key key) { } private static int findOidTerminalNode(byte[] encoded) { - int index = Bytes.indexOf(encoded, DER_OID_PREFIX); + int index = Bytes.indexOf(encoded, ASN1_OID_PREFIX); if (index > -1) { - index = index + DER_OID_PREFIX.length; + index = index + ASN1_OID_PREFIX.length; if (index < encoded.length) { return encoded[index]; } @@ -430,39 +391,4 @@ static K assertEdwards(K key) { forKey(key); // will throw UnsupportedKeyException if the key is not an Edwards key return key; } - -// private static final class Pkcs8KeySpecFactory implements Function { -// private final byte[] PREFIX; -// -// private Pkcs8KeySpecFactory(byte[] pkcs8EncodedKeyPrefix) { -// this.PREFIX = Assert.notEmpty(pkcs8EncodedKeyPrefix, "pkcs8EncodedKeyPrefix cannot be null or empty."); -// } -// -// @Override -// public KeySpec apply(byte[] d) { -// Assert.notEmpty(d, "Key bytes cannot be null or empty."); -// byte[] encoded = Bytes.concat(PREFIX, d); -// return new PKCS8EncodedKeySpec(encoded); -// } -// } - -// // visible for testing -// protected static final class ParameterizedKeySpecFactory implements Function { -// -// private final AlgorithmParameterSpec params; -// -// private final Function keySpecFactory; -// -// ParameterizedKeySpecFactory(AlgorithmParameterSpec params, Function keySpecFactory) { -// this.params = Assert.notNull(params, "AlgorithmParameterSpec cannot be null."); -// this.keySpecFactory = Assert.notNull(keySpecFactory, "KeySpec factory function cannot be null."); -// } -// -// @Override -// public KeySpec apply(byte[] d) { -// Assert.notEmpty(d, "Key bytes cannot be null or empty."); -// Object[] args = new Object[]{params, d}; -// return this.keySpecFactory.apply(args); -// } -// } } diff --git a/impl/src/main/java/io/jsonwebtoken/impl/security/JcaTemplate.java b/impl/src/main/java/io/jsonwebtoken/impl/security/JcaTemplate.java index b19aeb81d..9d6f58397 100644 --- a/impl/src/main/java/io/jsonwebtoken/impl/security/JcaTemplate.java +++ b/impl/src/main/java/io/jsonwebtoken/impl/security/JcaTemplate.java @@ -41,6 +41,7 @@ import java.io.InputStream; import java.security.AlgorithmParameters; import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; import java.security.KeyFactory; import java.security.KeyPair; import java.security.KeyPairGenerator; @@ -87,7 +88,8 @@ public Class apply(InstanceFactory factory) { } }); - private static Provider findBouncyCastle() { + // visible for testing + protected Provider findBouncyCastle() { return Providers.findBouncyCastle(Conditions.TRUE); } @@ -135,7 +137,6 @@ public R get() throws Exception { }); } - @SuppressWarnings("SameParameterValue") protected R fallback(final Class clazz, final CheckedFunction callback) throws SecurityException { return execute(clazz, new CheckedSupplier() { @Override @@ -248,6 +249,33 @@ public PublicKey apply(KeyFactory keyFactory) throws Exception { }); } + protected boolean isJdk11() { + return System.getProperty("java.version").startsWith("11"); + } + + private boolean isJdk8213363Bug(InvalidKeySpecException e) { + return isJdk11() && + ("XDH".equals(this.jcaName) || "X25519".equals(this.jcaName) || "X448".equals(this.jcaName)) && + e.getCause() instanceof InvalidKeyException && + !Objects.isEmpty(e.getStackTrace()) && + "sun.security.ec.XDHKeyFactory".equals(e.getStackTrace()[0].getClassName()) && + "engineGeneratePrivate".equals(e.getStackTrace()[0].getMethodName()); + } + + // visible for testing + private int getJdk8213363BugExpectedSize(InvalidKeyException e) { + String msg = e.getMessage(); + String prefix = "key length must be "; + if (Strings.hasText(msg) && msg.startsWith(prefix)) { + String expectedSizeString = msg.substring(prefix.length()); + try { + return Integer.parseInt(expectedSizeString); + } catch (NumberFormatException ignored) { // return -1 below + } + } + return -1; + } + private KeySpec respecIfNecessary(InvalidKeySpecException e, KeySpec spec) { if (!(spec instanceof PKCS8EncodedKeySpec)) { return null; @@ -260,38 +288,24 @@ private KeySpec respecIfNecessary(InvalidKeySpecException e, KeySpec spec) { // SunCE provider incorrectly expects an ASN.1 OCTET STRING (without the DER tag/length prefix) // when it should actually be a BER-encoded OCTET STRING (with the tag/length prefix). // So we get the raw key bytes and use our key factory method: - int xdhExpectedSize = getJdk8213363BugExpectedSize(e); - if ((xdhExpectedSize == 32 || xdhExpectedSize == 56) && Bytes.length(encoded) > xdhExpectedSize) { - byte[] adjusted = new byte[xdhExpectedSize]; - System.arraycopy(encoded, encoded.length - xdhExpectedSize, adjusted, 0, xdhExpectedSize); - if (xdhExpectedSize == 32) { // X25519 - return EdwardsCurve.X25519.privateKeySpec(adjusted, false); - } else { // X448 - return EdwardsCurve.X448.privateKeySpec(adjusted, false); + if (isJdk8213363Bug(e)) { + InvalidKeyException cause = // asserted in isJdk8213363Bug method + Assert.isInstanceOf(InvalidKeyException.class, e.getCause(), "Unexpected argument."); + int size = getJdk8213363BugExpectedSize(cause); + if ((size == 32 || size == 56) && Bytes.length(encoded) >= size) { + byte[] adjusted = new byte[size]; + System.arraycopy(encoded, encoded.length - size, adjusted, 0, size); + EdwardsCurve curve = size == 32 ? EdwardsCurve.X25519 : EdwardsCurve.X448; + return curve.privateKeySpec(adjusted, false); } } + return null; } - private int getJdk8213363BugExpectedSize(InvalidKeySpecException e) { - if (System.getProperty("java.version").startsWith("11") && - ("XDH".equals(this.jcaName) || "X25519".equals(this.jcaName) || "X448".equals(this.jcaName)) && - e.getCause() instanceof java.security.InvalidKeyException && - !Objects.isEmpty(e.getStackTrace()) && - "sun.security.ec.XDHKeyFactory".equals(e.getStackTrace()[0].getClassName()) && - "engineGeneratePrivate".equals(e.getStackTrace()[0].getMethodName())) { - java.security.InvalidKeyException ike = (java.security.InvalidKeyException) e.getCause(); - String msg = ike.getMessage(); - String prefix = "key length must be "; - if (Strings.hasText(msg) && msg.startsWith(prefix)) { - String expectedSizeString = msg.substring(prefix.length()); - try { - return Integer.parseInt(expectedSizeString); - } catch (NumberFormatException ignored) { // return -1 below - } - } - } - return -1; + // visible for testing + protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException { + return factory.generatePrivate(spec); } public PrivateKey generatePrivate(final KeySpec spec) { @@ -299,11 +313,11 @@ public PrivateKey generatePrivate(final KeySpec spec) { @Override public PrivateKey apply(KeyFactory keyFactory) throws Exception { try { - return keyFactory.generatePrivate(spec); + return generatePrivate(keyFactory, spec); } catch (InvalidKeySpecException e) { KeySpec respec = respecIfNecessary(e, spec); - if (spec != null) { - return keyFactory.generatePrivate(respec); + if (respec != null) { + return generatePrivate(keyFactory, respec); } throw e; // could not respec, propagate } @@ -349,6 +363,11 @@ public String getId() { return clazz.getSimpleName(); } + // visible for testing + protected Provider findBouncyCastle() { + return Providers.findBouncyCastle(Conditions.TRUE); + } + @SuppressWarnings("GrazieInspection") @Override public final T get(String jcaName, final Provider specifiedProvider) throws Exception { @@ -358,7 +377,7 @@ public final T get(String jcaName, final Provider specifiedProvider) throws Exce if (provider == null && attempted != null && attempted) { // We tried with the default provider previously, and needed to fallback, so just // preemptively load the fallback to avoid the fallback/retry again: - provider = Providers.findBouncyCastle(Conditions.TRUE); + provider = findBouncyCastle(); } try { return doGet(jcaName, provider); @@ -366,7 +385,7 @@ public final T get(String jcaName, final Provider specifiedProvider) throws Exce if (specifiedProvider == null && attempted == null) { // default provider doesn't support the alg name, // and we haven't tried BC yet, so try that now: - Provider fallback = Providers.findBouncyCastle(Conditions.TRUE); + Provider fallback = findBouncyCastle(); if (fallback != null) { // BC found, try again: try { T value = doGet(jcaName, fallback); @@ -391,16 +410,16 @@ public final T get(String jcaName, final Provider specifiedProvider) throws Exce // visible for testing: protected Exception wrap(Exception e, String jcaName, Provider specifiedProvider, Provider fallbackProvider) { - String msg = "Unable to obtain " + getId() + " instance from "; + String msg = "Unable to obtain '" + jcaName + "' " + getId() + " instance from "; if (specifiedProvider != null) { - msg += "specified Provider '" + specifiedProvider + "' "; + msg += "specified '" + specifiedProvider + "' Provider"; } else { - msg += "default JCA Provider "; + msg += "default JCA Provider"; } if (fallbackProvider != null) { - msg += "or fallback Provider '" + fallbackProvider + "' "; + msg += " or fallback '" + fallbackProvider + "' Provider"; } - msg += "for JCA algorithm '" + jcaName + "': " + e.getMessage(); + msg += ": " + e.getMessage(); return wrap(msg, e); } diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/lang/OptionalCtorInvokerTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/lang/OptionalCtorInvokerTest.groovy deleted file mode 100644 index 57022791f..000000000 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/lang/OptionalCtorInvokerTest.groovy +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright © 2023 jsonwebtoken.io - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.jsonwebtoken.impl.lang - -import org.junit.Test - -import javax.crypto.spec.PBEKeySpec - -import static org.junit.Assert.* - -class OptionalCtorInvokerTest { - - @Test - void testCtorWithClassArg() { - String foo = 'test' - def fn = new OptionalCtorInvoker<>("java.lang.String", String.class) // copy constructor - def result = fn.apply(foo) - assertEquals foo, result - } - - @Test - void testCtorWithFqcnArg() { - String foo = 'test' - def fn = new OptionalCtorInvoker<>("java.lang.String", "java.lang.String") // copy constructor - def result = fn.apply(foo) - assertEquals foo, result - } - - @Test - void testCtorWithMultipleMixedArgTypes() { - char[] chars = "foo".toCharArray() - byte[] salt = [0x00, 0x01, 0x02, 0x03] as byte[] - int iterations = 256 - def fn = new OptionalCtorInvoker<>("javax.crypto.spec.PBEKeySpec", char[].class, byte[].class, int.class) //password, salt, iteration count - def args = [chars, salt, iterations] as Object[] - def result = fn.apply(args) as PBEKeySpec - assertArrayEquals chars, result.getPassword() - assertArrayEquals salt, result.getSalt() - assertEquals iterations, result.getIterationCount() - } - - @Test - void testZeroArgConstructor() { - OptionalCtorInvoker fn = new OptionalCtorInvoker("java.util.LinkedHashMap") - Object args = null - def result = fn.apply(args) - assertTrue result instanceof LinkedHashMap - } - - @Test - void testMissingConstructor() { - def fn = new OptionalCtorInvoker('com.foo.Bar') - assertNull fn.apply(null) - } -} diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/security/EdwardsCurveTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/security/EdwardsCurveTest.groovy index ba6f1b2c0..8ecd3fdc7 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/security/EdwardsCurveTest.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/impl/security/EdwardsCurveTest.groovy @@ -20,6 +20,8 @@ import io.jsonwebtoken.security.InvalidKeyException import io.jsonwebtoken.security.UnsupportedKeyException import org.junit.Test +import java.security.spec.PKCS8EncodedKeySpec + import static org.junit.Assert.* class EdwardsCurveTest { @@ -101,7 +103,7 @@ class EdwardsCurveTest { @Test void testFindByKeyUsingMalformedEncoding() { curves.each { - byte[] encoded = EdwardsCurve.DER_OID_PREFIX // just the prefix isn't enough + byte[] encoded = EdwardsCurve.ASN1_OID_PREFIX // just the prefix isn't enough def key = new TestKey(algorithm: 'foo', encoded: encoded) assertNull EdwardsCurve.findByKey(key) } @@ -145,6 +147,17 @@ class EdwardsCurveTest { } } + @Test + void testPrivateKeySpecJdk11() { + curves.each { + byte[] d = new byte[it.encodedKeyByteLength]; Randoms.secureRandom().nextBytes(d) + def keySpec = it.privateKeySpec(d, false) // standard = false for JDK 11 bug + assertTrue keySpec instanceof PKCS8EncodedKeySpec + def expectedEncoded = Bytes.concat(it.PRIVATE_KEY_JDK11_PREFIX, d) + assertArrayEquals expectedEncoded, ((PKCS8EncodedKeySpec)keySpec).getEncoded() + } + } + @Test void testToPublicKeyInvalidLength() { curves.each { @@ -174,7 +187,7 @@ class EdwardsCurveTest { byte[] encoded = Bytes.concat( [0x30, it.encodedKeyByteLength + 10 + DER_NULL.length, 0x30, 0x05] as byte[], - it.DER_OID, + it.ASN1_OID, DER_NULL, // this should be skipped when getting key material [0x03, it.encodedKeyByteLength + 1, 0x00] as byte[], x @@ -212,7 +225,7 @@ class EdwardsCurveTest { it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: Missing or incorrect algorithm OID." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: Missing or incorrect algorithm OID." as String assertEquals msg, ike.getMessage() } } @@ -226,13 +239,13 @@ class EdwardsCurveTest { encoded[0] = 0x20 // anything other than 0x03, 0x04, 0x05 curves.each { // prefix it with the OID to make it look valid: - encoded = Bytes.concat(it.DER_OID, encoded) + encoded = Bytes.concat(it.ASN1_OID, encoded) def key = new TestKey(encoded: encoded) try { it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String assertEquals msg, ike.getMessage() } } @@ -246,13 +259,13 @@ class EdwardsCurveTest { size = it.encodedKeyByteLength byte[] keyBytes = new byte[size] Randoms.secureRandom().nextBytes(keyBytes) - byte[] encoded = Bytes.concat(it.PUBLIC_KEY_BER_PREFIX, keyBytes) + byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes) encoded[11] = 0x01 // should always be zero def key = new TestKey(encoded: encoded) it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: BIT STREAM should not indicate unused bytes." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: BIT STREAM should not indicate unused bytes." as String assertEquals msg, ike.getMessage() } } @@ -266,13 +279,13 @@ class EdwardsCurveTest { size = it.encodedKeyByteLength byte[] keyBytes = new byte[size] Randoms.secureRandom().nextBytes(keyBytes) - byte[] encoded = Bytes.concat(it.PRIVATE_KEY_BER_PREFIX, keyBytes) - encoded[14] = 0x0F // should always be 0x04 (DER SEQUENCE tag) + byte[] encoded = Bytes.concat(it.PRIVATE_KEY_ASN1_PREFIX, keyBytes) + encoded[14] = 0x0F // should always be 0x04 (ASN.1 SEQUENCE tag) def key = new TestKey(encoded: encoded) it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String assertEquals msg, ike.getMessage() } } @@ -286,13 +299,13 @@ class EdwardsCurveTest { size = it.encodedKeyByteLength - 1 // one less than required byte[] keyBytes = new byte[size] Randoms.secureRandom().nextBytes(keyBytes) - byte[] encoded = Bytes.concat(it.PUBLIC_KEY_BER_PREFIX, keyBytes) - encoded[10] = (byte) (size + 1) // DER size value (zero byte + key bytes) + byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes) + encoded[10] = (byte) (size + 1) // ASN.1 size value (zero byte + key bytes) def key = new TestKey(encoded: encoded) it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String assertEquals msg, ike.getMessage() } } @@ -306,64 +319,18 @@ class EdwardsCurveTest { size = it.encodedKeyByteLength + 1 // one less than required byte[] keyBytes = new byte[size] Randoms.secureRandom().nextBytes(keyBytes) - byte[] encoded = Bytes.concat(it.PUBLIC_KEY_BER_PREFIX, keyBytes) - encoded[10] = (byte) (size + 1) // DER size value (zero byte + key bytes) + byte[] encoded = Bytes.concat(it.PUBLIC_KEY_ASN1_PREFIX, keyBytes) + encoded[10] = (byte) (size + 1) // ASN.1 size value (zero byte + key bytes) def key = new TestKey(encoded: encoded) it.getKeyMaterial(key) fail() } catch (InvalidKeyException ike) { - String msg = "Invalid ${it.getId()} DER encoding: Invalid key length." as String + String msg = "Invalid ${it.getId()} ASN.1 encoding: Invalid key length." as String assertEquals msg, ike.getMessage() } } } -// @Test -// void testParamKeySpecFactoryWithNullSpec() { -// def fn = EdwardsCurve.paramKeySpecFactory(null, true) -// assertSame Functions.forNull(), fn -// } -// -// @Test -// void testXecParamKeySpecFactory() { -// AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test -// def fn = EdwardsCurve.paramKeySpecFactory(spec, false) as EdwardsCurve.ParameterizedKeySpecFactory -// assertSame spec, fn.params -// assertSame EdwardsCurve.XEC_PRIV_KEY_SPEC_CTOR, fn.keySpecFactory -// } -// -// @Test -// void testEdEcParamKeySpecFactory() { -// AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test -// def fn = EdwardsCurve.paramKeySpecFactory(spec, true) as EdwardsCurve.ParameterizedKeySpecFactory -// assertSame spec, fn.params -// assertSame EdwardsCurve.EDEC_PRIV_KEY_SPEC_CTOR, fn.keySpecFactory -// } - -// @Test -// void testParamKeySpecFactoryInvocation() { -// AlgorithmParameterSpec spec = new ECGenParameterSpec('foo') // any impl will do for this test -// KeySpec keySpec = new PasswordSpec("foo".toCharArray()) // any KeySpec impl will do -// -// byte[] d = new byte[32] -// Randoms.secureRandom().nextBytes(d) -// -// def keySpecFn = new Function() { -// @Override -// KeySpec apply(Object o) { -// assertTrue o instanceof Object[] -// Object[] args = (Object[]) o -// assertSame spec, args[0] -// assertSame d, args[1] -// return keySpec // simulate a creation -// } -// } -// -// def fn = new EdwardsCurve.ParameterizedKeySpecFactory(spec, keySpecFn) -// def result = fn.apply(d) -// assertSame keySpec, result -// } - @Test void testDerivePublicKeyFromPrivateKey() { for (def curve : EdwardsCurve.VALUES) { diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/security/JcaTemplateTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/security/JcaTemplateTest.groovy index 37bd3416c..b3d1cc69e 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/security/JcaTemplateTest.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/impl/security/JcaTemplateTest.groovy @@ -15,6 +15,7 @@ */ package io.jsonwebtoken.impl.security +import io.jsonwebtoken.impl.lang.Bytes import io.jsonwebtoken.impl.lang.CheckedFunction import io.jsonwebtoken.security.SecurityException import io.jsonwebtoken.security.SignatureException @@ -23,9 +24,12 @@ import org.junit.Test import javax.crypto.Cipher import javax.crypto.Mac -import java.security.Provider -import java.security.Security -import java.security.Signature +import java.security.* +import java.security.cert.CertificateException +import java.security.spec.InvalidKeySpecException +import java.security.spec.KeySpec +import java.security.spec.PKCS8EncodedKeySpec +import java.security.spec.X509EncodedKeySpec import static org.junit.Assert.* @@ -37,7 +41,7 @@ class JcaTemplateTest { @Test void testGetInstanceExceptionMessage() { def factories = JcaTemplate.FACTORIES - for(def factory : factories) { + for (def factory : factories) { def clazz = factory.getInstanceClass() try { factory.get('foo', null) @@ -45,8 +49,8 @@ class JcaTemplateTest { if (clazz == Signature || clazz == Mac) { assertTrue expected instanceof SignatureException } - String prefix = "Unable to obtain ${clazz.getSimpleName()} instance " + - "from default JCA Provider for JCA algorithm 'foo': " + String prefix = "Unable to obtain 'foo' ${clazz.getSimpleName()} instance " + + "from default JCA Provider: " assertTrue expected.getMessage().startsWith(prefix) } } @@ -56,7 +60,7 @@ class JcaTemplateTest { void testGetInstanceWithExplicitProviderExceptionMessage() { def factories = JcaTemplate.FACTORIES def provider = BC_PROVIDER - for(def factory : factories) { + for (def factory : factories) { def clazz = factory.getInstanceClass() try { factory.get('foo', provider) @@ -64,8 +68,8 @@ class JcaTemplateTest { if (clazz == Signature || clazz == Mac) { assertTrue expected instanceof SignatureException } - String prefix = "Unable to obtain ${clazz.getSimpleName()} instance " + - "from specified Provider '${provider.toString()}' for JCA algorithm 'foo': " + String prefix = "Unable to obtain 'foo' ${clazz.getSimpleName()} instance " + + "from specified '${provider.toString()}' Provider: " assertTrue expected.getMessage().startsWith(prefix) } } @@ -102,69 +106,212 @@ class JcaTemplateTest { }) } -// @Test -// void testGetInstanceFailureWithExplicitProvider() { -// //noinspection GroovyUnusedAssignment -// Provider provider = Security.getProvider('SunJCE') -// def supplier = new JcaTemplate.JcaInstanceSupplier(Cipher.class, "AES", provider) { -// @Override -// protected Cipher doGetInstance() { -// throw new IllegalStateException("foo") -// } -// } -// -// try { -// supplier.getInstance() -// } catch (SecurityException ce) { //should be wrapped as SecurityException -// String msg = ce.getMessage() -// //we check for starts-with/ends-with logic here instead of equals because the JCE provider String value -// //contains the JCE version number, and that can differ across JDK versions. Since we use different JDK -// //versions in the test machine matrix, we don't want test failures from JDKs that run on higher versions -// assertTrue msg.startsWith('Unable to obtain Cipher instance from specified Provider {SunJCE') -// assertTrue msg.endsWith('} for JCA algorithm \'AES\': foo') -// } -// } -// -// @Test -// void testGetInstanceDoesNotWrapCryptoExceptions() { -// def ex = new SecurityException("foo") -// def supplier = new JcaTemplate.JcaInstanceSupplier(Cipher.class, 'AES', null) { -// @Override -// protected Cipher doGetInstance() { -// throw ex -// } -// } -// -// try { -// supplier.getInstance() -// } catch (SecurityException ce) { -// assertSame ex, ce -// } -// } -// -// static void wrapInSignatureException(Class instanceType, String jcaName) { -// def ex = new IllegalArgumentException("foo") -// def supplier = new JcaTemplate.JcaInstanceSupplier(instanceType, jcaName, null) { -// @Override -// protected Object doGetInstance() { -// throw ex -// } -// } -// -// try { -// supplier.getInstance() -// } catch (SignatureException se) { -// assertSame ex, se.getCause() -// String msg = "Unable to obtain ${instanceType.simpleName} instance from default JCA Provider for JCA algorithm '${jcaName}': foo" -// assertEquals msg, se.getMessage() -// } -// } - -// @Test -// void testNonCryptoExceptionForSignatureOrMacInstanceIsWrappedInSignatureException() { -// wrapInSignatureException(Signature.class, 'RSA') -// wrapInSignatureException(Mac.class, 'HmacSHA256') -// } + @Test + void testInstanceFactoryFallbackFailureRetainsOriginalException() { + String alg = 'foo' + NoSuchAlgorithmException ex = new NoSuchAlgorithmException('foo') + def factory = new JcaTemplate.JcaInstanceFactory(Cipher.class) { + @Override + protected Cipher doGet(String jcaName, Provider provider) throws Exception { + throw ex + } + + @Override + protected Provider findBouncyCastle() { + return null + } + } + + try { + factory.get(alg, null) + fail() + } catch (SecurityException se) { + assertSame ex, se.getCause() + String msg = "Unable to obtain '$alg' Cipher instance from default JCA Provider: $alg" + assertEquals msg, se.getMessage() + } + } + + @Test + void testWrapWithDefaultJcaProviderAndFallbackProvider() { + JcaTemplate.FACTORIES.each { + Provider fallback = Providers.findBouncyCastle() + String jcaName = 'foo' + NoSuchAlgorithmException nsa = new NoSuchAlgorithmException("doesn't exist") + Exception out = ((JcaTemplate.JcaInstanceFactory) it).wrap(nsa, jcaName, null, fallback) + assertTrue out instanceof SecurityException + String msg = "Unable to obtain '${jcaName}' ${it.getId()} instance from default JCA Provider or fallback " + + "'${fallback.toString()}' Provider: doesn't exist" + assertEquals msg, out.getMessage() + } + } + + @Test + void testFallbackWithBouncyCastle() { + def template = new JcaTemplate('foo', null) + try { + template.generateX509Certificate(Bytes.random(32)) + } catch (SecurityException expected) { + String prefix = "Unable to obtain 'foo' CertificateFactory instance from default JCA Provider: " + assertTrue expected.getMessage().startsWith(prefix) + assertTrue expected.getCause() instanceof CertificateException + } + } + + @Test + void testFallbackWithoutBouncyCastle() { + def template = new JcaTemplate('foo', null) { + @Override + protected Provider findBouncyCastle() { + return null + } + } + try { + template.generateX509Certificate(Bytes.random(32)) + } catch (SecurityException expected) { + String prefix = "Unable to obtain 'foo' CertificateFactory instance from default JCA Provider: " + assertTrue expected.getMessage().startsWith(prefix) + assertTrue expected.getCause() instanceof CertificateException + } + } + + static InvalidKeySpecException jdk8213363BugEx(String msg) { + // mock up JDK 11 bug behavior: + String className = 'sun.security.ec.XDHKeyFactory' + String methodName = 'engineGeneratePrivate' + def ste = new StackTraceElement(className, methodName, null, 0) + def stes = new StackTraceElement[]{ste} + def cause = new InvalidKeyException(msg) + def ex = new InvalidKeySpecException(cause) { + @Override + StackTraceElement[] getStackTrace() { + return stes + } + } + return ex + } + + @Test + void testJdk8213363Bug() { + for(def bundle in [TestKeys.X25519, TestKeys.X448]) { + def privateKey = bundle.pair.private + byte[] d = bundle.alg.getKeyMaterial(privateKey) + byte[] pkcs8d = Bytes.concat(new byte[]{0x04, (byte) (d.length)}, d) + int callCount = 0 + def ex = jdk8213363BugEx("key length must be ${d.length}") + def template = new Jdk8213363JcaTemplate(bundle.alg.id) { + @Override + protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException { + if (callCount == 0) { // simulate first attempt throwing an exception + callCount++ + throw ex + } + // otherwise 2nd call due to fallback logic, simulate a successful call: + return privateKey + } + } + assertSame privateKey, template.generatePrivate(new PKCS8EncodedKeySpec(pkcs8d)) + } + } + + @Test + void testGeneratePrivateRespecWithoutPkcs8() { + byte[] invalid = Bytes.random(456) + def template = new JcaTemplate('X448', null) + try { + template.generatePrivate(new X509EncodedKeySpec(invalid)) + fail() + } catch (SecurityException expected) { + assertEquals 'KeyFactory callback execution failed: key spec not recognized', expected.getMessage() + } + } + + @Test + void testGeneratePrivateRespecTooSmall() { + byte[] invalid = Bytes.random(16) + def ex = jdk8213363BugEx("key length must be ${invalid.length}") + def template = new Jdk8213363JcaTemplate('X25519') { + @Override + protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException { + throw ex + } + } + try { + template.generatePrivate(new PKCS8EncodedKeySpec(invalid)) + fail() + } catch (SecurityException expected) { + String msg = "KeyFactory callback execution failed: java.security.InvalidKeyException: " + + "key length must be ${invalid.length}" + assertEquals msg, expected.getMessage() + } + } + + @Test + void testGeneratePrivateRespecTooLarge() { + byte[] invalid = Bytes.random(50) + def ex = jdk8213363BugEx("key length must be ${invalid.length}") + def template = new Jdk8213363JcaTemplate('X448') { + @Override + protected PrivateKey generatePrivate(KeyFactory factory, KeySpec spec) throws InvalidKeySpecException { + throw ex + } + } + try { + template.generatePrivate(new PKCS8EncodedKeySpec(invalid)) + fail() + } catch (SecurityException expected) { + String msg = "KeyFactory callback execution failed: java.security.InvalidKeyException: " + + "key length must be ${invalid.length}" + assertEquals msg, expected.getMessage() + } + } + + @Test + void testGetJdk8213363BugExpectedSizeNoExMsg() { + InvalidKeyException ex = new InvalidKeyException() + def template = new JcaTemplate('X448', null) + assertEquals(-1, template.getJdk8213363BugExpectedSize(ex)) + } + + @Test + void testGetJdk8213363BugExpectedSizeExMsgDoesntMatch() { + InvalidKeyException ex = new InvalidKeyException('not what is expected') + def template = new JcaTemplate('X448', null) + assertEquals(-1, template.getJdk8213363BugExpectedSize(ex)) + } + + @Test + void testGetJdk8213363BugExpectedSizeExMsgDoesntContainNumber() { + InvalidKeyException ex = new InvalidKeyException('key length must be foo') + def template = new JcaTemplate('X448', null) + assertEquals(-1, template.getJdk8213363BugExpectedSize(ex)) + } + + @Test + void testRespecIfNecessaryWithoutPkcs8KeySpec() { + def spec = new X509EncodedKeySpec(Bytes.random(32)) + def template = new JcaTemplate('X448', null) + assertNull template.respecIfNecessary(null, spec) + } + + @Test + void testRespecIfNecessaryNotJdk8213363Bug() { + def ex = new InvalidKeySpecException('foo') + def template = new JcaTemplate('X448', null) + assertNull template.respecIfNecessary(ex, new PKCS8EncodedKeySpec(Bytes.random(32))) + } + + @Test + void testIsJdk11() { + // determine which JDK the test is being run on in CI: + boolean testMachineIsJdk11 = System.getProperty('java.version').startsWith('11') + def template = new JcaTemplate('X448', null) + if (testMachineIsJdk11) { + assertTrue template.isJdk11() + } else { + assertFalse template.isJdk11() + } + } @Test void testCallbackThrowsException() { @@ -183,4 +330,13 @@ class JcaTemplateTest { } } + private static class Jdk8213363JcaTemplate extends JcaTemplate { + Jdk8213363JcaTemplate(String jcaName) { + super(jcaName, null) + } + @Override + protected boolean isJdk11() { + return true + } + } } diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/security/JwksTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/security/JwksTest.groovy index 5e103346c..5d7d0fe33 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/security/JwksTest.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/impl/security/JwksTest.groovy @@ -269,10 +269,6 @@ class JwksTest { assertEquals pub, pubJwk.toKey() def builder = Jwks.builder().key(priv).publicKeyUse('sig') - if (alg instanceof EdSignatureAlgorithm) { - // We haven't implemented EdDSA public-key derivation yet, so public key is required - builder.publicKey(pub) - } PrivateJwk privJwk = builder.build() assertEquals priv, privJwk.toKey() PublicJwk privPubJwk = privJwk.toPublicJwk() diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/security/ProvidersTest.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/security/ProvidersTest.groovy index eea0530c9..e95490fd2 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/security/ProvidersTest.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/impl/security/ProvidersTest.groovy @@ -61,6 +61,11 @@ class ProvidersTest { new Providers() } + @Test + void testFindBouncyCastleFalse() { + assertNull Providers.findBouncyCastle(Conditions.FALSE); + } + @Test void testBouncyCastleAlreadyExists() { diff --git a/impl/src/test/groovy/io/jsonwebtoken/impl/security/TestCertificates.groovy b/impl/src/test/groovy/io/jsonwebtoken/impl/security/TestCertificates.groovy index 9b7e0dca6..07c943df7 100644 --- a/impl/src/test/groovy/io/jsonwebtoken/impl/security/TestCertificates.groovy +++ b/impl/src/test/groovy/io/jsonwebtoken/impl/security/TestCertificates.groovy @@ -17,7 +17,6 @@ package io.jsonwebtoken.impl.security import io.jsonwebtoken.Identifiable import io.jsonwebtoken.impl.lang.CheckedFunction -import io.jsonwebtoken.impl.lang.Conditions import io.jsonwebtoken.lang.Classes import io.jsonwebtoken.lang.Strings import org.bouncycastle.asn1.pkcs.PrivateKeyInfo @@ -128,7 +127,10 @@ class TestCertificates { // (for example, an Ed25519 key on JDK 8 which doesn't natively support such keys). This means the // X.509 certificate should also be loaded by BC; otherwise the Sun X.509 CertificateFactory returns // a certificate with certificate.getPublicKey() being a sun X509Key instead of the type-specific key we want: - Provider provider = Providers.findBouncyCastle(Conditions.of(pub.getClass().getName().startsWith("org.bouncycastle"))) + Provider provider = null + if (pub.getClass().getName().startsWith("org.bouncycastle")) { + provider = Providers.findBouncyCastle() + } X509Certificate cert = readCert(alg, provider) as X509Certificate PublicKey certPub = cert.getPublicKey() assert pub.equals(certPub)