Skip to content

Commit

Permalink
feat(Sagemaker): allow deploying llama 3.1 from Sagemaker Jumpstart (#692)
Browse files Browse the repository at this point in the history

* feat(Sagemaker): add Jumpstart support for new regions and take gatedBucket into consideration.
---------

Signed-off-by: arief hidayat <[email protected]>
Co-authored-by: Alain Krok <[email protected]>
  • Loading branch information
arief-hidayat and krokoko authored Sep 19, 2024
1 parent 9e91372 commit 6ab07ac
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 3 deletions.
6 changes: 6 additions & 0 deletions apidocs/interfaces/IJumpStartModelSpec.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@

***

### gatedBucket

> **gatedBucket**: `boolean`

***

### instanceAliases?

> `optional` **instanceAliases**: [`IInstanceAliase`](IInstanceAliase.md)[]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ interface JumpStartModelSpec {
hosting_artifact_key?: string;
hosting_script_key?: string;
hosting_prepacked_artifact_key?: string;
gated_bucket: boolean;
hosting_eula_key?: string;
inference_environment_variables: {
name: string;
Expand Down Expand Up @@ -113,6 +114,7 @@ export async function download_data() {
hosting_script_key,
hosting_artifact_key,
hosting_prepacked_artifact_key,
gated_bucket,
inference_environment_variables,
hosting_instance_type_variants,
hosting_eula_key,
Expand Down Expand Up @@ -149,6 +151,7 @@ export async function download_data() {
hosting_artifact_key,
hosting_script_key,
hosting_prepacked_artifact_key,
gated_bucket,
inference_environment_variables,
hosting_instance_type_variants,
hosting_eula_key,
Expand Down Expand Up @@ -220,6 +223,7 @@ function generateCode() {
instanceTypes: specSource.supported_inference_instance_types,
modelPackageArns: specSource.hosting_model_package_arns,
prepackedArtifactKey: specSource.hosting_prepacked_artifact_key,
gatedBucket: specSource.gated_bucket,
artifactKey: specSource.hosting_artifact_key,
environment,
instanceAliases: instanceAliasesArr,
Expand Down Expand Up @@ -270,6 +274,7 @@ export interface IJumpStartModelSpec {
instanceTypes: string[];
modelPackageArns?: { [region: string]: string };
prepackedArtifactKey?: string;
gatedBucket: boolean;
artifactKey?: string;
environment: { [key: string]: string | number | boolean };
instanceAliases?: IInstanceAliase[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export interface IJumpStartModelSpec {
instanceTypes: string[];
modelPackageArns?: { [region: string]: string };
prepackedArtifactKey?: string;
gatedBucket: boolean;
artifactKey?: string;
environment: { [key: string]: string | number | boolean };
instanceAliases?: IInstanceAliase[];
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ export class JumpStartSageMakerEndpoint extends SageMakerEndpointBase {
...this.environment,
};

if (environment.SAGEMAKER_SUBMIT_DIRECTORY) {
delete environment.SAGEMAKER_SUBMIT_DIRECTORY;
}

return environment;
}

Expand All @@ -195,7 +199,8 @@ export class JumpStartSageMakerEndpoint extends SageMakerEndpointBase {
vpcConfig: sagemaker.CfnModel.VpcConfigProperty | undefined,
) {
const key = this.spec.prepackedArtifactKey ?? this.spec.artifactKey;
const bucket = JumpStartConstants.JUMPSTART_LAUNCHED_REGIONS[this.region]?.contentBucket;
const bucket = this.spec.gatedBucket ? JumpStartConstants.JUMPSTART_LAUNCHED_REGIONS[this.region]?.gatedContentBucket :
JumpStartConstants.JUMPSTART_LAUNCHED_REGIONS[this.region]?.contentBucket;
if (!bucket) {
throw new Error(`JumpStart is not available in the region ${this.region}.`);
}
Expand Down Expand Up @@ -224,7 +229,7 @@ export class JumpStartSageMakerEndpoint extends SageMakerEndpointBase {
executionRoleArn: this.role.roleArn,
enableNetworkIsolation: true,
primaryContainer: isArtifactCompressed ? {
// True: Artifact is a tarball
// True: Artifact is a tarball
image,
modelDataUrl: modelArtifactUrl,
environment,
Expand Down Expand Up @@ -252,6 +257,18 @@ export class JumpStartSageMakerEndpoint extends SageMakerEndpointBase {
key: 'modelVersion',
value: this.spec.version,
},
{
key: 'sagemaker-studio:jumpstart-model-id',
value: this.spec.modelId,
},
{
key: 'sagemaker-studio:jumpstart-model-version',
value: this.spec.version,
},
{
key: 'sagemaker-studio:jumpstart-hub-name',
value: 'SageMakerPublicHub',
},
],
vpcConfig: vpcConfig,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ export abstract class JumpStartConstants {
contentBucket: 'jumpstart-cache-prod-eu-north-1',
gatedContentBucket: 'jumpstart-private-cache-prod-eu-north-1',
},
'me-central-1': {
contentBucket: 'jumpstart-cache-prod-me-central-1',
gatedContentBucket: 'jumpstart-private-cache-prod-me-central-1',
},
'me-south-1': {
contentBucket: 'jumpstart-cache-prod-me-south-1',
gatedContentBucket: 'jumpstart-private-cache-prod-me-south-1',
Expand Down Expand Up @@ -72,6 +76,10 @@ export abstract class JumpStartConstants {
contentBucket: 'jumpstart-cache-prod-ap-northeast-2',
gatedContentBucket: 'jumpstart-private-cache-prod-ap-northeast-2',
},
'ap-northeast-3': {
contentBucket: 'jumpstart-cache-prod-ap-northeast-3',
gatedContentBucket: 'jumpstart-private-cache-prod-ap-northeast-3',
},
'eu-west-2': {
contentBucket: 'jumpstart-cache-prod-eu-west-2',
gatedContentBucket: 'jumpstart-private-cache-prod-eu-west-2',
Expand All @@ -96,13 +104,29 @@ export abstract class JumpStartConstants {
contentBucket: 'jumpstart-cache-prod-ap-southeast-2',
gatedContentBucket: 'jumpstart-private-cache-prod-ap-southeast-2',
},
'ap-southeast-3': {
contentBucket: 'jumpstart-cache-prod-ap-southeast-3',
gatedContentBucket: 'jumpstart-private-cache-prod-ap-southeast-3',
},
'ca-central-1': {
contentBucket: 'jumpstart-cache-prod-ca-central-1',
gatedContentBucket: 'jumpstart-private-cache-prod-ca-central-1',
},
'cn-north-1': {
contentBucket: 'jumpstart-cache-prod-cn-north-1',
},
'il-central-1': {
contentBucket: 'jumpstart-cache-prod-il-central-1',
gatedContentBucket: 'jumpstart-private-cache-prod-il-central-1',
},
'us-gov-east-1': {
contentBucket: 'jumpstart-cache-prod-us-gov-east-1',
gatedContentBucket: 'jumpstart-private-cache-prod-us-gov-east-1',
},
'us-gov-west-1': {
contentBucket: 'jumpstart-cache-prod-us-gov-west-1',
gatedContentBucket: 'jumpstart-private-cache-prod-us-gov-west-1',
},
};

public static JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = 'models_manifest.json';
Expand Down

0 comments on commit 6ab07ac

Please sign in to comment.