How to A/B test Tensorflow models using Sagemaker Endpoints
How can we run multiple versions of a Tensorflow model in production at the same time? There are many possible solutions. The simplest option is to deploy those versions as two separate models in Tensorflow Serving and switch between them in the application code. However, that quickly becomes difficult to maintain when we want to do a canary release or A/B test more than two models.
Thankfully, Sagemaker Endpoints simplify A/B testing of machine learning models. We can achieve the desired result using a few lines of code.
In this article, I’ll show you how to define multiple ML models, configure them as Sagemaker Endpoint variants, deploy the endpoint, and capture the model results from all deployed versions.
Dependencies and Imports
To run the code, we need two dependencies:
boto3==1.14.12
sagemaker==2.5.3
If you prefer to run the deployment script as a step in the AWS Code Pipeline, take a look at this article.
I assume the code runs in an environment where the AWS API key and secret have been provided as environment variables.
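Both boto3 and the Sagemaker SDK pick up the standard AWS environment variables automatically, so no credentials appear in the script itself. A minimal sanity check before running the deployment could look like this (the region value is only an example):

import os

# boto3 and sagemaker read these standard AWS environment variables automatically;
# here we only verify that they are present before deploying anything.
required_variables = ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY')
missing = [name for name in required_variables if name not in os.environ]
if missing:
    raise RuntimeError(f'Missing AWS credentials in environment: {missing}')

# The region can also come from the environment (example value only).
os.environ.setdefault('AWS_DEFAULT_REGION', 'eu-west-1')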
In the deployment script, we have to import Sagemaker and create the session.
import sagemaker
from sagemaker.session import production_variant
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.model_monitor import DataCaptureConfig

sagemaker_session = sagemaker.Session()

role = 'ARN of the role that has access to Sagemaker and the deployment bucket in S3'
Creating Model Versions
Before we continue, we have to archive the saved Tensorflow models as tar.gz files and store them in an S3 bucket. In the next step, we create two models in AWS Sagemaker using those tar.gz files.
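For reference, here is a sketch of how such an archive could be produced and uploaded with the Sagemaker SDK. The local directory, bucket, and prefix names are placeholders, and the layout assumes one common convention of placing the SavedModel under a numeric version directory, as Tensorflow Serving expects:

import tarfile

import sagemaker

# Package a SavedModel directory as tar.gz. Tensorflow Serving expects the model
# under a numeric version directory inside the archive (here: "1/").
with tarfile.open('model-version-a.tar.gz', 'w:gz') as archive:
    archive.add('saved_model_a', arcname='1')

# Upload the archive to S3; upload_data returns the s3:// URI we later pass as model_data.
session = sagemaker.Session()
model_data_a = session.upload_data(
    path='model-version-a.tar.gz',
    bucket='bucket',
    key_prefix='path'
)
print(model_data_a)  # s3://bucket/path/model-version-a.tar.gz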
model_version_A = TensorFlowModel(
    name='model-name-version-a',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-a.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-a',
    role,
    model_version_A.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

model_version_B = TensorFlowModel(
    name='model-name-version-b',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-b.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-b',
    role,
    model_version_B.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)
The entry_point is the file containing the input_handler and output_handler functions used to convert the HTTP request into input compatible with Tensorflow Serving and to convert the response back into the format expected by the client application. The source_dir is the directory where we stored the inference.py script. We can also include a requirements.txt file in the source_dir directory to install additional dependencies.
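A minimal sketch of such an inference.py, following the handler signatures supported by the Sagemaker Tensorflow Serving container (treat it as a starting point, not the exact script used here; the payload format is an assumption):

import json


def input_handler(data, context):
    # Convert the incoming HTTP request body into the JSON format
    # expected by the Tensorflow Serving REST API.
    if context.request_content_type == 'application/json':
        # Pass the JSON payload through unchanged; Tensorflow Serving expects
        # a body like {"instances": [...]}.
        return data.read().decode('utf-8')
    raise ValueError(f'Unsupported content type: {context.request_content_type}')


def output_handler(response, context):
    # Convert the Tensorflow Serving response back into the format
    # expected by the client application.
    if response.status_code != 200:
        raise ValueError(response.content.decode('utf-8'))
    return response.content, context.accept_header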
Now, we must define endpoint variants by specifying the model names and the percentage of traffic redirected to every variant (the initial_weight parameter).
variantA = production_variant(
    model_name='model-version-a',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantA",
    initial_weight=50,
)

variantB = production_variant(
    model_name='model-version-b',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantB",
    initial_weight=50,
)
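The weights do not have to stay fixed. Once the endpoint created later in this article is running, we can shift traffic gradually, for example during a canary release, without redeploying anything. A sketch using the Sagemaker API and the variant names defined above:

import boto3

sagemaker_client = boto3.client('sagemaker')

# Shift 90% of the traffic to VariantA and 10% to VariantB on a live endpoint.
sagemaker_client.update_endpoint_weights_and_capacities(
    EndpointName='AB-endpoint-with-monitoring',
    DesiredWeightsAndCapacities=[
        {'VariantName': 'VariantA', 'DesiredWeight': 90},
        {'VariantName': 'VariantB', 'DesiredWeight': 10},
    ],
)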
Configuring Data Capture
We want to log the requests and responses to verify which version of the model performs better. A Sagemaker Endpoint stores the logs of every variant separately in JSON files, and we can log every request by configuring Data Capture with the sampling percentage set to 100%.
data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri='s3://bucket/logs'
)
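Once the endpoint defined in the next section starts receiving traffic, the captured records land as JSON Lines files under the configured S3 prefix, grouped by endpoint and variant name. A rough sketch of reading them back, using the bucket, prefix, and endpoint name from this article (the exact key layout and record fields may differ slightly between container versions):

import json

import boto3

s3 = boto3.client('s3')

# Captured records are grouped per endpoint and per variant under destination_s3_uri.
listing = s3.list_objects_v2(
    Bucket='bucket',
    Prefix='logs/AB-endpoint-with-monitoring/VariantA/'
)
for item in listing.get('Contents', []):
    body = s3.get_object(Bucket='bucket', Key=item['Key'])['Body'].read().decode('utf-8')
    for line in body.strip().splitlines():
        record = json.loads(line)
        print(record['captureData']['endpointInput']['data'])
        print(record['captureData']['endpointOutput']['data'])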
Deploying the Endpoint
Finally, we can deploy the endpoint with two variants and data capture:
sagemaker_session.endpoint_from_production_variants(
    name='AB-endpoint-with-monitoring',
    production_variants=[variantA, variantB],
    data_capture_config_dict=data_capture_config._to_request_dict()
)
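To verify that traffic is really split between the variants, we can call the endpoint with the runtime client; the response metadata names the variant that served each request. A sketch with a placeholder JSON payload (the real structure depends on what inference.py expects):

import json

import boto3

runtime = boto3.client('sagemaker-runtime')

# Placeholder payload; adjust it to the input format your inference.py accepts.
payload = json.dumps({'instances': [[1.0, 2.0, 3.0]]})

response = runtime.invoke_endpoint(
    EndpointName='AB-endpoint-with-monitoring',
    ContentType='application/json',
    Body=payload
)

# The response tells us which variant handled this particular request.
print(response['InvokedProductionVariant'])
print(response['Body'].read().decode('utf-8'))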