How to A/B test TensorFlow models using SageMaker Endpoints

How can we run multiple versions of a TensorFlow model in production at the same time? There are many possible solutions. The simplest option is to deploy the versions as two separate models in TensorFlow Serving and switch between them in the application code. However, that quickly becomes difficult to maintain when we want to do a canary release or A/B test more than two models.

Thankfully, SageMaker Endpoints simplify A/B testing of machine learning models. We can achieve the desired result with a few lines of code.

In this article, I’ll show you how to define multiple ML models, configure them as SageMaker Endpoint variants, deploy the endpoint, and capture the model results from all deployed versions.

Dependencies and Imports

To run the code, we need two dependencies:

boto3==1.14.12
sagemaker==2.5.3

If you prefer to run the deployment script as a step in AWS CodePipeline, take a look at this article.

I assume that the code runs in an environment in which the AWS API key and secret have been provided as environment variables.
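For reference, the AWS SDK picks the credentials up from the standard environment variables; the values and the region below are placeholders:

```shell
# Standard AWS SDK credential variables (values are placeholders):
export AWS_ACCESS_KEY_ID=AKIA...
export AWS_SECRET_ACCESS_KEY=...
export AWS_DEFAULT_REGION=eu-west-1
```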

In the deployment script, we have to import the sagemaker library and create a session.

import sagemaker
from sagemaker.session import production_variant

from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.model_monitor import DataCaptureConfig

sagemaker_session = sagemaker.Session()

role = 'ARN of the role that has access to Sagemaker and the deployment bucket in S3'

Creating Model Versions

Before we continue, we have to archive the saved TensorFlow models as tar.gz files and store them in an S3 bucket. In the next step, we create two models in AWS SageMaker using those tar.gz files.
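The packaging step can be sketched as follows. The file names and the S3 path are placeholders, and I assume the SavedModel sits in a numeric version directory (here: 1/), which is the layout the TensorFlow Serving container expects:

```shell
# Build a stand-in SavedModel layout (replace with your real export):
mkdir -p export/1/variables
touch export/1/saved_model.pb

# Archive the numeric version directory at the root of the tarball:
tar -C export -czf model-version-a.tar.gz 1/
tar -tzf model-version-a.tar.gz

# Upload to S3 (requires AWS credentials):
#   aws s3 cp model-version-a.tar.gz s3://bucket/path/model-version-a.tar.gz
```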

model_version_A = TensorFlowModel(
    name='model-name-version-a',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-a.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-a',
    role,
    model_version_A.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

model_version_B = TensorFlowModel(
    name='model-name-version-b',
    role=role,
    entry_point='inference.py',
    source_dir='src',
    model_data='s3://bucket/path/model-version-b.tar.gz',
    framework_version="2.3",
    sagemaker_session=sagemaker_session
)

sagemaker_session.create_model(
    'model-version-b',
    role,
    model_version_B.prepare_container_def(
        instance_type='ml.t2.medium'
    )
)

The entry_point is the file containing the input_handler and output_handler functions used to convert the HTTP request into input compatible with TensorFlow Serving and to convert the response back into the format expected by the client application. The source_dir is the directory where we stored the inference.py script. We can also include a requirements.txt file in the source_dir directory to install additional dependencies.
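A minimal sketch of such a handler pair; the JSON field names "rows" and "scores" are my assumptions, so adapt them to whatever your client application sends and expects:

```python
# src/inference.py -- a sketch of the handler interface used by the
# SageMaker TensorFlow Serving container.
import json

def input_handler(data, context):
    """Convert the incoming HTTP request into a TF Serving REST payload."""
    if context.request_content_type == "application/json":
        payload = json.loads(data.read().decode("utf-8"))
        # TF Serving's REST API expects {"instances": [...]}
        return json.dumps({"instances": payload["rows"]})
    raise ValueError(f"Unsupported content type: {context.request_content_type}")

def output_handler(response, context):
    """Convert the TF Serving response into the format the client expects."""
    if response.status_code != 200:
        raise ValueError(response.content.decode("utf-8"))
    predictions = json.loads(response.content.decode("utf-8"))["predictions"]
    return json.dumps({"scores": predictions}), "application/json"
```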

Creating Variants

Now, we must define the endpoint variants by specifying the model names and the percentage of traffic routed to each variant (the initial_weight parameter):

variantA = production_variant(
    model_name='model-version-a',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantA",
    initial_weight=50,
)

variantB = production_variant(
    model_name='model-version-b',
    instance_type="ml.t2.medium",
    initial_instance_count=1,
    variant_name="VariantB",
    initial_weight=50,
)

Configuring Data Capture

We want to log the requests and responses to verify which version of the model performs better. A SageMaker Endpoint stores the logs of every variant separately in JSON files, and we can log every request by configuring Data Capture with the sampling percentage set to 100%.

data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri='s3://bucket/logs'
)

Deploying the Endpoint

Finally, we can deploy the endpoint with two variants and data capture:

sagemaker_session.endpoint_from_production_variants(
    name='AB-endpoint-with-monitoring',
    production_variants=[variantA, variantB],
    data_capture_config_dict=data_capture_config._to_request_dict()
)
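To check that the traffic is actually split between the variants, we can call the endpoint and inspect the response metadata: the InvokedProductionVariant field names the variant that served each request. A sketch, assuming the payload format from the inference.py example (the "rows" field is a placeholder):

```python
import json

def invoke_variant(client, payload):
    """Call the A/B endpoint; the response tells us which variant served it."""
    response = client.invoke_endpoint(
        EndpointName="AB-endpoint-with-monitoring",
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    body = json.loads(response["Body"].read().decode("utf-8"))
    return response["InvokedProductionVariant"], body

# Usage with the real runtime client:
#   import boto3
#   runtime = boto3.client("sagemaker-runtime")
#   variant, scores = invoke_variant(runtime, {"rows": [[1.0, 2.0, 3.0]]})
```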

What to do when a SageMaker Endpoint does not run because of the “Invalid protobuf file” error?

Recently, the SageMaker backend has been updated, and the following function causes errors in some projects:

def find_model_versions(model_path):
    return [version.lstrip("0") for version in os.listdir(model_path) if version.isnumeric()]

# source: https://github.com/aws/deep-learning-containers/blob/fe4864d0ce873c269da58ad8f3d29a4733cddc80/tensorflow/inference/docker/build_artifacts/sagemaker/tfs_utils.py#L137

The SageMaker backend lists the model versions and removes the leading zeros from every version number. The problem is that some people who have only one model version use ‘0’ as the version id. That zero gets trimmed to an empty string, and the SageMaker Endpoint crashes.
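We can reproduce the problem locally with the same list comprehension:

```python
import os
import tempfile

# Simplified version-listing logic from the container's tfs_utils.py:
def find_model_versions(model_path):
    return [v.lstrip("0") for v in os.listdir(model_path) if v.isnumeric()]

with tempfile.TemporaryDirectory() as model_path:
    os.mkdir(os.path.join(model_path, "0"))
    # "0".lstrip("0") yields an empty string, which later produces the
    # empty "versions:" line in the generated model-config.cfg
    print(find_model_versions(model_path))  # prints ['']
```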

This may be the problem in your project if you see the following messages in the SageMaker Endpoint logs:

INFO:__main__:tensorflow serving model config: 
model_config_list: {
  config: {
    name: 'saved_model'
    base_path: '/opt/ml/model/tensorflow/saved_model'
    model_platform: 'tensorflow'
    model_version_policy: {
      specific: {
        versions: 
      }
    }
  }
}

INFO:__main__:tensorflow version info:
TensorFlow ModelServer: 2.3.0-rc0+dev.sha.no_git
TensorFlow Library: 2.3.0
INFO:__main__:tensorflow serving command: tensorflow_model_server --port=20000 --rest_api_port=20001 --model_config_file=/sagemaker/model-config.cfg --max_num_load_retries=0    
INFO:__main__:started tensorflow serving (pid: 16)
INFO:tfs_utils:Trying to connect with model server: http://localhost:20001/v1/models/saved_model
WARNING:urllib3.connectionpool:Retrying (Retry(total=8, connect=None, read=None, redirect=None, status=None)) after connection broken by 'NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7fc8dfc4d6d0>: Failed to establish a new connection: [Errno 111] Connection refused')': /v1/models/saved_model
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:324] Error parsing text-format tensorflow.serving.ModelServerConfig: 9:7: Expected integer, got: }
Failed to start server. Error: Invalid argument: Invalid protobuf file: '/sagemaker/model-config.cfg'

I don’t know whether an easier fix exists. The simplest workaround is to repackage the model tar.gz file and change the version id from ‘0’ to ‘1’.
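The repackaging can be done locally. A sketch using a stand-in archive (replace model.tar.gz and the placeholder files with your real model):

```shell
# Simulate a broken archive whose only version directory is "0":
mkdir -p broken/0 && touch broken/0/saved_model.pb
tar -C broken -czf model.tar.gz 0/

# The workaround: extract, rename the version directory to "1", repackage:
mkdir -p repack
tar -xzf model.tar.gz -C repack
mv repack/0 repack/1
tar -C repack -czf model-fixed.tar.gz 1/
tar -tzf model-fixed.tar.gz   # now lists 1/saved_model.pb
```

Then upload model-fixed.tar.gz to S3 and recreate the model and the endpoint.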


Bartosz Mikulski * MLOps Engineer / data engineer * conference speaker * co-founder of Software Craft Poznan & Poznan Scala User Group
