In an article recently submitted to the arXiv* preprint server, researchers introduced MEDITRON, a pair of open-source Large Language Models (LLMs) tailored for medical reasoning, with 7 billion and 70 billion parameters. MEDITRON was built on Llama-2 and leveraged curated medical data sources to address inequitable access to medical knowledge.
*Important notice: arXiv publishes preliminary scientific reports that are not peer-reviewed and, therefore, should not be regarded as definitive, used to guide development decisions, or treated as established information in the field of artificial intelligence research.
Background
LLMs hold immense potential for democratizing access to medical knowledge, a critical aspect of evidence-based decision-making in clinical settings. However, existing models are either closed-source, such as the Pathways Language Model (PaLM) and Generative Pre-trained Transformer (GPT)-4, or limited in scale (≤ 13 billion parameters), constraining their capabilities. While recent advancements in LLMs have demonstrated their potential in various domains, their application in medicine has faced challenges.
Prior efforts focused on generalist models trained on diverse internet data, leading to limitations in task-specific performance, especially in the medical domain. Task-specific models, although showing promise, were confined to smaller scales. Recognizing this limitation, the paper introduced MEDITRON, an open-source suite of LLMs with 7 billion and 70 billion parameters tailored to the medical domain. Unlike prior models, MEDITRON leveraged a comprehensive, curated medical corpus drawn from selected PubMed articles, abstracts, and internationally recognized medical guidelines.
The paper introduced an optimized workflow for scaling domain-specific pretraining in medical LLMs. This approach included knowledge-based data curation, distributed training, finetuning, and advanced inference methods like chain-of-thought reasoning and self-consistency. By releasing the curated training corpus, the distributed training library, and the MEDITRON models to the public, the paper encouraged real-world evaluation and aimed to catalyze similar developments in other domains. MEDITRON's substantial performance improvements showcased its potential to enhance medical knowledge access and reasoning capabilities, marking a significant step towards more capable and accessible medical LLMs.
Medical Training Data
MEDITRON's domain-adaptive pre-training corpus, GAP-REPLAY, integrated 48.1 billion tokens from diverse datasets:
- Clinical Guidelines: Comprising 46,000 clinical practice guidelines from global healthcare sources, the GUIDELINES corpus supported evidence-based decision-making. A subset of 35,733 guidelines from open-access sources was released, spanning varied medical domains.
- PubMed Papers & Abstracts: Drawing on the S2ORC corpus, the authors collected 4.47 million full-text papers, 444,521 open-access PubMed papers, and 16,209,047 abstracts. Pre-processing involved removing metadata, formatting citations, and indicating each paper's hierarchical structure.
- Experience Replay: To mitigate catastrophic forgetting, experience replay incorporated general domain data (1% of the mixture) from the RedPajama dataset, enhancing the retention of pre-trained knowledge.
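A toy sketch of how such a data mixture might be sampled during training is shown below; only the roughly 1% replay share comes from the article, while the other weights are illustrative placeholders rather than the paper's actual proportions.

```python
import random

# Toy weights for a GAP-REPLAY-style mixture (illustrative shares; the full
# corpus totals 48.1 billion tokens, with ~1% general-domain replay).
sources = {
    "pubmed_full_text": 0.55,   # assumed share
    "pubmed_abstracts": 0.40,   # assumed share
    "guidelines": 0.04,         # assumed share
    "replay_general": 0.01,     # ~1% experience replay, per the article
}

def sample_source(weights: dict[str, float]) -> str:
    """Pick which corpus the next training document is drawn from."""
    names, probs = zip(*weights.items())
    return random.choices(names, weights=probs, k=1)[0]

# Draw a small batch of source labels to illustrate the mixing behavior.
print([sample_source(sources) for _ in range(10)])
```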
Engineering
Training LLMs at scale posed significant engineering challenges due to the models' large parameter counts and the extensive number of tokens involved in pre-training. To address this, the Megatron-LLM distributed training library was developed, extending Nvidia's Megatron-LM to support three open-source LLMs: Llama, Falcon, and Llama-2. Various forms of parallelism, including Data Parallelism (DP), Pipeline Parallelism (PP), and Tensor Parallelism (TP), were employed for distributed training.
To enhance training efficiency, the library incorporated features specific to the Llama, Llama-2, and Falcon models, such as rotary position embeddings, grouped-query attention, and FlashAttention. In terms of hardware, the MEDITRON models were trained on a cluster of 16 nodes, each equipped with 8 Nvidia A100 80-gigabyte (GB) Graphics Processing Units (GPUs), Advanced Micro Devices (AMD) EPYC 7543 processors, and 512 GB of Random Access Memory (RAM).
Model parallelism was crucial: for the 70-billion-parameter model, TP and PP factors of 8 each were used, so the cluster's 128 GPUs yielded a DP factor of 2 (128 / (8 × 8)). The authors noted the importance of 3D model parallelism for efficient training, where TP, PP, and DP are all greater than one.
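A minimal sketch of that arithmetic, assuming the 16 × 8 GPU cluster and the parallelism factors described above (the variable names are illustrative and not part of the Megatron-LLM API):

```python
# Illustrative check of the 3D-parallelism arithmetic described above.
# Assumptions: 16 nodes with 8 GPUs each, and tensor-parallel (TP) and
# pipeline-parallel (PP) factors of 8 for the 70B model, as stated in the text.
nodes, gpus_per_node = 16, 8
world_size = nodes * gpus_per_node   # 128 GPUs in total
tp, pp = 8, 8                        # model-parallel factors

assert world_size % (tp * pp) == 0, "world size must divide evenly by TP * PP"
dp = world_size // (tp * pp)         # data-parallel replicas

print(f"world_size={world_size}, TP={tp}, PP={pp}, DP={dp}")  # DP = 2
```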
Modeling
The modeling phase comprised three parts:
- Pretraining: To adapt the Llama-2 language model to the medical domain, continued pre-training was performed on the GAP-REPLAY data mixture, including papers from PubMed and PubMed Central, abstracts from PubMed, and medical guidelines. The pretraining settings and model architecture largely followed Llama-2, utilizing the transformer architecture, Root Mean Square Normalization (RMSNorm), Swish Gated Linear Unit (SwiGLU) activation, rotary positional embeddings, and grouped-query attention. Training details, hyperparameters, and tokenization specifics were aligned with Llama-2's implementation.
- Supervised Finetuning: The MEDITRON models underwent supervised finetuning to evaluate downstream performance on medical reasoning benchmarks. Each model was individually finetuned on the training set of the respective benchmark and evaluated on the corresponding test set. Finetuning involved manually crafting clear instructions for each training set, following OpenAI's ChatML format (a hedged formatting sketch appears after this list). Finetuning hyperparameters included the AdamW optimizer, a cosine learning rate schedule, and specific learning rate, weight decay, and batch size settings.
- Inference: Various inference methods were employed to obtain answers from the pretrained or finetuned models. Top-token selection was used for tasks with single-label answers. Chain-of-Thought (CoT) prompting was applied to multi-step problems, enhancing the model's reasoning ability. Self-consistency CoT (SC-CoT) sampled multiple reasoning traces and answers and took a majority vote for the final prediction (see the voting sketch below). Inference also differed between the finetuned models and those evaluated directly after pretraining, with the latter relying on in-context learning and the former on direct generation.
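To illustrate the instruction formatting used in finetuning, the sketch below wraps a hypothetical multiple-choice item in ChatML-style role markers; the helper name, instruction wording, and example question are assumptions for illustration, not the authors' actual prompts.

```python
def to_chatml(question: str, options: list[str], answer: str) -> str:
    """Wrap a multiple-choice item in ChatML-style role markers (illustrative prompt text)."""
    letters = "ABCD"
    choices = "\n".join(f"({letters[i]}) {opt}" for i, opt in enumerate(options))
    system = "You are a medical expert. Answer with a single option letter."  # assumed wording
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{question}\n{choices}<|im_end|>\n"
        f"<|im_start|>assistant\n{answer}<|im_end|>"
    )

print(to_chatml(
    "Deficiency of which vitamin causes scurvy?",
    ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"],
    "C",
))
```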
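A minimal sketch of the SC-CoT majority vote follows, assuming the final answers have already been extracted from several temperature-sampled reasoning traces (the answers shown are placeholders):

```python
from collections import Counter

def sc_cot_vote(sampled_answers: list[str]) -> str:
    """Self-consistency: return the most common answer across sampled reasoning traces."""
    return Counter(sampled_answers).most_common(1)[0][0]

# Placeholder final answers extracted from several sampled CoT generations.
traces = ["B", "B", "A", "B", "C"]
print(sc_cot_vote(traces))  # -> B
```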
Medical Benchmark
The study utilized four medical benchmarks, namely MedQA, MedMCQA, PubMedQA, and MMLU-Medical, in line with prior research on medical LLM development and evaluation methods. MedQA, modeled on the United States Medical Licensing Examination (USMLE), presented challenges in contextualizing diverse medical knowledge. MedMCQA comprised over 194 thousand questions from Indian medical entrance exams, while PubMedQA required predicting yes, no, or maybe answers from PubMed abstracts.
MMLU-Medical amalgamated subjects related to medical and clinical knowledge from the MMLU dataset. MedQA and MedMCQA varied in the number of answer choices, and MedMCQA's explanatory answers were used for chain-of-thought training. For MedMCQA, the validation set was employed for evaluation because the public test set's answer keys are unavailable. PubMedQA was finetuned using artificially labeled examples due to the limited size of the expert-labeled set. Because MMLU-Medical lacks training data, models finetuned on MedMCQA were evaluated on it to assess generalization from MedMCQA to MMLU-Medical.
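All four benchmarks are scored as accuracy over labeled questions, so a generic scoring routine might look like the sketch below (the prediction/gold pairs are hypothetical placeholders):

```python
def accuracy(pairs: list[tuple[str, str]]) -> float:
    """Fraction of examples where the predicted label matches the gold answer."""
    correct = sum(pred == gold for pred, gold in pairs)
    return correct / len(pairs)

# Hypothetical PubMedQA-style outputs (labels are yes / no / maybe).
results = [("yes", "yes"), ("no", "yes"), ("maybe", "maybe"), ("yes", "yes")]
print(f"accuracy = {accuracy(results):.2f}")  # 0.75
```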
Main Results
Pretrained Model Evaluation:
- MEDITRON-7B outperformed baselines, particularly excelling in PubMedQA.
- MEDITRON-70B surpassed the base model Llama-2-70B on all benchmarks.
- At the 70 billion scale, MEDITRON-70B demonstrated robust reasoning ability even before task-specific finetuning.
- In-context learning enhanced MEDITRON-7B's performance, especially notable in PubMedQA.
Finetuned Model Evaluation:
- MEDITRON-7B outperformed Llama-2-7B and PMC-Llama-7B, achieving the best accuracy in various benchmarks.
- MEDITRON-70B improved over Clinical-Camel-70B and Med42-70B, showcasing its effectiveness.
- CoT and SC-CoT further enhanced MEDITRON-70B's performance.
- MEDITRON-70B outperformed commercial LLMs in multiple medical benchmarks despite having a smaller parameter count.
Key Findings:
- MEDITRON models, especially at the 70 billion scale, exhibited strong reasoning abilities in medical tasks.
- In-context learning (for the pretrained models) and SC-CoT (for the finetuned models) contributed to improved performance.
- MEDITRON-70B competed favorably on medical reasoning tasks with commercial LLMs that have significantly larger parameter counts.
Analysis
Impact of Continued Pretraining:
- Continued pre-training was monitored through language modeling losses and intermediate evaluations.
- Steadily decreasing language-modeling losses indicated effective learning without overfitting.
- Intermediate evaluations showed consistent performance improvement across iterations in in-context learning.
- Certain datasets showed performance drops at intermediate checkpoints, underscoring the benefit of carrying large-scale continued pretraining through to completion.
Data Mixture Ablation:
- Replay tokens from the general domain improved downstream performance, except in MedMCQA.
- PMC + Replay outperformed PMC, showing a 1.6% average performance increase.
- Upsampling medical papers weakened downstream performance compared to PMC + Replay.
- Adding code to the training corpus decreased overall performance on medical benchmarks.
- GAP + Replay, incorporating PubMed abstracts and medical guidelines, led to the best average performance and was chosen for MEDITRON's continued pretraining.
Conclusion
In conclusion, MEDITRON, a domain-adapted suite of medical LLMs, exhibited advanced medical reasoning and surpassed prior state-of-the-art models on standard medical benchmarks. Achieving notable performance gains through continued pretraining on curated medical resources, including clinical guidelines, MEDITRON outperformed open-source and commercial LLMs at comparable scales. Released openly along with its tools and training resources, MEDITRON aimed to catalyze advancements in medical research, patient care, and innovation by fostering collaborative efforts in the health domain.