A new technique tackles bias in machine-learning models by targeting specific harmful training data points, improving outcomes for minority groups without compromising overall model performance.
*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.
Machine-learning models can fail when they make predictions for individuals who were underrepresented in the datasets they were trained on.
For instance, a model that predicts the best treatment option for someone with a chronic disease may be trained on a dataset consisting mainly of male patients. When deployed in a hospital, that model might make incorrect predictions for female patients.
To improve outcomes, engineers can try balancing the training dataset by removing data points until all subgroups are represented equally. While dataset balancing is promising, removing a large amount of data often hurts the model's overall performance.
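To make the cost of conventional balancing concrete, here is a minimal sketch of balancing by random subsampling: every subgroup is cut down to the size of the smallest one. The function name and the use of random sampling are illustrative assumptions, not the specific baseline used in the paper.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence


def balance_by_subsampling(groups: Sequence[Hashable], seed: int = 0) -> List[int]:
    """Return indices of a subset in which every group has the same size.

    Illustrative baseline only: each group is randomly subsampled down to the
    size of the smallest group, so most of a large majority group is discarded.
    """
    by_group = defaultdict(list)
    for i, g in enumerate(groups):
        by_group[g].append(i)

    smallest = min(len(idx) for idx in by_group.values())
    rng = random.Random(seed)

    kept: List[int] = []
    for idx in by_group.values():
        kept.extend(rng.sample(idx, smallest))  # drop the excess at random
    return sorted(kept)


# Hypothetical dataset: 900 male and 100 female patients.
groups = ["male"] * 900 + ["female"] * 100
kept = balance_by_subsampling(groups)
print(len(kept), "of", len(groups), "examples kept")
```

In this toy case, 800 of the 1,000 examples are thrown away, which illustrates why heavy balancing tends to hurt overall accuracy.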
MIT researchers have introduced a new method called 'Data Debiasing with Datamodels' (D3M), which uses a technique called TRAK to pinpoint and remove specific training examples responsible for model failures on minority subgroups. Because it removes far fewer data points than other approaches, this technique maintains the model's overall accuracy while improving its performance on underrepresented groups.
In addition, the technique can identify hidden sources of bias in a training dataset that lacks subgroup labels. This variant, called AUTO-D3M, detects biases without requiring annotated subgroup information, which makes it especially useful in real-world scenarios where such annotations are unavailable. For many applications, unlabeled data are far more prevalent than labeled data.
This method could also be combined with other approaches to improve the fairness of machine-learning models deployed in high-stakes situations. For example, it might someday help ensure underrepresented patients aren't misdiagnosed due to a biased AI model.
"Many other algorithms that try to address this issue assume each datapoint matters as much as every other datapoint. In this paper, we are showing that assumption is not true. There are specific points in our dataset that are contributing to this bias, and we can find those data points, remove them, and get better performance," says Kimia Hamidieh, an electrical engineering and computer science (EECS) graduate student at MIT and co-lead author of a paper on this technique.
She wrote the paper with co-lead authors Saachi Jain PhD '24 and fellow EECS graduate student Kristian Georgiev; Andrew Ilyas MEng '18, PhD '23, a Stein Fellow at Stanford University; and senior authors Marzyeh Ghassemi, an associate professor in EECS and a member of the Institute of Medical Engineering Sciences and the Laboratory for Information and Decision Systems, and Aleksander Madry, the Cadence Design Systems Professor at MIT. The research will be presented at the Conference on Neural Information Processing Systems.
Enhancing fairness through targeted data removal
Often, machine-learning models are trained using huge datasets gathered from many sources across the internet. These datasets are far too large to be carefully curated by hand, so they may contain bad examples that hurt model performance.
Scientists also know that some data points impact a model's performance on certain downstream tasks more than others.
The MIT researchers combined these two ideas into an approach that identifies and removes these problematic data points. They seek to solve a problem known as worst-group error, which occurs when a model underperforms on minority subgroups in a training dataset.
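Worst-group performance is usually reported as worst-group accuracy: the model's accuracy is computed separately within each subgroup, and the lowest of those numbers is reported. A minimal sketch, assuming each test example carries a subgroup label (the function names are illustrative):

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence


def group_accuracies(y_true: Sequence, y_pred: Sequence,
                     groups: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Accuracy computed separately within each subgroup."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p, g in zip(y_true, y_pred, groups):
        total[g] += 1
        correct[g] += int(t == p)
    return {g: correct[g] / total[g] for g in total}


def worst_group_accuracy(y_true: Sequence, y_pred: Sequence,
                         groups: Sequence[Hashable]) -> float:
    """The lowest per-group accuracy, i.e. performance on the worst-off group."""
    return min(group_accuracies(y_true, y_pred, groups).values())


y_true = [1, 1, 0, 0, 1, 0]
y_pred = [1, 1, 0, 1, 0, 0]
groups = ["a", "a", "a", "b", "b", "b"]
print(worst_group_accuracy(y_true, y_pred, groups))
```

Here overall accuracy is 4/6, yet group "b" is only 1/3 correct; a single average hides exactly the failure that worst-group accuracy exposes.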
The researchers' new technique is based on prior work, in which they introduced a method called TRAK that identifies the most critical training examples for a specific model output.
For this new technique, they take the incorrect predictions the model made for minority subgroups and use TRAK to identify which training examples contributed most to those incorrect predictions.
"By aggregating this information across bad test predictions in the right way, we are able to find the specific parts of the training that are driving worst-group accuracy down overall," Ilyas explains.
Then they remove those specific samples and retrain the model on the remaining data.
Since having more data usually yields better overall performance, removing just the samples that drive worst-group failures maintains the model's overall accuracy while boosting its performance on minority subgroups.
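The select-and-remove step can be sketched as follows, under stated assumptions. The attribution matrix is assumed to be precomputed (for example, TRAK-style scores), with `attributions[j][i]` measuring how much training example `i` pushed test prediction `j` toward the correct label, so negative values mean "hurt". The plain sum used to aggregate over failed predictions is a deliberate simplification, not the exact scoring rule from the D3M paper, and all function names are hypothetical.

```python
from typing import List, Sequence


def aggregate_harm(attributions: Sequence[Sequence[float]],
                   failed_test_idx: Sequence[int]) -> List[float]:
    """Sum each training example's attribution over the failed test predictions.

    ASSUMPTION: attributions[j][i] > 0 means training example i helped test
    example j get the correct label; < 0 means it hurt. Only failed
    minority-group test predictions (failed_test_idx) are aggregated.
    """
    n_train = len(attributions[0])
    totals = [0.0] * n_train
    for j in failed_test_idx:
        for i, score in enumerate(attributions[j]):
            totals[i] += score
    return totals


def select_removals(attributions: Sequence[Sequence[float]],
                    failed_test_idx: Sequence[int], k: int) -> List[int]:
    """Indices of the k training examples with the most negative aggregate score."""
    totals = aggregate_harm(attributions, failed_test_idx)
    ranked = sorted(range(len(totals)), key=lambda i: totals[i])
    return sorted(ranked[:k])


def remaining_indices(n_train: int, removed: Sequence[int]) -> List[int]:
    """Training indices to keep when the model is retrained."""
    removed_set = set(removed)
    return [i for i in range(n_train) if i not in removed_set]


# Hypothetical scores: 3 test predictions x 5 training examples.
attributions = [
    [0.2, -0.5, 0.1, 0.0, -0.1],
    [0.1, -0.4, 0.3, -0.2, 0.0],
    [-0.9, 0.0, 0.0, 0.0, 0.0],  # test example 2 was predicted correctly
]
failed = [0, 1]  # minority-group test examples the model got wrong
removed = select_removals(attributions, failed, k=2)
keep = remaining_indices(5, removed)
print(removed, keep)  # retrain the model on `keep`
```

Because only the failed minority-group predictions are aggregated, training example 0 is kept even though it hurt a prediction the model already gets right. How many examples to remove is left as a parameter here; in practice it would presumably be tuned, for instance against held-out worst-group accuracy.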
Achieving superior results with fewer deletions
Their method outperformed multiple techniques across three machine-learning datasets. For instance, on the CelebA-Age dataset, D3M substantially improved worst-group accuracy while removing 20,000 fewer examples than conventional balancing methods. Their technique also achieved higher accuracy than methods that require changing a model's inner workings.
Because the MIT method involves changing a dataset instead, it would be easier for a practitioner to use and can be applied to many types of models.
It can also be used when bias is unknown because the subgroups in a training dataset are not labeled. With AUTO-D3M, practitioners can surface such unlabeled biases directly, such as those arising from spurious correlations or underrepresented subgroups.
"This is a tool anyone can use when they are training a machine-learning model. They can look at those datapoints and see whether they are aligned with the capability they are trying to teach the model," says Hamidieh.
Using the technique to detect unknown subgroup bias would require intuition about which groups to look for, so the researchers hope to validate it and explore it more fully through future human studies.
They also want to improve the performance and reliability of their technique and ensure that it is accessible and easy to use for practitioners who may deploy it in real-world environments.
"When you have tools that let you critically look at the data and figure out which datapoints are going to lead to bias or other undesirable behavior, it gives you a first step toward building models that are going to be more fair and more reliable," Ilyas says.
This work is funded, in part, by the National Science Foundation and the U.S. Defense Advanced Research Projects Agency.
Source:
- Massachusetts Institute of Technology
Journal reference:
- Preliminary scientific report.
Jain, S., Hamidieh, K., Georgiev, K., Ilyas, A., Ghassemi, M., & Madry, A. (2024). Data Debiasing with Datamodels (D3M): Improving Subgroup Robustness via Data Selection. ArXiv. https://arxiv.org/abs/2406.16846