Running BERT Models on GCP Serverless Infrastructure

Last year I took on the quest of building WhippedUp, a platform aimed at helping people eat better and more sustainably. To make the platform work, I rely heavily on machine learning techniques: some process the system’s data, and others help users consume it.

Since I’m still a tiny player in the space, I can’t afford the production infrastructure machine learning requires to run fast. However, as a rock climber, I learned that clever moves often make up for strength. Today I will share some of these moves that allowed me to operate production workloads reasonably well while reducing the infrastructure footprint and associated costs.

TLDR

I solved the problem iteratively, following a standard performance-tuning methodology and applying these steps:

  • Adopting Multi-Task ML
  • Knowledge Distillation
  • Quantization
  • Ditching Python in favor of NodeJS as my deployment runtime.

With these steps, I achieved up to a 32x improvement in memory consumption and a 15x improvement in inference and cold-start times.

The Classification Problem

While developing the WhippedUp platform, much research went into classifying unstructured information to streamline decision-making. Concretely, I have to classify ingredients in three ways for my flavor pairings feature.

  • Ingredient / Not Ingredient: As much as possible, I want to stop people from testing funny combinations and getting recipes for a poop cake, or any of the other creative inputs hackers like to try.
  • Source: For every ingredient, I want to know whether it likely comes from an animal or a plant source and, if it comes from an animal, whether it’s an animal product or a dead animal.
  • Broader Categories: Each ingredient will belong to one (or more) of a given set of categories; this helps me identify how to suggest ingredients and make the search smarter.

Model Selection

After much experimentation, I landed on a domain-adapted uncased BERT model that I would fine-tune for each specific task. To domain-adapt it, I built a proprietary dataset with cherry-picked information from the internet, explicitly designed to enhance the clustering of ingredients.

Initially, I wanted to use the Hugging Face pipelines in my production code because they are so convenient. So I fine-tuned my domain-adapted BERT on each of the previously mentioned tasks. This left me with three models to run on my serverless infrastructure.

Multi-Task ML

I built a Python app that exposed these models through a REST API. The app ran reasonably quickly on my laptop. Once I deployed it to the serverless infrastructure, though, it was clear that I had to optimize it. It needed over 16GB of memory and took over 30 seconds to get through a cold start.

The machine learning literature describes multi-task learning, in which you keep the weights of the first N layers unchanged and fine-tune one head for each classification task. I wanted to implement this without losing the flexibility that pipelines gave me. I also didn’t want to run the classification three times, because that would cost additional time and CPU cycles.

To solve the problem, I replaced the three sequence classifiers with a single token classifier that uses three unique control tokens (>+|), one for each of my classification tasks. I know what you’re thinking: yes, I picked these tokens to look like an emoji, as an inside joke.

The intuition is the following: If the model is smart enough to understand contextual information, it will assign different labels to these same tokens depending on the words that precede them.

(Figure: first, the initial solution with one model for each classification task; second, the alternative with one model running all the classifications at once.)
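The following Python sketch shows the pre- and post-processing around such a single token classifier. It rests on several assumptions:

  • The control tokens are appended after the ingredient text in the order `>`, `+`, `|`.
  • The tokenizer keeps each control token as a single token.
  • The model emits one label per token.

The task names, the label strings, and the helpers `build_input` and `read_task_labels` are hypothetical. A hard-coded label list stands in for the model call.

```python
from typing import Dict, List

# Hypothetical mapping from each control token to the task it answers.
CONTROL_TOKENS: Dict[str, str] = {
    ">": "is_ingredient",
    "+": "source",
    "|": "category",
}


def build_input(ingredient: str) -> str:
    """Append the three control tokens so one forward pass covers every task."""
    return f"{ingredient} {' '.join(CONTROL_TOKENS)}"


def read_task_labels(tokens: List[str], labels: List[str]) -> Dict[str, str]:
    """Return the label predicted at each control token's position.

    `tokens` and `labels` must be aligned (labels[i] is the prediction for
    tokens[i]). Labels on ordinary word tokens are ignored. Simplification:
    one label per task, although the category task can be multi-label.
    """
    if len(tokens) != len(labels):
        raise ValueError("tokens and labels must be aligned")
    result: Dict[str, str] = {}
    for token, label in zip(tokens, labels):
        task = CONTROL_TOKENS.get(token)
        if task is not None:
            result[task] = label
    return result


if __name__ == "__main__":
    print(build_input("smoked paprika"))  # smoked paprika > + |
    # Stand-in for the token classifier's output: one made-up label per token.
    tokens = ["smoked", "paprika", ">", "+", "|"]
    labels = ["O", "O", "INGREDIENT", "PLANT", "SPICE"]
    print(read_task_labels(tokens, labels))
```

The same control token can receive different labels for different ingredients, because its label depends on the words that precede it.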

Since BERT is quite strong at understanding contextual information, this solution solved my classification problem. Moreover, it increased the model’s inference performance as a byproduct by helping it find relationships between the categories.

Now my Python backend’s cold starts took only ten seconds, and it used one-third of the memory.

The Performance Problem

I come from Java backends, where I’m used to deploying web apps in, let’s say, J2EE containers. There I’m accustomed to paying a cold-start premium because it happens only once. In serverless, however, cold starts are the norm, so slow cold starts are not an option. Cutting the startup time to a third was good progress, but not enough.

Having dealt with performance for years and years, I know the approach starts with identifying bottlenecks. To this end, I created a simple Python framework that captures run times and logs them. I’m sure I reinvented the wheel, but it does things the way I like, and it’s open source! See https://github.com/juancavallotti/py-profile.
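The core of such a tool is small. The following Python sketch does not reproduce py-profile’s API. Instead, it is a hypothetical `timed` context manager plus a `timed_fn` decorator, built only on the standard library’s `time.perf_counter` and `logging`. Together they record and log the elapsed wall-clock time for each named section. `TIMINGS` and `load_model` are also illustrative names.

```python
import logging
import time
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Dict, Iterator, List, TypeVar

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("timing")

# Hypothetical in-memory store: section name -> list of durations in ms.
TIMINGS: Dict[str, List[float]] = {}

T = TypeVar("T")


@contextmanager
def timed(name: str) -> Iterator[None]:
    """Measure the wall-clock time of a block, store it, and log it."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        TIMINGS.setdefault(name, []).append(elapsed_ms)
        log.info("%s took %.1f ms", name, elapsed_ms)


def timed_fn(fn: Callable[..., T]) -> Callable[..., T]:
    """Decorator form: time every call of `fn` under its qualified name."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with timed(fn.__qualname__):
            return fn(*args, **kwargs)
    return wrapper


@timed_fn
def load_model() -> str:
    time.sleep(0.05)  # stand-in for reading model weights from disk
    return "model"


if __name__ == "__main__":
    with timed("startup"):
        load_model()
    print({name: len(samples) for name, samples in TIMINGS.items()})
```

Wrapping both coarse phases (`startup`) and individual functions (`load_model`) is what separates startup cost from per-request inference cost in the logs.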

After instrumenting my whole app with this simple tool and deploying it, I identified two areas of concern:

  • Application Startup
  • Inference Time

Inference Time

My metrics first showed that each call to classify took about 500ms on the serverless infrastructure. This latency was acceptable from a user-experience perspective, but it added up on top of the cold starts.

To address inference time, I had to make those expensive computations cheaper, so I took the classical approach:

  • Knowledge Distillation: I switched from a regular BERT model to its distilled counterpart, which has fewer parameters and a barely noticeable drop in model performance.
  • Quantization: I quantized the weights. In layman’s terms, this means lowering the numeric precision of the weights so that they take less space and computations need fewer CPU cycles.
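The following pure-Python sketch illustrates what "lowering the precision" means. It uses symmetric per-tensor int8 quantization, where each weight is stored as an 8-bit integer `q` and recovered as `scale * q`. A dot product can then be computed with integer multiply-adds and rescaled once at the end. This is the general technique, not necessarily the exact scheme the ONNX tooling applied to my model. The function names are hypothetical.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization: w ~ scale * q with q in [-127, 127]."""
    max_abs = max((abs(w) for w in weights), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Recover approximate float weights from the int8 values."""
    return [scale * v for v in q]


def int8_dot(qa: List[int], sa: float, qb: List[int], sb: float) -> float:
    """Dot product with integer accumulation, rescaled once at the end."""
    acc = sum(x * y for x, y in zip(qa, qb))
    return acc * sa * sb


if __name__ == "__main__":
    w = [0.42, -1.27, 0.003, 0.9]
    x = [1.0, 0.5, -2.0, 0.25]
    qw, sw = quantize_int8(w)
    qx, sx = quantize_int8(x)
    print(qw, sw)
    exact = sum(a * b for a, b in zip(w, x))
    print(exact, int8_dot(qw, sw, qx, sx))
    # Each weight needs 1 byte instead of 4 (float32): roughly 4x smaller.
```

The rounding error per weight is at most half a quantization step (`scale / 2`). This is why the accuracy loss is usually small, while storage shrinks by about 4x compared with float32.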

These improvements reduced the inference time to under 100ms on the serverless infrastructure!

Application Startup

One of the things I observed while deploying on serverless infrastructure is that IO is extremely slow. Reading a file of 512 MB or even 1GB takes a very long time. In my case, the application was reading two sets of files: the machine learning model, and Python along with its libraries.

When I implemented the inference-time improvements, I knew they would also reduce the model’s file size. They shrank the model’s weight file to under 100 MB and brought cold starts under four seconds.
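A back-of-the-envelope calculation shows why distillation plus quantization shrinks the weight file so much: weight size is roughly the parameter count times the bytes per weight. The sketch below uses the published approximate sizes of BERT-base (about 110M parameters) and DistilBERT (about 66M parameters). It assumes float32 (4 bytes) before quantization and int8 (1 byte) after, and that my domain-adapted models match those architectures. It ignores tokenizer files and graph overhead. `weights_mb` is a hypothetical helper.

```python
# Approximate parameter counts of the public architectures (assumption:
# the fine-tuned models have the same size).
BERT_BASE_PARAMS = 110_000_000
DISTILBERT_PARAMS = 66_000_000


def weights_mb(params: int, bytes_per_weight: int) -> float:
    """Approximate weight-file size in megabytes (10^6 bytes)."""
    return params * bytes_per_weight / 1_000_000


if __name__ == "__main__":
    print(weights_mb(BERT_BASE_PARAMS, 4))   # 440.0 (float32 BERT-base)
    print(weights_mb(DISTILBERT_PARAMS, 4))  # 264.0 (float32 DistilBERT)
    print(weights_mb(DISTILBERT_PARAMS, 1))  # 66.0  (int8 DistilBERT)
```

Under these assumptions, the two steps together cut the file by more than 6x, which is consistent with ending up under 100 MB.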

At this point, I could have been happy and called it done, but I noticed something else. When I quantized the model, I had to convert it to the ONNX format. ONNX is a cross-platform model format, and ONNX Runtime can execute it almost anywhere, including browsers and JavaScript. I also noticed that NodeJS containers started particularly fast on GCP’s infrastructure. So I put two and two together and migrated my backend from Python to NodeJS. I know this sounds crazy because Python is the de facto standard for running machine learning, but oh boy, this change brought startup time under 2 seconds!

In a corporate environment, this change would have been hard to justify because nobody likes straying from the beaten path. Since the project was mine, I did it, and the results were spectacular! In addition, NodeJS has a smaller memory footprint, so I was back in free-tier land, and the performance was unbeatable.

Machine Learning in The Land of NodeJS

The last two challenges came from moving my codebase to Node. Specifically, ONNX Runtime for JS didn’t come bundled with two components I needed: a BERT WordPiece tokenizer and a softmax operation to post-process my results.

For the tokenizer, I found an open-source implementation by NLPJS. Unfortunately, it wasn’t tokenizing according to the WordPiece specification for subword tokenization. So I submitted a bug report and a reference implementation, which the team incorporated into the library!
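For reference, here is how WordPiece, as used by BERT, tokenizes text:

  • It splits the text on whitespace.
  • It walks through each word greedily, taking the longest vocabulary entry that matches at the current position.
  • It marks word-internal pieces with a `##` prefix.
  • If any part of a word can’t be matched, the whole word becomes `[UNK]`.

The following Python sketch shows that algorithm with a tiny made-up vocabulary. It omits BERT’s lower-casing, accent stripping, and punctuation splitting, and it is not the NLPJS code.

```python
from typing import List, Set


def wordpiece(text: str, vocab: Set[str], unk: str = "[UNK]",
              max_chars: int = 100) -> List[str]:
    """Greedy longest-match-first WordPiece over whitespace-split words."""
    output: List[str] = []
    for word in text.split():
        if len(word) > max_chars:
            output.append(unk)
            continue
        pieces: List[str] = []
        start = 0
        while start < len(word):
            end = len(word)
            match = None
            # Shrink the candidate from the right until it is in the vocabulary.
            while start < end:
                candidate = word[start:end]
                if start > 0:
                    candidate = "##" + candidate  # continuation piece
                if candidate in vocab:
                    match = candidate
                    break
                end -= 1
            if match is None:
                pieces = [unk]  # one unmatched span makes the whole word unknown
                break
            pieces.append(match)
            start = end
        output.extend(pieces)
    return output


if __name__ == "__main__":
    vocab = {"straw", "##berry", "##berries", "jam", "un", "##aff", "##able"}
    print(wordpiece("strawberries jam unaffable xyz", vocab))
    # ['straw', '##berries', 'jam', 'un', '##aff', '##able', '[UNK]']
```

The `##` prefix and the whole-word `[UNK]` fallback matter when porting the tokenizer. If either behaves differently, the token IDs no longer match what the model saw during fine-tuning.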

The last piece was writing a vectorized softmax implementation in JS. It took me some time to get there, but it was worth it. The main challenge was that ONNX JS’s abstractions didn’t seem intended for direct user consumption. Solving this added some unwanted fragility to my codebase, but I don’t believe it’s a considerable risk.
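The post-processing step is a softmax over the last axis of the logits tensor, followed by picking the most likely label. ONNX Runtime exposes tensor data as a flat buffer plus its dimensions, so the following sketch treats the logits as a flat list whose last dimension is `num_labels`. It is written in Python rather than JS and is not my actual implementation. To keep `exp()` from overflowing, it subtracts each row’s maximum before exponentiating. `softmax_last_axis` and `argmax_per_row` are hypothetical names.

```python
import math
from typing import List, Sequence


def softmax_last_axis(flat: Sequence[float], num_labels: int) -> List[float]:
    """Row-wise softmax over a flat buffer laid out as [..., num_labels]."""
    if num_labels <= 0 or len(flat) % num_labels != 0:
        raise ValueError("buffer length must be a multiple of num_labels")
    out: List[float] = [0.0] * len(flat)
    for row_start in range(0, len(flat), num_labels):
        row = flat[row_start:row_start + num_labels]
        m = max(row)  # subtracting the max keeps exp() from overflowing
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        for i, e in enumerate(exps):
            out[row_start + i] = e / total
    return out


def argmax_per_row(probs: Sequence[float], num_labels: int) -> List[int]:
    """Index of the most likely label for each row (e.g., each token)."""
    return [
        max(range(num_labels), key=lambda j: probs[r + j])
        for r in range(0, len(probs), num_labels)
    ]


if __name__ == "__main__":
    # Two tokens, three labels each, flattened the way a [2, 3] tensor is stored.
    logits = [1.0, 2.0, 3.0, 1000.0, 1000.0, 1000.0]
    probs = softmax_last_axis(logits, 3)
    print(probs)
    print(argmax_per_row(probs, 3))
```

Working on the flat buffer directly avoids reshaping into nested arrays. The cost is that the code must know the tensor layout, which is the kind of coupling to the runtime’s data representation described above.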

The Results

I will now summarize my results for each stage and close this post with my conclusions.

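Stage / Optimization        | Memory Setting | Startup Time | Inference Time
----------------------------|----------------|--------------|---------------
Initial Stage               | 16GB           | ~30s         | ~1500ms
Multi-Task ML               | 8GB            | ~10s         | ~500ms
Distillation / Quantization | 1GB            | ~4s          | < 100ms
NodeJS                      | 512MB          | ~2s          | < 100ms
Total Gains                 | 32x            | 15x          | > 15x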
These are the results I got at each stage of optimization.
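The "Total Gains" row is the ratio between the first and last stages. The following Python check uses the table’s approximate values. It makes two assumptions: 16GB means 16 × 1024 MB (binary units, as container memory settings typically use), and "< 100ms" is taken as 100ms. Because of the second assumption, the inference gain is a lower bound. The constant names and `gain` are illustrative.

```python
from typing import Tuple

# (initial stage, NodeJS stage) values from the results table.
MEMORY_MB: Tuple[float, float] = (16 * 1024, 512)  # 16GB -> 512MB
STARTUP_S: Tuple[float, float] = (30, 2)           # ~30s -> ~2s
INFERENCE_MS: Tuple[float, float] = (1500, 100)    # ~1500ms -> "< 100ms"


def gain(before_after: Tuple[float, float]) -> float:
    """Improvement factor: how many times smaller the final value is."""
    before, after = before_after
    return before / after


if __name__ == "__main__":
    for name, pair in [("memory", MEMORY_MB), ("startup", STARTUP_S),
                       ("inference", INFERENCE_MS)]:
        print(f"{name}: {gain(pair):.0f}x")
    # memory: 32x, startup: 15x, inference: 15x (a lower bound)
```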

Additionally, I was able to move from the premium tiers to the free tier of serverless containers.

In summary, this exercise showed me that it’s possible to run machine learning competitively in production on inexpensive infrastructure. Still, I had to ask the hard questions, be willing to rethink my architecture, and challenge the golden rules of specific technologies.

Finally, I also tried other alternatives, such as the Hugging Face Accelerated Inference API, which partly solved the memory problem but made startup time worse. Please contact me if you have questions or are struggling with similar challenges.
