Reputation: 995
I need to verify, in Java, the signature of a token that is signed with the ES256 algorithm. Its header is:
{"typ":"JWT","alg":"ES256","kid":"4"}
The public JWKS has the following format:
{
"kty" : "EC",
"kid" : "4",
"use" : "sig",
"x" : "hkjfghkjfdghkjdfsglkdjhg",
"y" : "skjgf krhgkre",
"crv" : "P-256",
"x5c" : [ "uchfgurhnvgrejbhkltjrhbkrknlytknjlkfldfmndfkfvmlkasdfkljflksdanfgklnsdkjfnsadkjnkjdfnglksdfhkljdlkhfklhjdfgklghjkldfjklfjgklnvdfngjksdnfngkjvnsdfjkvsdfjkgndjkhnkjdsnhkltejhk" ],
"x5t" : "jcdhsvkjgnrekngk"
}
How can I verify such a token?
I looked at RS256 JWT verification, but there the public JWKS is quite different:
{
"kty" : "RSA",
"kid" : "16",
"use" : "sig",
"n" : "sdghjfhgjhfjdghjkdfhghfdghfdkjhgkjfhgkjfhgkjhfjkffghjkshgjkfhgjkhfjkghjkfhgjkhfgjkhfjkghjkfhgjkafhjghfjkhgkjfhgkjhfjklghlsjkfhgjksfagmfnvmbrgmberkjltgnerkjhgjkerngkjerngjkhsjkghsjklghjkhgjkhjkghjkfahgjkhgjkhfjklghjkfhg",
"e" : "AQAB",
"x5c" : [ "MIIClkdfgjlkdfjklgjdfkljgkfdjgkljfkgjfklgjkldjgdfjgkldftuioreutiourtoiuriotuieorutiorutioeurtiueriotuioerwutioerutioukgjkldjgkldfjgkljklgjdfklgjorutoireutioerutiueriotuklgjkldjgklsdfjgkldfjgklsdjgkl" ],
"x5t" : "jshgjfhgjkhkhghgkdfhgklhdklh"
}
Is there a Java library that offers this functionality? I understand that RS256 verification uses the modulus and exponent, but I am not sure what ES256 uses.
Here is the code I am using for RS256:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.Key;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import org.cache2k.Cache;
import org.cache2k.Cache2kBuilder;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import java.net.HttpURLConnection;
public class JWTValidation {
//Externalize the hosts as per the environment
private static final String USER_AGENT = "Mozilla/5.0";
// Create a cache object
final static Cache<String,String> _cache = new Cache2kBuilder<String, String>() {}
.expireAfterWrite(30, TimeUnit.MINUTES) // expire/refresh after 30 minutes
.build();
static String _jwkVersionCache = _cache.peek("jwk_version");
static String _modulusCache = _cache.peek("modulus");
static String _exponentCache = _cache.peek("exponent");
public static void main(String[] args) throws JsonParseException, JsonMappingException, IOException {
System.out.println("Sample code to validate JWT");
// Running the code in loop to test multiple scenarios
while(true) {
// Used only for console app to get the JWS as user input
Scanner reader = new Scanner(System.in);
// Get the JWT
System.out.println("Enter jwt or enter exit to terminate");
String signedJwtToken = reader.next();
if(signedJwtToken.equalsIgnoreCase("Exit"))
{
break;
}
try {
// Validate the signed JWT (JWS)
ValidateJWS(signedJwtToken);
}
catch (Exception e) {
System.out.println("JWS validation failed");
}
finally {
}
}
}
// Code to validate signed JWT (JWS)
private static void ValidateJWS(String signedJwtToken)
{
StringBuilder sb = null;
String jwtWithoutSignature;
String jwtVersion;
String jwksUri;
String jwksUrl;
String kid;
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> jwks = null;
@SuppressWarnings("rawtypes")
Jwt<Header, Claims> jwtClaims = null;
try {
// Extract the base64 encoded JWT from the signed JWT token (JWS)
sb = new StringBuilder();
sb.append(signedJwtToken);
jwtWithoutSignature = sb.substring(0, sb.toString().lastIndexOf(".") + 1);
// Parse claims without validating the signature
jwtClaims = Jwts.parser().parseClaimsJwt(jwtWithoutSignature);
// Extract the jwk uri 'jku' & the version 'ver' from the JWT
jwtVersion = (String) jwtClaims.getBody().get("ver");
jwksUri = (String) jwtClaims.getBody().get("jku");
// Extract the kid from JWT
kid = (String) jwtClaims.getHeader().get("kid");
jwksUrl = jwksHost;
System.out.println("jwtVersion: " + jwtVersion);
System.out.println("jwksUri: " + jwksUri);
System.out.println("kid: " + kid);
// Cache the jwk version (ver), modulus (n) and exponent (e) for lifetime of the application.
// The JWT version will be same as jwk version. The jwt version will change only when the
// JWT signing certificate is renewed.
// Invoke the JWK url only if the jwt version is different from the JWK version.
// check if the JWK version is cached or not
if (_cache.get("jwk_version") != null) {
// check if jwt version is same as jwk version
if (!jwtVersion.equals(_jwkVersionCache)) {
// Get the jwk key & add the modulus, exponent & the jwk version to the cache
GetJWK(jwksUrl, kid);
}
}
else
{
// Get the jwk key & add the modulus, exponent & the jwk version to the cache
GetJWK(jwksUrl, kid);
}
// Calling the setSigningKeyResolver as the JWT is parsed before validating the signature
SigningKeyResolver resolver = new SigningKeyResolverAdapter() {
@SuppressWarnings("rawtypes")
public Key resolveSigningKey(JwsHeader jwsHeader, Claims claims) {
try {
// Build the RSA public key from modulus & exponent in JWK
BigInteger modulus = new BigInteger(1, Base64.getUrlDecoder().decode(_modulusCache));
BigInteger exponent = new BigInteger(1, Base64.getUrlDecoder().decode(_exponentCache));
PublicKey rsaPublicKey = KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(modulus, exponent));
return rsaPublicKey;
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
System.out.println("Failed to resolve key: " + e);
return null;
}
}
};
try {
// Parse claims and validate the signature
Jws<Claims> jwsClaims = Jwts.parser().setSigningKeyResolver(resolver).parseClaimsJws(signedJwtToken);
System.out.println("Signature on this JWT is good and the JWT token has not expired");
// OK, we can trust this JWT
// Parse the claims
System.out.println("JWS claims: " + jwsClaims.getBody());
// Code below to validate the claims
}
catch (Exception ex) {
System.out.println("Unable to validate JWS");
}
}
// catch (SignatureException e)
catch (Exception e) {
// don't trust the JWT!
System.out.println("JWT is malformed or expired");
}
}
// Get the corresponding JWK using key Id from the JWK set
@SuppressWarnings("unchecked")
static private Map<String, String> GetKeyById(Map<String, Object> jwks, String kid) {
List<Map<String, String>> keys = (List<Map<String, String>>)jwks.get("keys");
Map<String, String> ret = null;
for (int i = 0; i < keys.size(); i++) {
if (keys.get(i).get("kid").equals(kid)) {
System.out.println("i-->"+ keys.get(i).get("kid"));
System.out.println("i set-->"+ keys.get(i));
return keys.get(i);
}
}
return ret;
}
// Get the JWK Set from the JWK endpoint
private static void GetJWK(String jwkUrl, String kid) throws IOException {
URL url = new URL(jwkUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
try {
//URL url = new URL(jwkUrl);
System.out.println("url: "+url);
//connection = (HttpURLConnection) url.openConnection();
System.out.println("connection: "+connection);
connection.setRequestMethod("GET");
connection.setRequestProperty("User-Agent", USER_AGENT);
int responseCode = connection.getResponseCode();
System.out.println("GET Response Code :: " + responseCode);
BufferedReader rd = new BufferedReader(new InputStreamReader(connection.getInputStream()));
System.out.println("rd: "+rd);
//StringBuilder response = new StringBuilder();
String inputLine;
StringBuffer response = new StringBuffer();
while ((inputLine = rd.readLine()) != null) {
response.append(inputLine);
}
//in.close();
// print result
System.out.println(response.toString());
System.out.println("response:--> "+response);
String line;
//while ((line = rd.readLine()) != null) {
// response.append(line);
// response.append('\r');
//}
rd.close();
// Jackson mapper for parsing the json
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> jwks = mapper.readValue(response.toString(), typeRef);
// Get the jwk by using the key Id from the jwt
Map<String, String> jwk = GetKeyById(jwks, kid);
// Get the modulus 'n' & the exponent 'n' from the JWK & add it to cache
if (jwk != null) {
_cache.put("modulus", jwk.get("x5c"));
_modulusCache = _cache.get("modulus");
_cache.put("exponent", jwk.get("e"));
_exponentCache = _cache.get("exponent");
_cache.put("jwk_version", jwk.get("ver"));
_jwkVersionCache = _cache.get("jwk_version");
}
} catch (Exception e) {
// Unable to fetch JWKS. Terminate this program
System.out.println("Error getting jwks: " + e);
} finally {
if (connection != null) {
connection.disconnect();
}
}
}
}
RFC 7518 (https://www.rfc-editor.org/rfc/rfc7518#section-3.1) says:
"alg" Param Value = ES256; Digital Signature or MAC Algorithm = ECDSA using P-256 and SHA-256
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.*;
import java.security.spec.*;
import java.text.ParseException;
import java.util.Base64;
import java.util.Base64.Decoder;
import java.util.Base64.Encoder;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import com.nimbusds.jose.JOSEException;
import org.cache2k.Cache;
import org.cache2k.Cache2kBuilder;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import java.net.HttpURLConnection;
public class Es256stack {
final static String jwksHost = "";
private static final String USER_AGENT = "Mozilla/5.0";
// Create a cache object
final static Cache<String,String> _cache = new Cache2kBuilder<String, String>() {}
.expireAfterWrite(30, TimeUnit.MINUTES) // expire/refresh after 30 minutes
.build();
static String _jwkVersionCache = _cache.peek("jwk_version");
static String _modulusCache = _cache.peek("modulus");
static String _exponentCache = _cache.peek("exponent");
public static void main(String[] args) throws NoSuchAlgorithmException, InvalidParameterSpecException, InvalidKeySpecException {
ValidateJWS(signedJwtToken);
}
//String signedJwtToken="";
// Code to validate signed JWT (JWS)
private static void ValidateJWS(String signedJwtToken)
{
StringBuilder sb = null;
String jwtWithoutSignature;
String jwtVersion;
String jwksUri;
String jwksUrl;
String kid;
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> jwks = null;
@SuppressWarnings("rawtypes")
Jwt<Header, Claims> jwtClaims = null;
try {
// Extract the base64 encoded JWT from the signed JWT token (JWS)
sb = new StringBuilder();
sb.append(signedJwtToken);
jwtWithoutSignature = sb.substring(0, sb.toString().lastIndexOf(".") + 1);
// Parse claims without validating the signature
jwtClaims = Jwts.parser().parseClaimsJwt(jwtWithoutSignature);
// Extract the jwk uri 'jku' & the version 'ver' from the JWT
jwtVersion = (String) jwtClaims.getBody().get("ver");
jwksUri = (String) jwtClaims.getBody().get("jku");
// Extract the kid from JWT
kid = (String) jwtClaims.getHeader().get("kid");
jwksUrl = "";
System.out.println("jwtVersion: " + jwtVersion);
System.out.println("jwksUri: " + jwksUri);
System.out.println("kid: " + kid);
// Cache the jwk version (ver), modulus (n) and exponent (e) for lifetime of the application.
// The JWT version will be same as jwk version. The jwt version will change only when the
// JWT signing certificate is renewed.
// Invoke the JWK url only if the jwt version is different from the JWK version.
// check if the JWK version is cached or not
if (_cache.get("jwk_version") != null) {
// check if jwt version is same as jwk version
if (!jwtVersion.equals(_jwkVersionCache)) {
// Get the jwk key & add the modulus, exponent & the jwk version to the cache
GetJWK(jwksUrl, kid);
}
}
else
{
// Get the jwk key & add the modulus, exponent & the jwk version to the cache
GetJWK(jwksUrl, kid);
}
// Calling the setSigningKeyResolver as the JWT is parsed before validating the signature
SigningKeyResolver resolver = new SigningKeyResolverAdapter() {
@SuppressWarnings("rawtypes")
public Key resolveSigningKey(JwsHeader jwsHeader, Claims claims) {
try {
AlgorithmParameters a = AlgorithmParameters.getInstance("EC");
a.init(new ECGenParameterSpec("secp256r1"));
ECParameterSpec p = a.getParameterSpec(ECParameterSpec.class);
// Build the RSA public key from modulus & exponent in JWK
BigInteger x = new BigInteger(1, Base64.getDecoder().decode(x)); // either direct or cached
BigInteger y = new BigInteger(1, Base64.getDecoder().decode(y)); // ditto
PublicKey ecPublicKey = KeyFactory.getInstance("EC").generatePublic(new ECPublicKeySpec(new ECPoint(x,y), p));
return ecPublicKey;
} catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
System.out.println("Failed to resolve key: " + e);
return null;
} catch (InvalidParameterSpecException e) {
throw new RuntimeException(e);
}
}
};
try {
// Parse claims and validate the signature
Jws<Claims> jwsClaims = Jwts.parser().setSigningKeyResolver(resolver).parseClaimsJws(signedJwtToken);
System.out.println("Signature on this JWT is good and the JWT token has not expired");
// OK, we can trust this JWT
// Parse the claims
System.out.println("JWS claims: " + jwsClaims.getBody());
// Code below to validate the claims
}
catch (Exception ex) {
System.out.println("Unable to validate JWS");
}
}
// catch (SignatureException e)
catch (Exception e) {
// don't trust the JWT!
System.out.println("JWT is malformed or expired");
}
}
// Get the JWK Set from the JWK endpoint
private static void GetJWK(String jwkUrl, String kid) throws IOException {
URL url = new URL(jwkUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
try {
//URL url = new URL(jwkUrl);
System.out.println("url: "+url);
//connection = (HttpURLConnection) url.openConnection();
System.out.println("connection: "+connection);
connection.setRequestMethod("GET");
connection.setRequestProperty("User-Agent", USER_AGENT);
int responseCode = connection.getResponseCode();
System.out.println("GET Response Code :: " + responseCode);
BufferedReader rd = new BufferedReader(new InputStreamReader(connection.getInputStream()));
System.out.println("rd: "+rd);
//StringBuilder response = new StringBuilder();
String inputLine;
StringBuffer response = new StringBuffer();
while ((inputLine = rd.readLine()) != null) {
response.append(inputLine);
}
//in.close();
// print result
System.out.println(response.toString());
System.out.println("response:--> "+response);
String line;
//while ((line = rd.readLine()) != null) {
// response.append(line);
// response.append('\r');
//}
rd.close();
// Jackson mapper for parsing the json
TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
ObjectMapper mapper = new ObjectMapper();
Map<String, Object> jwks = mapper.readValue(response.toString(), typeRef);
System.out.println("jwks:--> "+jwks);
// Get the jwk by using the key Id from the jwt
Map<String, String> jwk = GetKeyById(jwks, kid);
System.out.println("jwk:--> "+jwk);
// Get the modulus 'n' & the exponent 'n' from the JWK & add it to cache
if (jwk != null) {
}
} catch (Exception e) {
// Unable to fetch JWKS. Terminate this program
System.out.println("Error getting jwks: " + e);
} finally {
if (connection != null) {
connection.disconnect();
}
}
}
static private Map<String, String> GetKeyById(Map<String, Object> jwks, String kid) {
List<Map<String, String>> keys = (List<Map<String, String>>)jwks.get("keys");
Map<String, String> ret = null;
for (int i = 0; i < keys.size(); i++) {
if (keys.get(i).get("kid").equals(kid)) {
System.out.println("i-->"+ keys.get(i).get("kid"));
System.out.println("i set-->"+ keys.get(i));
return keys.get(i);
}
}
return ret;
}
}
With my ES256 attempt I get this error: JWT signature does not match locally computed signature. JWT validity cannot be asserted and should not be trusted.
Building the key from the certificate (x5c) instead worked, as below:
String x5c = "";
System.out.println(" x5c =" + x5c);
// Strip optional PEM armour and all whitespace. What remains must be plain (not URL-safe) base64 of
// the certificate's DER bytes, which is the encoding JWK "x5c" entries use.
String stripped = x5c.replaceAll("-----BEGIN (.*)-----", "");
stripped = stripped.replaceAll("-----END (.*)-----", "");
// Removing every whitespace character (CR, LF, spaces) also makes a separate trim() unnecessary;
// note that the original stripped.trim() discarded its result because Strings are immutable.
stripped = stripped.replaceAll("\\s+", "");
System.out.println(" stripped =" + stripped);
// Use the public JDK decoder; com.sun.org.apache.xerces.internal.* is JDK-internal and is not
// accessible from application code on Java 9+.
byte[] keyBytes = java.util.Base64.getDecoder().decode(stripped);
CertificateFactory fact = CertificateFactory.getInstance("X.509");
X509Certificate cer = (X509Certificate) fact.generateCertificate(new ByteArrayInputStream(keyBytes));
System.out.println(cer);
return cer.getPublicKey();
Upvotes: 1
Views: 3603
Reputation: 39029
First, your code has a bug or is miscopied. In `GetJWK`, in the last block (before `catch`), you have the comment `Get the modulus 'n' & the exponent 'n'`. That comment is wrong, because the public exponent is 'e'. Worse, the code shown actually gets 'x5c', not 'n', and uses it as the modulus. That is very wrong and shouldn't even work, because 'x5c' is an array, not a scalar.
Yes, the library you are using (jjwt) can verify (and generate) ECDSA signatures in JWS/JWT. For any algorithm, the signature is generated using the algorithm-dependent elements of the private key. It is verified using the algorithm-dependent elements of the public key:

- `n` and `e` for RSA;
- `x` and `y` for ECDSA, on a curve that is both implied by `alg` and restated in `crv`.

See part of RFC 7518 section 3.4 and RFC 7518 section 6.2.1. Note that x and y must be valid base64url (yours aren't). They must also be exactly the length required by the size of the curve group defined by `alg` and `crv` (yours aren't). For P-256 that is 32 bytes each, i.e. 43 base64url characters.
You can construct a Java-crypto `ECPublicKey` (or, pedantically, a provider's object implementing that interface) much as you now do for RSA in `resolveSigningKey`. The difference is that EC requires 'parameters' for the curve in addition to x and y:
// this part is the same for all keys and could be done at init or memoized
AlgorithmParameters a = AlgorithmParameters.getInstance("EC");
// The default SunEC provider knows this curve as "secp256r1" (aliases "NIST P-256",
// "X9.62 prime256v1"); the bare JOSE name "P-256" is not one of its curve names.
a.init(new ECGenParameterSpec("secp256r1"));
ECParameterSpec p = a.getParameterSpec(ECParameterSpec.class);
// this part must be redone for each different key
// to prevent misuse verify crv_field (either direct or cached) equals("P-256")
// and probably alg_field (ditto) equals("ES256")
// JWK x and y are base64url (RFC 7518 section 6.2.1), so use the URL-safe decoder
BigInteger x = new BigInteger(1, java.util.Base64.getUrlDecoder().decode(x_field)); // either direct or cached
BigInteger y = new BigInteger(1, java.util.Base64.getUrlDecoder().decode(y_field)); // ditto
PublicKey ecPublicKey = KeyFactory.getInstance("EC").generatePublic(new ECPublicKeySpec(new ECPoint(x,y), p));
// add exception handling to taste
However, if your JWKs always have x5c, as in your examples, you can use much simpler code. This requires a valid x5c value; yours aren't valid base64-of-DER and are much too small. For all signature algorithms just do:
// Sketch, not compilable as-is: fill in the placeholder and exception handling.
String x5c = // placeholder: element 0 of the JWK's 'x5c' array (cached if you like)
X509Certificate cert = (X509Certificate) CertificateFactory.getInstance("X.509")
.generateCertificate( new ByteArrayInputStream( base64decode(x5c) ) ); // x5c is standard base64, NOT base64url
// base64decode stands for a standard-alphabet decoder, e.g. java.util.Base64.getDecoder().decode
// (with exception handling of course)
// then use cert.getPublicKey() as the PublicKey for jjwt
Upvotes: 2