hongbo-miao/hongbomiao.com

machine-learning/triton/amazon-sagamaker-triton-resnet-50/deploy/src/deploy.py

import logging
import time

import boto3
import sagemaker
from botocore.client import BaseClient


def check_endpoint_status(
    sagemaker_client: BaseClient, sagemaker_endpoint_name: str
) -> None:
    # Poll until the endpoint leaves the "Creating" state (ends as "InService" or "Failed")
    while (
        status := sagemaker_client.describe_endpoint(
            EndpointName=sagemaker_endpoint_name
        )["EndpointStatus"]
    ) == "Creating":
        logging.info(f"Status: {status}")
        time.sleep(30)
    logging.info(f"Status: {status}")


def deploy() -> None:
    model_name = "resnet-50"

    # `sagemaker.get_execution_role()` only works in a Jupyter notebook hosted by Amazon SageMaker, so build the execution role ARN explicitly
    aws_account_id = boto3.client("sts").get_caller_identity()["Account"]
    sagemaker_execution_role = f"arn:aws:iam::{aws_account_id}:role/AmazonSageMakerExecutionRole-hm-sagemaker-notebook"

    sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
    sagemaker_model_name = f"{model_name}-model"
    sagemaker_endpoint_config_name = f"{model_name}-endpoint-config"
    sagemaker_endpoint_name = f"{model_name}-endpoint"
    model_s3_url = f"s3://{sagemaker_session.default_bucket()}/{model_name}/"
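    # With MultiModel mode, this S3 prefix is expected to hold one .tar.gz archive per model,
    # each packaging a Triton model repository entry (a model directory with config.pbtxt and
    # versioned model files); invocations then pick a model via the TargetModel parameter.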

    # AWS account IDs that host the SageMaker multi-model endpoint (MME) Triton server image, per region
    aws_account_id_dict = {
        "us-east-1": "785573368785",
        "us-east-2": "007439368137",
        "us-west-1": "710691900526",
        "us-west-2": "301217895009",
        "eu-west-1": "802834080501",
        "eu-west-2": "205493899709",
        "eu-west-3": "254080097072",
        "eu-north-1": "601324751636",
        "eu-south-1": "966458181534",
        "eu-central-1": "746233611703",
        "ap-east-1": "110948597952",
        "ap-south-1": "763008648453",
        "ap-northeast-1": "941853720454",
        "ap-northeast-2": "151534178276",
        "ap-southeast-1": "324986816169",
        "ap-southeast-2": "355873309152",
        "cn-northwest-1": "474822919863",
        "cn-north-1": "472730292857",
        "sa-east-1": "756306329178",
        "ca-central-1": "464438896020",
        "me-south-1": "836785723513",
        "af-south-1": "774647643957",
    }
    region = boto3.Session().region_name
    if region not in aws_account_id_dict:
        raise ValueError(f"Unsupported region: {region}")
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    triton_server_image_uri = f"{aws_account_id_dict[region]}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.07-py3"
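    # For example, in us-west-2 this resolves to:
    # 301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:22.07-py3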

    # Create a model
    sagemaker_client = boto3.client("sagemaker")
    res = sagemaker_client.create_model(
        ModelName=sagemaker_model_name,
        ExecutionRoleArn=sagemaker_execution_role,
        PrimaryContainer={
            "Image": triton_server_image_uri,
            "ModelDataUrl": model_s3_url,
            "Mode": "MultiModel",
        },
    )
    logging.info(f'Model Arn: {res["ModelArn"]}')

    # Create an endpoint config
    res = sagemaker_client.create_endpoint_config(
        EndpointConfigName=sagemaker_endpoint_config_name,
        ProductionVariants=[
            {
                "InstanceType": "ml.g4dn.4xlarge",
                "InitialVariantWeight": 1,
                "InitialInstanceCount": 1,
                "ModelName": sagemaker_model_name,
                "VariantName": "AllTraffic",
            }
        ],
    )
    logging.info(f'Endpoint Config Arn: {res["EndpointConfigArn"]}')

    # Create an endpoint
    res = sagemaker_client.create_endpoint(
        EndpointName=sagemaker_endpoint_name,
        EndpointConfigName=sagemaker_endpoint_config_name,
    )
    logging.info(f'Endpoint Arn: {res["EndpointArn"]}')

    check_endpoint_status(sagemaker_client, sagemaker_endpoint_name)

    # Configure auto-scaling of the endpoint based on GPU utilization
    # This is the format in which Application Auto Scaling references the endpoint variant
    resource_id = f"endpoint/{sagemaker_endpoint_name}/variant/AllTraffic"
    auto_scaling_client = boto3.client("application-autoscaling")
    auto_scaling_client.register_scalable_target(
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",
        MinCapacity=1,
        MaxCapacity=5,
    )
    # GPUUtilization metric
    auto_scaling_client.put_scaling_policy(
        PolicyName="GPUUtil-ScalingPolicy",
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension="sagemaker:variant:DesiredInstanceCount",  # SageMaker supports only instance count
        PolicyType="TargetTrackingScaling",  # StepScaling, TargetTrackingScaling
        TargetTrackingScalingPolicyConfiguration={
            # Scale out when average GPUUtilization rises above the target value,
            # and scale back in when it stays below it.
            "TargetValue": 60.0,
            "CustomizedMetricSpecification": {
                "MetricName": "GPUUtilization",
                "Namespace": "/aws/sagemaker/Endpoints",
                "Dimensions": [
                    {"Name": "EndpointName", "Value": sagemaker_endpoint_name},
                    {"Name": "VariantName", "Value": "AllTraffic"},
                ],
                "Statistic": "Average",  # Average, Minimum, Maximum, SampleCount, Sum
                "Unit": "Percent",
            },
            "ScaleInCooldown": 600,
            "ScaleOutCooldown": 200,
        },
    )
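

# A minimal smoke-test sketch (hypothetical helper, not invoked by this script): it sends a
# Triton/KServe v2 JSON inference request to the multi-model endpoint created above. The
# input name ("INPUT__0"), shape, datatype, and the TargetModel archive name
# ("resnet-50.tar.gz") are assumptions that depend on how the model repository under
# model_s3_url was packaged; adjust them to match the model's config.pbtxt.
def invoke_endpoint_smoke_test(sagemaker_endpoint_name: str) -> None:
    import json  # local import to keep this sketch self-contained

    sagemaker_runtime_client = boto3.client("sagemaker-runtime")
    payload = {
        "inputs": [
            {
                "name": "INPUT__0",  # hypothetical input tensor name
                "shape": [1, 3, 224, 224],
                "datatype": "FP32",
                "data": [0.0] * (3 * 224 * 224),  # dummy image tensor
            }
        ]
    }
    res = sagemaker_runtime_client.invoke_endpoint(
        EndpointName=sagemaker_endpoint_name,
        ContentType="application/octet-stream",
        Body=json.dumps(payload),
        TargetModel="resnet-50.tar.gz",  # hypothetical archive name under model_s3_url
    )
    logging.info(json.loads(res["Body"].read().decode("utf-8")))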


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    deploy()