---
# SAM template: an Application Load Balancer that authenticates users against
# an existing Cognito user pool and forwards authenticated requests to an
# inline Lambda function that verifies and displays the ALB-issued JWT.
AWSTemplateFormatVersion: "2010-09-09"
Transform: "AWS::Serverless-2016-10-31"
Description: "ALB with Cognito authentication and Lambda backend"

Metadata:
  # Groups the stack parameters in the CloudFormation console's
  # create-stack form.
  AWS::CloudFormation::Interface:
    ParameterGroups:
      - Label:
          default: ALB
        Parameters:
          - VPCID
          - Subnets
          - ALBScheme
          - SSLCert
          - R53Zone
      - Label:
          default: Cognito
        Parameters:
          - CognitoUserpoolArn
          - CognitoUserpoolDomain
      - Label:
          default: Lambda
        Parameters:
          - LambdaLayerArn

Parameters:

  VPCID:
    Type: AWS::EC2::VPC::Id
    Description: VPCID

  Subnets:
    Type: List<AWS::EC2::Subnet::Id>
    # An ALB needs subnets in at least two Availability Zones. CloudFormation
    # parameter types cannot enforce a minimum list length, so this is only
    # documented in the description.
    Description: Subnets (min 2)

  ALBScheme:
    Type: String
    Description: ALB Scheme
    Default: internet-facing
    AllowedValues:
      - internal
      - internet-facing

  R53Zone:
    Type: String
    Description: R53 Domain for ALB DNS entry
    Default: domain.without.dot.in.the.end.com
    # The template appends the trailing dot itself (R53Record HostedZoneName
    # and Name), so a value ending in "." would produce an invalid "zone.."
    # name. Reject it at parameter validation instead of mid-deployment.
    AllowedPattern: '^.*[^.]$'
    ConstraintDescription: must be a domain name without a trailing dot

  SSLCert:
    Type: String
    Description: ALB SSL Certificate ARN
    Default: arn:aws:acm:eu-west-1:123412341234:certificate/12345678-abcd-efgh-1234-123456789012

  CognitoUserpoolArn:
    Type: String
    Description: Cognito Userpool ARN
    Default: arn:aws:cognito-idp:eu-west-1:123412341234:userpool/eu-west-1_123412341

  CognitoUserpoolDomain:
    Type: String
    Description: Cognito Userpool Domain
    Default: mycognito.auth.eu-west-1.amazoncognito.com

  LambdaLayerArn:
    Type: String
    Description: Lambda layer for Python libraries
    # NOTE(review): unlike the other placeholder defaults (account
    # 123412341234), this default points at a layer in AWS account
    # 017001121740. Code in that layer runs inside the function. Prefer a
    # layer published from your own account (see the build notes in the
    # JWTverifier inline code).
    Default: arn:aws:lambda:eu-west-1:017001121740:layer:pyjwt:1

Outputs:

  URL:
    # HTTPS address of the Route 53 alias record that points at the ALB.
    Description: ALB URL
    Value: !Sub "https://${AWS::StackName}.${R53Zone}/"

Resources:

  # Alias A record <stack-name>.<R53Zone> pointing at the load balancer.
  R53Record:
    Type: AWS::Route53::RecordSet
    Properties:
      HostedZoneName: !Sub "${R53Zone}."
      Name: !Sub "${AWS::StackName}.${R53Zone}."
      Type: A
      Comment: !Sub "${AWS::StackName} ALB alias"
      AliasTarget:
        DNSName: !GetAtt ALB.DNSName
        HostedZoneId: !GetAtt ALB.CanonicalHostedZoneID

  # Application Load Balancer in the selected subnets, guarded by ALBSecGroup.
  ALB:
    Type: AWS::ElasticLoadBalancingV2::LoadBalancer
    Properties:
      Type: application
      Scheme: !Ref ALBScheme
      Subnets: !Ref Subnets
      SecurityGroups:
        - !Ref ALBSecGroup

  # Opens only HTTPS (443/tcp). The template defines no plain-HTTP listener.
  ALBSecGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      VpcId: !Ref VPCID
      GroupDescription: !Sub "${AWS::StackName} ALB, allow port 443 from 0.0.0.0/0"
      SecurityGroupIngress:
        - IpProtocol: tcp
          FromPort: 443
          ToPort: 443
          CidrIp: 0.0.0.0/0

  # App client used by the ALB's authenticate-cognito action (authorization
  # code flow). The callback is the ALB's /oauth2/idpresponse endpoint on the
  # Route 53 name created by this stack.
  CognitoAppClient:
    Type: AWS::Cognito::UserPoolClient
    Properties:
      # The user pool id is the part of the ARN after "userpool/".
      UserPoolId: !Select [1, !Split ["userpool/", !Ref CognitoUserpoolArn]]
      ClientName: !Sub "${AWS::StackName} ALB"
      GenerateSecret: true
      RefreshTokenValidity: 1
      ExplicitAuthFlows:
        - USER_PASSWORD_AUTH
      AllowedOAuthFlowsUserPoolClient: true
      AllowedOAuthFlows:
        - code
      AllowedOAuthScopes:
        - openid
      SupportedIdentityProviders:
        - COGNITO
      ReadAttributes:
        - email
        - email_verified
      CallbackURLs:
        - !Sub "https://${AWS::StackName}.${R53Zone}/oauth2/idpresponse"

  # HTTPS listener: every request is authenticated against Cognito (Order 1)
  # and then forwarded to the Lambda target group (Order 2).
  Listener:
    Type: AWS::ElasticLoadBalancingV2::Listener
    # No explicit DependsOn on CognitoAppClient: the !Ref to it in
    # AuthenticateCognitoConfig already creates that dependency.
    Properties:
      LoadBalancerArn: !Ref ALB
      Port: 443
      Protocol: HTTPS
      Certificates:
        - CertificateArn: !Ref SSLCert
      # Pin a TLS 1.2+ policy. When SslPolicy is omitted, listeners created
      # through the API/CloudFormation have historically received
      # ELBSecurityPolicy-2016-08, which still accepts TLS 1.0 and 1.1.
      SslPolicy: ELBSecurityPolicy-TLS13-1-2-2021-06
      DefaultActions:
        - Type: authenticate-cognito
          Order: 1
          AuthenticateCognitoConfig:
            OnUnauthenticatedRequest: authenticate
            Scope: openid
            SessionCookieName: AWSELBAuthSessionCookie
            # Seconds. NOTE(review): 60 s is far below the ALB default of
            # 604800 (7 days); confirm the short session is intentional.
            SessionTimeout: 60
            UserPoolArn: !Ref CognitoUserpoolArn
            UserPoolDomain: !Ref CognitoUserpoolDomain
            UserPoolClientId: !Ref CognitoAppClient
        - Type: forward
          Order: 2
          TargetGroupArn: !Ref LambdaTargetGroup

  # Target group whose single target is the JWTverifier function.
  LambdaTargetGroup:
    Type: AWS::ElasticLoadBalancingV2::TargetGroup
    # Elastic Load Balancing must already be allowed to invoke the function
    # before it can be registered as a target.
    DependsOn: LambdaExecPermission
    Properties:
      # No explicit Name: target group names are limited to 32 characters,
      # while stack names may be up to 128, so naming the group after the
      # stack failed for longer stack names. CloudFormation generates a
      # unique name instead (nothing in the template refers to the name).
      TargetType: lambda
      HealthCheckEnabled: false
      Targets:
        - Id: !GetAtt JWTverifier.Arn

  # Lets Elastic Load Balancing invoke JWTverifier.
  LambdaExecPermission:
    Type: AWS::Lambda::Permission
    Properties:
      Action: lambda:InvokeFunction
      FunctionName: !GetAtt JWTverifier.Arn
      Principal: elasticloadbalancing.amazonaws.com
      # Without a SourceArn condition, any ALB target group (including ones in
      # other accounts) could invoke the function through this service
      # principal. The exact target group ARN cannot be used here because
      # LambdaTargetGroup depends on this permission (a !Ref would form a
      # cycle), so scope it to target groups in this account and region.
      SourceArn: !Sub 'arn:${AWS::Partition}:elasticloadbalancing:${AWS::Region}:${AWS::AccountId}:targetgroup/*'

  # Lambda backend behind the ALB. It verifies the ALB-signed user-claims JWT
  # (x-amzn-oidc-data header) and renders its parts as an HTML page.
  JWTverifier:
    Type: 'AWS::Serverless::Function'
    Properties:
      # NOTE(review): python3.7 is a deprecated Lambda runtime. Upgrading means
      # rebuilding the layer for the new runtime, because it bundles the
      # compiled `cryptography` package.
      Runtime: python3.7
      Timeout: 30
      Handler: index.lambda_handler
      Description: Decode and verify JWT
      MemorySize: 128
      Layers:
        - !Ref LambdaLayerArn
      Environment:
        Variables:
          # Expected value of the JWT "signer" header: the ARN of this stack's
          # load balancer (!Ref on a load balancer returns its ARN).
          ALB_ARN: !Ref ALB
      InlineCode: |
        import base64
        import html
        import json
        import os
        import urllib.parse

        import jwt
        import requests

        # NOTE: This requires following Lambda layer ...
        # pip3 install PyJWT -t python
        # pip3 install requests -t python
        # pip3 install cryptography -t python --upgrade
        # aws --region eu-west-1 lambda publish-layer-version --layer-name pyjwt --description "PyJWT for Python 3.7" --zip-file fileb://jwt.zip --compatible-runtimes python3.7

        # ALB signs x-amzn-oidc-data with ES256. The algorithm is pinned here
        # instead of being read from the token header, which the client
        # controls. Trusting it (e.g. "HS256") would let the downloaded public
        # key be used as an HMAC secret, so forged tokens would verify.
        ALB_JWT_ALGORITHMS = ['ES256']

        # Seconds to wait for the regional public-key endpoint.
        PUBLIC_KEY_TIMEOUT = 5


        def _error_response(status_code, reason):
            """Build a plain-text ALB Lambda response for a rejected request."""
            return {
                "statusCode": status_code,
                "statusDescription": "%d %s" % (status_code, reason),
                "isBase64Encoded": False,
                "headers": {
                    "Content-Type": "text/plain; charset=utf-8"
                },
                "body": reason + "\n",
            }


        def lambda_handler(event, context):
            """Verify the ALB user-claims JWT and return its parts as HTML.

            Returns a 401 response when the header is missing, the token is
            malformed, expired or fails signature verification, or it was
            signed by a load balancer other than this stack's ALB.
            """

            # Step 1: Get the key id from JWT headers (the kid field)
            encoded_jwt = (event.get('headers') or {}).get('x-amzn-oidc-data')
            if not encoded_jwt:
                return _error_response(401, 'Unauthorized')

            try:
                jwt_headers = jwt.get_unverified_header(encoded_jwt)
            except jwt.InvalidTokenError:
                return _error_response(401, 'Unauthorized')
            # get_unverified_header succeeded, so there are at least 3 segments.
            jwt_fields = encoded_jwt.split('.')
            jwt_sig = jwt_fields[2]

            # The signer must be this stack's ALB; otherwise a token signed
            # for a different load balancer would be accepted.
            if jwt_headers.get('signer') != os.environ['ALB_ARN'] or 'kid' not in jwt_headers:
                return _error_response(401, 'Unauthorized')

            # NOTE: Payload is in base64 clear-text, but you should use jwt.decode() to verify the signature!
            # jwt_payload = json.loads(base64.b64decode(jwt_fields[1]))

            # Step 2: Get the public key from regional endpoint
            aws_region = context.invoked_function_arn.split(':')[3]
            # kid comes from the unverified header; quote it so it can only
            # name a single key object on the AWS key host.
            kid = urllib.parse.quote(str(jwt_headers['kid']), safe='')
            url = 'https://public-keys.auth.elb.' + aws_region + '.amazonaws.com/' + kid
            key_response = requests.get(url, timeout=PUBLIC_KEY_TIMEOUT)
            key_response.raise_for_status()
            pub_key = key_response.text

            # Step 3: Get the payload (PyJWT also verifies the signature and,
            # by default, the "exp" claim)
            try:
                jwt_payload = jwt.decode(encoded_jwt, pub_key, algorithms=ALB_JWT_ALGORITHMS)
            except jwt.InvalidTokenError:
                return _error_response(401, 'Unauthorized')

            # Every value below is echoed into HTML, so escape it. The token
            # content may carry user attributes that must not inject markup.
            sections = [
                ('Encoded JWT', encoded_jwt),
                ('JWT Headers', json.dumps(jwt_headers, indent=2, sort_keys=True)),
                ('JWT Payload', json.dumps(jwt_payload, indent=2, sort_keys=True)),
                ('JWT Signature', jwt_sig),
            ]
            body = '\n'.join(
                title + ':\n' + html.escape(text, quote=False) + '\n'
                for title, text in sections
            )

            return {
                "statusCode": 200,
                "statusDescription": "200 OK",
                "isBase64Encoded": False,
                "headers": {
                    "Content-Type": "text/html; charset=utf-8"
                },
                "body": "<html><head><title>Hello JWT</title></head><body><pre>" + body + "</pre></body></html>",
            }
