How to A/B test Tensorflow models using Sagemaker Endpoints

How can we run multiple versions of a Tensorflow model in production at the same time? There are many possible solutions. The simplest option is to deploy each version as a separate model in Tensorflow Serving and switch between them in the application code. However, that quickly becomes difficult to maintain when we want to do a canary release or A/B test more than two models.

Thankfully, Sagemaker Endpoints simplify A/B testing of machine learning models. We can achieve the desired result using a few lines of code.

In this article, I’ll show you how to define multiple ML models, configure them as Sagemaker Endpoint variants, deploy the endpoint, and capture the model results from all deployed versions.

Dependencies and Imports

To run the code, we need two dependencies:

boto3==1.14.12
sagemaker==2.5.3

If you prefer to run the deployment script as a step in AWS CodePipeline, take a look at this article.

I assume that the code runs in an environment in which the AWS access key ID and secret access key are provided as environment variables.
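Before running the deployment, it can help to fail fast when that configuration is missing. The sketch below assumes the standard variable names that boto3 reads (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) and also checks AWS_DEFAULT_REGION, because creating a Sagemaker session requires a region. The region can also come from ~/.aws/config, so treat this check as an assumption about your setup. The helper name missing_aws_variables is hypothetical.

```python
import os
from typing import List, Mapping, Optional

# Assumed set of variables: boto3 reads credentials and the default region from them.
REQUIRED_AWS_VARIABLES = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")


def missing_aws_variables(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the names of required AWS variables that are unset or empty (hypothetical helper)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_AWS_VARIABLES if not env.get(name)]
```

At the top of the deployment script, we could call missing_aws_variables() and stop with an error message if the returned list is not empty.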

In the deployment script, we have to import Sagemaker and create the session.

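import sagemaker
from sagemaker.session import production_variant

from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.model_monitor import DataCaptureConfig

# The session uses the AWS credentials and region available in the environment.
sagemaker_session = sagemaker.Session()

# IAM role assumed by Sagemaker; it needs access to Sagemaker and to the S3 deployment bucket.
role = 'ARN of the role that has access to Sagemaker and the deployment bucket in S3'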

Creating Model Versions

Before we continue, we have to archive the saved Tensorflow models as tar.gz files and store them in an S3 bucket. In the next step, we create two models in AWS Sagemaker using those tar.gz files.
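Packaging a SavedModel can be sketched with the standard library alone. This assumes the layout the Sagemaker TensorFlow Serving container looks for: the SavedModel files (saved_model.pb and variables/) inside a numeric version directory at the root of the archive. Please verify the layout against the container documentation for your framework version. The function name package_saved_model is hypothetical.

```python
import os
import tarfile


def package_saved_model(saved_model_dir: str, archive_path: str, version: str = "1") -> str:
    """Pack a TensorFlow SavedModel directory into a tar.gz archive (hypothetical helper).

    Assumption: the serving container expects the SavedModel under a numeric
    version directory (e.g. 1/saved_model.pb) at the archive root.
    """
    if not os.path.isfile(os.path.join(saved_model_dir, "saved_model.pb")):
        raise FileNotFoundError(f"no saved_model.pb in {saved_model_dir}")
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname stores the directory inside the archive under the version number
        tar.add(saved_model_dir, arcname=version)
    return archive_path
```

The resulting file can then be uploaded with sagemaker_session.upload_data(path='model-version-a.tar.gz', bucket='bucket', key_prefix='path'), which returns the S3 URI to pass as model_data.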

# Version A: the model artifact, the inference script, and the TF Serving version.
model_version_A = TensorFlowModel(
    name='model-name-version-a',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-a.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

# Register the model in Sagemaker under the name the endpoint variant refers to.
# The instance type is needed to select the serving container image.
sagemaker_session.create_model(
    'model-version-a',
    role,
    model_version_A.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

# Version B: the same configuration with a different model artifact.
model_version_B = TensorFlowModel(
    name='model-name-version-b',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-b.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-b',
    role,
    model_version_B.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

The entry_point is the file containing the input_handler and output_handler functions. The first converts the HTTP request into input compatible with Tensorflow Serving; the second converts the Tensorflow Serving response back into the format expected by the client application. The source_dir is the directory where we stored the inference.py script. We can also put a requirements.txt file in the source_dir directory to install additional dependencies.
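A minimal inference.py could look like the sketch below. The handler signatures follow the Sagemaker TensorFlow Serving container convention (data, context), where context exposes request_content_type. The request format, a JSON list of feature vectors such as [[1.0, 2.0]], is a hypothetical assumption; adapt it to what your client actually sends.

```python
import json


def input_handler(data, context):
    """Convert the client's HTTP request into a TensorFlow Serving REST request.

    Assumption (hypothetical format): the body is a JSON list of feature vectors.
    """
    if context.request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {context.request_content_type}")
    instances = json.loads(data.read().decode("utf-8"))
    # TensorFlow Serving's REST predict API expects {"instances": [...]}
    return json.dumps({"instances": instances})


def output_handler(data, context):
    """Convert TensorFlow Serving's response into the response sent to the client."""
    if data.status_code != 200:
        raise ValueError(data.content.decode("utf-8"))
    # Pass TF Serving's {"predictions": [...]} JSON body through unchanged.
    return data.content, "application/json"
```

Keeping the handlers thin makes it easier to compare variants later, because the captured responses contain the raw model predictions.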

Creating Variants

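Now, we must define the endpoint variants by specifying the model names and the traffic weight of every variant (the initial_weight parameter). The weights are relative: each variant receives its weight divided by the sum of all weights, so two variants with a weight of 50 split the traffic evenly: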

variantA = production_variant(
    model_name='model-version-a',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantA",
    initial_weight=50,
)

variantB = production_variant(
    model_name='model-version-b',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantB",
    initial_weight=50,
)
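Because initial_weight is a relative weight rather than a percentage, the share of requests a variant receives is its weight divided by the sum of all weights. The small function below (hypothetical name traffic_shares) makes that arithmetic explicit; for a canary release, weights of 9 and 1 send 90% and 10% of the traffic to the old and new model.

```python
from typing import Dict


def traffic_shares(weights: Dict[str, float]) -> Dict[str, float]:
    """Return the fraction of requests each variant receives.

    Each variant gets its weight divided by the sum of all variant weights.
    """
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("at least one variant needs a positive weight")
    return {name: weight / total for name, weight in weights.items()}


even_split = traffic_shares({"VariantA": 50, "VariantB": 50})
canary = traffic_shares({"VariantA": 9, "VariantB": 1})
```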

Configuring Data Capture

We want to log the requests and responses to verify which version of the model performs better. Sagemaker Endpoints store the captured data of every variant separately, as JSON Lines files in S3, and we can log every request by configuring Data Capture with the sampling percentage set to 100%.

data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri='s3://bucket/logs'
)
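To compare the variants, we have to read those files back. The sketch below assumes the layout described in the Sagemaker Model Monitor documentation: object keys of the form {prefix}/{endpoint-name}/{variant-name}/yyyy/mm/dd/hh/file.jsonl, and one JSON record per line with the request in captureData.endpointInput and the response in captureData.endpointOutput, each stored either as text or base64-encoded. Please check these field names against your own capture files. The function names are hypothetical.

```python
import base64
import json
from typing import Iterator, Tuple


def variant_from_capture_key(key: str, prefix: str) -> str:
    """Extract the variant name from a data capture object key.

    Assumed layout: {prefix}/{endpoint-name}/{variant-name}/yyyy/mm/dd/hh/file.jsonl
    """
    if not key.startswith(prefix):
        raise ValueError(f"{key!r} is not under {prefix!r}")
    parts = key[len(prefix):].strip("/").split("/")
    return parts[1]  # parts[0] is the endpoint name


def _decode(payload: dict) -> str:
    # Assumption: payloads are stored as text (JSON/CSV) or as base64-encoded bytes.
    if payload.get("encoding") == "BASE64":
        return base64.b64decode(payload["data"]).decode("utf-8")
    return payload["data"]


def parse_capture_file(jsonl_text: str) -> Iterator[Tuple[str, str]]:
    """Yield (request, response) pairs from one JSON Lines capture file."""
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue
        record = json.loads(line)["captureData"]
        yield _decode(record["endpointInput"]), _decode(record["endpointOutput"])
```

Grouping the parsed pairs by variant_from_capture_key(key, 'logs') gives one dataset per variant, which we can then score with whatever business metric decides the A/B test.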

Deploying the Endpoint

Finally, we can deploy the endpoint with two variants and data capture:

sagemaker_session.endpoint_from_production_variants(
    name='AB-endpoint-with-monitoring',
    production_variants=[variantA, variantB],
    data_capture_config_dict=data_capture_config._to_request_dict()
)
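Once the endpoint is running, requests sent without a target are distributed according to the weights, and the invoke_endpoint response from the sagemaker-runtime client reports which variant answered in its InvokedProductionVariant field. Passing TargetVariant sends a request to one specific variant, and the sagemaker client's update_endpoint_weights_and_capacities call changes the weights without redeploying, for example to gradually promote a canary. The helper names below are hypothetical, the request body format depends on your input_handler, and boto3 is imported inside the AWS-calling functions so the request-building helpers work without it.

```python
import json
from typing import Dict, List, Optional


def invoke_request(endpoint_name: str, payload, variant: Optional[str] = None) -> dict:
    """Build keyword arguments for sagemaker-runtime invoke_endpoint."""
    kwargs = {
        "EndpointName": endpoint_name,
        "ContentType": "application/json",
        "Body": json.dumps(payload),
    }
    if variant is not None:
        # Bypass the weights and send the request to one variant.
        kwargs["TargetVariant"] = variant
    return kwargs


def weight_update(weights: Dict[str, float]) -> List[dict]:
    """Build DesiredWeightsAndCapacities for update_endpoint_weights_and_capacities."""
    return [{"VariantName": name, "DesiredWeight": float(w)} for name, w in weights.items()]


def invoke(endpoint_name: str, payload, variant: Optional[str] = None):
    import boto3  # imported here so the helpers above can be used without boto3
    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(**invoke_request(endpoint_name, payload, variant))
    # InvokedProductionVariant names the variant that handled the request.
    return response["InvokedProductionVariant"], response["Body"].read()


def shift_traffic(endpoint_name: str, weights: Dict[str, float]) -> None:
    import boto3
    boto3.client("sagemaker").update_endpoint_weights_and_capacities(
        EndpointName=endpoint_name,
        DesiredWeightsAndCapacities=weight_update(weights),
    )
```

For example, shift_traffic('AB-endpoint-with-monitoring', {'VariantA': 10, 'VariantB': 90}) would move most of the traffic to version B once the captured data shows it performs better.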

Bartosz Mikulski * data/machine learning engineer * conference speaker * co-founder of Software Craft Poznan & Poznan Scala User Group
