Transfer Learning: Making use of what is already there
As the number of AI models grows, it becomes clear that many new applications share similarities with existing ones. Transfer learning and domain adaptation harness this overlap by repurposing previously trained models. This approach accelerates training, reduces data requirements, and improves model performance, making it a valuable tool.
The traditional approach to deep learning involves collecting large amounts of labeled data and training a neural network on it. However, as the number of models grows, this approach faces increasing challenges:
- Gathering sufficient labeled data for training high-quality models can be challenging. For instance, specialized medical image data is often scarce or inaccessible due to privacy regulations. Furthermore, creating labeled datasets is expensive and time-consuming, as it typically requires manual annotation of vast amounts of data.
- Training deep learning models is computationally demanding, typically requiring powerful hardware like GPUs. Access to such resources can be costly and limited. This challenge is particularly pronounced for large language models (LLMs) such as GPT-4, which are reported to have trillions of parameters and to cost hundreds of millions of dollars to train.
To address these challenges, we propose leveraging existing data or models as a knowledge source for our target applications. Through transfer learning, we aim to effectively adapt this source knowledge to the specific requirements of the new task. For instance, consider Figure 1. It illustrates two very different datasets and tasks. On the left, we have readily available colorful butterfly images, while on the right, we have black-and-white medical data subject to privacy restrictions and limited availability. Despite these differences, both tasks involve image recognition and share common high-level features like shape and texture.
The different types of shift
There are various types of differences between source and target data (see [1] for more details). These disparities influence the choice of transfer learning method and the amount of data needed for effective learning. We categorize these differences into three main types:
Prior shift: A prior shift (also called label shift) occurs when the distribution of class labels differs between the source and target datasets. For example, consider the task of classifying patients as sick or healthy. In a source dataset collected via online surveys, only 10% of participants reported being sick, whereas in a target dataset collected in a hospital, 80% of the patients are sick. This imbalance in class proportions is a prior shift.
Covariate shift: A covariate shift occurs when the distribution of input features differs between the source and target domains, while the underlying task remains unchanged. For example, in facial emotion recognition, a model trained on studio-quality images might underperform when evaluated on smartphone selfies and homemade photos because of varying lighting and backgrounds. Figures 3 and 4 illustrate two scenarios of covariate shift: one where the shifted features significantly impact predictions and another where they are irrelevant.
Concept shift: Concept shift refers to a change in the relationship between inputs and outputs: the same input calls for a different prediction in the source and target data. This can occur due to factors like device variations or device aging. For example, a battery's power output differs between its initial state and after 100 cycles, so the model must be adjusted to the battery's age.
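To make the prior shift example more concrete, here is a minimal sketch of one standard correction. If we assume a pure prior shift (p(x | y) is the same in both domains and only p(y) changes), Bayes' rule gives p_t(y | x) ∝ p_s(y | x) · p_t(y) / p_s(y), so a source classifier's output can be re-weighted without retraining. The priors are the 10% and 80% figures from the example; the classifier output of 0.30 is a hypothetical value, and in practice the target prior usually has to be estimated rather than known.

```python
def adjust_posterior(source_posterior, source_prior, target_prior):
    """Re-weight source-domain class posteriors for a new class prior.

    Assumes a pure prior shift: p(x | y) is the same in both domains and
    only the class proportions p(y) differ.
    """
    # Bayes' rule: p_t(y | x) is proportional to p_s(y | x) * p_t(y) / p_s(y).
    unnormalized = {
        label: p * target_prior[label] / source_prior[label]
        for label, p in source_posterior.items()
    }
    # Renormalize so that the adjusted posteriors sum to one.
    total = sum(unnormalized.values())
    return {label: value / total for label, value in unnormalized.items()}


# Class proportions from the example: 10% sick in the online survey (source)
# and 80% sick in the hospital (target).
source_prior = {"sick": 0.10, "healthy": 0.90}
target_prior = {"sick": 0.80, "healthy": 0.20}

# Hypothetical output of a classifier trained on the survey data for one patient.
source_posterior = {"sick": 0.30, "healthy": 0.70}

adjusted = adjust_posterior(source_posterior, source_prior, target_prior)
print(f"P(sick | x) in the hospital: {adjusted['sick']:.3f}")  # prints 0.939
```

A patient the survey-trained model rates as 30% likely to be sick becomes roughly 94% likely under hospital proportions, purely because sick patients are far more common there.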
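For covariate shift, a small simulation of the Figure 3 scenario may help. The labeling rule y = x² is a hypothetical stand-in that is identical in both domains; only the input range changes (source inputs in [0, 1], target inputs in [2, 3], both assumed for illustration). A straight line fitted on the source data looks accurate there but extrapolates badly to the target range.

```python
import random


def true_rule(x):
    # Hypothetical labeling function, shared by the source and target domains.
    return x ** 2


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a * x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    a = cov / var
    return a, mean_y - a * mean_x


def mean_squared_error(model, xs):
    a, b = model
    return sum((a * x + b - true_rule(x)) ** 2 for x in xs) / len(xs)


rng = random.Random(0)
# Covariate shift: p(x) differs between the domains, p(y | x) does not.
source_x = [rng.uniform(0.0, 1.0) for _ in range(200)]
target_x = [rng.uniform(2.0, 3.0) for _ in range(200)]

model = fit_line(source_x, [true_rule(x) for x in source_x])
print(f"source MSE: {mean_squared_error(model, source_x):.4f}")
print(f"target MSE: {mean_squared_error(model, target_x):.4f}")
```

The target error is orders of magnitude larger than the source error even though the labeling rule never changed: the shifted input is exactly the one the prediction depends on.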
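Concept shift can be illustrated with a toy version of the battery example. The linear relation between load and output power and the efficiency factors 0.95 (new) and 0.80 (after ageing) are invented for illustration. The inputs are drawn from the same distribution in both domains, yet a model fitted on the new battery is systematically wrong on the aged one.

```python
import random


def power_output(load, aged):
    # Hypothetical battery behaviour: the same load yields less output power
    # after ageing (0.95 and 0.80 are illustrative values, not measurements).
    efficiency = 0.80 if aged else 0.95
    return efficiency * load


rng = random.Random(1)
# Identical input distribution in both domains, so there is no covariate shift.
loads = [rng.uniform(10.0, 50.0) for _ in range(100)]

fresh_targets = [power_output(x, aged=False) for x in loads]
aged_targets = [power_output(x, aged=True) for x in loads]

# Fit y = k * x on new-battery data (least squares through the origin).
k = sum(x * y for x, y in zip(loads, fresh_targets)) / sum(x * x for x in loads)

fresh_error = sum(abs(k * x - y) for x, y in zip(loads, fresh_targets)) / len(loads)
aged_error = sum(abs(k * x - y) for x, y in zip(loads, aged_targets)) / len(loads)
print(f"k = {k:.3f}, mean error new: {fresh_error:.3f}, aged: {aged_error:.3f}")
```

Because the inputs look exactly the same in both domains, this kind of shift cannot be detected from unlabeled inputs alone; some labeled measurements from the aged battery are needed.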
In practical situations, the three types of shifts we described often coexist and interact in complex ways. While we’ve presented idealized scenarios of label shift, covariate shift, and concept shift, real-world data is likely to exhibit combinations of these challenges. For instance, a medical imaging dataset might simultaneously exhibit label shift due to changes in disease prevalence and covariate shift due to variations in imaging equipment.
Understanding these shifts is crucial for effective transfer learning. By identifying the dominant shift type, researchers can select appropriate adaptation techniques, a choice that also depends on how much labeled target data is available.
The different types of transfer
Often, we possess extensive source data or well-established models. However, target applications frequently grapple with data scarcity and limited or absent annotations. These different situations present unique challenges that necessitate tailored transfer learning approaches.
- Supervised transfer: When target data is labeled, a straightforward approach is to fine-tune the source model on the target data. This is often computationally efficient, especially with limited target data, and is particularly effective for adapting to concept shift. How much target data is enough, however, depends on how large the shift between source and target data is. Moreover, fine-tuning on a small dataset can cause overfitting, leading to poor generalization, as well as forgetting of what the source model had learned[1]. To address this, regularization techniques can be employed to prevent excessive deviation from the source model, as depicted in Figure 6.
- Semi-supervised transfer: When target data is predominantly unlabeled, with possibly a few labeled examples, a combined approach is often employed. This usually involves an unsupervised learning phase to extract shared features from both source and target data. Subsequently, a supervised learning component is introduced, often involving the prediction of pseudo-labels for the unlabeled target data.
- Unsupervised transfer: When target data has no labels, addressing concept and prior shifts is often not possible, because these shifts change the correct prediction in ways that cannot be observed from unlabeled inputs alone. Without labeled data or additional information (for example, from physics-informed learning), the model cannot discern the correct predictions.
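The regularized fine-tuning idea from the supervised case can be sketched in its simplest form: a one-parameter model y = w · x, fitted to a few target points with a penalty λ(w − w_src)² that pulls the weight back toward the source weight. This is similar in spirit to the L2-SP penalty used for deep networks, but the model, the source weight of 0.95 and the three "aged battery" measurements are hypothetical.

```python
def fine_tune(source_weight, xs, ys, strength):
    """Fit y = w * x on target data while penalizing deviation from the source weight.

    Minimizes  sum((y - w * x) ** 2) + strength * (w - source_weight) ** 2,
    whose minimizer has the closed form returned below.
    """
    numerator = sum(x * y for x, y in zip(xs, ys)) + strength * source_weight
    denominator = sum(x * x for x in xs) + strength
    return numerator / denominator


source_weight = 0.95  # hypothetical weight learned on new-battery data

# Three hypothetical, noisy labeled measurements from the aged battery.
target_x = [12.0, 25.0, 40.0]
target_y = [9.9, 19.6, 32.3]

for strength in (0.0, 1000.0, 1e6):
    w = fine_tune(source_weight, target_x, target_y, strength)
    print(f"strength={strength:>9.1f}  w={w:.3f}")
```

With strength 0 this is plain fine-tuning and the weight follows the three target points; a very large strength keeps the source model almost unchanged. Choosing the strength in between is the trade-off between adapting to the target data and not overfitting it.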
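The pseudo-labeling step of semi-supervised transfer can be sketched with a deliberately tiny self-training loop. It assumes the shared features have already been extracted (here a single hypothetical number per example), uses a nearest-centroid classifier as a stand-in for the real model, and only accepts pseudo-labels whose margin exceeds a threshold; all data and the threshold are illustrative.

```python
def centroids(points, labels):
    """Mean feature value of each class."""
    sums, counts = {}, {}
    for x, label in zip(points, labels):
        sums[label] = sums.get(label, 0.0) + x
        counts[label] = counts.get(label, 0) + 1
    return {label: sums[label] / counts[label] for label in sums}


def predict(x, centers):
    """Return (label, margin): the nearest centroid and how much closer it is than the runner-up."""
    ranked = sorted(centers, key=lambda label: abs(x - centers[label]))
    best, second = ranked[0], ranked[1]
    margin = abs(x - centers[second]) - abs(x - centers[best])
    return best, margin


def self_train(labeled_x, labeled_y, unlabeled_x, min_margin, max_rounds=5):
    """Iteratively add confident pseudo-labels to the training set."""
    xs, ys = list(labeled_x), list(labeled_y)
    remaining = list(unlabeled_x)
    for _ in range(max_rounds):
        centers = centroids(xs, ys)
        confident, uncertain = [], []
        for x in remaining:
            label, margin = predict(x, centers)
            if margin >= min_margin:
                confident.append((x, label))
            else:
                uncertain.append(x)
        if not confident:
            break  # nothing left is confident enough; stop early
        for x, label in confident:
            xs.append(x)
            ys.append(label)
        remaining = uncertain
    return centroids(xs, ys), remaining


# Hypothetical 1-D target features: two labeled examples, several unlabeled ones.
labeled_x, labeled_y = [1.0, 5.0], [0, 1]
unlabeled_x = [1.5, 2.0, 2.5, 3.2, 4.1, 6.5, 7.0, 7.5, 8.0]

centers, still_unlabeled = self_train(labeled_x, labeled_y, unlabeled_x, min_margin=1.0)
print("class centres:", centers)
print("never pseudo-labeled:", still_unlabeled)
```

The ambiguous point 3.2 is initially slightly closer to the class-1 centre, so it is held back in the first round; once confident pseudo-labels have moved the class-0 centre from 1.0 to 1.75, it is assigned to class 0. The threshold matters because wrong pseudo-labels, once accepted, reinforce themselves.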
Covariate shift, characterized by changes in the input distribution that leave the prediction task unchanged, can be partially addressed using unsupervised methods. Covariate shifts involving critical feature changes (as illustrated in Figure 3) remain challenging without labeled target data: without additional knowledge, the model's behavior on inputs outside the range seen in the source domain is unpredictable. However, in the case of covariate shift involving irrelevant features (Figure 4), deep domain alignment is an effective transfer method. This technique aligns the source and target input distributions in a common feature representation, from which predictions are then made.
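Alignment methods need a way to measure how far apart the source and target distributions are. One such measure, used in several deep domain adaptation methods, is the maximum mean discrepancy (MMD). The sketch below computes a biased MMD estimate with a Gaussian kernel on a single hypothetical feature (think of image brightness) that is shifted and rescaled in the target domain, then applies a very crude fixed alignment: standardizing each domain separately. The kernel bandwidth, the data and the standardization step are assumptions for illustration; a deep method would instead learn features that keep this discrepancy small.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth):
    return math.exp(-((a - b) ** 2) / (2.0 * bandwidth ** 2))


def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased estimate of the squared maximum mean discrepancy between two 1-D samples."""
    def mean_kernel(us, vs):
        total = sum(gaussian_kernel(u, v, bandwidth) for u in us for v in vs)
        return total / (len(us) * len(vs))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


def standardize(values):
    """Map a sample to zero mean and unit variance (a fixed, per-domain alignment)."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / std for v in values]


rng = random.Random(0)
# Hypothetical feature that is shifted and rescaled in the target domain.
source = [rng.gauss(0.0, 1.0) for _ in range(200)]
target = [rng.gauss(3.0, 2.0) for _ in range(200)]

before = mmd_squared(source, target)
after = mmd_squared(standardize(source), standardize(target))
print(f"MMD^2 before alignment: {before:.4f}")
print(f"MMD^2 after alignment:  {after:.4f}")
```

The discrepancy drops sharply after alignment. This is only appropriate in the Figure 4 situation: if the shift itself carried information about the label, aligning it away would also remove that information.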
Examples of applications
The proliferation of vast datasets and powerful models has made domain shift, where models trained on one dataset struggle to perform well on another, increasingly common. To address this challenge and make better use of existing resources, domain adaptation has emerged as a critical technique.
Common sources of domain shift include:
- Personalization: The unique nature of individual data, generated by smartphones and other IoT devices, makes it challenging to create one-size-fits-all AI models. For instance, text prediction relies heavily on personal writing style, which calls for personalized models. While personalization offers significant benefits, the large number of participants and the privacy requirements of personal data create challenges that are best addressed using federated learning [3]. Domain adaptation and transfer learning in the context of federated learning is therefore an important research topic.
- Geographic variations: Factors like location significantly influence various tasks. Weather prediction, for example, requires localized models due to geographic differences.
- Medical image analysis: Knowledge from large-scale natural image datasets is transferred to medical image tasks such as disease classification, object detection, and segmentation. For example, the authors of [4] use ImageNet (more than 14 million images) as the source dataset and then transfer to their smaller medical datasets.
- Synthetic data: Sometimes the scarcity of real data can be mitigated by creating a large source dataset of synthetic data, and then transferring this knowledge to the real target dataset [5].
Conclusion
In today’s world, we’re surrounded by vast amounts of data. But not all data is created equal. Often, the information we need to solve a problem is different from the information we already have. This is where transfer learning and domain adaptation come in.
Imagine teaching a child to ride a bike. Once they learn how to balance, they can easily adapt to different bikes. This is like transfer learning—using what you already know to learn something new. But what if the child has to learn to ride a bike on ice? That’s where domain adaptation comes in—adjusting what you know to fit a new situation. These techniques are essential for creating smart technologies that can handle different situations and improve over time.
References
[1] W. M. Kouw, An introduction to domain adaptation and transfer learning, ArXiv abs/1812.11806 (2018). URL https://api.semanticscholar.org/CorpusID:57189554
[2] L. Wang, X. Zhang, H. Su, J. Zhu, A comprehensive survey of continual learning: Theory, method and application, IEEE Transactions on Pattern Analysis and Machine Intelligence 46 (8) (2024) 5362–5383. doi:10.1109/TPAMI.2024.3367329.
[3] Federated learning: The future of AI without compromising privacy, https://www.societybyte.swiss/en/2024/04/26/federated-learning-the-future-of-ai-without-compromising-privacy/ (2024).
[4] H.-C. Shin, H. R. Roth, M. Gao, L. Lu, Z. Xu, I. Nogues, J. Yao, D. Mollura, R. M. Summers, Deep convolutional neural networks for computer-aided detection: CNN architectures, dataset characteristics and transfer learning, IEEE Transactions on Medical Imaging 35 (5) (2016) 1285–1298. doi:10.1109/TMI.2016.2528162.
[5] S. Mishra, R. Panda, C. P. Phoo, C.-F. R. Chen, L. Karlinsky, K. Saenko, V. Saligrama, R. S. Feris, Task2sim: Towards effective pre-training and transfer from synthetic data, in: 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2022, pp. 9184–9194. doi:10.1109/CVPR52688.2022.00898.
Additional Reading Materials on the Topic:
[6] Fine-tune large language model (LLM) on a custom dataset with QLoRA, https://dassum.medium.com/fine-tune-large-language-model-llm-on-a-custom-dataset-with-qlora-fb60abdeba07 (2024).
[7] Awesome-domain-adaptation, https://github.com/zhaoxin94/awesome-domain-adaptation?tab=readme-ov-file#survey (2024).
[8] L. Zhang, X. Gao, Transfer adaptation learning: A decade survey, IEEE Transactions on Neural Networks and Learning Systems 35 (2019) 23–44. URL https://api.semanticscholar.org/CorpusID:75137541
[9] S. J. Pan, Q. Yang, A survey on transfer learning, IEEE Transactions on Knowledge and Data Engineering 22 (10) (2010) 1345–1359. doi:10.1109/TKDE.2009.191.
[10] F. Zhuang, Z. Qi, K. Duan, D. Xi, Y. Zhu, H. Zhu, H. Xiong, Q. He, A comprehensive survey on transfer learning, Proceedings of the IEEE 109 (2019) 43–76. URL https://api.semanticscholar.org/CorpusID:207847753
Footnote
[1] Forgetting is a common challenge in continual learning, where models must continually adapt to new, small datasets without compromising performance on previously learned tasks (catastrophic forgetting). This is a form of concept shift, and readers can refer to [2] for a deeper discussion of continual learning and the trade-off between adaptation and remembering.