From 94256cf384e4258b7b156c3018dbd70a747a70a0 Mon Sep 17 00:00:00 2001 From: Kevin Date: Mon, 23 Oct 2023 01:36:47 +0200 Subject: [PATCH 1/2] Use lambda url as cloudfront origin if functionName is defined --- src/constructs/aws/ServerSideWebsite.ts | 25 +++++++++++++++++++------ src/providers/AwsProvider.ts | 1 + src/types/serverless.ts | 1 + 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/constructs/aws/ServerSideWebsite.ts b/src/constructs/aws/ServerSideWebsite.ts index 0371e871..bab1dcc4 100644 --- a/src/constructs/aws/ServerSideWebsite.ts +++ b/src/constructs/aws/ServerSideWebsite.ts @@ -37,6 +37,7 @@ const SCHEMA = { properties: { type: { const: "server-side-website" }, apiGateway: { enum: ["http", "rest"] }, + functionName: { type: "string" }, assets: { type: "object", additionalProperties: { type: "string" }, @@ -115,11 +116,23 @@ export class ServerSideWebsite extends AwsConstruct { })(); const backendCachePolicy = CachePolicy.CACHING_DISABLED; - const apiId = - configuration.apiGateway === "rest" - ? this.provider.naming.getRestApiLogicalId() - : this.provider.naming.getHttpApiLogicalId(); - const apiGatewayDomain = Fn.join(".", [Fn.ref(apiId), `execute-api.${this.provider.region}.amazonaws.com`]); + // Resolve domain to use as cloudfront origin + const originDomain = (() => { + if (configuration.functionName !== void 0) { + // Use Lambda Url + const lambdaUrlId = this.provider.naming.getLambdaFunctionUrlLogicalId(configuration.functionName); + + return Fn.select(2, Fn.split("/", Fn.getAtt(lambdaUrlId, "FunctionUrl").toString())); + } else { + // Use API Gateway + const apiId = + configuration.apiGateway === "rest" + ? this.provider.naming.getRestApiLogicalId() + : this.provider.naming.getHttpApiLogicalId(); + + return Fn.join(".", [Fn.ref(apiId), `execute-api.${this.provider.region}.amazonaws.com`]); + } + })(); // Cast the domains to an array this.domains = configuration.domain !== undefined ? flatten([configuration.domain]) : undefined; @@ -132,7 +145,7 @@ export class ServerSideWebsite extends AwsConstruct { comment: `${provider.stackName} ${id} website CDN`, defaultBehavior: { // Origins are where CloudFront fetches content - origin: new HttpOrigin(apiGatewayDomain, { + origin: new HttpOrigin(originDomain, { // API Gateway only supports HTTPS protocolPolicy: OriginProtocolPolicy.HTTPS_ONLY, }), diff --git a/src/providers/AwsProvider.ts b/src/providers/AwsProvider.ts index 0046d44d..239a7d5a 100644 --- a/src/providers/AwsProvider.ts +++ b/src/providers/AwsProvider.ts @@ -62,6 +62,7 @@ export class AwsProvider implements ProviderInterface { public naming: { getStackName: () => string; getLambdaLogicalId: (functionName: string) => string; + getLambdaFunctionUrlLogicalId: (functionName: string) => string; getRestApiLogicalId: () => string; getHttpApiLogicalId: () => string; }; diff --git a/src/types/serverless.ts b/src/types/serverless.ts index d4ff66f1..19763c00 100644 --- a/src/types/serverless.ts +++ b/src/types/serverless.ts @@ -24,6 +24,7 @@ export type Provider = { naming: { getStackName: () => string; getLambdaLogicalId: (functionName: string) => string; + getLambdaFunctionUrlLogicalId: (functionName: string) => string; getRestApiLogicalId: () => string; getHttpApiLogicalId: () => string; getCompiledTemplateFileName: () => string; From d4ff0bb5488d1439b2d6e4086df02639873aa378 Mon Sep 17 00:00:00 2001 From: Kevin Date: Sat, 28 Oct 2023 20:54:42 +0200 Subject: [PATCH 2/2] Run config checks and make implementation backwards compatible Reformat code --- .eslintignore | 1 + src/constructs/aws/ServerSideWebsite.ts | 75 ++++++++++++++------ src/interfaces/WebLambdaFunctionInterface.ts | 4 ++ src/providers/AwsProvider.ts | 14 ++++ src/utils/GetWebLambdaFunctions.ts | 21 ++++++ 5 files changed, 95 insertions(+), 20 deletions(-) create mode 100644 src/interfaces/WebLambdaFunctionInterface.ts create mode 100644 src/utils/GetWebLambdaFunctions.ts diff --git a/.eslintignore b/.eslintignore index 6d5509a3..b3d5ee4e 100644 --- a/.eslintignore +++ b/.eslintignore @@ -2,4 +2,5 @@ /lib /dist /utils +/src/utils/GetWebLambdaFunctions.ts /test/fixtures diff --git a/src/constructs/aws/ServerSideWebsite.ts b/src/constructs/aws/ServerSideWebsite.ts index bab1dcc4..819f9b23 100644 --- a/src/constructs/aws/ServerSideWebsite.ts +++ b/src/constructs/aws/ServerSideWebsite.ts @@ -37,7 +37,7 @@ const SCHEMA = { properties: { type: { const: "server-side-website" }, apiGateway: { enum: ["http", "rest"] }, - functionName: { type: "string" }, + originName: { type: "string" }, assets: { type: "object", additionalProperties: { type: "string" }, @@ -116,24 +116,6 @@ export class ServerSideWebsite extends AwsConstruct { })(); const backendCachePolicy = CachePolicy.CACHING_DISABLED; - // Resolve domain to use as cloudfront origin - const originDomain = (() => { - if (configuration.functionName !== void 0) { - // Use Lambda Url - const lambdaUrlId = this.provider.naming.getLambdaFunctionUrlLogicalId(configuration.functionName); - - return Fn.select(2, Fn.split("/", Fn.getAtt(lambdaUrlId, "FunctionUrl").toString())); - } else { - // Use API Gateway - const apiId = - configuration.apiGateway === "rest" - ? this.provider.naming.getRestApiLogicalId() - : this.provider.naming.getHttpApiLogicalId(); - - return Fn.join(".", [Fn.ref(apiId), `execute-api.${this.provider.region}.amazonaws.com`]); - } - })(); - // Cast the domains to an array this.domains = configuration.domain !== undefined ? flatten([configuration.domain]) : undefined; const certificate = @@ -145,7 +127,7 @@ export class ServerSideWebsite extends AwsConstruct { comment: `${provider.stackName} ${id} website CDN`, defaultBehavior: { // Origins are where CloudFront fetches content - origin: new HttpOrigin(originDomain, { + origin: new HttpOrigin(this.getCloudFrontOrigin(), { // API Gateway only supports HTTPS protocolPolicy: OriginProtocolPolicy.HTTPS_ONLY, }), @@ -448,4 +430,57 @@ export class ServerSideWebsite extends AwsConstruct { private getErrorPageFileName(): string { return this.configuration.errorPage !== undefined ? path.basename(this.configuration.errorPage) : ""; } + + private getCloudFrontOrigin(): string { + const functions = this.provider.getWebLambdaFunctions(); + const functionsUsingLambdaUrl = functions.reduce((count, func) => count + (func.usesLambdaUrl ? 1 : 0), 0); + + // Fail if no web functions defined + if (functions.length === 0) { + throw new ServerlessError( + "Error trying to detect CloudFront origin. Please check that at least one Lambda function uses 'url', 'events.httpApi', 'events.http' or 'events.alb'.", + "LIFT_INVALID_STACK_CONFIGURATION" + ); + } + + // Try to use ApiGateway if one or more functions are defined and none uses Lambda URL + if (functions.length >= 1 && functionsUsingLambdaUrl === 0) { + return this.getApiGatewayUrl(); + } + + // Try to use Lambda URL if only one web function is defined + if (functions.length === 1 && functionsUsingLambdaUrl === 1) { + return this.getLambdaUrl(functions[0].name); + } + + // Try to use configured origin + if (this.configuration.originName !== undefined) { + const selectedWebFunction = functions.find((f) => f.name === this.configuration.originName); + if (selectedWebFunction) { + return selectedWebFunction.usesLambdaUrl + ? this.getLambdaUrl(selectedWebFunction.name) + : this.getApiGatewayUrl(); + } + } + + throw new ServerlessError( + `Error trying to detect CloudFront origin. Invalid or missing 'constructs.${this.id}.originName' key.`, + "LIFT_INVALID_CONSTRUCT_CONFIGURATION" + ); + } + + private getApiGatewayUrl() { + const apiId = + this.configuration.apiGateway === "rest" + ? this.provider.naming.getRestApiLogicalId() + : this.provider.naming.getHttpApiLogicalId(); + + return Fn.join(".", [Fn.ref(apiId), `execute-api.${this.provider.region}.amazonaws.com`]); + } + + private getLambdaUrl(name: string) { + const lambdaUrlId = this.provider.naming.getLambdaFunctionUrlLogicalId(name); + + return Fn.select(2, Fn.split("/", Fn.getAtt(lambdaUrlId, "FunctionUrl").toString())); + } } diff --git a/src/interfaces/WebLambdaFunctionInterface.ts b/src/interfaces/WebLambdaFunctionInterface.ts new file mode 100644 index 00000000..ac25fb9c --- /dev/null +++ b/src/interfaces/WebLambdaFunctionInterface.ts @@ -0,0 +1,4 @@ +export interface WebLambdaFunctionInterface { + name: string; + usesLambdaUrl: boolean; +} diff --git a/src/providers/AwsProvider.ts b/src/providers/AwsProvider.ts index 239a7d5a..f6403c8f 100644 --- a/src/providers/AwsProvider.ts +++ b/src/providers/AwsProvider.ts @@ -18,6 +18,8 @@ import { getStackOutput } from "../CloudFormation"; import type { CloudformationTemplate, Provider as LegacyAwsProvider, Serverless } from "../types/serverless"; import { awsRequest } from "../classes/aws"; import ServerlessError from "../utils/error"; +import type { WebLambdaFunctionInterface } from "../interfaces/WebLambdaFunctionInterface"; +import { GetWebLambdaFunctions } from "../utils/GetWebLambdaFunctions"; const AWS_DEFINITION = { type: "object", @@ -161,6 +163,18 @@ export class AwsProvider implements ProviderInterface { resources: this.app.synth().getStackByName(this.stack.stackName).template as CloudformationTemplate, }); } + + /** + * This function can be used by other constructs to get all web Lambda functions + * Web Lambda functions must contain at least one of the following keys: + * - url + * - events.http + * - events.httpApi + * - events.alb + */ + getWebLambdaFunctions(): WebLambdaFunctionInterface[] { + return GetWebLambdaFunctions(this.serverless.service.functions); + } } /** diff --git a/src/utils/GetWebLambdaFunctions.ts b/src/utils/GetWebLambdaFunctions.ts new file mode 100644 index 00000000..6c624943 --- /dev/null +++ b/src/utils/GetWebLambdaFunctions.ts @@ -0,0 +1,21 @@ +import { WebLambdaFunctionInterface } from "../interfaces/WebLambdaFunctionInterface"; + +export function GetWebLambdaFunctions(functions: any): WebLambdaFunctionInterface[] { + if (functions === undefined) { + return []; + } + + return Object.keys(functions) + .filter((key) => { + const fn = functions[key]; + + return fn.url !== undefined || (fn.events !== undefined && fn.events.some((e: any) => e.httpApi || e.http || e.alb)); + }) + .map((key) => { + return { + name: key, + usesLambdaUrl: functions[key].url !== undefined, + }; + }) + ; +}