user890234
user890234

Reputation: 995

ECDSA ES256 JWKS JWT signature verification in Java (Elliptic Curve Signatures)

I need to verify the signature of a token in Java; the token uses the ES256 algorithm:

{"typ":"JWT","alg":"ES256","kid":"4"}

The public JWKS has the following format:

{
    "kty" : "EC",
    "kid" : "4",
    "use" : "sig",
    "x" : "hkjfghkjfdghkjdfsglkdjhg",
    "y" : "skjgf krhgkre",
    "crv" : "P-256",
    "x5c" : [ "uchfgurhnvgrejbhkltjrhbkrknlytknjlkfldfmndfkfvmlkasdfkljflksdanfgklnsdkjfnsadkjnkjdfnglksdfhkljdlkhfklhjdfgklghjkldfjklfjgklnvdfngjksdnfngkjvnsdfjkvsdfjkgndjkhnkjdsnhkltejhk" ],
    "x5t" : "jcdhsvkjgnrekngk"
  }

How can I verify signatures made with this key?

I have looked at RS256 JWT verification, where the public JWKS looks very different:

{
    "kty" : "RSA",
    "kid" : "16",
    "use" : "sig",
    "n" : "sdghjfhgjhfjdghjkdfhghfdghfdkjhgkjfhgkjfhgkjhfjkffghjkshgjkfhgjkhfjkghjkfhgjkhfgjkhfjkghjkfhgjkafhjghfjkhgkjfhgkjhfjklghlsjkfhgjksfagmfnvmbrgmberkjltgnerkjhgjkerngkjerngjkhsjkghsjklghjkhgjkhjkghjkfahgjkhgjkhfjklghjkfhg",
    "e" : "AQAB",
    "x5c" : [ "MIIClkdfgjlkdfjklgjdfkljgkfdjgkljfkgjfklgjkldjgdfjgkldftuioreutiourtoiuriotuieorutiorutioeurtiueriotuioerwutioerutioukgjkldjgkldfjgkljklgjdfklgjorutoireutioerutiueriotuklgjkldjgklsdfjgkldfjgklsdjgkl" ],
    "x5t" : "jshgjfhgjkhkhghgkdfhgklhdklh"
  }

Is there a Java library that offers this functionality? For RS256 I believe the modulus (n) and exponent (e) are used to verify the signature, but I am not sure what is used for ES256.

Here is the code I am using for RS256:

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;

import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.Key;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;

import org.cache2k.Cache;
import org.cache2k.Cache2kBuilder;

import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import java.net.HttpURLConnection;

public class JWTValidation {
    
    //Externalize the hosts as per the environment
    
    private static final String USER_AGENT = "Mozilla/5.0";
    // Create a cache object
    final static Cache<String,String> _cache = new Cache2kBuilder<String, String>() {}
      .expireAfterWrite(30, TimeUnit.MINUTES)    // expire/refresh after 30 minutes
    .build();     
    static String _jwkVersionCache = _cache.peek("jwk_version");
    static String _modulusCache = _cache.peek("modulus");
    static String _exponentCache = _cache.peek("exponent");
    
    public static void main(String[] args) throws JsonParseException, JsonMappingException, IOException {   
        System.out.println("Sample code to validate JWT");  
        // Running the code in loop to test multiple scenarios
        while(true) {   
            
            // Used only for console app to get the JWS as user input
            Scanner reader = new Scanner(System.in);
            // Get the JWT 
            System.out.println("Enter jwt or enter exit to terminate"); 
            String signedJwtToken = reader.next();

            if(signedJwtToken.equalsIgnoreCase("Exit")) 
            {               
                break;
            }
            
            try {
                // Validate the signed JWT (JWS)
                ValidateJWS(signedJwtToken);
            }
            catch (Exception e) {
                System.out.println("JWS validation failed");
            }
            finally {

            }
        }               
    }   
    
    // Code to validate signed JWT (JWS)
    private static void ValidateJWS(String signedJwtToken)
    {
        StringBuilder sb = null;
        String jwtWithoutSignature; 
        String jwtVersion;
        String jwksUri;
        String jwksUrl;
        String kid;
        TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
        ObjectMapper mapper = new ObjectMapper();   
        Map<String, Object> jwks = null;
            
        @SuppressWarnings("rawtypes")
        Jwt<Header, Claims> jwtClaims = null;
        try {           
            
            // Extract the base64 encoded JWT from the signed JWT token (JWS) 
            sb = new StringBuilder();
            sb.append(signedJwtToken);
            jwtWithoutSignature = sb.substring(0, sb.toString().lastIndexOf(".") + 1);  
            
            // Parse claims without validating the signature
            jwtClaims = Jwts.parser().parseClaimsJwt(jwtWithoutSignature);  
            
            // Extract the jwk uri 'jku' & the version 'ver' from the JWT   
            jwtVersion = (String) jwtClaims.getBody().get("ver");
            jwksUri = (String) jwtClaims.getBody().get("jku");
            // Extract the kid from JWT
            kid = (String) jwtClaims.getHeader().get("kid");
            
            jwksUrl = jwksHost;
            System.out.println("jwtVersion: " + jwtVersion);
            System.out.println("jwksUri: " + jwksUri);
            System.out.println("kid: " + kid);
            
            // Cache the jwk version (ver), modulus (n) and exponent (e) for lifetime of the application.
            // The JWT version will be same as jwk version. The jwt version will change only when the 
            // JWT signing certificate is renewed.
            // Invoke the JWK url only if the jwt version is different from the JWK version. 
            
            // check if the JWK version is cached or not
            if (_cache.get("jwk_version") != null) {
                // check if jwt version is same as jwk version 
                if (!jwtVersion.equals(_jwkVersionCache)) {
                    // Get the jwk key & add the modulus, exponent & the jwk version to the cache
                    GetJWK(jwksUrl, kid);   
                }
            }
            else
            {
                // Get the jwk key & add the modulus, exponent & the jwk version to the cache
                GetJWK(jwksUrl, kid);       
            }
            
            // Calling the setSigningKeyResolver as the JWT is parsed before validating the signature 
            SigningKeyResolver resolver = new SigningKeyResolverAdapter() {
                @SuppressWarnings("rawtypes")
                public Key resolveSigningKey(JwsHeader jwsHeader, Claims claims) {
                    try {            
                        // Build the RSA public key from modulus & exponent in JWK 
                        BigInteger modulus = new BigInteger(1, Base64.getUrlDecoder().decode(_modulusCache));
                        BigInteger exponent = new BigInteger(1, Base64.getUrlDecoder().decode(_exponentCache));
                        PublicKey rsaPublicKey = KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(modulus, exponent));
                        return rsaPublicKey;
                    } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
                        System.out.println("Failed to resolve key: " + e);
                        return null;
                    }
                }
            };
            
            try {
                // Parse claims and validate the signature
                Jws<Claims> jwsClaims = Jwts.parser().setSigningKeyResolver(resolver).parseClaimsJws(signedJwtToken);
                System.out.println("Signature on this JWT is good and the JWT token has not expired");
                // OK, we can trust this JWT
                
                // Parse the claims
                System.out.println("JWS claims: " + jwsClaims.getBody());
                
                // Code below to validate the claims
                
            }
            catch (Exception ex) {
                System.out.println("Unable to validate JWS");
            }

        }
        // catch (SignatureException e)
        catch (Exception e) {
            // don't trust the JWT!
            System.out.println("JWT is malformed or expired");
        }                   
    }
        
    // Get the corresponding JWK using key Id from the JWK set 
    @SuppressWarnings("unchecked")
    static private Map<String, String> GetKeyById(Map<String, Object> jwks, String kid) {
        List<Map<String, String>> keys = (List<Map<String, String>>)jwks.get("keys");
        Map<String, String> ret = null;
        for (int i = 0; i < keys.size(); i++) {
            if (keys.get(i).get("kid").equals(kid)) {
                System.out.println("i-->"+ keys.get(i).get("kid"));
                System.out.println("i set-->"+ keys.get(i));
                return keys.get(i);
            }
        }    
        return ret;
    }
   
    // Get the JWK Set from the JWK endpoint 
    private static void GetJWK(String jwkUrl, String kid) throws IOException {
        URL url = new URL(jwkUrl);
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();

        try {
            //URL url = new URL(jwkUrl);
            System.out.println("url: "+url);
            //connection = (HttpURLConnection) url.openConnection();
            System.out.println("connection: "+connection);


            connection.setRequestMethod("GET");
            connection.setRequestProperty("User-Agent", USER_AGENT);
            int responseCode = connection.getResponseCode();
            System.out.println("GET Response Code :: " + responseCode);

            BufferedReader rd = new BufferedReader(new InputStreamReader(connection.getInputStream()));

            System.out.println("rd: "+rd);
            //StringBuilder response = new StringBuilder();
            String inputLine;
            StringBuffer response = new StringBuffer();
            while ((inputLine = rd.readLine()) != null) {
                response.append(inputLine);
            }
            //in.close();

            // print result
            System.out.println(response.toString());

            System.out.println("response:--> "+response);
            String line;
            //while ((line = rd.readLine()) != null) {
            //  response.append(line);
            //  response.append('\r');

            //}
            rd.close();
            
            // Jackson mapper for parsing the json
            TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
            ObjectMapper mapper = new ObjectMapper();
            Map<String, Object> jwks = mapper.readValue(response.toString(), typeRef);
            
            // Get the jwk by using the key Id from the jwt
            Map<String, String> jwk = GetKeyById(jwks, kid);
            
            // Get the modulus 'n' & the exponent 'n' from the JWK & add it to cache 
            if (jwk != null) {
                
                _cache.put("modulus", jwk.get("x5c"));
                    
                _modulusCache = _cache.get("modulus");
                
                _cache.put("exponent", jwk.get("e"));
                
                _exponentCache = _cache.get("exponent");
                
                _cache.put("jwk_version", jwk.get("ver"));
                
                _jwkVersionCache = _cache.get("jwk_version");
                
            }            
        } catch (Exception e) {
            // Unable to fetch JWKS. Terminate this program
            System.out.println("Error getting jwks: " + e);             
        } finally {
            if (connection != null) {
                connection.disconnect();
            }
        }
    }
}

RFC 7518, section 3.1 (https://www.rfc-editor.org/rfc/rfc7518#section-3.1), says:

| "alg" Param Value | Digital Signature or MAC Algorithm |
| ES256 | ECDSA using P-256 and SHA-256 |


import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.math.BigInteger;

import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.security.*;
import java.security.spec.*;
import java.text.ParseException;
import java.util.Base64;
import java.util.Base64.Decoder;
import java.util.Base64.Encoder;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;

import com.nimbusds.jose.JOSEException;
import org.cache2k.Cache;
import org.cache2k.Cache2kBuilder;

import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.Jws;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.Jwt;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.SigningKeyResolverAdapter;
import java.net.HttpURLConnection;


public class Es256stack {

    final static String jwksHost = "";
    
    private static final String USER_AGENT = "Mozilla/5.0";
    // Create a cache object
    final static Cache<String,String> _cache = new Cache2kBuilder<String, String>() {}
            .expireAfterWrite(30, TimeUnit.MINUTES)    // expire/refresh after 30 minutes
            .build();
    static String _jwkVersionCache = _cache.peek("jwk_version");
    static String _modulusCache = _cache.peek("modulus");
    static String _exponentCache = _cache.peek("exponent");



    public static void main(String[] args) throws NoSuchAlgorithmException, InvalidParameterSpecException, InvalidKeySpecException {
        ValidateJWS(signedJwtToken);
    }

        //String signedJwtToken="";

        // Code to validate signed JWT (JWS)
        private static void ValidateJWS(String signedJwtToken)
        {
            StringBuilder sb = null;
            String jwtWithoutSignature;
            String jwtVersion;
            String jwksUri;
            String jwksUrl;
            String kid;
            TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
            ObjectMapper mapper = new ObjectMapper();
            Map<String, Object> jwks = null;

            @SuppressWarnings("rawtypes")
            Jwt<Header, Claims> jwtClaims = null;
            try {

                // Extract the base64 encoded JWT from the signed JWT token (JWS)
                sb = new StringBuilder();
                sb.append(signedJwtToken);
                jwtWithoutSignature = sb.substring(0, sb.toString().lastIndexOf(".") + 1);

                // Parse claims without validating the signature
                jwtClaims = Jwts.parser().parseClaimsJwt(jwtWithoutSignature);

                // Extract the jwk uri 'jku' & the version 'ver' from the JWT
                jwtVersion = (String) jwtClaims.getBody().get("ver");
                jwksUri = (String) jwtClaims.getBody().get("jku");
                // Extract the kid from JWT
                kid = (String) jwtClaims.getHeader().get("kid");

                jwksUrl = "";
                System.out.println("jwtVersion: " + jwtVersion);
                System.out.println("jwksUri: " + jwksUri);
                System.out.println("kid: " + kid);

                // Cache the jwk version (ver), modulus (n) and exponent (e) for lifetime of the application.
                // The JWT version will be same as jwk version. The jwt version will change only when the
                // JWT signing certificate is renewed.
                // Invoke the JWK url only if the jwt version is different from the JWK version.

                // check if the JWK version is cached or not
                if (_cache.get("jwk_version") != null) {
                    // check if jwt version is same as jwk version
                    if (!jwtVersion.equals(_jwkVersionCache)) {
                        // Get the jwk key & add the modulus, exponent & the jwk version to the cache
                        GetJWK(jwksUrl, kid);
                    }
                }
                else
                {
                    // Get the jwk key & add the modulus, exponent & the jwk version to the cache
                    GetJWK(jwksUrl, kid);
                }

                // Calling the setSigningKeyResolver as the JWT is parsed before validating the signature
                SigningKeyResolver resolver = new SigningKeyResolverAdapter() {
                    @SuppressWarnings("rawtypes")
                    public Key resolveSigningKey(JwsHeader jwsHeader, Claims claims) {
                        try {
                            AlgorithmParameters a = AlgorithmParameters.getInstance("EC");
                            a.init(new ECGenParameterSpec("secp256r1"));
                            ECParameterSpec p = a.getParameterSpec(ECParameterSpec.class);
                            // Build the RSA public key from modulus & exponent in JWK
                            BigInteger x = new BigInteger(1, Base64.getDecoder().decode(x)); // either direct or cached
                            BigInteger y = new BigInteger(1, Base64.getDecoder().decode(y)); // ditto
                            PublicKey ecPublicKey = KeyFactory.getInstance("EC").generatePublic(new ECPublicKeySpec(new ECPoint(x,y), p));
                            return ecPublicKey;
                        } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
                            System.out.println("Failed to resolve key: " + e);
                            return null;
                        } catch (InvalidParameterSpecException e) {
                            throw new RuntimeException(e);
                        }
                    }
                };

                try {
                    // Parse claims and validate the signature
                    Jws<Claims> jwsClaims = Jwts.parser().setSigningKeyResolver(resolver).parseClaimsJws(signedJwtToken);
                    System.out.println("Signature on this JWT is good and the JWT token has not expired");
                    // OK, we can trust this JWT

                    // Parse the claims
                    System.out.println("JWS claims: " + jwsClaims.getBody());

                    // Code below to validate the claims

                }
                catch (Exception ex) {
                    System.out.println("Unable to validate JWS");
                }

            }
            // catch (SignatureException e)
            catch (Exception e) {
                // don't trust the JWT!
                System.out.println("JWT is malformed or expired");
            }
        }

        // Get the JWK Set from the JWK endpoint
        private static void GetJWK(String jwkUrl, String kid) throws IOException {
            URL url = new URL(jwkUrl);
            HttpURLConnection connection = (HttpURLConnection) url.openConnection();

            try {
                //URL url = new URL(jwkUrl);
                System.out.println("url: "+url);
                //connection = (HttpURLConnection) url.openConnection();
                System.out.println("connection: "+connection);


                connection.setRequestMethod("GET");
                connection.setRequestProperty("User-Agent", USER_AGENT);
                int responseCode = connection.getResponseCode();
                System.out.println("GET Response Code :: " + responseCode);

                BufferedReader rd = new BufferedReader(new InputStreamReader(connection.getInputStream()));

                System.out.println("rd: "+rd);
                //StringBuilder response = new StringBuilder();
                String inputLine;
                StringBuffer response = new StringBuffer();
                while ((inputLine = rd.readLine()) != null) {
                    response.append(inputLine);
                }
                //in.close();

                // print result
                System.out.println(response.toString());

                System.out.println("response:--> "+response);
                String line;
                //while ((line = rd.readLine()) != null) {
                //  response.append(line);
                //  response.append('\r');

                //}
                rd.close();

                // Jackson mapper for parsing the json
                TypeReference<Map<String, Object>> typeRef = new TypeReference<Map<String, Object>>() {};
                ObjectMapper mapper = new ObjectMapper();
                Map<String, Object> jwks = mapper.readValue(response.toString(), typeRef);
                System.out.println("jwks:--> "+jwks);
                // Get the jwk by using the key Id from the jwt
                Map<String, String> jwk = GetKeyById(jwks, kid);
                System.out.println("jwk:--> "+jwk);
                // Get the modulus 'n' & the exponent 'n' from the JWK & add it to cache
                if (jwk != null) {
                   
                }
            } catch (Exception e) {
                // Unable to fetch JWKS. Terminate this program
                System.out.println("Error getting jwks: " + e);
            } finally {
                if (connection != null) {
                    connection.disconnect();
                }
            }
        }


    static private Map<String, String> GetKeyById(Map<String, Object> jwks, String kid) {
        List<Map<String, String>> keys = (List<Map<String, String>>)jwks.get("keys");
        Map<String, String> ret = null;
        for (int i = 0; i < keys.size(); i++) {
            if (keys.get(i).get("kid").equals(kid)) {
                System.out.println("i-->"+ keys.get(i).get("kid"));
                System.out.println("i set-->"+ keys.get(i));
                return keys.get(i);
            }
        }
        return ret;
    }







}

JWT signature does not match locally computed signature. JWT validity cannot be asserted and should not be trusted.

enter image description here

Building the key from the certificate (x5c) instead worked, as follows:

// Placeholder: set to element 0 of the JWK's "x5c" array (the signing certificate).
String x5c = "";
System.out.println(" x5c =" + x5c);
// Strip optional PEM armour. A non-greedy marker pattern is used so a PEM that has been
// collapsed onto one line does not lose its body between the BEGIN and END markers.
String stripped = x5c.replaceAll("-----(BEGIN|END) [^-]*-----", "");
// Remove every line break/space; Strings are immutable, so the result must be reassigned
// (the former bare stripped.trim() call discarded its result).
stripped = stripped.replaceAll("\\s", "");
System.out.println(" stripped =" + stripped);
// java.util.Base64 replaces the JDK-internal xerces Base64, which application code cannot
// access on JDK 9+. x5c values are standard base64 (NOT base64url) DER.
byte[] keyBytes = java.util.Base64.getDecoder().decode(stripped);
CertificateFactory fact = CertificateFactory.getInstance("X.509");
X509Certificate cer = (X509Certificate) fact.generateCertificate(new ByteArrayInputStream(keyBytes));
System.out.println(cer);
return cer.getPublicKey();

Upvotes: 1

Views: 3603

Answers (1)

dave_thompson_085
dave_thompson_085

Reputation: 39029

First, your code has a bug or is miscopied. In GetJWK, in the last block (before the catch), you have the comment "Get the modulus 'n' & the exponent 'n'", which is wrong (the public exponent is 'e'). But the code shown actually gets 'x5c', not 'n', and uses it as the modulus. That is very wrong, and it shouldn't even work, because 'x5c' is an array, not a scalar.

Yes, the library you are using (jjwt) can verify (and generate) ECDSA signatures in JWS/JWT. In any implementation, the signature is generated using the algorithm-dependent elements of the private key and verified using the algorithm-dependent elements of the public key. For RSA these are n and e. For ECDSA they are x and y, on a curve that is both implied by alg and restated in crv; see the relevant parts of RFC 7518 section 3.4 and RFC 7518 section 6.2.1. Note that x and y must be valid base64url (yours aren't). They must also be exactly the length required by the size of the curve group defined by alg and crv (yours aren't).

You can construct a Java-crypto ECPublicKey (or pedantically a provider's implementation object implementing that) similar to what you do in resolveSigningKey now for RSA, except that EC requires 'parameters' for the curve in addition to x and y:

// this part is the same for all keys and could be done at init or memoized
// NOTE(review): "secp256r1" (as used in the question's code) is the Java standard name for
// this curve; confirm your provider also accepts "P-256" before relying on that alias.
AlgorithmParameters a = AlgorithmParameters.getInstance("EC");
a.init(new ECGenParameterSpec("P-256"));
ECParameterSpec p = a.getParameterSpec(ECParameterSpec.class);

// this part must be redone for each different key
// to prevent misuse verify crv_field (either direct or cached) equals("P-256")
// and probably alg_field (ditto) equals("ES256")
// x_field / y_field stand for the JWK's "x" and "y" members; base64urldecode(...) is a
// placeholder, e.g. java.util.Base64.getUrlDecoder().decode(...) -- JWK coordinates are
// base64url, not plain base64.
BigInteger x = new BigInteger(1, base64urldecode(x_field)); // either direct or cached
BigInteger y = new BigInteger(1, base64urldecode(y_field)); // ditto 
PublicKey ecPublicKey = KeyFactory.getInstance("EC").generatePublic(new ECPublicKeySpec(new ECPoint(x,y), p));

// add exception handling to taste

However, if your JWKs always have x5c as in your examples, but with a valid value, you can use much simpler code. (Yours aren't valid base64 of DER and are much too small; note that x5c uses standard base64, not base64url.) For all signature algorithms, just do:

// Alternative when the JWK carries an "x5c" chain: the certificate holding the key is the
// first element, encoded as standard base64 of DER (RFC 7517 section 4.7).
String x5c = // get element 0 of field 'x5c' from JWK (cached if you like)
X509Certificate cert = (X509Certificate) CertificateFactory.getInstance("X.509")
    .generateCertificate( new ByteArrayInputStream( base64decode(x5c) ) ); // NOT base64url
// (with exception handling of course)
// then use cert.getPublicKey() as the PublicKey for jjwt

Upvotes: 2

Related Questions