KyelJmD
KyelJmD

Reputation: 4732

How to verify the Signature of a JWT generated by AWS Cognito in Python 3.6?

Here's my script

import urllib.request
import json
import time
from jose import jwk, jwt
from jose.utils import base64url_decode
import base64


# Question snippet: fetch the user pool's JWKS and check the token's signature
# against the key whose "kid" matches the token header.
# NOTE(review): `request` is never defined and `return False` appears outside any
# function, so this looks like an excerpt from a request-handler body; it does
# not run as a standalone script.
region = '....'
userpool_id = '.....'
app_client_id = '...' 
keys_url = 'https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json'.format(region, userpool_id)

# download the pool's public keys (JWKS)
response = urllib.request.urlopen(keys_url)
keys = json.loads(response.read())['keys']
# NOTE(review): if the client sends "Bearer <token>", the prefix must be stripped
# before parsing — confirm what the Authorization header contains.
token = request.headers['Authorization']
print(token)
# get the kid from the headers prior to verification
headers = jwt.get_unverified_headers(request.headers['Authorization'])
kid = headers['kid']

print(kid)
# search for the kid in the downloaded public keys
key_index = -1
for i in range(len(keys)):
    if kid == keys[i]['kid']:
        key_index = i
        break
if key_index == -1:
    print('Public key not found in jwks.json')
    return False
# construct the public key
public_key = jwk.construct(keys[key_index])
# get the last two sections of the token,
# message and signature (encoded in base64url)
message, encoded_signature = str(token).rsplit('.', 1)
# decode the signature
# BUG(review): JWT segments are base64url-encoded ('-' and '_' instead of '+'
# and '/', with the '=' padding stripped). base64.b64decode expects standard
# base64, so it can return the wrong bytes or raise on the missing padding,
# and the signature check below then fails. base64url_decode (imported above
# from jose.utils) handles this encoding, e.g.
# base64url_decode(encoded_signature.encode('utf-8')).
print('>>encoded signature')
print(encoded_signature)
decoded_signature = base64.b64decode(encoded_signature)
# NOTE(review): `message` is a str here; the key backends likely expect bytes,
# so message.encode('utf-8') is probably needed as well — TODO confirm.
if not public_key.verify(message, decoded_signature):
    print('Signature verification failed')
    return False
print('Signature successfully verified')

I always end up with "Signature verification failed", even though the JWT was issued by a valid, legitimate Cognito user pool. I've looked at the documentation, but it does not really describe the whole verification process.

Upvotes: 5

Views: 8877

Answers (2)

mousazeidbaker
mousazeidbaker

Reputation: 31

The following class verifies Cognito tokens. It requires the jose (python-jose) and pydantic packages.

The implementation is derived from this repo, which contains more details, additional functionality, tests, etc.

import json
import logging
import os
import time
import urllib.request
from typing import Dict, List

from jose import jwk, jwt
from jose.utils import base64url_decode
from pydantic import BaseModel


class JWK(BaseModel):
    """A JSON Web Key (JWK) model that represents a cryptographic key.

    The JWK specification:
    https://datatracker.ietf.org/doc/html/rfc7517

    Every field is declared without a default, so pydantic treats all of them
    as required: building this model from a JWKS entry that lacks any of them
    raises a validation error.
    """

    alg: str  # algorithm the key is intended to be used with
    e: str  # RSA public exponent, base64url-encoded (RFC 7518 section 6.3.1)
    kid: str  # key ID; compared with the "kid" header of a token to pick the key
    kty: str  # key type (e.g. "RSA")
    n: str  # RSA modulus, base64url-encoded (RFC 7518 section 6.3.1)
    use: str  # intended public-key use (e.g. "sig" for signature verification)


class CognitoAuthenticator:
    """Verify JSON Web Tokens (JWTs) issued by one Amazon Cognito user pool.

    The pool's JSON Web Key Set (JWKS) is downloaded once, when the
    authenticator is constructed, and reused for every verification.

    Args:
        pool_region: AWS region of the user pool.
        pool_id: Cognito user pool ID.
        client_id: App client ID expected in the token's ``client_id`` claim.
    """

    def __init__(self, pool_region: str, pool_id: str, client_id: str) -> None:
        self.pool_region = pool_region
        self.pool_id = pool_id
        self.client_id = client_id
        self.issuer = f"https://cognito-idp.{self.pool_region}.amazonaws.com/{self.pool_id}"
        self.jwks = self.__get_jwks()

    def __get_jwks(self) -> List[JWK]:
        """Returns a list of JSON Web Keys (JWKs) from the issuer. A JWK is a
        public key used to verify a JSON Web Token (JWT).

        Returns:
            List of keys
        Raises:
            Exception when JWKS endpoint does not contain any keys
        """

        # Context manager so the HTTP response is always closed.
        with urllib.request.urlopen(f"{self.issuer}/.well-known/jwks.json") as file:
            res = json.loads(file.read().decode("utf-8"))
        if not res.get("keys"):
            raise Exception("The JWKS endpoint does not contain any keys")
        return [JWK(**key) for key in res["keys"]]

    def verify_token(
        self,
        token: str,
    ) -> bool:
        """Verify a JSON Web Token (JWT).

        Checks the token structure, its signature against the pool's JWKS, and
        its exp / iss / client_id / token_use claims. Any failure, including a
        malformed header or claims set, is reported as ``False`` rather than
        raised.

        For more details refer to:
        https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html

        Args:
            token: The token to verify
        Returns:
            True if valid, False otherwise
        """

        try:
            self._is_jwt(token)
            self._get_verified_header(token)
            self._get_verified_claims(token)
        except CognitoError:
            return False
        return True

    def _is_jwt(self, token: str) -> bool:
        """Validate a JSON Web Token (JWT).
        A JSON Web Token (JWT) includes three sections: Header, Payload and
        Signature. They are base64url encoded and are separated by dot (.)
        characters. If JWT token does not conform to this structure, it is
        considered invalid.

        Args:
            token: The token to validate
        Returns:
            True if valid
        Raises:
            CognitoError when invalid token
        """

        try:
            jwt.get_unverified_header(token)
            jwt.get_unverified_claims(token)
        except jwt.JWTError as err:
            logging.info("Invalid JWT")
            raise InvalidJWTError from err
        return True

    def _get_verified_header(self, token: str) -> Dict:
        """Verifies the signature of a JSON Web Token (JWT) and returns its
        decoded header.

        Args:
            token: The token to decode header from
        Returns:
            A dict representation of the token header
        Raises:
            CognitoError when the header has no usable kid, the signature
            segment cannot be decoded, or the signature does not verify
        """

        # extract key ID (kid) from token; a header without one cannot be
        # matched to a signing key (previously this raised KeyError)
        headers = jwt.get_unverified_header(token)
        kid = headers.get("kid")
        if kid is None:
            logging.info("Token header does not contain a 'kid'")
            raise InvalidKidError

        # find JSON Web Key (JWK) that matches kid from token
        key = None
        for k in self.jwks:
            if k.kid == kid:
                # construct a key object from found key data
                key = jwk.construct(k.dict())
                break
        if key is None:
            logging.info("Unable to find a signing key that matches '%s'", kid)
            raise InvalidKidError

        # get message and signature (base64url encoded)
        message, encoded_signature = str(token).rsplit(".", 1)
        try:
            signature = base64url_decode(encoded_signature.encode("utf-8"))
        except ValueError as err:  # binascii.Error is a ValueError subclass
            logging.info("Signature is not valid base64url")
            raise SignatureError from err

        if not key.verify(message.encode("utf8"), signature):
            logging.info("Signature verification failed")
            raise SignatureError

        # signature successfully verified
        return headers

    def _get_verified_claims(self, token: str) -> Dict:
        """Verifies the claims of a JSON Web Token (JWT) and returns its claims.

        Missing claims are treated as failed checks (a CognitoError subclass)
        instead of leaking KeyError/TypeError to the caller.

        Args:
            token: The token to decode claims from
        Returns:
            A dict representation of the token claims
        Raises:
            CognitoError when unable to verify claims
        """

        claims = jwt.get_unverified_claims(token)

        # verify expiration time; a missing or non-numeric exp cannot be
        # compared with the clock, so it is rejected like an expired token
        exp = claims.get("exp")
        if not isinstance(exp, (int, float)) or exp < time.time():
            logging.info("Expired token")
            raise TokenExpiredError

        # verify issuer
        if claims.get("iss") != self.issuer:
            logging.info("Invalid issuer claim")
            raise InvalidIssuerError

        # verify audience
        # note: claims["client_id"] for access token, claims["aud"] otherwise
        if claims.get("client_id") != self.client_id:
            logging.info("Invalid audience claim")
            raise InvalidAudienceError

        # verify token use
        if claims.get("token_use") != "access":
            logging.info("Invalid token use claim")
            raise InvalidTokenUseError

        # claims successfully verified
        return claims


class CognitoError(Exception):
    """Base class for token-verification failures raised by CognitoAuthenticator."""


class InvalidJWTError(CognitoError):
    """The token's header or claims could not be decoded as a JWT."""


class InvalidKidError(CognitoError):
    """No key in the pool's JWKS matches the token's key ID (kid)."""


class SignatureError(CognitoError):
    """The token's signature could not be verified."""


class TokenExpiredError(CognitoError):
    """The token is expired: its exp claim is earlier than the current time."""


class InvalidIssuerError(CognitoError):
    """The token's iss claim does not name the expected user pool."""


class InvalidAudienceError(CognitoError):
    """The token's client_id claim does not match the configured app client."""


class InvalidTokenUseError(CognitoError):
    """The token's token_use claim is not the expected value."""


if __name__ == "__main__":
    import sys

    # Pool settings come from the environment; a KeyError here means one of
    # the three variables is not set.
    auth = CognitoAuthenticator(
        pool_region=os.environ["AWS_COGNITO_REGION"],
        pool_id=os.environ["AWS_USER_POOL_ID"],
        client_id=os.environ["AWS_USER_POOL_CLIENT_ID"],
    )

    # note: only access tokens pass as-is. To verify ID tokens instead, compare
    # claims["aud"] (not claims["client_id"]) with the client ID and expect
    # token_use == "id" in CognitoAuthenticator._get_verified_claims.
    # The token may be passed as the first command-line argument; otherwise
    # the placeholder below is checked.
    access_token = sys.argv[1] if len(sys.argv) > 1 else "my_access_token"
    print(f"Token verified: {auth.verify_token(access_token)}")

Upvotes: 2

jcvalverde
jcvalverde

Reputation: 398

I see you're using jose while I'm using pyjwt, but this solution might help you. Most of the boilerplate code at the bottom comes from the "api-gateway-authorizer-python" blueprint. Note that this is very fragile code that will simply break if anything fails. I ended up not using Lambda authentication and instead selected AWS_IAM authentication for my API Gateway with Identity Pools, so I never finished it.

This example requires that you install pyjwt and cryptography with pip into your working directory and upload everything as a .zip file.

I'd recommend that you watch this video if you want to consider the AWS_IAM authentication option: https://www.youtube.com/watch?v=VZqG7HjT2AQ

They also have a solution with a more elaborate Lambda authorizer implementation on GitHub at: https://github.com/awslabs/aws-serverless-auth-reference-app (they show the link at the beginning of the video), but I don't know about its pip dependencies.

from __future__ import print_function
from jwt.algorithms import RSAAlgorithm

import re
import jwt
import json
import sys
import urllib

# Placeholders: replace with your own user pool settings.
region = 'your-region'
userpoolId = 'your-user-pool-id'
appClientId = 'your-app-client-id'
# The pool's public signing keys (JWKS) are published at this well-known URL.
keysUrl = 'https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json' % (region, userpoolId)

def lambda_handler(event, context):
    """API Gateway Lambda authorizer that accepts Cognito-issued JWTs.

    Reads the bearer token from ``event['authorizationToken']``, verifies it
    against the user pool's JWKS, and returns an allow-all policy for the
    API stage named in ``event['methodArn']``.

    Returns:
        The dict produced by AuthPolicy.build() with an added 'context' dict.
    Raises:
        Exception('Unauthorized') when no JWKS key matches the token's kid;
        any error from decodeJwtToken is re-raised.
    """
    # urllib.urlopen only exists on Python 2; Python 3 moved it to
    # urllib.request. Support both, since this file still targets Python 2.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib import urlopen

    bearerToken = event['authorizationToken']
    methodArn = event['methodArn']
    # The raw token is a live credential, so it is deliberately not logged.
    print("Method ARN: " + methodArn)

    response = urlopen(keysUrl)
    try:
        keys = json.loads(response.read())['keys']
    finally:
        response.close()

    jwtToken = bearerToken.split(' ')[-1]
    header = jwt.get_unverified_header(jwtToken)
    kid = header.get('kid')

    jwkValue = findJwkValue(keys, kid)
    if jwkValue is None:
        # Otherwise json.dumps(None) would reach from_jwk and fail obscurely.
        # API Gateway turns an authorizer error whose message is exactly
        # 'Unauthorized' into a 401 response.
        raise Exception('Unauthorized')
    publicKey = RSAAlgorithm.from_jwk(json.dumps(jwkValue))

    decoded = decodeJwtToken(jwtToken, publicKey)
    print('Decoded token: ' + json.dumps(decoded))
    principalId = decoded['cognito:username']

    # arn:aws:execute-api:<region>:<accountId>:<restApiId>/<stage>/<verb>/<resource>
    arnParts = methodArn.split(':')
    apiGatewayArnTmp = arnParts[5].split('/')
    awsAccountId = arnParts[4]

    policy = AuthPolicy(principalId, awsAccountId)
    policy.restApiId = apiGatewayArnTmp[0]
    policy.region = arnParts[3]
    policy.stage = apiGatewayArnTmp[1]
    # Use policy.denyAllMethods() instead to reject every request.
    policy.allowAllMethods()

    # Finally, build the policy
    authResponse = policy.build()

    # add additional key-value pairs associated with the authenticated principal
    # these are made available by APIGW like so: $context.authorizer.<key>
    # additional context is cached
    # (named authContext so it does not shadow the Lambda `context` argument)
    authContext = {
        'key': 'value',  # $context.authorizer.key -> value
        'number': 1,
        'bool': True
    }
    # authContext['arr'] = ['foo'] <- this is invalid, APIGW will not accept it
    # authContext['obj'] = {'foo':'bar'} <- also invalid

    authResponse['context'] = authContext

    return authResponse

def findJwkValue(keys, kid):
    """Return the first JWK dict in *keys* whose 'kid' equals *kid*, else None."""
    matches = (candidate for candidate in keys if candidate['kid'] == kid)
    return next(matches, None)

def decodeJwtToken(token, publicKey):
    """Verify *token* with *publicKey* and return its decoded claims.

    jwt.decode checks the RS256 signature and expiry. Because ``audience`` is
    passed, it also requires ``aud`` to equal appClientId. In addition, this
    function requires the ``iss`` claim to name this user pool and
    ``token_use`` to be ``'id'``. The caller reads ``cognito:username``,
    which is an ID-token claim.

    Raises:
        A jwt.InvalidTokenError subclass when any check fails; the error is
        printed before being re-raised.
    """
    # Cognito's guide requires checking the issuer, not just the signature.
    issuer = 'https://cognito-idp.{}.amazonaws.com/{}'.format(region, userpoolId)
    try:
        decoded = jwt.decode(token, publicKey, algorithms=['RS256'],
                             audience=appClientId, issuer=issuer)
        if decoded.get('token_use') != 'id':
            raise jwt.InvalidTokenError('Unexpected token_use claim')
        return decoded
    except Exception as e:
        print(e)
        raise

class HttpVerb:
    """HTTP methods that AuthPolicy accepts when building execute-api ARNs.

    ALL is the '*' wildcard, matching every method.
    """
    # Read-only methods
    GET = 'GET'
    HEAD = 'HEAD'
    OPTIONS = 'OPTIONS'
    # Methods that create, change or remove resources
    POST = 'POST'
    PUT = 'PUT'
    PATCH = 'PATCH'
    DELETE = 'DELETE'
    # Wildcard
    ALL = '*'


class AuthPolicy(object):
    """Builds the policy document an API Gateway Lambda authorizer returns.

    Set restApiId / region / stage on an instance, add methods with the
    allow* / deny* helpers, then call build().
    """
    # The AWS account id the policy will be generated for. This is used to create the method ARNs.
    awsAccountId = ''
    # The principal used for the policy, this should be a unique identifier for the end user.
    principalId = ''
    # The policy version used for the evaluation. This should always be '2012-10-17'
    version = '2012-10-17'
    # The regular expression used to validate resource paths for the policy.
    # Raw string: '\*' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    pathRegex = r'^[/.a-zA-Z0-9-\*]+$'

    # Internal lists of allowed and denied methods. Each entry is a dict with
    # a 'resourceArn' and a nullable 'conditions' statement; build() turns
    # them into policy statements. __init__ gives every instance its own
    # lists, so instances do not share these class-level defaults.
    allowMethods = []
    denyMethods = []

    # The API Gateway API id. By default this is set to '*'
    restApiId = '*'
    # The region where the API is deployed. By default this is set to '*'
    region = '*'
    # The name of the stage used in the policy. By default this is set to '*'
    stage = '*'

    def __init__(self, principal, awsAccountId):
        self.awsAccountId = awsAccountId
        self.principalId = principal
        self.allowMethods = []
        self.denyMethods = []

    def _addMethod(self, effect, verb, resource, conditions):
        '''Adds a method to the internal lists of allowed or denied methods. Each object in
        the internal list contains a resource ARN and a condition statement. The condition
        statement can be null.

        Raises NameError for a verb that is not one of the HttpVerb values, for a
        resource path that does not match pathRegex, or for an effect other than
        Allow/Deny (case-insensitive).'''
        # Validate against the HttpVerb *values*. The previous hasattr() check
        # also accepted attribute names such as 'ALL' or '__module__' and then
        # wrote them verbatim into the ARN.
        validVerbs = set(value for name, value in vars(HttpVerb).items()
                         if not name.startswith('_'))
        if verb not in validVerbs:
            raise NameError('Invalid HTTP verb ' + verb + '. Allowed verbs in HttpVerb class')
        if not re.match(self.pathRegex, resource):
            raise NameError('Invalid resource path: ' + resource + '. Path should match ' + self.pathRegex)

        if resource[:1] == '/':
            resource = resource[1:]

        resourceArn = 'arn:aws:execute-api:{}:{}:{}/{}/{}/{}'.format(
            self.region, self.awsAccountId, self.restApiId, self.stage, verb, resource)

        entry = {
            'resourceArn': resourceArn,
            'conditions': conditions
        }
        normalizedEffect = effect.lower()
        if normalizedEffect == 'allow':
            self.allowMethods.append(entry)
        elif normalizedEffect == 'deny':
            self.denyMethods.append(entry)
        else:
            # Previously an unknown effect was silently dropped.
            raise NameError('Invalid effect ' + effect + '. Expected Allow or Deny')

    def _getEmptyStatement(self, effect):
        '''Returns an empty statement object prepopulated with the correct action and the
        desired effect.'''
        return {
            'Action': 'execute-api:Invoke',
            'Effect': effect[:1].upper() + effect[1:].lower(),
            'Resource': []
        }

    def _getStatementForEffect(self, effect, methods):
        '''This function loops over an array of objects containing a resourceArn and
        conditions statement and generates the array of statements for the policy.

        Unconditional methods share one statement; each conditional method gets its
        own statement carrying its Condition block.'''
        statements = []

        if methods:
            statement = self._getEmptyStatement(effect)

            for curMethod in methods:
                if not curMethod['conditions']:
                    statement['Resource'].append(curMethod['resourceArn'])
                else:
                    conditionalStatement = self._getEmptyStatement(effect)
                    conditionalStatement['Resource'].append(curMethod['resourceArn'])
                    conditionalStatement['Condition'] = curMethod['conditions']
                    statements.append(conditionalStatement)

            if statement['Resource']:
                statements.append(statement)

        return statements

    def allowAllMethods(self):
        '''Adds a '*' allow to the policy to authorize access to all methods of an API'''
        self._addMethod('Allow', HttpVerb.ALL, '*', [])

    def denyAllMethods(self):
        '''Adds a '*' deny to the policy to deny access to all methods of an API'''
        self._addMethod('Deny', HttpVerb.ALL, '*', [])

    def allowMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods for the policy'''
        self._addMethod('Allow', verb, resource, [])

    def denyMethod(self, verb, resource):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods for the policy'''
        self._addMethod('Deny', verb, resource, [])

    def allowMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of allowed
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Allow', verb, resource, conditions)

    def denyMethodWithConditions(self, verb, resource, conditions):
        '''Adds an API Gateway method (Http verb + Resource path) to the list of denied
        methods and includes a condition for the policy statement. More on AWS policy
        conditions here: http://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements.html#Condition'''
        self._addMethod('Deny', verb, resource, conditions)

    def build(self):
        '''Generates the policy document based on the internal lists of allowed and denied
        conditions. This will generate a policy with two main statements for the effect:
        one statement for Allow and one statement for Deny.
        Methods that includes conditions will have their own statement in the policy.

        Raises NameError when neither list has any entries.'''
        if not self.allowMethods and not self.denyMethods:
            raise NameError('No statements defined for the policy')

        policy = {
            'principalId': self.principalId,
            'policyDocument': {
                'Version': self.version,
                'Statement': []
            }
        }

        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Allow', self.allowMethods))
        policy['policyDocument']['Statement'].extend(self._getStatementForEffect('Deny', self.denyMethods))

        return policy

Upvotes: 13

Related Questions