---
AWSTemplateFormatVersion: '2010-09-09'
Transform: 'AWS::Serverless-2016-10-31'
Description: ALB with Cognito authentication and Lambda backend

Metadata:
  AWS::CloudFormation::Interface:
    # Groups the parameters by component in the CloudFormation console.
    ParameterGroups:
      - Label:
          default: ALB
        Parameters:
          - VPCID
          - Subnets
          - ALBScheme
          - SSLCert
          - R53Zone
      - Label:
          default: Cognito
        Parameters:
          - CognitoUserpoolArn
          - CognitoUserpoolDomain
      - Label:
          default: Lambda
        Parameters:
          - LambdaLayerArn

Parameters:

  # ---------------------------------------------------------------- ALB
  VPCID:
    Description: VPCID
    Type: AWS::EC2::VPC::Id

  Subnets:
    Description: Subnets (min 2)
    Type: List<AWS::EC2::Subnet::Id>

  ALBScheme:
    Description: ALB Scheme
    Type: String
    AllowedValues: [internal, internet-facing]
    Default: internet-facing

  # Hosted zone name WITHOUT the trailing dot; the record set appends it.
  R53Zone:
    Description: R53 Domain for ALB DNS entry
    Type: String
    Default: domain.without.dot.in.the.end.com

  SSLCert:
    Description: ALB SSL Certificate ARN
    Type: String
    Default: 'arn:aws:acm:eu-west-1:123412341234:certificate/12345678-abcd-efgh-1234-123456789012'

  # ------------------------------------------------------------ Cognito
  # The user pool id is derived from the text after 'userpool/' in this ARN.
  CognitoUserpoolArn:
    Description: Cognito Userpool ARN
    Type: String
    Default: 'arn:aws:cognito-idp:eu-west-1:123412341234:userpool/eu-west-1_123412341'

  CognitoUserpoolDomain:
    Description: Cognito Userpool Domain
    Type: String
    Default: mycognito.auth.eu-west-1.amazoncognito.com

  # ------------------------------------------------------------- Lambda
  LambdaLayerArn:
    Description: Lambda layer for Python libraries
    Type: String
    Default: 'arn:aws:lambda:eu-west-1:017001121740:layer:pyjwt:1'

Outputs:

  # Public HTTPS URL of the Route 53 alias created for the load balancer.
  URL:
    Value: !Sub 'https://${AWS::StackName}.${R53Zone}/'
    Description: ALB URL

Resources:

  # Alias A record "<stack-name>.<R53Zone>." pointing at the load balancer.
  R53Record:
    Type: AWS::Route53::RecordSet
    Properties:
      HostedZoneName: !Sub '${R53Zone}.'
      Name: !Sub '${AWS::StackName}.${R53Zone}.'
      Type: A
      Comment: !Sub '${AWS::StackName} ALB alias'
      AliasTarget:
        DNSName: !GetAtt [ALB, DNSName]
        HostedZoneId: !GetAtt [ALB, CanonicalHostedZoneID]

  # Application Load Balancer; scheme and subnets come from the parameters.
  ALB:
    Type: AWS::ElasticLoadBalancingV2::LoadBalancer
    Properties:
      Type: application
      Scheme: !Ref ALBScheme
      Subnets: !Ref Subnets
      SecurityGroups: [!Ref ALBSecGroup]

  # Load balancer security group: only tcp/443 inbound, from any IPv4 address.
  ALBSecGroup:
    Type: AWS::EC2::SecurityGroup
    Properties:
      VpcId: !Ref VPCID
      GroupDescription: !Sub '${AWS::StackName} ALB, allow port 443 from 0.0.0.0/0'
      SecurityGroupIngress:
        - IpProtocol: tcp
          FromPort: 443
          ToPort: 443
          CidrIp: 0.0.0.0/0

  # User pool app client used by the listener's authenticate-cognito action.
  # The OAuth settings are declared here natively as well as being applied by
  # CognitoAppClientSettings: UpdateUserPoolClient resets every attribute it
  # is not given, so without them an in-place CloudFormation update of this
  # resource would silently wipe the OAuth configuration.
  # Keep these values in sync with CognitoAppClientSettings.
  CognitoAppClient:
    Type: AWS::Cognito::UserPoolClient
    Properties:
      ClientName: !Sub '${AWS::StackName} ALB'
      # The pool id is the part of the ARN after 'userpool/'.
      UserPoolId: !Select [1, !Split [ 'userpool/', !Ref CognitoUserpoolArn ]]
      GenerateSecret: true
      ExplicitAuthFlows:
        - USER_PASSWORD_AUTH
      RefreshTokenValidity: 1
      ReadAttributes:
        - email
        - email_verified
      AllowedOAuthFlowsUserPoolClient: true
      AllowedOAuthFlows:
        - code
      AllowedOAuthScopes:
        - openid
      SupportedIdentityProviders:
        - COGNITO
      CallbackURLs:
        - !Sub 'https://${AWS::StackName}.${R53Zone}/oauth2/idpresponse'

  # Applies the OAuth settings to CognitoAppClient through the generic
  # CustomResource function, which calls the SDK's updateUserPoolClient.
  # UpdateUserPoolClient resets every attribute it is not given, so the
  # client's other settings (name, auth flows, refresh-token validity, read
  # attributes) are re-sent here too; otherwise this call would undo them.
  # Keep these values in sync with CognitoAppClient.
  # Only `Create` is given: the handler reuses it for Update requests and
  # makes no API call on Delete.
  CognitoAppClientSettings:
    Type: Custom::CognitoUserPoolClientSettings
    Properties:
      ServiceToken: !GetAtt CustomResource.Arn
      Service: CognitoIdentityServiceProvider
      Create:
        Action: updateUserPoolClient
        Parameters:
          UserPoolId: !Select [1, !Split [ 'userpool/', !Ref CognitoUserpoolArn ]]
          ClientId: !Ref CognitoAppClient
          ClientName: !Sub '${AWS::StackName} ALB'
          ExplicitAuthFlows: [ USER_PASSWORD_AUTH ]
          # NOTE(review): arrives as the string "1"; the handler only converts
          # booleans/null, so this relies on the SDK coercing numeric strings.
          RefreshTokenValidity: 1
          ReadAttributes: [ email, email_verified ]
          AllowedOAuthFlows: [ code ]
          AllowedOAuthScopes: [ openid ]
          SupportedIdentityProviders: [ COGNITO ]
          AllowedOAuthFlowsUserPoolClient: true
          CallbackURLs:
            - !Sub https://${AWS::StackName}.${R53Zone}/oauth2/idpresponse

  # HTTPS listener: every request is first authenticated against Cognito
  # (Order 1) and then forwarded to the Lambda target group (Order 2).
  # DependsOn makes sure the app client settings are applied before the
  # listener starts using the client.
  Listener:
    Type: AWS::ElasticLoadBalancingV2::Listener
    DependsOn: CognitoAppClientSettings
    Properties:
      LoadBalancerArn: !Ref ALB
      Port: 443
      Protocol: HTTPS
      # Pin a TLS 1.2+ policy (TLS 1.3 capable) instead of relying on the API
      # default policy, historically ELBSecurityPolicy-2016-08 (TLS 1.0/1.1).
      SslPolicy: ELBSecurityPolicy-TLS13-1-2-2021-06
      Certificates:
        - CertificateArn: !Ref SSLCert
      DefaultActions:
        - Type: authenticate-cognito
          Order: 1
          AuthenticateCognitoConfig:
            OnUnauthenticatedRequest: authenticate
            Scope: openid
            SessionCookieName: AWSELBAuthSessionCookie
            # Seconds. Deliberately short (the ALB default is 7 days).
            SessionTimeout: 60
            UserPoolArn: !Ref CognitoUserpoolArn
            UserPoolDomain: !Ref CognitoUserpoolDomain
            UserPoolClientId: !Ref CognitoAppClient
        - Type: forward
          Order: 2
          TargetGroupArn: !Ref LambdaTargetGroup

  # Lambda target group for JWTverifier, created only after the invoke
  # permission exists (DependsOn). The stack name is used as the target group
  # name; ELBv2 limits these to 32 characters, so keep stack names short.
  LambdaTargetGroup:
    Type: AWS::ElasticLoadBalancingV2::TargetGroup
    DependsOn: LambdaExecPermission
    Properties:
      Name: !Ref 'AWS::StackName'
      TargetType: lambda
      HealthCheckEnabled: false
      Targets:
        - Id: !GetAtt [JWTverifier, Arn]

  # Lets Elastic Load Balancing invoke JWTverifier. No SourceArn is set:
  # LambdaTargetGroup DependsOn this permission, so referencing the target
  # group's ARN here would create a circular dependency.
  LambdaExecPermission:
    Type: AWS::Lambda::Permission
    Properties:
      Principal: elasticloadbalancing.amazonaws.com
      FunctionName: !GetAtt [JWTverifier, Arn]
      Action: 'lambda:InvokeFunction'

  # Lambda target behind the ALB. It verifies the JWT that the ALB adds in the
  # x-amzn-oidc-data header after Cognito authentication and renders the
  # decoded token as an HTML page.
  JWTverifier:
    Type: 'AWS::Serverless::Function'
    Properties:
      # NOTE(review): python3.7 is a deprecated Lambda runtime. Moving to a
      # newer runtime requires rebuilding the LambdaLayerArn layer for it.
      Runtime: python3.7
      Timeout: 30
      Handler: index.lambda_handler

      Description: Decode and verify JWT
      MemorySize: 128
      Layers:
        - !Ref LambdaLayerArn
      Environment:
        Variables:
          # Ref on the load balancer yields its ARN, which the ALB writes into
          # the 'signer' field of the JWT header.
          EXPECTED_SIGNER: !Ref ALB
      InlineCode: |
        import html
        import json
        import os
        import urllib.parse

        import jwt
        import requests

        # NOTE: This requires following Lambda layer ...
        # pip3 install PyJWT -t python
        # pip3 install requests -t python
        # pip3 install cryptography -t python --upgrade
        # aws --region eu-west-1 lambda publish-layer-version --layer-name pyjwt --description "PyJWT for Python 3.7" --zip-file fileb://jwt.zip --compatible-runtimes python3.7

        # The ALB signs x-amzn-oidc-data with ES256. Pin the algorithm instead of
        # trusting the unverified 'alg' header, which the token itself supplies.
        ALB_JWT_ALGORITHMS = ['ES256']

        # Seconds to wait for the regional public-key endpoint.
        KEY_FETCH_TIMEOUT = 5


        def _html_response(status_code, status_text, content):
            """Wrap already-escaped HTML content in an ALB Lambda-target response."""
            return {
                "statusCode": status_code,
                "statusDescription": "%d %s" % (status_code, status_text),
                "isBase64Encoded": False,
                "headers": {
                    "Content-Type": "text/html; charset=utf-8"
                },
                "body": ("<html><head><title>Hello JWT</title></head><body><pre>"
                         + content + "</pre></body></html>"),
            }


        def _unauthorized(reason):
            """Return a 401 page showing the HTML-escaped reason."""
            return _html_response(401, "Unauthorized", html.escape(reason) + '\n')


        def lambda_handler(event, context):
            """Verify the ALB-issued JWT and render its header, payload and signature.

            Returns a 401 page when the token is missing, malformed, signed by a
            different load balancer, or fails signature/expiry verification.
            Failures fetching the public key are raised to the caller.
            """
            # Step 1: Get the JWT the ALB attached after Cognito authentication
            encoded_jwt = (event.get('headers') or {}).get('x-amzn-oidc-data')
            if not encoded_jwt:
                return _unauthorized('Missing x-amzn-oidc-data header')

            # Step 2: Read the (unverified) header, check the signer, get the key id
            try:
                jwt_headers = jwt.get_unverified_header(encoded_jwt)
                expected_signer = os.environ.get('EXPECTED_SIGNER')
                if expected_signer and jwt_headers.get('signer') != expected_signer:
                    return _unauthorized('JWT was not signed by this load balancer')
                kid = str(jwt_headers['kid'])
            except (jwt.InvalidTokenError, KeyError) as err:
                return _unauthorized('Invalid JWT header: %s' % err)

            # Step 3: Get the public key from the regional endpoint. The kid comes
            # from an unverified header, so quote it before putting it in the URL.
            aws_region = context.invoked_function_arn.split(':')[3]
            url = ('https://public-keys.auth.elb.' + aws_region + '.amazonaws.com/'
                   + urllib.parse.quote(kid, safe=''))
            key_response = requests.get(url, timeout=KEY_FETCH_TIMEOUT)
            key_response.raise_for_status()
            pub_key = key_response.text

            # Step 4: Verify signature (and expiry) and get the payload
            try:
                jwt_payload = jwt.decode(encoded_jwt, pub_key, algorithms=ALB_JWT_ALGORITHMS)
            except jwt.InvalidTokenError as err:
                return _unauthorized('JWT verification failed: %s' % err)

            # Claims are user-controlled (e.g. username), so escape before
            # embedding them in HTML.
            jwt_sig = encoded_jwt.split('.')[2]
            return _html_response(200, "OK", (
                'Encoded JWT:\n' + html.escape(encoded_jwt) + '\n\n'
                + 'JWT Headers:\n' + html.escape(json.dumps(jwt_headers, indent=2, sort_keys=True)) + '\n\n'
                + 'JWT Payload:\n' + html.escape(json.dumps(jwt_payload, indent=2, sort_keys=True)) + '\n\n'
                + 'JWT Signature:\n' + html.escape(jwt_sig) + '\n'
            ))

  # Generic custom resource that calls the AWS SDK (v2) API named in the
  # resource properties (Service, Create/Update/Delete -> Action, Parameters).
  # Here it applies Cognito userpool client settings via updateUserPoolClient;
  # when this template was written CloudFormation could not set them natively.
  # NOTE(review): newer AWS::Cognito::UserPoolClient versions expose these OAuth
  # properties directly, which would make this function unnecessary.
  # Code is from https://github.com/emdgroup/cfn-custom-resource

  CustomResource:
    Type: 'AWS::Serverless::Function'
    Properties:
      # NOTE(review): nodejs8.10 is a deprecated Lambda runtime. The inline code
      # require()s 'aws-sdk' (v2 API) and 'jmespath' without packaging them, so
      # a runtime bump must make both modules available to the function.
      Runtime: nodejs8.10
      Timeout: 300
      Handler: index.handler
      Description: Cloudformation custom resource
      MemorySize: 128
      Policies:
        - Version: '2012-10-17'
          Statement:
            - Effect: Allow
              Action:
                - cognito-idp:UpdateUserPoolClient
              Resource: !Ref CognitoUserpoolArn
      InlineCode: |
        const AWS = require('aws-sdk'),
          jmespath = require('jmespath'),
          querystring = require('querystring'),
          crypto = require('crypto'),
          https = require("https"),
          url = require("url");

        let pid = 'PhysicalResourceId', rp = 'ResourceProperties';

        // Lambda entry point. Picks the arguments for this request type (falling
        // back to `Create` for Update and to {} for Delete), performs the SDK
        // call and reports the result to CloudFormation.
        exports.handler = (ev, ctx, cb) => {
          console.log(JSON.stringify(Object.assign({}, ev, {
            ResourceProperties: null,
            OldResourceProperties: null,
          })));
          let rand = random();
          ev[rp] = fixBooleans(ev[rp], ev.RequestType !== 'Create' ? ev[pid] : fixBooleans(ev[rp][pid], null, rand), rand);
          let args = ev[rp][ev.RequestType];
          if (!args) args = ev.RequestType === 'Delete' ? {} : ev[rp]['Create'];
          ['Attributes', pid, 'PhysicalResourceIdQuery', 'Parameters'].forEach(attr =>
            args[attr] = args[attr] || ev[rp][attr]
          );
          if (ev.RequestType === 'Delete') {
            request(args, ev, ctx, () => response.send(ev, ctx, 'SUCCESS', {}, ev[pid]));
          } else if (ev.RequestType === 'Create' || ev.RequestType === 'Update') {
            request(args, ev, ctx, function(data) {
              let props = ev[rp][ev.RequestType] || ev[rp]['Create'];
              if (props.PhysicalResourceIdQuery) ev[pid] = jmespath.search(data, props.PhysicalResourceIdQuery);
              if (props[pid]) ev[pid] = props[pid];
              if (props.Attributes) data = jmespath.search(data, props.Attributes);
              response.send(ev, ctx, 'SUCCESS', data, ev[pid]);
            });
          }
        };

        // Random uppercase token (base64 without '+', '=', '/') used for ${Random}.
        function random() {
          return crypto.randomBytes(6).toString('base64').replace(/[+=\/]/g, '').toUpperCase();
        }

        // Recursively converts the strings 'true' / 'false' / 'null' into real
        // values (custom-resource properties arrive as strings) and expands the
        // first ${Random} and ${PhysicalId} / ${PhysicalResourceId} placeholder.
        function fixBooleans(obj, id, rand) {
          if (Array.isArray(obj)) return obj.map(item => fixBooleans(item, id, rand));
          else if (typeof obj === 'object') {
            for (let key in obj) obj[key] = fixBooleans(obj[key], id, rand);
            return obj;
          } else if (typeof obj === 'string') {
            obj = obj === 'true' ? true : obj === 'false' ? false : obj === 'null' ? null : obj.replace(/\$\{Random\}/, rand);
            if (typeof obj === 'string' && id) obj = obj.replace(/\$\{Physical(Resource)?Id\}/, id);
            return obj;
          } else return obj;
        }

        // Recursively replaces Buffer values with their base64 string form.
        function b64ify(obj) {
          if (Buffer.isBuffer(obj))
            return obj.toString('base64');
          else if (Array.isArray(obj)) return obj.map(item => b64ify(item));
          else if (typeof obj === 'object') {
            for (let key in obj) obj[key] = b64ify(obj[key]);
            return obj;
          } else return obj;
        }


        // Runs args.Action on the configured service client and passes the result
        // to cb. Every failure (credentials, unknown service/action, API error,
        // exception in cb) is reported as FAILED so the stack never waits on a
        // response that is never sent.
        function request(args, ev, ctx, cb) {
          if (ev.RequestType === 'Delete' && !args.Action) return cb();
          let fail = err => response.send(ev, ctx, 'FAILED', err, ev[pid]);
          let creds = AWS.config.credentials;
          creds.getPromise().then(() => {
            if(ev[rp].RoleArn) creds = new AWS.TemporaryCredentials({
              RoleArn: ev[rp].RoleArn,
            });
            let client = new AWS[ev[rp].Service]({
              credentials: creds,
              region: ev[rp].Region || AWS.config.region,
            });
            client[args.Action](args.Parameters, (err, data) => {
              if (err && args.IgnoreErrors !== true) return fail(err);
              try {
                cb(data);
              } catch (e) {
                fail(e);
              }
            });
          }).catch(fail);
        }

        let response = {
          // Builds the CloudFormation response document; Data is dropped when the
          // serialised body would exceed 4096 bytes.
          body: function(ev, ctx, responseStatus, resData, pId) {
            let body = {
              Status: responseStatus,
              Reason: resData instanceof Error ? resData.toString() : '',
              PhysicalResourceId: pId || ev.RequestId,
              StackId: ev.StackId,
              RequestId: ev.RequestId,
              LogicalResourceId: ev.LogicalResourceId,
              Data: responseStatus === 'FAILED' ? null : b64ify(resData),
            }
            if (JSON.stringify(body).length > 4096) {
              console.log('truncated responseData as it exceeded 4096 bytes');
              return Object.assign(body, {
                Data: null
              });
            } else { return body }
          },
          // PUTs the response to the pre-signed ResponseURL, then ends the invocation.
          send: function(ev, ctx) {
            let responseBody = response.body.apply(this, arguments);
            console.log('Response', JSON.stringify(Object.assign({}, responseBody, {
              Data: null
            })));

            var parsed = url.parse(ev.ResponseURL);
            var payload = JSON.stringify(responseBody);
            https.request({
              hostname: parsed.hostname,
              path: parsed.path,
              method: 'PUT',
              // Explicit length avoids chunked encoding on the pre-signed PUT;
              // empty content-type matches what the URL was signed for.
              headers: {
                'content-type': '',
                'content-length': Buffer.byteLength(payload),
              },
            }, res => {
              // Drain the response so the socket is released, then finish.
              res.resume();
              ctx.done();
            }).on("error", function(error) {
              console.log(error);
              ctx.done();
            }).end(payload);
          },
        };
