Last year, I set out to build WhippedUp, a platform aimed at helping people eat better and more sustainably. To make the platform run, I heavily leverage machine learning techniques: some process the system’s data, and others help users consume it.
Since I’m still a tiny player in the space, I can’t afford the production infrastructure machine learning requires to run fast. However, as a rock climber, I learned that clever moves often make up for strength. Today I will share some of these moves that allowed me to operate production workloads reasonably well while reducing the infrastructure footprint and associated costs.
I solved the problem iteratively, following a standard performance-tuning methodology:
- Adopting Multi-Task ML
- Knowledge Distillation
- Ditching Python in favor of NodeJS as my deployment platform.
With these steps, I achieved up to a 32x improvement in memory consumption and a 15x improvement in inference and cold-start times.
The Classification Problem
While developing the WhippedUp platform, I put a lot of research into classifying unstructured information to streamline decision-making. Concretely, for my flavor pairings feature, I need to classify ingredients in three ways:
- Ingredient / Not Ingredient: I want to prevent, as much as possible, people from testing funny combinations and getting recipes for a poop cake or any of the other creative inputs hackers try.
- Source: For every ingredient, I want to know whether it likely comes from an animal or a plant source and, if it comes from an animal, whether it’s an animal product or a dead animal.
- Broader Categories: Each ingredient will belong to one (or more) of a given set of categories; this helps me identify how to suggest ingredients and make the search smarter.
After much experimentation, I landed on a domain-adapted uncased BERT model that I would fine-tune for each specific task. To domain-adapt it, I built a proprietary dataset with cherry-picked information from the internet, explicitly designed to enhance the clustering of ingredients.
Initially, I wanted to use the Hugging Face pipelines, which are so convenient in production code, so I fine-tuned my domain-adapted BERT on each of the tasks above. This left me with three models to run on my serverless infrastructure.
I built a Python app with a REST API around these models. While the app ran reasonably quickly on my laptop, it became clear that I had to optimize it once I deployed it to the serverless infrastructure: it not only needed over 16 GB of memory but also took over 30 seconds to complete a cold start.
In the machine learning literature, there is the concept of multi-task learning, where you keep the weights of the first N layers unchanged and fine-tune one head per classification task. I wanted to implement this without losing the flexibility that pipelines gave me. Also, I didn’t want to run the classification three times, because that would cost additional time and CPU cycles.
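To make the classic multi-task setup concrete, here is a toy TypeScript sketch of the idea from the literature (not the approach this post ends up using): a shared representation is computed once, and each task owns a small linear head. The Head type, the label names, the weights, and the 2-dimensional "shared" vector are all made up for illustration; in a real model the shared vector would come from the BERT body. Only two of the three tasks are shown, and the multi-label category task is left out.

```typescript
// A task head: label names plus a weight matrix (one row per label) and a bias.
interface Head {
  labels: string[];
  weights: number[][];
  bias: number[];
}

// Dense layer: logits[i] = bias[i] + sum_j weights[i][j] * x[j].
function linear(x: number[], weights: number[][], bias: number[]): number[] {
  const out: number[] = [];
  for (let i = 0; i < weights.length; i++) {
    let s = bias[i];
    for (let j = 0; j < x.length; j++) s += weights[i][j] * x[j];
    out.push(s);
  }
  return out;
}

// Index of the largest value (ties go to the lowest index).
function argmax(v: number[]): number {
  let best = 0;
  for (let i = 1; i < v.length; i++) if (v[i] > v[best]) best = i;
  return best;
}

// `shared` stands in for the output of the shared, unchanged layers, computed once.
// Every head reuses it, so each extra task costs one small matrix-vector product.
function classifyAll(shared: number[], heads: { [task: string]: Head }): { [task: string]: string } {
  const result: { [task: string]: string } = {};
  for (const task in heads) {
    if (!Object.prototype.hasOwnProperty.call(heads, task)) continue;
    const head = heads[task];
    result[task] = head.labels[argmax(linear(shared, head.weights, head.bias))];
  }
  return result;
}

// Hypothetical heads with made-up weights for a 2-dimensional shared representation.
const heads: { [task: string]: Head } = {
  ingredient: { labels: ["not_ingredient", "ingredient"], weights: [[-1, 0], [1, 0]], bias: [0, 0] },
  source: { labels: ["plant", "animal_product", "animal"], weights: [[0, 1], [0, -1], [-1, -1]], bias: [0, 0, 0] },
};

const prediction = classifyAll([2, 0.5], heads);
console.log(prediction["ingredient"], prediction["source"]); // ingredient plant
```

The catch, and the reason the post goes a different way, is that this layout needs custom code around the model, which is exactly the pipeline flexibility the author wanted to keep.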
To solve the problem, I turned the three sequence classifiers into a single token classifier with three unique control tokens (>+|), each mapped to one of my classification tasks. I know what you’re thinking, and yes, I selected these tokens to look like an emoji as an inside joke.
The intuition is the following: If the model is smart enough to understand contextual information, it will assign different labels to these same tokens depending on the words that precede them.
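Here is a sketch of how the control-token trick could look at inference time, assuming the model output has already been reduced to one predicted label per token. The assignment of >, + and | to tasks and the label names (O, INGREDIENT, PLANT, SPICE) are hypothetical, since the post doesn’t specify them. The sketch also assumes a single label per control token, so it doesn’t cover multi-label categories.

```typescript
// Hypothetical assignment of control tokens to tasks; the post does not say
// which token maps to which task.
const CONTROL_TOKENS: { [token: string]: string } = {
  ">": "isIngredient",
  "+": "source",
  "|": "category",
};

// Appends the control tokens after the ingredient, so each one sees the ingredient
// as preceding context. BERT's basic tokenizer splits punctuation, so ">+|"
// becomes three separate tokens.
function buildInput(ingredient: string): string {
  return ingredient + " >+|";
}

// Reads one label per task from a single token-classification pass: only labels
// predicted at control-token positions matter. If the ingredient text itself
// contains one of these characters, the trailing control tokens still win,
// because later positions overwrite earlier ones.
function decodeTasks(tokens: string[], labels: string[]): { [task: string]: string } {
  if (tokens.length !== labels.length) throw new Error("tokens and labels must align");
  const result: { [task: string]: string } = {};
  for (let i = 0; i < tokens.length; i++) {
    if (Object.prototype.hasOwnProperty.call(CONTROL_TOKENS, tokens[i])) {
      result[CONTROL_TOKENS[tokens[i]]] = labels[i];
    }
  }
  return result;
}

// Made-up tokens and predicted labels for one input.
const exampleTokens = ["[CLS]", "smoked", "paprika", ">", "+", "|", "[SEP]"];
const exampleLabels = ["O", "O", "O", "INGREDIENT", "PLANT", "SPICE", "O"];
console.log(buildInput("smoked paprika")); // smoked paprika >+|
const decoded = decodeTasks(exampleTokens, exampleLabels);
console.log(decoded["isIngredient"], decoded["source"], decoded["category"]); // INGREDIENT PLANT SPICE
```

One forward pass yields all three answers, which is what removes the need to run three classifiers.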
Since BERT is quite strong at understanding contextual information, this approach solved my classification problem. Moreover, as a byproduct, it improved the model’s prediction quality by helping it find relationships between the categories.
My Python backend’s cold starts now took only ten seconds and used one-third of the memory.
The Performance Problem
Coming from a Java background, where I’m used to deploying web apps in, say, J2EE containers, I’m accustomed to paying a cold-start premium because it happens only once. In serverless, however, cold starts are the norm, so slow cold starts are not an option. Cutting startup time to a third was good progress, but not enough.
As someone who’s dealt with performance for years, I’m familiar with the approach of identifying bottlenecks first. To this end, I created a simple Python framework that captures run times and logs them. I’m sure I reinvented the wheel, but it does things the way I like, and it’s open source! See https://github.com/juancavallotti/py-profile.
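py-profile is a Python tool, and the post doesn’t show its API. Since the rest of this story ends up in NodeJS, here is a minimal TypeScript sketch of the same idea, not py-profile’s actual interface: wrap a unit of work, record its duration under a name, and log it. The names timed, summary and timings are hypothetical.

```typescript
// Hypothetical registry: metric name -> durations in milliseconds.
const timings: { [name: string]: number[] } = {};

// Runs fn, records its wall-clock duration under `name`, logs it, and returns
// fn's result. The finally block records the time even if fn throws.
// Works for synchronous functions only; an async variant would await inside.
function timed<T>(name: string, fn: () => T): T {
  const start = Date.now();
  try {
    return fn();
  } finally {
    const elapsed = Date.now() - start;
    if (!Object.prototype.hasOwnProperty.call(timings, name)) timings[name] = [];
    timings[name].push(elapsed);
    console.log(`[profile] ${name}: ${elapsed} ms`);
  }
}

// Summarizes each metric as a call count and an average duration.
function summary(): { [name: string]: { calls: number; avgMs: number } } {
  const out: { [name: string]: { calls: number; avgMs: number } } = {};
  for (const name in timings) {
    if (!Object.prototype.hasOwnProperty.call(timings, name)) continue;
    const durations = timings[name];
    let total = 0;
    for (let i = 0; i < durations.length; i++) total += durations[i];
    out[name] = { calls: durations.length, avgMs: total / durations.length };
  }
  return out;
}

// Usage: wrap the stages you suspect, such as startup and classification.
const loaded = timed("startup", () => "model loaded"); // stand-in for loading a model
const label = timed("inference", () => "ingredient"); // stand-in for a classify call
console.log(summary());
```

Keeping the per-call durations, rather than just a running total, makes it easy to spot the difference between a one-off startup cost and a per-request cost.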
After instrumenting my app with this tool and deploying it, I identified two areas of concern:
- Application Startup
- Inference Time
The first thing I noticed in my metrics was that each call to classify took about 500 ms on the serverless infrastructure. While this latency was acceptable from a user-experience perspective, it added up on top of the cold starts.
To address inference time, I had to make those expensive computations cheaper, so I took the classic approach:
- Knowledge Distillation: I switched from a regular BERT model to its distilled counterpart, which has fewer parameters and a barely noticeable loss in prediction quality.
- Quantization: I quantized the weights, which, in layman’s terms, means reducing the precision of their floating-point numbers so they take less space and computations take fewer CPU cycles.
These improvements reduced the inference time to under 100ms on the serverless infrastructure!
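The post doesn’t say which quantization scheme was used. To make "reducing the precision" concrete, here is a generic symmetric per-tensor int8 scheme in TypeScript, shown only as an illustration: each 4-byte float32 weight is stored as a 1-byte int8, plus one shared scale factor per tensor. The function names are hypothetical.

```typescript
// Symmetric per-tensor int8 quantization (an illustrative scheme, not
// necessarily the one used in the post): scale = max|w| / 127, q = round(w / scale).
function quantize(weights: Float32Array): { q: Int8Array; scale: number } {
  let maxAbs = 0;
  for (let i = 0; i < weights.length; i++) maxAbs = Math.max(maxAbs, Math.abs(weights[i]));
  const scale = maxAbs === 0 ? 1 : maxAbs / 127; // avoid dividing by zero for all-zero tensors
  const q = new Int8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    // Clamp so rounding can never leave the symmetric range [-127, 127].
    q[i] = Math.max(-127, Math.min(127, Math.round(weights[i] / scale)));
  }
  return { q: q, scale: scale };
}

// Approximate reconstruction: w ≈ q * scale, with error at most scale / 2 per weight.
function dequantize(q: Int8Array, scale: number): Float32Array {
  const w = new Float32Array(q.length);
  for (let i = 0; i < q.length; i++) w[i] = q[i] * scale;
  return w;
}

const exampleWeights = new Float32Array([0.5, -1.27, 0.01, 1.0]);
const quantized = quantize(exampleWeights);
const restored = dequantize(quantized.q, quantized.scale);
console.log(quantized.q.join(",")); // 50,-127,1,100
console.log(exampleWeights.byteLength, "bytes ->", quantized.q.byteLength, "bytes"); // 16 bytes -> 4 bytes
```

The 4x storage saving is fixed by the data types; the speedup depends on the runtime having fast integer kernels, which is why production setups rely on a library’s quantizer rather than hand-rolled code like this.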
One of the things I observed while deploying on serverless infrastructure is that I/O is extremely slow: reading a 512 MB or even 1 GB file takes a very long time. In this case, my application was reading two sets of files: the machine learning model, and Python with its libraries.
When I implemented the inference-time improvements, I knew they would also shrink the model’s file size. The model’s weight file dropped to under 100 MB, and cold starts took under four seconds.
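A back-of-the-envelope estimate shows why distillation plus quantization shrinks the weight file so much. The parameter counts are approximate public figures for BERT-base (about 110M) and DistilBERT (about 66M); the post doesn’t state the size of its domain-adapted model, so treat them as assumptions. The estimate ignores file-format overhead and any tensors a quantizer may leave in floating point, so real files can be somewhat larger.

```typescript
// Rough weight-file size: parameters × bytes per parameter, in megabytes.
// Ignores file-format overhead and tensors a quantizer may leave in float.
function sizeMB(params: number, bytesPerParam: number): number {
  return (params * bytesPerParam) / 1e6;
}

// Approximate public parameter counts; the post's model may differ.
const BERT_BASE_PARAMS = 110e6;
const DISTILBERT_PARAMS = 66e6;

const bertFp32 = sizeMB(BERT_BASE_PARAMS, 4); // 440 MB per model
const threeModelsFp32 = 3 * bertFp32; // 1320 MB for three separately fine-tuned models
const distilFp32 = sizeMB(DISTILBERT_PARAMS, 4); // 264 MB
const distilInt8 = sizeMB(DISTILBERT_PARAMS, 1); // 66 MB
console.log((bertFp32 / distilInt8).toFixed(1) + "x smaller"); // 6.7x smaller
```

Under these assumptions, a distilled int8 model lands comfortably under 100 MB, in line with the file size reported above.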
That left the second set of files: Python and its libraries. So I ditched Python and moved my deployment to NodeJS, running the model with ONNX Runtime for JS. In a corporate environment, this change would have been hard to justify because nobody likes straying from the beaten path, but since the project was mine, I did it, and the results were spectacular! In addition, NodeJS has a smaller memory footprint, so I was back in free-tier land, and the performance was unbeatable.
Machine Learning in The Land of NodeJS
The last two challenges came from moving my codebase to Node. Specifically, ONNX Runtime for JS didn’t come bundled with the two components I needed: a BERT WordPiece tokenizer and a softmax operation to post-process my results.
For the tokenizer, I found an open-source implementation by NLPJS. Unfortunately, it wasn’t tokenizing according to the WordPiece specification for subword tokenization, so I submitted a bug report and a reference implementation, which the team incorporated into the library!
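The NLPJS fix itself isn’t reproduced in the post. As an illustration of what "tokenizing according to WordPiece" means, here is a TypeScript sketch of BERT’s greedy longest-match-first algorithm, where non-initial subword pieces carry a "##" prefix and a word with any unmatchable remainder becomes [UNK]. The pre-tokenizer is simplified (lowercasing plus whitespace and ASCII punctuation splitting only; the real uncased BERT tokenizer also strips accents and handles CJK characters), and the toy vocabulary and the maxChars cap are hypothetical.

```typescript
type Vocab = { [piece: string]: number };

// Simplified uncased pre-tokenization: lowercase, isolate ASCII punctuation
// (so ">+|" becomes ">", "+", "|"), then split on whitespace.
function basicTokenize(text: string): string[] {
  const spaced = text.toLowerCase().replace(/([!-\/:-@\[-`{-~])/g, " $1 ");
  return spaced.split(/\s+/).filter((t) => t.length > 0);
}

// Greedy longest-match-first WordPiece: repeatedly take the longest vocabulary
// entry that prefixes the rest of the word; non-initial pieces carry "##".
// Works on UTF-16 code units for brevity; maxChars is an illustrative cap.
function wordPiece(word: string, vocab: Vocab, unk: string = "[UNK]", maxChars: number = 100): string[] {
  if (word.length > maxChars) return [unk];
  const pieces: string[] = [];
  let start = 0;
  while (start < word.length) {
    let end = word.length;
    let match: string | null = null;
    while (start < end) {
      const candidate = (start > 0 ? "##" : "") + word.substring(start, end);
      if (Object.prototype.hasOwnProperty.call(vocab, candidate)) {
        match = candidate;
        break;
      }
      end--;
    }
    if (match === null) return [unk]; // no piece fits: the whole word becomes unknown
    pieces.push(match);
    start = end;
  }
  return pieces;
}

// Full (simplified) pipeline: pre-tokenize, then split each word into pieces.
function tokenize(text: string, vocab: Vocab): string[] {
  const out: string[] = [];
  const words = basicTokenize(text);
  for (let i = 0; i < words.length; i++) {
    const pieces = wordPiece(words[i], vocab);
    for (let j = 0; j < pieces.length; j++) out.push(pieces[j]);
  }
  return out;
}

// Toy vocabulary (hypothetical); real BERT vocabularies have around 30k entries.
const toyVocab: Vocab = { straw: 0, "##berry": 1, "##ber": 2, jam: 3, ">": 4, "+": 5, "|": 6 };
console.log(tokenize("Strawberry Jam >+|", toyVocab).join(" ")); // straw ##berry jam > + |
console.log(wordPiece("jamx", toyVocab).join(" ")); // [UNK]
```

Note that "strawberry" becomes "straw ##berry" rather than "straw ##ber ##ry": the longest match always wins, which is the behavior a spec-compliant tokenizer must reproduce for the model to see the same token IDs it was trained on.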
The last piece was writing a vectorized softmax implementation in JS. It took me some time to get there, but it was worth it. The main challenge was that ONNX JS’s abstractions didn’t seem meant for direct user consumption. Solving this piece added some unwanted fragility to my codebase, but I don’t believe it’s a considerable risk.
I will now summarize my results for each stage and close this post with my conclusions.
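The author’s softmax code isn’t shown, so here is a minimal TypeScript sketch of the technique. It assumes the logits arrive as a flat, row-major Float32Array, which is how ONNX Runtime for JS exposes a float32 output tensor’s data (its dims field gives the shape, e.g. [sequence length, number of labels] for a token classifier, ignoring the batch dimension here). It processes the whole buffer row by row and subtracts each row’s maximum for numerical stability; the function names are hypothetical.

```typescript
// Row-wise, numerically stable softmax over a flat row-major buffer of shape
// [rows, cols], e.g. [sequenceLength, numLabels] for a token classifier.
function softmax(logits: Float32Array, rows: number, cols: number): Float32Array {
  if (logits.length !== rows * cols) throw new Error("logits length does not match rows * cols");
  const probs = new Float32Array(logits.length);
  for (let r = 0; r < rows; r++) {
    const offset = r * cols;
    // Subtracting the row maximum keeps Math.exp from overflowing on large logits.
    let max = -Infinity;
    for (let c = 0; c < cols; c++) max = Math.max(max, logits[offset + c]);
    let sum = 0;
    for (let c = 0; c < cols; c++) {
      const e = Math.exp(logits[offset + c] - max);
      probs[offset + c] = e;
      sum += e;
    }
    for (let c = 0; c < cols; c++) probs[offset + c] /= sum;
  }
  return probs;
}

// Index of the most probable label in each row (ties go to the lowest index).
function argmaxRows(values: Float32Array, rows: number, cols: number): number[] {
  const best: number[] = [];
  for (let r = 0; r < rows; r++) {
    let idx = 0;
    for (let c = 1; c < cols; c++) {
      if (values[r * cols + c] > values[r * cols + idx]) idx = c;
    }
    best.push(idx);
  }
  return best;
}

// Two token positions, three labels each (made-up logits).
const logits = new Float32Array([1, 2, 3, 0, 0, 0]);
const probs = softmax(logits, 2, 3);
console.log(argmaxRows(probs, 2, 3).join(",")); // 2,0
```

Working directly on the flat typed array avoids allocating per-token JS arrays, which keeps post-processing cheap even for long inputs.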
| Stage / Optimization | Memory Setting | Startup Time | Inference Time |
|---|---|---|---|
| Distillation / Quantization | | | |
| Total Gains | 32x | 15x | > 15x |
I was also able to move from premium tiers to the free tier of serverless containers.
In summary, this exercise showed me that it’s possible to run machine learning competitively in production on inexpensive infrastructure. Still, I had to ask hard questions and be willing to rethink my architecture and challenge the status quo on the golden rules of specific technologies.
Finally, I will share that I tried other alternatives, such as the Hugging Face accelerated inference API, which partly solved the memory problem but worsened the startup time. Please contact me if you have questions or are struggling with similar challenges.