Drift Detection in Robust Machine Learning Systems



This article was co-authored by Sebastian Humberg and Morris Stallmann.


Introduction     

Machine learning (ML) models are designed to make accurate predictions based on patterns in historical data. But what if these patterns change overnight? For instance, in credit card fraud detection, today’s legitimate transaction patterns might look suspicious tomorrow as criminals evolve their tactics and honest customers change their habits. Or picture an e-commerce recommender system: what worked for summer shoppers may suddenly flop as winter holidays sweep in new trends. This subtle, yet relentless, shifting of data, known as drift, can quietly erode your model’s performance, turning yesterday’s accurate predictions into today’s costly mistakes.

In this article, we’ll lay the foundation for understanding drift: what it is, why it matters, and how it can sneak up on even the best machine learning systems. We’ll break down the two main types of drift: data drift and concept drift. Then, we’ll move from theory to practice by outlining robust frameworks and statistical tools for detecting drift before it derails your models. Finally, you’ll get a glimpse of how to counter drift, so your machine learning systems remain resilient in a constantly evolving world.

What is drift? 

Drift refers to unexpected changes in the data distribution over time, which can negatively impact the performance of predictive models. ML models solve prediction tasks by applying patterns that the model learned from historical data. More formally, in supervised ML, the model learns a joint distribution of some set of feature vectors X and target values y from all data available at time t0:

\[P_{t_{0}}(X, y) = P_{t_{0}}(X) \times P_{t_{0}}(y|X)\]

After training and deployment, the model will be applied to new data X_t to predict y_t under the assumption that the new data follows the same joint distribution. However, if that assumption is violated, then the model’s predictions may no longer be reliable, as the patterns in the training data may have become irrelevant. The violation of that assumption, namely the change of the joint distribution, is called drift. Formally, we say drift has occurred if:

\[P_{t_0}(X,y) \ne P_{t}(X,y) \quad \text{for some } t > t_0.\]

The Main Types of Drift: Data Drift and Concept Drift

Generally, drift occurs when the joint probability P(X, y) changes over time. But if we look more closely, we notice there are different sources of drift with different implications for the ML system. In this section, we introduce the notions of data drift and concept drift.

Recall that the joint probability can be decomposed as follows: 

\[P(X,y) = P(X) \times P(y|X).\]

Depending on which part of the joint distribution changes, we either talk about data drift or concept drift.

Data Drift

If the distribution of the features changes, then we speak of data drift:

\[ P_{t_0}(X) \ne P_{t}(X), \quad t > t_0. \]

Note that data drift does not necessarily mean that the relationship between the target values y and the features X has changed. Hence, it is possible that the machine learning model still performs reliably even after the occurrence of data drift.

However, data drift often coincides with concept drift and can be a good early indicator of model performance degradation. Especially in scenarios where ground truth labels are not (immediately) available, detecting data drift can be an important component of a drift warning system. For example, during the COVID-19 pandemic, the distribution of patient inputs, such as symptoms, changed for models trying to predict clinical outcomes. The accompanying change in clinical outcomes was concept drift and would only become observable after a while. To avoid incorrect treatment based on outdated model predictions, it is important to detect and signal data drift, which can be observed immediately.

Moreover, drift can also occur in unsupervised ML systems where target values y are not of interest at all. In such unsupervised systems, only data drift is defined.

Data drift is a shift in the distribution (figure created by the authors and inspired by Evidently AI).

Concept Drift

Concept drift is the change in the relationship between target values and features over time:

\[P_{t_0}(y|X) \ne P_{t}(y|X), \quad t > t_0.\]

Usually, performance is negatively impacted if concept drift occurs.

In practice, the ground truth label y often only becomes available with a delay (or not at all). Hence, observing P_t(y|X) may also only be possible with a delay. Therefore, in many scenarios, detecting concept drift in a timely and reliable manner can be much more involved or even impossible. In such cases, we may need to rely on data drift as an indicator of concept drift.

How Drift Can Evolve Over Time

Drift evolution patterns over time (Figure from Towards Unsupervised Sudden Data Drift Detection in Federated Learning with Fuzzy Clustering).

Concept and data drift can take different forms, and these forms may have varying implications for drift detection and drift handling strategies.

Drift may occur suddenly with abrupt distribution changes. For example, purchasing behavior may change overnight with the introduction of a new product or promotion.

In other cases, drift may occur more gradually or incrementally over a longer period of time. For instance, if a digital platform introduces a new feature, this may affect user behavior on that platform. While only a few users may adopt the new feature at first, more and more users may adopt it in the long run.

Lastly, drift may be recurring and driven by seasonality. Imagine a clothing company: while its top-selling products in the summer may be T-shirts and shorts, those are unlikely to sell equally well in winter, when customers may be more interested in coats and other warm clothing.

How to Identify Drift

A mental framework for identifying drift (figure created by the authors).

Before drift can be handled, it must be detected. To discuss drift detection effectively, we introduce a mental framework borrowed from the excellent read “Learning under Concept Drift: A Review” (see reference list). A drift detection framework can be described in three stages:

  1. Data Collection and Modelling: The data retrieval logic specifies the data and time periods to be compared. Moreover, the data is prepared for the next steps by applying a data model. This model could be a machine learning model, histograms, or even no model at all. We will see examples in subsequent sections.
  2. Test Statistic Calculation: The test statistic defines how we measure (dis)similarity between historical and new data. For example, by comparing model performance on historical and new data, or by measuring how different the data chunks’ histograms are.
  3. Hypothesis Testing: Finally, we apply a hypothesis test to decide whether the system should signal drift. We formulate a null hypothesis and a decision criterion (such as a significance level α that the p-value is compared against).
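To make the three stages concrete, here is a minimal Python sketch that wires them together as plain functions. The class name `DriftDetector`, its fields, and the mean-difference statistic with a threshold of 0.5 in the usage example are illustrative assumptions, not part of the framework described in the review.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np


@dataclass
class DriftDetector:
    """Hypothetical three-stage drift detector (names are illustrative)."""

    # Stage 1: data model applied to both windows (identity means "no model at all")
    transform: Callable[[np.ndarray], np.ndarray]
    # Stage 2: test statistic measuring the (dis)similarity of the two windows
    statistic: Callable[[np.ndarray, np.ndarray], float]
    # Stage 3: decision rule (threshold or hypothesis test) applied to the statistic
    decide: Callable[[float], bool]

    def check(self, reference: Sequence[float], new: Sequence[float]) -> tuple[float, bool]:
        ref = self.transform(np.asarray(reference, dtype=float))
        cur = self.transform(np.asarray(new, dtype=float))
        dis = float(self.statistic(ref, cur))
        return dis, bool(self.decide(dis))


detector = DriftDetector(
    transform=lambda x: x,                            # no data model
    statistic=lambda a, b: abs(a.mean() - b.mean()),  # difference of means
    decide=lambda dis: dis > 0.5,                     # hypothetical threshold
)
print(detector.check([0.0, 1.0, 2.0], [3.0, 4.0, 5.0]))  # (3.0, True)
```

Swapping any of the three callables (for example, a histogram as the data model or a K-S p-value as the decision rule) changes one stage without touching the others.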

Data Collection and Modelling

In this stage, we define exactly which chunks of data will be compared in subsequent steps. First, the time windows of our reference and comparison (i.e., new) data need to be defined. The reference data could be fixed to the historical training data (see figure below), or it could change over time as defined by a sliding window. Similarly, the comparison data can be restricted to the newest batches of data, or it can extend the historical data over time; both time windows can also be sliding.

Once the data is available, it needs to be prepared for the test statistic calculation. Depending on the statistic, it might need to be fed through a machine learning model (e.g., when calculating performance metrics), transformed into histograms, or not be processed at all.

Data collection techniques (figure from “Learning under Concept Drift: A Review”).
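The window definitions above can be sketched as index ranges over time-ordered data. In this minimal sketch, `make_windows` is a hypothetical helper: the comparison window slides forward in non-overlapping steps, while the reference window is either fixed to the first samples (e.g., the training data) or slides along directly in front of the comparison window. Growing (expanding) comparison windows are not covered here.

```python
def make_windows(n_samples: int, ref_size: int, cmp_size: int, sliding_reference: bool = False):
    """Yield (reference, comparison) index ranges over time-ordered samples.

    Hypothetical helper: the comparison window moves forward in non-overlapping
    steps of cmp_size. With sliding_reference=False the reference stays fixed to
    the first ref_size samples (e.g. the training data); with True it covers the
    ref_size samples right before the comparison window.
    """
    start = ref_size
    while start + cmp_size <= n_samples:
        ref_start = start - ref_size if sliding_reference else 0
        yield range(ref_start, ref_start + ref_size), range(start, start + cmp_size)
        start += cmp_size


for ref_idx, cmp_idx in make_windows(10, ref_size=4, cmp_size=3, sliding_reference=True):
    print(list(ref_idx), list(cmp_idx))
# [0, 1, 2, 3] [4, 5, 6]
# [3, 4, 5, 6] [7, 8, 9]
```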

Drift Detection Methods

One can identify drift by applying dedicated detection methods. These methods either monitor the performance of a model (concept drift detection) or directly analyse incoming data (data drift detection). Using statistical tests or monitoring metrics, drift detection methods help keep your model reliable. Whether they rely on simple thresholds or more advanced techniques, these methods make your machine learning system more robust and adaptive.

Observing Concept Drift Through Performance Metrics

Observable ML model performance degradation as a consequence of drift (figure created by the authors).

The most direct way to spot concept drift (or its consequences) is by tracking the model’s performance over time. Given two time windows [t0, t1] and [t2, t3], we calculate the performance p[t0, t1] and p[t2, t3]. Then, the test statistic can be defined as the difference (or dissimilarity) of performance: 

\[dis = |p_{[t_0, t_1]} - p_{[t_2, t_3]}|.\]

Performance can be any metric of interest, such as accuracy, precision, recall, F1-score (in classification tasks), or mean squared error, mean absolute percentage error, R-squared, etc. (in regression problems).
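As a sketch of this performance-based test statistic, the following Python snippet computes accuracy on a reference window and on a new window and takes the absolute difference. Accuracy, the toy labels, and the threshold of 0.05 are placeholder choices for illustration; in a full system the threshold would be replaced by a proper decision rule.

```python
import numpy as np


def accuracy(y_true, y_pred) -> float:
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


def performance_dissimilarity(y_ref, pred_ref, y_new, pred_new, metric=accuracy) -> float:
    """dis = |p_[t0,t1] - p_[t2,t3]| for any metric(y_true, y_pred)."""
    return abs(metric(y_ref, pred_ref) - metric(y_new, pred_new))


# Labels and predictions from a reference window and a new window (toy data)
y_ref = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1]
pred_ref = [1, 0, 1, 1, 0, 1, 0, 1, 1, 1]  # 9 of 10 correct
y_new = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1]
pred_new = [0, 0, 1, 0, 0, 1, 1, 0, 1, 0]  # 6 of 10 correct

dis = performance_dissimilarity(y_ref, pred_ref, y_new, pred_new)
print(f"dis = {dis:.2f}, drift = {dis > 0.05}")  # dis = 0.30, drift = True
```

Any function with the signature `metric(y_true, y_pred)`, such as `f1_score` or `mean_absolute_error` from scikit-learn's `sklearn.metrics`, could be passed as `metric` instead.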

Calculating performance metrics often requires ground truth labels that may only become available with a delay, or may never become available.

To detect drift in a timely manner even in such cases, proxy performance metrics can sometimes be derived. For example, in a spam detection system, we might never know whether an email was actually spam or not, so we cannot calculate the accuracy of the model on live data. However, we might be able to observe a proxy metric: the percentage of emails that were moved to the spam folder. If the rate changes significantly over time, this might indicate concept drift.
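One way to decide whether such a proxy rate has changed significantly is a two-proportion z-test. This test is our choice for illustration rather than something prescribed above, and the counts in the example are made up.

```python
import math


def proportion_change_test(k_ref: int, n_ref: int, k_new: int, n_new: int) -> tuple[float, float]:
    """Two-sided two-proportion z-test for a change in a rate, e.g. the share
    of emails moved to the spam folder (k moved out of n received)."""
    p_ref, p_new = k_ref / n_ref, k_new / n_new
    # Pooled rate under the null hypothesis "the rate did not change"
    p_pool = (k_ref + k_new) / (n_ref + n_new)
    se = math.sqrt(p_pool * (1 - p_pool) * (1 / n_ref + 1 / n_new))
    z = (p_new - p_ref) / se
    # Two-sided p-value from the standard normal distribution: 2 * (1 - Phi(|z|))
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return z, p_value


# Hypothetical counts: 2.0% of emails moved to spam last month, 3.2% this month
z, p = proportion_change_test(200, 10_000, 320, 10_000)
print(f"z = {z:.2f}, p = {p:.1e}")
```

A small p-value only says the rate moved; whether that reflects concept drift, a change in user behavior, or something else still needs investigation.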

If such proxy metrics are not available either, we can base the detection framework on data distribution-based metrics, which we introduce in the next section.

Data Distribution-Based Methods

Methods in this category quantify how dissimilar the data distributions of reference data X[t0,t1] and new data X[t2,t3] are without requiring ground truth labels. 

How can the dissimilarity between two distributions be quantified? In the next subsections, we will introduce some popular univariate and multivariate metrics.

Univariate Metrics

Let’s start with a very simple univariate approach: 

First, calculate the mean of the i-th feature in the reference and new data. Then, define the absolute difference of the means as the dissimilarity measure:

\[dis_i = |mean_{i}^{[t_0,t_1]} - mean_{i}^{[t_2,t_3]}|. \]

Finally, signal drift if dis_i is unexpectedly large, that is, whenever a feature’s mean changes unexpectedly over time. Other simple statistics that can be monitored in the same way include the minimum, maximum, quantiles, and the ratio of null values in a column. These are easy to calculate and are an excellent starting point for building drift detection systems.
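A minimal sketch of this approach in Python with NumPy: it computes the statistics mentioned above for one feature in both windows and flags every statistic whose absolute change exceeds a threshold. The function names and threshold values are hypothetical; in practice, thresholds would come from domain knowledge.

```python
import numpy as np


def summary_stats(x) -> dict:
    """Simple statistics of one numerical feature; NaN marks missing values."""
    x = np.asarray(x, dtype=float)
    valid = x[~np.isnan(x)]
    return {
        "mean": float(valid.mean()),
        "min": float(valid.min()),
        "max": float(valid.max()),
        "median": float(np.quantile(valid, 0.5)),
        "null_ratio": float(np.isnan(x).mean()),
    }


def simple_stat_drift(x_ref, x_new, thresholds: dict):
    """Absolute difference per statistic, plus the statistics exceeding their threshold."""
    ref, new = summary_stats(x_ref), summary_stats(x_new)
    diffs = {name: abs(ref[name] - new[name]) for name in ref}
    flagged = [name for name, limit in thresholds.items() if diffs[name] > limit]
    return diffs, flagged


x_ref = [1.0, 2.0, 3.0, 4.0, np.nan]
x_new = [2.0, 3.0, 4.0, 5.0, 6.0]
diffs, flagged = simple_stat_drift(x_ref, x_new, thresholds={"mean": 1.0, "null_ratio": 0.1})
print(diffs["mean"], flagged)  # 1.5 ['mean', 'null_ratio']
```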

However, these approaches can be overly simplistic. For example, the mean can miss changes in the tails of the distribution, and so can other simple statistics. This is why we need slightly more involved data drift detection methods.

Kolmogorov-Smirnov (K-S) Test
Kolmogorov-Smirnov (K-S) test statistic (figure from Wikipedia).

Another popular univariate method is the Kolmogorov-Smirnov (K-S) test. The K-S test examines the entire distribution of a single feature: it calculates the empirical cumulative distribution functions (CDFs) of X(i)[t0,t1] and X(i)[t2,t3]. Then, the test statistic is calculated as the maximum absolute difference between the two CDFs:

\[ dis_i = \sup_{x} \left| CDF_{X(i)_{[t_0,t_1]}}(x) - CDF_{X(i)_{[t_2,t_3]}}(x) \right|, \]

and can detect differences in the mean and the tails of the distribution. 

The null hypothesis is that both samples are drawn from the same distribution. If the p-value is less than a predefined significance level α (e.g., 0.05), we reject the null hypothesis and conclude drift. Equivalently, we signal drift if dis_i exceeds the critical value for the given α, which can be looked up in a two-sample K-S table. Alternatively, if the sample sizes n (number of reference samples) and m (number of new samples) are large, the critical value cv_α can be calculated according to

\[cv_{\alpha} = c(\alpha)\sqrt{\frac{n+m}{n \cdot m}}, \]

where c(α) for common values of α can be found on Wikipedia (see reference list); for example, c(0.05) ≈ 1.358.
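In Python, the two-sample K-S test is available as `scipy.stats.ks_2samp`, which returns the statistic and the p-value. The sketch below wraps it and also computes the large-sample critical value from the formula above, using c(α) = sqrt(-ln(α/2) / 2), the expression behind the tabulated values on Wikipedia. The helper names and the synthetic normal data are assumptions for illustration.

```python
import numpy as np
from scipy.stats import ks_2samp


def ks_critical_value(n: int, m: int, alpha: float = 0.05) -> float:
    """Large-sample critical value cv_alpha = c(alpha) * sqrt((n + m) / (n * m))."""
    c_alpha = np.sqrt(-np.log(alpha / 2) / 2)  # c(0.05) is about 1.358
    return float(c_alpha * np.sqrt((n + m) / (n * m)))


def ks_drift(x_ref, x_new, alpha: float = 0.05):
    """Return (statistic, p-value, drift flag) for one numerical feature."""
    result = ks_2samp(x_ref, x_new)  # H0: both samples come from the same distribution
    return float(result.statistic), float(result.pvalue), bool(result.pvalue < alpha)


rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=2000)
x_new = rng.normal(0.5, 1.0, size=2000)  # mean shifted by half a standard deviation

stat, p_value, drift = ks_drift(x_ref, x_new)
cv = ks_critical_value(len(x_ref), len(x_new))
print(f"dis = {stat:.3f}, cv = {cv:.3f}, p = {p_value:.2e}, drift = {drift}")
```

For large samples, comparing `stat` with `cv` and comparing the p-value with α should lead to essentially the same decision.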

The K-S test is widely used in drift detection and is relatively robust against extreme values. Nevertheless, be aware that even small numbers of extreme outliers can disproportionately affect the dissimilarity measure and lead to false positive alarms.

Population Stability Index
Bin distribution for Population Stability Index test statistic calculation (figure created by the authors).

An even less sensitive alternative (or complement) is the population stability index (PSI). Instead of using cumulative distribution functions, the PSI involves dividing the range of observations into bins b and calculating frequencies for each bin, effectively generating histograms of the reference and new data. We compare the histograms, and if they appear to have changed unexpectedly, the system signals drift. Formally, the dissimilarity is calculated according to:

\[dis = \sum_{b\in B} \left(ratio(b^{new}) - ratio(b^{ref})\right) \ln\left(\frac{ratio(b^{new})}{ratio(b^{ref})}\right) = \sum_{b\in B} PSI_{b}, \]

where ratio(b^new) is the share of data points falling into bin b in the new dataset, ratio(b^ref) is the share of data points falling into bin b in the reference dataset, and B is the set of all bins. The smaller the difference between ratio(b^new) and ratio(b^ref), the smaller the PSI. Hence, if a large PSI is observed, a drift detection system signals drift. In practice, a threshold of 0.2 or 0.25 is often applied as a rule of thumb; that is, if PSI > 0.25, the system signals drift.
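A minimal NumPy sketch of the PSI formula above. Two choices here are assumptions rather than part of the definition: bin edges are taken from deciles of the reference data (so each reference bin holds roughly 10% of the points), and empty bins are clipped to a small epsilon so that the logarithm stays finite.

```python
import numpy as np


def population_stability_index(x_ref, x_new, n_bins: int = 10, eps: float = 1e-6) -> float:
    """PSI = sum_b (ratio_new - ratio_ref) * ln(ratio_new / ratio_ref)."""
    x_ref = np.asarray(x_ref, dtype=float)
    x_new = np.asarray(x_new, dtype=float)
    # Interior bin edges at reference quantiles; the outer bins are open-ended,
    # so new values outside the reference range still land in the first/last bin.
    inner_edges = np.quantile(x_ref, np.linspace(0, 1, n_bins + 1)[1:-1])
    ref_bins = np.searchsorted(inner_edges, x_ref, side="right")
    new_bins = np.searchsorted(inner_edges, x_new, side="right")
    # Share of data points per bin, i.e. ratio(b^ref) and ratio(b^new)
    ratio_ref = np.bincount(ref_bins, minlength=n_bins) / len(x_ref)
    ratio_new = np.bincount(new_bins, minlength=n_bins) / len(x_new)
    # Avoid division by zero and log(0) for empty bins
    ratio_ref = np.clip(ratio_ref, eps, None)
    ratio_new = np.clip(ratio_new, eps, None)
    return float(np.sum((ratio_new - ratio_ref) * np.log(ratio_new / ratio_ref)))


rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=5000)
psi_stable = population_stability_index(x_ref, rng.normal(0.0, 1.0, size=5000))
psi_shifted = population_stability_index(x_ref, rng.normal(1.0, 1.0, size=5000))
print(f"stable: {psi_stable:.3f}, shifted: {psi_shifted:.3f}, drift: {psi_shifted > 0.25}")
```

The PSI value depends on the binning, so a rule-of-thumb threshold like 0.25 is only comparable across features when they are binned the same way.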

Chi-Squared Test

Lastly, we introduce a univariate drift detection method that can be applied to categorical features. All previous methods only work with numerical features.

So, let x be a categorical feature with n categories. Calculating the chi-squared test statistic is somewhat similar to calculating the PSI from the previous section. Rather than calculating the histogram of a continuous feature, we now consider the (relative) counts per category i. With these counts, we define the dissimilarity as the (normalized) sum of squared frequency differences in the reference and new data:

\[dis = \sum_{i=1}^{n} \frac{(count_{i}^{new}-count_{i}^{ref})^{2}}{count_{i}^{ref}}.\]

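Note that if the reference and new data contain different numbers of observations, the reference counts must first be rescaled to the size of the new data, i.e., count_i^ref is replaced by the relative frequency of category i in the reference data multiplied by the total number of new observations.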

To decide whether an observed dissimilarity is significant at a predefined significance level α, it is compared to the critical value of the chi-squared distribution with n - 1 degrees of freedom, which can be looked up in a table (e.g., on Wikipedia).
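A sketch using `scipy.stats.chisquare`, which computes this statistic with the reference counts as expected frequencies and, by default, n - 1 degrees of freedom. The expected counts are rescaled to the size of the new data to account for different sample sizes. Treating categories that never occur in the reference data as immediate drift is our own assumption, since their expected count would be zero.

```python
from collections import Counter

import numpy as np
from scipy.stats import chisquare


def chi2_drift(ref_values, new_values, alpha: float = 0.05):
    """Chi-squared test for one categorical feature: (statistic, p-value, drift flag)."""
    ref_counts, new_counts = Counter(ref_values), Counter(new_values)
    if set(new_counts) - set(ref_counts):
        # Unseen categories have an expected count of zero, so the statistic is
        # undefined; this sketch treats them as drift right away (an assumption).
        return float("inf"), 0.0, True
    categories = sorted(ref_counts)
    observed = np.array([new_counts[c] for c in categories], dtype=float)
    # Rescale reference counts to the size of the new data (expected counts)
    expected = np.array([ref_counts[c] for c in categories], dtype=float)
    expected *= observed.sum() / expected.sum()
    stat, p_value = chisquare(f_obs=observed, f_exp=expected)  # df = n - 1
    return float(stat), float(p_value), bool(p_value < alpha)


ref = ["a"] * 50 + ["b"] * 30 + ["c"] * 20
print(chi2_drift(ref, ["a"] * 100 + ["b"] * 60 + ["c"] * 40))  # same proportions
print(chi2_drift(ref, ["a"] * 20 + ["b"] * 30 + ["c"] * 50))   # shifted proportions
```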

Multivariate Tests

In many cases, each feature’s distribution individually may not be affected by drift according to the univariate tests in the previous section, but the overall distribution X may still be affected. For example, the correlation between x1 and x2 may change while the histograms of both (and, hence, the univariate PSI) appear to be stable. Clearly, such changes in feature interactions can severely impact machine learning model performance and must be detected. Therefore, we introduce a multivariate test that can complement the univariate tests of the previous sections.
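The following simulation illustrates this effect with synthetic data (an assumption for illustration): two standard normal features whose correlation flips from 0.8 to -0.8. The marginal distributions, and hence univariate tests such as K-S, stay essentially unchanged, while the joint distribution clearly does not.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(42)


def correlated_normals(rho: float, n: int) -> np.ndarray:
    """Two standard normal features with correlation rho (synthetic data)."""
    cov = [[1.0, rho], [rho, 1.0]]
    return rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=n)


x_ref = correlated_normals(0.8, 5000)
x_new = correlated_normals(-0.8, 5000)

# Univariate view: each marginal is still N(0, 1), so the K-S statistics stay small
ks_stats = [ks_2samp(x_ref[:, i], x_new[:, i]).statistic for i in range(2)]
print("K-S statistics per feature:", [round(float(s), 3) for s in ks_stats])

# Multivariate view: the correlation between x1 and x2 has flipped sign
corr_ref = np.corrcoef(x_ref, rowvar=False)[0, 1]
corr_new = np.corrcoef(x_new, rowvar=False)[0, 1]
print(f"correlation: {corr_ref:.2f} -> {corr_new:.2f}")
```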

Reconstruction-Error Based Test
A schematic overview of autoencoder architectures (figure from Wikipedia)

This approach is based on self-supervised autoencoders that can be trained without labels. Such models consist of an encoder and a decoder part, where the encoder maps the data to a, typically low-dimensional, latent space and the decoder learns to reconstruct the original data from the latent space representation. The learning objective is to minimize the reconstruction error, i.e., the difference between the original and reconstructed data.

How can such autoencoders be used for drift detection? First, we train the autoencoder on the reference dataset and store the mean reconstruction error. Then, using the same model, we calculate the mean reconstruction error on new data and use the absolute difference between the two as the dissimilarity metric:

\[ dis = |error_{[t_0, t_1]} - error_{[t_2, t_3]}|. \]

Intuitively, if the new and reference data are similar, the original model should not have problems reconstructing the data. Hence, if the dissimilarity is greater than a predefined threshold, the system signals drift. 

This approach can spot more subtle multivariate drift. Note that principal component analysis (PCA) can be interpreted as a special case of an autoencoder. NannyML demonstrates how PCA reconstructions can identify changes in feature correlations that univariate methods miss.
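Following the PCA interpretation, a linear autoencoder can be sketched in a few lines of NumPy: the encoder projects standardized data onto the top principal components of the reference data, and the decoder maps it back. This is not NannyML's implementation; the standardization step, the single latent component, the synthetic correlation-flip data, and the threshold of 0.5 are assumptions for illustration.

```python
import numpy as np


class PCAReconstructor:
    """Linear autoencoder: encode by projecting onto the top principal
    components of the reference data, decode by projecting back."""

    def __init__(self, n_components: int):
        self.n_components = n_components

    def fit(self, x):
        x = np.asarray(x, dtype=float)
        self.mean_, self.std_ = x.mean(axis=0), x.std(axis=0)
        z = (x - self.mean_) / self.std_
        # Rows of vt are the principal directions, ordered by explained variance
        _, _, vt = np.linalg.svd(z, full_matrices=False)
        self.components_ = vt[: self.n_components]
        return self

    def reconstruction_error(self, x) -> float:
        """Mean squared reconstruction error per sample (in standardized units)."""
        z = (np.asarray(x, dtype=float) - self.mean_) / self.std_
        latent = z @ self.components_.T   # encoder: to the low-dimensional latent space
        z_hat = latent @ self.components_  # decoder: back to the feature space
        return float(np.mean(np.sum((z - z_hat) ** 2, axis=1)))


rng = np.random.default_rng(0)
x_ref = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.8], [0.8, 1.0]], size=5000)
x_new = rng.multivariate_normal([0.0, 0.0], [[1.0, -0.8], [-0.8, 1.0]], size=5000)

model = PCAReconstructor(n_components=1).fit(x_ref)
error_ref = model.reconstruction_error(x_ref)
error_new = model.reconstruction_error(x_new)
dis = abs(error_ref - error_new)
print(f"error_ref = {error_ref:.2f}, error_new = {error_new:.2f}, drift = {dis > 0.5}")
```

In practice, the reference error would ideally be measured on held-out reference data rather than on the data the model was fitted on.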

Summary of Popular Drift Detection Methods

To conclude this section, we would like to summarize the drift detection methods in the following table:

| Name | Applied to | Test statistic | Drift if | Notes |
|---|---|---|---|---|
| Statistical and threshold-based tests | Univariate, numerical data | Differences in simple statistics like mean, quantiles, counts, etc. | The difference is greater than a predefined threshold | May miss differences in tails of distributions; setting the threshold requires domain knowledge or gut feeling |
| Kolmogorov-Smirnov (K-S) | Univariate, numerical data | Maximum difference in the cumulative distribution functions of reference and new data | p-value is small (e.g., p < 0.05) | Can be sensitive to outliers |
| Population Stability Index (PSI) | Univariate, numerical data | Differences in the histograms of reference and new data | PSI is greater than a predefined threshold (e.g., PSI > 0.25) | Choosing a threshold is often based on gut feeling |
| Chi-Squared Test | Univariate, categorical data | Differences in counts of observations per category in reference and new data | p-value is small (e.g., p < 0.05) | |
| Reconstruction-Error Test | Multivariate, numerical data | Difference in mean reconstruction error in reference and new data | The difference is greater than a predefined threshold | Defining a threshold can be hard; the method may be relatively complex to implement and maintain |

What to Do Against Drift

Even though the focus of this article is the detection of drift, we would also like to give an idea of what can be done against drift.

As a general rule, it is important to automate drift detection and mitigation as much as possible and to define clear responsibilities to ensure that ML systems remain relevant.

First Line of Defense: Robust Modeling Techniques

The first line of defense is applied even before the model is deployed. Training data and model engineering decisions directly impact sensitivity to drift, and model developers should focus on robust modeling techniques or robust machine learning. For example, a machine learning model relying on many features may be more susceptible to the consequences of drift. Naturally, more features mean a larger “attack surface”, and some features may be more sensitive to drift than others (e.g., sensor measurements are subject to noise, whereas sociodemographic data may be more stable). Investing in robust feature selection is likely to pay off in the long run.

Furthermore, including noisy or malicious data in the training dataset may make models more robust against smaller distributional changes. The field of adversarial machine learning is concerned with teaching ML models how to deal with adversarial inputs.

Second Line of Defense: Define a Fallback Strategy

Even the most carefully engineered model will likely experience drift at some point. When this happens, make sure to have a backup plan ready. To prepare such a plan, first, the consequences of failure must be understood. Recommending the wrong pair of shoes in an email newsletter has very different implications from misclassifying objects in autonomous driving systems. In the first case, it may be acceptable to wait for human feedback before sending the email if drift is detected. In the latter case, a much more immediate reaction is required. For example, a rule-based system or any other system not affected by drift may take over. 

Striking Back: Model Updates

After addressing the immediate effects of drift, you can work to restore the model’s performance. The most obvious activity is retraining the model or updating model weights with the newest data. One of the challenges of retraining is defining a new training dataset. Should it include all available data? In the case of concept drift, this may harm convergence since the dataset may contain inconsistent training samples. If the dataset is too small, this may lead to catastrophic forgetting of previously learned patterns since the model may not be exposed to enough training samples.

To prevent catastrophic forgetting, methods from continual and active learning can be applied, e.g., by introducing memory systems.

It is important to weigh different options, be aware of the trade-offs, and make a decision based on the impact on the use case.

Conclusion

In this article, we describe why drift detection is important if you care about the long-term success and robustness of machine learning systems. If drift occurs and is not taken care of, then machine learning models’ performance will degrade, potentially harming revenue, eroding trust and reputation, or even having legal consequences.

We formally introduce concept and data drift as unexpected differences between training and inference data. Such unexpected changes can be detected by applying univariate tests like the Kolmogorov-Smirnov test, the Population Stability Index, and the Chi-Squared test, or multivariate tests like reconstruction-error-based tests. Lastly, we briefly touch upon a few strategies for dealing with drift.

In the future, we plan to follow up with a hands-on guide building on the concepts introduced in this article. Finally, one last note: while the article introduces several increasingly complex methods and concepts, bear in mind that any drift detection is always better than no drift detection. Depending on the use case, a very simple detection system can prove very effective.

References

  • https://en.wikipedia.org/wiki/Catastrophic_interference
  • J. Lu, A. Liu, F. Dong, F. Gu, J. Gama and G. Zhang, “Learning under Concept Drift: A Review,” in IEEE Transactions on Knowledge and Data Engineering, vol. 31, no. 12, pp. 2346-2363, 1 Dec. 2019
  • M. Stallmann, A. Wilbik and G. Weiss, “Towards Unsupervised Sudden Data Drift Detection in Federated Learning with Fuzzy Clustering,” 2024 IEEE International Conference on Fuzzy Systems (FUZZ-IEEE), Yokohama, Japan, 2024, pp. 1-8, doi: 10.1109/FUZZ-IEEE60900.2024.10611883
  • https://www.evidentlyai.com/ml-in-production/concept-drift
  • https://www.evidentlyai.com/ml-in-production/data-drift
  • https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test
  • https://stats.stackexchange.com/questions/471732/intuitive-explanation-of-kolmogorov-smirnov-test
  • Yurdakul, Bilal, “Statistical Properties of Population Stability Index” (2018). Dissertations. 3208. https://scholarworks.wmich.edu/dissertations/3208
  • https://en.wikipedia.org/wiki/Chi-squared_test
  • https://www.nannyml.com/blog/hypothesis-testing-for-ml-performance#chi-2-test
  • https://nannyml.readthedocs.io/en/main/how_it_works/multivariate_drift.html#how-multiv-drift
  • https://en.wikipedia.org/wiki/Autoencoder



The post Drift Detection in Robust Machine Learning Systems first appeared on TechToday.
