Let’s imagine the following use case: I’ve built a model that predicts how much a person will spend in a webshop, using the person’s age and gender. The training data set consists mostly of women, with an average customer age of 35. The model learns that it is mostly women and older people who buy something. After monitoring my model’s predictive performance, I notice that it’s starting to decrease… But why?
After doing some research, I discovered that the webshop now attracts younger customers: the average age has dropped from 35 to 25. Furthermore, the live data shows that it is now the men, not the women, who spend more.
So, the relationship between the features that the model learned from the original training data set is no longer valid for the real-time incoming data, which causes the model to lose its predictive power. This loss in predictive power is called model drift, and in this article I will explain what model drift is, what types of model drift there are, and some techniques that can help to prevent it.
When a model is put into production, it moves from a static environment where the relationship between the input & output doesn’t change to a live environment where the relationship between the input & output does change. The changing relationship causes the model to start losing predictive power. This loss in predictive power is called model drift.
Model drift in machine learning
There are many types of model drift, but they can all be categorised into two broad categories: concept drift and data drift. In short, data drift occurs when the properties of the independent variables change, whereas concept drift occurs when the properties of the dependent variables change. Both concepts will be explained in more detail later in this article.
In general, there are two reasons why drift can occur in machine learning models.
- The first one is when an external event shakes up the data: customers’ preferences suddenly change due to the pandemic, or a competitor launches a new product or service. If this happens, your model needs an update or retraining to keep producing accurate predictions.
- The other reason is a problem with data integrity. For example, an application swaps two values, like a person’s height and age, or a website suddenly accepts blank fields due to a bug.
What is data drift?
So, I mentioned previously that data drift occurs when the properties of the independent variables change, but what does this mean? It basically means that the distributions of the input data have changed. An example is shown in the figure below, which represents a change in the distribution of an example feature, Feature 1.
Image by UbiOps
Let’s look back at the use case of the webshop I described at the beginning. As I said before, the training set consisted mostly of women, and the users had an average age of 35. The model learned from the training set that it was mostly women and older people who spent more.
After some time, the webshop is visited more often by men than by women. Since the model had limited examples to learn from in the training set, it is harder for the model to predict what men will spend on average. In our case, the behaviour of the newly attracted men is different from that of the subpopulation that was in the original training set, which causes the performance of the model to drop. If the behaviour of the newly attracted population were the same as the behaviour of the smaller subpopulation, the model would still perform well.
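A shift like this in a single feature can be detected with a two-sample statistical test. As a minimal sketch (the sample sizes, distributions, and 0.05 significance level are illustrative assumptions, not from the article), a Kolmogorov-Smirnov test compares the training-time age distribution with the live one:

```python
# Sketch: detecting data drift in one feature (age) with a two-sample
# Kolmogorov-Smirnov test. Data and thresholds are illustrative.
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(42)
train_age = rng.normal(loc=35, scale=8, size=5000)  # training distribution
live_age = rng.normal(loc=25, scale=8, size=5000)   # live distribution

stat, p_value = ks_2samp(train_age, live_age)
if p_value < 0.05:  # conventional significance level
    print(f"Data drift detected (KS statistic={stat:.3f})")
```

A small p-value means the two samples are unlikely to come from the same distribution, which is exactly the data-drift scenario described above.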
Image by UbiOps
What is concept drift?
As explained before, concept drift occurs when the properties of the dependent variables change, i.e., the relationship between the input and output starts to change, even though the distribution of the input may stay the same. The “concept” in concept drift is the hidden, unknown relationship (the decision boundary) between the input and output variables. This is marked with a grey line in the leftmost graph below.
Mathematically, concept drift can be described as a change in the distribution p(y|X), where X are the available features and y is the true label. Concept drift is usually the result of an external event, like COVID. The figure below shows a graph of sales predictions for a company that sells loungewear (source: EvidentlyAI). Notice that the model produced fairly accurate predictions before the lockdown was announced. Obviously, people spent a lot more time at home once the lockdown started, which resulted in people ordering two sets (or more) of loungewear where they would previously have bought only one, so the relationship the model had learned between its inputs and sales no longer held.
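The change in p(y|X) can be made concrete with a toy simulation (entirely synthetic, for illustration only): the input distribution p(X) stays identical, but the labelling rule flips, and a model trained on the old concept collapses.

```python
# Toy illustration of concept drift: p(X) stays the same, p(y|X) changes,
# so a model trained on the old concept loses its accuracy.
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 1))
y_old = (X[:, 0] > 0).astype(int)          # old concept: positive if x > 0
model = LogisticRegression().fit(X, y_old)

X_new = rng.normal(size=(2000, 1))         # same input distribution
y_new = (X_new[:, 0] < 0).astype(int)      # new concept: relationship flipped
acc = model.score(X_new, y_new)
print(f"Accuracy under the new concept: {acc:.2f}")
```

A data-drift detector looking only at X would see nothing wrong here, which is why concept drift needs label-aware monitoring.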
How fast does concept drift occur?
The rate at which concept drift occurs differs, mostly depending on the use case. Machine learning models used for quality checks in manufacturing processes are more stable than models that try to predict consumer behaviour. In general, drift can occur at a few different rates:
- Sudden drift: Here the drift occurs in a short amount of time, typically caused by an unforeseen event. Think of the change in consumer behaviour because of COVID: shops that used a model to predict the availability of products in a given store saw their model’s accuracy drop from 90% to 61% within just a couple of months.
- Gradual drift: Here the change happens over time and is, for most use cases, quite natural. An example would be a model that predicts house prices. Say the model was put into production in 2015 with an accuracy in the high 80s. After a couple of years, the accuracy, and thus the validity of the predictions, starts to drop due to the rise in house prices over time. This is logical, since house prices generally increase, but if it is not accounted for, it can have a detrimental impact on the model’s accuracy.
- Recurrent drift: This is the final type I want to talk about in this article. Recurrent drift means that the changes re-occur after the first observed occurrence, i.e., they happen periodically. A simple example is the sales of winter coats in the colder months; shopping days like Black Friday or Christmas are other examples of recurrent drift.
How to detect model drift
The last thing I want to cover in this article is some techniques you can use to monitor model drift. The first one is obvious: monitor the accuracy of the model by comparing the predicted values with the true values. If the predictions start to deviate farther and farther from the true values, drift has occurred.
One little side note about the technique I just mentioned: the “accuracy” of a model can be expressed in different Key Performance Indicators (KPIs), and which one is best depends on the use case. There are four KPIs that can be used to monitor a model’s performance:
- The first one is called accuracy (a little confusing, I know), and is defined as the number of predictions a model gets right, divided by the total number of predictions made: accuracy = (TP + TN) / (TP + TN + FP + FN).
- The second KPI is called precision, which shows how many of the selected items are relevant, i.e., what proportion of positive identifications is correct. This metric is useful when the cost of a false positive is high. Say SpaceX is launching another rocket and uses a machine learning model to pick a day with good weather conditions for the launch: if the predicted day turns out to be a false positive, the outcome could be catastrophic. The precision of a model is calculated by dividing the true positives by the sum of the true positives and false positives: precision = TP / (TP + FP).
- The third KPI is called recall, sometimes referred to as sensitivity. This KPI measures the ability of a model to find all actual positives within a data set. Recall is used in use cases where there is a high cost associated with a false negative, like COVID tests for example: it is better to keep an uninfected person indoors than to let an infected person go about their day-to-day business. Recall is calculated by dividing the true positives by the sum of the true positives and false negatives: recall = TP / (TP + FN).
- The final KPI I want to mention is the F1 score, the harmonic mean of the precision and recall of a model. This metric is used in cases where both false positives and false negatives matter, like information retrieval tasks. The F1 score is calculated as follows: F1 = 2 × (precision × recall) / (precision + recall).
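The four KPIs above can be computed directly from the cells of a confusion matrix. The counts below are made up for illustration:

```python
# The four KPIs, computed from confusion-matrix counts (made-up numbers):
# TP = true positives, FP = false positives, TN = true negatives,
# FN = false negatives.
tp, fp, tn, fn = 80, 10, 90, 20

accuracy = (tp + tn) / (tp + tn + fp + fn)          # correct / total
precision = tp / (tp + fp)                          # TP / (TP + FP)
recall = tp / (tp + fn)                             # TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean

print(f"accuracy={accuracy:.3f} precision={precision:.3f} "
      f"recall={recall:.3f} f1={f1:.3f}")
```

Note how precision ignores the false negatives and recall ignores the false positives, which is why the F1 score is useful when both error types are costly.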
There are two general ways of detecting model drift: a statistical approach and a model-based approach. The statistical approach basically calculates the difference between two populations. This can be done with multiple techniques, like the Population Stability Index (PSI), the Kullback-Leibler (KL) divergence, the Jensen-Shannon (JS) divergence, etc. Explaining all of these methods is beyond the scope of this article, but if you want to read more about them you can visit this site.
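To give a flavour of the statistical approach, here is a minimal Population Stability Index sketch: bin the reference feature, compute each bin’s share in the reference versus the current data, and sum (current − reference) × ln(current / reference). The 0.25 rule of thumb and the synthetic data are assumptions for illustration, not part of the article:

```python
# Minimal Population Stability Index (PSI) sketch. A common rule of
# thumb (an assumption here) reads PSI > 0.25 as significant drift.
import numpy as np

def psi(reference, current, bins=10, eps=1e-6):
    edges = np.histogram_bin_edges(reference, bins=bins)
    ref_pct = np.histogram(reference, bins=edges)[0] / len(reference) + eps
    cur_pct = np.histogram(current, bins=edges)[0] / len(current) + eps
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

rng = np.random.default_rng(1)
same = psi(rng.normal(35, 8, 5000), rng.normal(35, 8, 5000))
shifted = psi(rng.normal(35, 8, 5000), rng.normal(25, 8, 5000))
print(f"PSI same={same:.3f}, shifted={shifted:.3f}")
```

The `eps` term only guards against empty bins; the shifted population (the webshop’s age drop from 35 to 25) produces a PSI far above the no-drift baseline.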
The model-based approach uses a model to determine the similarity of incoming data points to a reference baseline. This technique usually gives a more accurate result, but it is less intuitive, and it can therefore be difficult to explain to others how it works.
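One common model-based variant (sometimes called a domain classifier; the data and model choice here are illustrative assumptions) trains a classifier to separate reference data from current data. If the classifier can tell the two apart, i.e., its ROC AUC is well above 0.5, the distributions differ:

```python
# Model-based drift detection sketch ("domain classifier"): label the
# reference data 0 and the current data 1, then see whether a classifier
# can distinguish them. AUC near 0.5 means the populations look alike.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(2)
reference = rng.normal(35, 8, size=(2000, 1))  # e.g. training-time ages
current = rng.normal(25, 8, size=(2000, 1))    # e.g. live ages

X = np.vstack([reference, current])
y = np.array([0] * len(reference) + [1] * len(current))
auc = cross_val_score(RandomForestClassifier(n_estimators=50, random_state=0),
                      X, y, cv=3, scoring="roc_auc").mean()
print(f"Domain classifier AUC: {auc:.2f}")
```

An appealing property of this approach is that the classifier’s feature importances can hint at which features drifted, even though the overall method is harder to explain than a single statistical test.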
Best practices for dealing with model drift
Detecting model drift is great, but what can be done afterwards to let your model regain its predictive power? Just as with detecting model drift, there are a few general methods. The first one is to prevent model drift from ever occurring. This can be done with online learning, where your model is constantly updated by a real-time data stream. While this technique is very effective, it is also more complex to implement than the methods described below.
The second way model drift can be addressed is by periodically retraining your model, either at a fixed interval or when the model’s performance drops below a certain threshold. You can also choose to take a representative subsample that has the same distribution as the live data and use that to re-label the data points.
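The threshold-triggered variant can be sketched in a few lines. Everything here is a hypothetical placeholder: the 0.80 threshold, the `maybe_retrain` helper, and the stand-in `train_model` callable are not from the article:

```python
# Threshold-based retraining sketch: retrain whenever live accuracy
# drops below a chosen threshold. All names and numbers are placeholders.
ACCURACY_THRESHOLD = 0.80  # hypothetical trigger level

def maybe_retrain(model, live_accuracy, train_model, X_recent, y_recent):
    """Return (model, retrained): retrain on recent data if accuracy dropped."""
    if live_accuracy < ACCURACY_THRESHOLD:
        return train_model(X_recent, y_recent), True
    return model, False

# Simulated check: live accuracy has fallen to 0.72.
model, retrained = maybe_retrain("old_model", 0.72,
                                 lambda X, y: "new_model", [], [])
print(retrained)  # True: accuracy 0.72 is below the 0.80 threshold
```

In practice the trigger check would run on a schedule inside your monitoring pipeline, with `X_recent`/`y_recent` being the re-labelled live subsample mentioned above.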
The third way is by using ensemble learning with model weighting, in which multiple models produce a combined prediction, typically computed as a weighted average of the individual predictions of each model.
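The weighted average can be sketched as follows; the per-model predictions and the idea of weighting by recent accuracy are illustrative assumptions:

```python
# Ensemble with model weighting, sketched: combine per-model predictions
# via a weighted average, weighting better-performing models more heavily.
import numpy as np

predictions = np.array([0.9, 0.6, 0.3])          # one prediction per model
recent_accuracy = np.array([0.92, 0.85, 0.60])   # recent live performance
weights = recent_accuracy / recent_accuracy.sum()  # normalise to sum to 1

combined = float(np.dot(weights, predictions))
print(f"Weighted ensemble prediction: {combined:.3f}")
```

When drift hurts one model more than the others, its weight shrinks and the ensemble’s combined prediction degrades more gracefully than any single model would.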
The last one I want to mention briefly, because it is used a lot in the industry, is feature dropping. Here, just as with ensemble learning, multiple models are built, each using one feature at a time while keeping the target variable unchanged. The AUC-ROC is monitored for every model, and if the value for one of them drops below a threshold, drift might have occurred.
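A sketch of that per-feature monitoring loop (the synthetic data, the flipped relationship for the second feature, and the 0.6 AUC threshold are all assumptions for illustration):

```python
# Feature-dropping sketch: one single-feature model per input feature,
# each monitored via ROC AUC on live data. A low AUC flags the feature.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(4)
X_train = rng.normal(size=(3000, 2))
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)

# Live data where feature 1's relationship to the target has flipped.
X_live = rng.normal(size=(3000, 2))
y_live = (X_live[:, 0] - X_live[:, 1] > 0).astype(int)

AUC_THRESHOLD = 0.6  # hypothetical alert level
for i in range(X_train.shape[1]):
    m = LogisticRegression().fit(X_train[:, [i]], y_train)
    auc = roc_auc_score(y_live, m.predict_proba(X_live[:, [i]])[:, 1])
    flag = "possible drift" if auc < AUC_THRESHOLD else "ok"
    print(f"feature {i}: AUC={auc:.2f} ({flag})")
```

Because each model uses a single feature, the per-feature AUC curves localise the drift: here only the feature whose relationship flipped falls below the threshold.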
Set up your monitoring plan. Identify model drift early on
There are also companies like Whylabs, Neptune, or Arize where you can upload your model and monitor drift in real time. In the past, UbiOps has released three articles that explain how you can connect your UbiOps environment to a model monitoring tool: