CheXpert MAE-DenseNet-FPN
A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining Masked Autoencoders (MAE), DenseNet with CBAM attention, and Feature Pyramid Networks (FPN) with bidirectional cross-attention fusion.
🏗️ Architecture Overview
This project implements a dual-branch fusion architecture for chest X-ray classification, built from the following components:
- MAE Encoder: Vision Transformer-based masked autoencoder for self-supervised feature extraction
- DenseNet-169: Dense convolutional network with CBAM (Convolutional Block Attention Module) channel and spatial attention
- Feature Pyramid Network: Multi-scale feature extraction at 4 different resolutions
- Bidirectional Cross-Attention: Fusion mechanism allowing MAE and DenseNet features to attend to each other
- Learned Logit Ensemble: Learned weighted combination of the logits from 7 prediction heads, with learnable temperature scaling
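As a rough illustration of the learned logit ensemble, the sketch below fuses per-head logits with softmax-normalised head weights and a single temperature. The names `ensemble_probs`, `weight_params`, and `log_temperature` are hypothetical, and the exact parameterisation used in `models/classifier.py` may differ; in training these values would be learned by gradient descent rather than fixed.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_probs(head_logits, weight_params, log_temperature=0.0):
    """Fuse H heads of C logits each into C probabilities.

    weight_params: H unnormalised scores (assumed learnable), softmax-normalised.
    log_temperature: scalar (assumed learnable); temperature = exp(log_temperature).
    """
    weights = softmax(weight_params)
    temperature = math.exp(log_temperature)
    num_classes = len(head_logits[0])
    probs = []
    for c in range(num_classes):
        fused = sum(w * logits[c] for w, logits in zip(weights, head_logits))
        probs.append(sigmoid(fused / temperature))  # multi-label: sigmoid per class
    return probs

# 7 heads, 14 classes, equal head weights at initialisation
heads = [[0.1 * (h - 3)] * 14 for h in range(7)]
probs = ensemble_probs(heads, weight_params=[0.0] * 7)
```

Combining logits rather than probabilities keeps the fusion linear before the sigmoid, and a temperature below 1 sharpens the fused predictions while a temperature above 1 softens them.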
Key Components
Input Image (384×384)
          │
          ├─────────────────────────────┐
          │                             │
          ▼                             ▼
     MAE Encoder                  DenseNet-169
     (ViT-based)                   (with CBAM)
          │                             │
          │         ┌───────────────────┤
          │         │                   │
          │    FPN Pyramid       Dense Features
          │      (P1-P4)          (Multi-scale)
          │         │                   │
          └─────────┴───────────────────┘
                    │
      Bidirectional Cross-Attention
                    │
          ┌─────────┴──────────┐
          │                    │
      MAE Head     Dense Head + 4 FPN Heads
          │                    │
          └────────┬───────────┘
                   │
      Learned Ensemble (7 heads)
                   │
                   ▼
         14-class Predictions
✨ Features
- Hybrid Architecture: Combines transformer-based and convolutional approaches
- Multi-scale Learning: FPN extracts features at 4 different resolutions
- Advanced Fusion: Bidirectional cross-attention between MAE and DenseNet features
- Optimized Training:
  - Mixed precision training (FP16)
  - Gradient accumulation
  - Weighted sampling for class imbalance
  - Cosine annealing with linear warmup
  - Gradient checkpointing for memory efficiency
- Smart Data Loading:
  - ZIP file reader with LRU caching
  - On-the-fly augmentation using Albumentations
  - Multi-worker data loading with persistent workers
- Comprehensive Evaluation:
  - Per-class AUC metrics
  - Optimal threshold computation per class
  - Macro and Micro AUC tracking
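The evaluation metrics listed above can be sketched without any dependencies. This is an illustrative version only: the function names are hypothetical, the quadratic-time AUC is meant for small examples, and choosing the threshold by Youden's J (TPR minus FPR) is an assumption, since the criterion used by the trainer is not stated here.

```python
import math

def roc_auc(labels, scores):
    """AUC as the chance that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when a class has only one label value
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def best_threshold(labels, scores):
    """Threshold maximising Youden's J (assumed criterion); needs both label values present."""
    n_pos = sum(labels)
    n_neg = len(labels) - n_pos
    best_t, best_j = 0.5, -1.0
    for t in sorted(set(scores)):
        tp = sum(1 for y, s in zip(labels, scores) if y == 1 and s >= t)
        fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s >= t)
        j = tp / n_pos - fp / n_neg
        if j > best_j:
            best_t, best_j = t, j
    return best_t

def macro_micro_auc(label_matrix, score_matrix):
    """Rows are studies, columns are classes."""
    num_classes = len(label_matrix[0])
    per_class = [
        roc_auc([row[c] for row in label_matrix], [row[c] for row in score_matrix])
        for c in range(num_classes)
    ]
    valid = [a for a in per_class if not math.isnan(a)]
    macro = sum(valid) / len(valid)  # unweighted mean over classes
    micro = roc_auc(
        [y for row in label_matrix for y in row],
        [s for row in score_matrix for s in row],
    )  # pools every (study, class) pair
    return macro, micro, per_class
```

Macro AUC weights every pathology equally, so rare classes influence it as much as common ones; micro AUC pools all predictions and is dominated by frequent labels.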
📋 Requirements
- Python 3.8+
- CUDA-capable GPU (recommended: 16GB+ VRAM)
- CheXpert dataset
🚀 Installation
- Clone the repository
git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
cd chexpert-mae-densenet-fpn
- Create a virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
- Install dependencies
pip install -r requirements.txt
📊 Dataset Setup
Download CheXpert Dataset
- Visit: https://stanfordmlgroup.github.io/competitions/chexpert/
- Download CheXpert-v1.0-small
Prepare the dataset
# Extract the dataset
unzip CheXpert-v1.0-small.zip
# Optionally, create a ZIP archive for faster loading
cd CheXpert-v1.0-small
zip -r chexpert.zip train/ valid/
Update configuration
- Edit configs/configs.py
- Update the root variable to point to your dataset location
- Update all other paths accordingly
🔧 Configuration
Edit configs/configs.py to customize:
# Example: Update paths
root = "/path/to/your/data"
mae_config = {
"lr": 1e-4,
"num_epochs": 200,
"batch_size": 96,
# ... other parameters
}
config = {
"lr": 1e-4,
"num_epochs": 200,
"batch_size": 36,
# ... other parameters
}
🎯 Training
Phase 1: Pre-train MAE
python trainer/trainer.py
# When prompted, type: mae
MAE pre-training learns robust feature representations by reconstructing masked image patches.
Phase 2: Train Classifier
python trainer/trainer.py
# When prompted, type: classifier
This loads the pre-trained MAE encoder and trains the full classification pipeline.
Training Configuration
MAE Training:
- Batch size: 96
- Mask ratio: 0.75 (masks 75% of patches)
- Reconstruction loss on masked patches
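To make the masking numbers concrete, here is a minimal sketch of random patch masking and a masked-only reconstruction loss. It assumes pre-training uses the same 384×384 input and 16×16 patches described elsewhere in this README; `random_mask` and `masked_mse` are hypothetical helper names, not functions from `models/mae.py`.

```python
import random

IMAGE_SIZE = 384   # assumed to match the classifier input size
PATCH_SIZE = 16
MASK_RATIO = 0.75

def random_mask(num_patches, mask_ratio, rng):
    """Split patch indices into visible (fed to the encoder) and masked (to reconstruct)."""
    num_keep = int(num_patches * (1 - mask_ratio))
    order = list(range(num_patches))
    rng.shuffle(order)
    return sorted(order[:num_keep]), sorted(order[num_keep:])

def masked_mse(pred, target, masked):
    """Mean squared error over masked patches only; pred/target hold per-patch pixel lists."""
    total, count = 0.0, 0
    for i in masked:
        for p, t in zip(pred[i], target[i]):
            total += (p - t) ** 2
            count += 1
    return total / count

grid = IMAGE_SIZE // PATCH_SIZE          # 24 patches per side
num_patches = grid * grid                # 576 patches in total
visible, masked = random_mask(num_patches, MASK_RATIO, random.Random(0))
print(len(visible), len(masked))         # 144 visible, 432 masked
```

Under these assumptions the encoder sees only 144 of 576 patches per image, which is where most of the pre-training speed-up comes from.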
Classifier Training:
- Batch size: 36 with gradient accumulation (8 steps)
- Effective batch size: 288
- Asymmetric loss with class weights
- Per-class threshold optimization
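The asymmetric loss can be sketched as follows, based on the published formulation (separate focusing exponents for positives and negatives, plus probability shifting for negatives). The hyperparameter values `gamma_pos=1`, `gamma_neg=4`, and `clip=0.05` are common defaults used here as assumptions, and applying class weights as a per-class multiplier is also an assumption; `loss/assymetric.py` may do either differently.

```python
import math

def asymmetric_loss(probs, targets, class_weights,
                    gamma_pos=1.0, gamma_neg=4.0, clip=0.05, eps=1e-8):
    """Asymmetric loss for one sample: probs/targets/class_weights have one entry per class."""
    total = 0.0
    for p, y, w in zip(probs, targets, class_weights):
        if y == 1:
            # positives: focal-style down-weighting of already-confident predictions
            term = -((1.0 - p) ** gamma_pos) * math.log(max(p, eps))
        else:
            # negatives: shift the probability so very easy negatives contribute nothing
            p_m = max(p - clip, 0.0)
            term = -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, eps))
        total += w * term
    return total / len(probs)

# An easy negative (p below the clip margin) is ignored; a confident false positive is not.
print(asymmetric_loss([0.04], [0], [1.0]))
print(asymmetric_loss([0.90], [0], [1.0]))
```

The larger exponent on negatives suppresses the many easy negative labels that dominate multi-label chest X-ray data, so the gradient is driven mostly by positives and hard negatives.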
🧪 Testing
from trainer.utils import Trainer
from configs.configs import config
# Initialize trainer
trainer = Trainer(config)
# Run evaluation on test set
macro_auc, micro_auc, per_class = trainer.test(
model_path="path/to/checkpoint.pth"
)
print(f"Macro AUC: {macro_auc:.4f}")
print(f"Micro AUC: {micro_auc:.4f}")
📁 Project Structure
chexpert-mae-densenet-fpn/
├── configs/
│ ├── __init__.py
│ └── configs.py # Configuration parameters
├── data/
│ ├── __init__.py
│ ├── dataset.py # CheXpert dataset with ZIP caching
│ └── splitter.py # Data splitting utilities
├── loss/
│ ├── __init__.py
│ └── assymetric.py # Asymmetric loss for imbalanced data
├── models/
│ ├── __init__.py
│ ├── mae.py # Masked Autoencoder implementation
│ ├── densenet.py # DenseNet-169 with CBAM
│ └── classifier.py # Full classification architecture
├── trainer/
│ ├── __init__.py
│ ├── trainer.py # Main training script
│ ├── utils.py # Training utilities and loops
│ └── test.py # Testing utilities
├── notebooks/
│ ├── chexpert_mae.ipynb # MAE experiments
│ └── chexpert_mae_mask_classifier.ipynb # Full pipeline experiments
├── requirements.txt
└── README.md
📈 Model Architecture Details
MAE Encoder
- Patch size: 16×16
- Embedding dim: 768
- Depth: 12 transformer blocks
- Heads: 8 attention heads
- MLP ratio: 4×
DenseNet-169
- Growth rate (k): 64
- Layers: [6, 12, 24, 16]
- CBAM: Channel + Spatial attention at each stage
- Dropout: Progressive (0.05 → 0.1 → 0.1 → 0.1)
Cross-Attention Fusion
- 12 bidirectional cross-attention layers
- Projection dim: 512
- Attention heads: 8
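The core idea of one bidirectional cross-attention layer can be shown with a dependency-free toy. This sketch is deliberately simplified: it uses a single head and omits the learned query/key/value projections, layer normalisation, and feed-forward blocks that a real layer would have, and the function names are hypothetical. A practical implementation would more likely build on a framework attention module.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def cross_attend(queries, context):
    """Scaled dot-product attention: each query token becomes a weighted mix of context tokens."""
    dim = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in context]
        weights = softmax(scores)
        out.append([sum(w * v[d] for w, v in zip(weights, context)) for d in range(dim)])
    return out

def bidirectional_layer(mae_tokens, dense_tokens):
    """Each stream attends to the other's input tokens, then adds the result as a residual."""
    mae_update = cross_attend(mae_tokens, dense_tokens)
    dense_update = cross_attend(dense_tokens, mae_tokens)
    mae_out = [[x + u for x, u in zip(tok, upd)] for tok, upd in zip(mae_tokens, mae_update)]
    dense_out = [[x + u for x, u in zip(tok, upd)] for tok, upd in zip(dense_tokens, dense_update)]
    return mae_out, dense_out

# Toy input: 3 MAE tokens and 2 DenseNet tokens, already projected to a shared dim of 4
mae_tokens = [[0.1 * (i + j) for j in range(4)] for i in range(3)]
dense_tokens = [[0.2 * (i - j) for j in range(4)] for i in range(2)]
mae_out, dense_out = bidirectional_layer(mae_tokens, dense_tokens)
```

Both streams keep their own token count, so the two branches can have different spatial resolutions as long as they share the projection dimension.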
FPN
- Feature levels: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
- Channel unification: 256 channels per level
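The sizes in this section follow from the 384×384 input by simple division. The short check below derives them; the FPN strides are inferred from the listed resolutions rather than read from the code, so treat them as an assumption.

```python
IMAGE_SIZE = 384

# Strides inferred from the resolutions listed above (assumption)
FPN_STRIDES = {"P1": 2, "P2": 4, "P3": 8, "P4": 16}
fpn_sizes = {name: IMAGE_SIZE // stride for name, stride in FPN_STRIDES.items()}

# MAE encoder: 16x16 patches, embedding dim 768, 8 heads
vit_grid = IMAGE_SIZE // 16          # patches per side
vit_tokens = vit_grid * vit_grid     # patch tokens per image
mae_head_dim = 768 // 8              # channels per attention head

# Cross-attention fusion: projection dim 512, 8 heads
fusion_head_dim = 512 // 8

print(fpn_sizes)
print(vit_grid, vit_tokens, mae_head_dim, fusion_head_dim)
```

One consequence is that the coarsest FPN level (P4, 24×24) has the same spatial grid as the ViT patch tokens (24×24).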
🎓 CheXpert Labels
The model predicts 14 pathologies:
- No Finding
- Enlarged Cardiomediastinum
- Cardiomegaly
- Lung Opacity
- Lung Lesion
- Edema
- Consolidation
- Pneumonia
- Atelectasis
- Pneumothorax
- Pleural Effusion
- Pleural Other
- Fracture
- Support Devices
🔬 Data Augmentation
Training augmentations (conservative for medical images):
- Horizontal flip (p=0.5)
- Random affine (translation, scale, rotation ±10°)
- Random brightness/contrast
- CLAHE histogram equalization
- Gaussian blur and noise
💾 Checkpoints
Training automatically saves the following:
- Best MAE checkpoint: Based on validation reconstruction loss
- Best classifier checkpoint: Based on validation AUC (macro/micro)
- Training history: JSON file with all metrics
- Per-epoch metrics plots: Loss and AUC curves
📊 Monitoring
Training logs are saved to:
- training_log.txt: Training progress with live metrics
- val_log.txt: Validation results
- test_log.txt: Test evaluation results
- history.json: All metrics across epochs
- metrics.png: Visualization plots
⚡ Performance Tips
Memory Optimization:
- Use gradient checkpointing (already enabled)
- Reduce batch size if OOM occurs
- Increase gradient accumulation steps
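When reducing the batch size to fit in memory, the accumulation steps can be raised so the effective batch size stays at 288 (36 × 8 in the default classifier setup). The helper below is a hypothetical illustration of that arithmetic, not part of the repository.

```python
def accumulation_steps(effective_batch, micro_batch):
    """Number of micro-batches to accumulate before each optimizer step."""
    if effective_batch % micro_batch != 0:
        raise ValueError("micro_batch must divide effective_batch")
    return effective_batch // micro_batch

def accumulated_mean(micro_batch_means):
    """Average of per-micro-batch mean gradients.

    Equals the full-batch mean gradient when all micro-batches have the same size,
    which is why each micro-batch loss is usually divided by the number of steps.
    """
    return sum(micro_batch_means) / len(micro_batch_means)

for micro_batch in (36, 18, 12):
    print(micro_batch, accumulation_steps(288, micro_batch))
```

Halving the per-step batch to 18 therefore needs 16 accumulation steps, and a batch of 12 needs 24, at the cost of more forward and backward passes per optimizer update.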
Speed Optimization:
- Use persistent workers (already enabled)
- Enable cuDNN benchmark (already enabled)
- Use ZIP caching for faster data loading
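A ZIP reader with LRU caching can be sketched with the standard library alone. The class name `ZipReader` and the `cache_size` parameter are hypothetical, and the actual reader in `data/dataset.py` may be organised differently; the lazy open is an assumption meant to give each data-loading worker process its own file handle.

```python
import zipfile
from functools import lru_cache

class ZipReader:
    """Read raw member bytes from a ZIP archive, caching recently used members."""

    def __init__(self, zip_path, cache_size=1024):
        self.zip_path = zip_path
        self._zf = None  # opened lazily, on first read
        # Per-instance cache: keeps up to cache_size members' bytes in memory
        self.read = lru_cache(maxsize=cache_size)(self._read)

    def _read(self, member):
        if self._zf is None:
            self._zf = zipfile.ZipFile(self.zip_path, "r")
        return self._zf.read(member)

# Usage (paths are placeholders):
# reader = ZipReader("/path/to/your/data/chexpert.zip")
# raw = reader.read("train/patient00001/study1/view1_frontal.jpg")
# print(reader.read.cache_info())
```

Caching helps most when the same images are revisited within a worker; with a cache much smaller than the dataset and fully shuffled epochs, the hit rate will be low and the main gain comes from avoiding many small file opens.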
Training Stability:
- Gradient clipping at norm 1.0
- Mixed precision with dynamic loss scaling
- Warmup learning rate schedule
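The warmup schedule and gradient clipping mentioned above can be written as two small functions. This is a sketch under assumptions: the warmup length, the minimum learning rate, and the per-step (rather than per-epoch) update are illustrative choices, and `lr_at` and `clip_by_global_norm` are hypothetical names.

```python
import math

def lr_at(step, total_steps, warmup_steps, base_lr=1e-4, min_lr=0.0):
    """Linear warmup to base_lr, then cosine annealing down to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grads]

# Example: 1000 total steps with 100 warmup steps (illustrative numbers)
schedule = [lr_at(s, total_steps=1000, warmup_steps=100) for s in range(1000)]
clipped = clip_by_global_norm([3.0, 4.0])  # norm 5 is scaled down to about 1
```

Warmup avoids large early updates while the fusion layers and heads are still randomly initialised, and clipping bounds the update size when a batch produces an unusually large gradient.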
🐛 Troubleshooting
Q: Out of memory errors?
- Reduce batch size in configs.py
- Increase gradient accumulation steps
- Confirm gradient checkpointing is enabled
Q: Slow training?
- Check if ZIP caching is enabled
- Verify persistent workers are active
- Monitor GPU utilization
Q: Poor convergence?
- Ensure MAE is properly pre-trained first
- Check learning rate and warmup settings
- Verify class weights are computed correctly
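One way to sanity-check class weights is to recompute them from the label matrix and compare. The sketch below assumes inverse-frequency weights (negatives divided by positives per class), which is a common choice but not confirmed as the formula used in this repository; `positive_class_weights` is a hypothetical helper.

```python
def positive_class_weights(label_rows):
    """Per-class weight = (#negatives / #positives), computed over binary label rows.

    Raises if a class has no positive example, which usually signals a
    label-parsing or data-split problem worth fixing before training.
    """
    num_samples = len(label_rows)
    num_classes = len(label_rows[0])
    weights = []
    for c in range(num_classes):
        positives = sum(row[c] for row in label_rows)
        if positives == 0:
            raise ValueError(f"class index {c} has no positive samples")
        weights.append((num_samples - positives) / positives)
    return weights

# Toy example: class 0 is positive in 1 of 4 studies, class 1 in 2 of 4
labels = [[1, 0], [0, 0], [0, 1], [0, 1]]
print(positive_class_weights(labels))
```

Rare pathologies get large weights under this scheme, so it is worth checking that the largest weight is not extreme enough to destabilise training.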
📚 Citation
If you use this code in your research, please cite:
@misc{chexpert-mae-densenet-fpn,
author = {adel elsayed},
title = {CheXpert Classification with MAE-DenseNet-FPN},
year = {2025},
publisher = {GitHub},
url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
}
🙏 Acknowledgments
- CheXpert Dataset: Stanford ML Group
- Masked Autoencoders: Meta AI Research (He et al., 2021)
- DenseNet: Huang et al., 2017
- CBAM: Woo et al., 2018
- Feature Pyramid Networks: Lin et al., 2017
📄 License
This project is licensed under the MIT License.
📧 Contact
https://www.linkedin.com/in/adel-elsayed-a5260246/
Note: This is a research project. For clinical use, please ensure proper validation and regulatory approval.