CheXpert MAE-DenseNet-FPN

A deep learning framework for multi-label chest X-ray classification using a hybrid architecture combining Masked Autoencoders (MAE), DenseNet with CBAM attention, and Feature Pyramid Networks (FPN) with bidirectional cross-attention fusion.

🏗️ Architecture Overview

This project implements a novel dual-branch fusion architecture for medical image classification:

  • MAE Encoder: Vision Transformer-based masked autoencoder for self-supervised feature extraction
  • DenseNet-169: Dense convolutional network with Channel and Spatial Attention (CBAM)
  • Feature Pyramid Network: Multi-scale feature extraction at 4 different resolutions
  • Bidirectional Cross-Attention: Fusion mechanism allowing MAE and DenseNet features to attend to each other
  • Learned Logit Ensemble: Combination of 7 prediction heads with learnable temperature scaling

Key Components

Input Image (384×384)
    │
    ├─────────────────────────────┐
    │                             │
    ▼                             ▼
MAE Encoder                  DenseNet-169
(ViT-based)                  (with CBAM)
    │                             │
    │         ┌───────────────────┤
    │         │                   │
    │    FPN Pyramid          Dense Features
    │    (P1-P4)              (Multi-scale)
    │         │                   │
    └─────────┴───────────────────┘
              │
    Bidirectional Cross-Attention
              │
    ┌─────────┴──────────┐
    │                    │
MAE Head          Dense Head + 4 FPN Heads
    │                    │
    └────────┬───────────┘
             │
    Learned Ensemble (7 heads)
             │
             ▼
    14-class Predictions
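
The learned ensemble at the bottom of the diagram can be pictured as a weighted sum of per-head logits with a learnable temperature. A minimal sketch, assuming learnable softmax weights over heads (module and parameter names are illustrative, not the repository's actual API):

import torch
import torch.nn as nn

class LearnedLogitEnsemble(nn.Module):
    """Combine logits from several heads with learnable weights and a temperature."""
    def __init__(self, num_heads=7, num_classes=14):
        super().__init__()
        self.head_weights = nn.Parameter(torch.zeros(num_heads))       # softmax-normalized
        self.log_temperature = nn.Parameter(torch.zeros(num_classes))  # per-class temperature

    def forward(self, head_logits):
        # head_logits: (num_heads, batch, num_classes)
        w = torch.softmax(self.head_weights, dim=0).view(-1, 1, 1)
        fused = (w * head_logits).sum(dim=0)       # weighted sum over the 7 heads
        return fused / self.log_temperature.exp()  # temperature scaling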

✨ Features

  • Hybrid Architecture: Combines transformer-based and convolutional approaches
  • Multi-scale Learning: FPN extracts features at 4 different resolutions
  • Advanced Fusion: Bidirectional cross-attention between MAE and DenseNet features
  • Optimized Training:
    • Mixed precision training (FP16)
    • Gradient accumulation
    • Weighted sampling for class imbalance
    • Cosine annealing with linear warmup
    • Gradient checkpointing for memory efficiency
  • Smart Data Loading:
    • ZIP file reader with LRU caching
    • On-the-fly augmentation using Albumentations
    • Multi-worker data loading with persistent workers
  • Comprehensive Evaluation:
    • Per-class AUC metrics
    • Optimal threshold computation per class (see the sketch after this list)
    • Macro and Micro AUC tracking
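
For the per-class thresholds, one common recipe is maximizing Youden's J statistic on the validation ROC curve. A minimal sketch using scikit-learn, assuming that criterion (the repository may use a different one):

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

def evaluate(y_true, y_prob):
    """y_true, y_prob: arrays of shape (num_samples, num_classes)."""
    thresholds = []
    for c in range(y_true.shape[1]):
        fpr, tpr, thr = roc_curve(y_true[:, c], y_prob[:, c])
        thresholds.append(thr[np.argmax(tpr - fpr)])  # Youden's J = TPR - FPR
    macro_auc = roc_auc_score(y_true, y_prob, average="macro")  # mean of per-class AUCs
    micro_auc = roc_auc_score(y_true, y_prob, average="micro")  # pooled over all decisions
    return macro_auc, micro_auc, np.array(thresholds)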

📋 Requirements

  • Python 3.8+
  • CUDA-capable GPU (recommended: 16GB+ VRAM)
  • CheXpert dataset

🚀 Installation

  1. Clone the repository

git clone https://github.com/adelelsayed/chexpert-mae-densenet-fpn.git
cd chexpert-mae-densenet-fpn

  2. Create a virtual environment

python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

  3. Install dependencies

pip install -r requirements.txt

📊 Dataset Setup

  1. Download the CheXpert dataset from the Stanford ML Group

  2. Prepare the dataset

# Extract the dataset
unzip CheXpert-v1.0-small.zip

# Optionally, create a ZIP archive for faster loading
cd CheXpert-v1.0-small
zip -r chexpert.zip train/ valid/
  3. Update configuration
    • Edit configs/configs.py
    • Update the root variable to point to your dataset location
    • Update all other paths accordingly

🔧 Configuration

Edit configs/configs.py to customize:

# Example: Update paths
root = "/path/to/your/data"

mae_config = {
    "lr": 1e-4,
    "num_epochs": 200,
    "batch_size": 96,
    # ... other parameters
}

config = {
    "lr": 1e-4,
    "num_epochs": 200,
    "batch_size": 36,
    # ... other parameters
}

🎯 Training

Phase 1: Pre-train MAE

python trainer/trainer.py
# When prompted, type: mae

The MAE pre-training learns robust feature representations through masked image reconstruction.
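
Conceptually, the pre-training objective is mean-squared error between reconstructed and original patches, computed only on the masked 75%. A simplified sketch of that loss (the actual patchify and masking code in models/mae.py will differ in details):

import torch

def mae_loss(pred, target, mask):
    """pred, target: (batch, num_patches, patch_dim); mask: (batch, num_patches),
    1 where a patch was masked out."""
    loss = ((pred - target) ** 2).mean(dim=-1)  # per-patch reconstruction error
    return (loss * mask).sum() / mask.sum()     # average over masked patches only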

Phase 2: Train Classifier

python trainer/trainer.py
# When prompted, type: classifier

This loads the pre-trained MAE encoder and trains the full classification pipeline.
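
The hand-off between the two phases amounts to restoring the pre-trained encoder weights into the classifier before training it. A rough sketch, where the checkpoint path, state-dict keys, and the classifier variable are all assumptions for illustration:

import torch

# Illustrative names only; see trainer/utils.py for the actual loading logic.
checkpoint = torch.load("checkpoints/best_mae.pth", map_location="cpu")
encoder_state = {
    k.replace("encoder.", ""): v
    for k, v in checkpoint["model"].items()
    if k.startswith("encoder.")  # keep the encoder, drop the MAE decoder
}
classifier.mae_encoder.load_state_dict(encoder_state, strict=False)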

Training Configuration

  • MAE Training:

    • Batch size: 96
    • Mask ratio: 0.75 (masks 75% of patches)
    • Reconstruction loss on masked patches
  • Classifier Training:

    • Batch size: 36 with gradient accumulation (8 steps; see the sketch below)
    • Effective batch size: 288
    • Asymmetric loss with class weights
    • Per-class threshold optimization
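
The effective batch of 36 × 8 = 288 comes from stepping the optimizer only every 8 mini-batches. A minimal sketch of that pattern (model, criterion, optimizer, and train_loader are assumed to exist):

accumulation_steps = 8  # 36 x 8 = 288 effective batch size

optimizer.zero_grad()
for step, (images, labels) in enumerate(train_loader):
    loss = criterion(model(images), labels) / accumulation_steps  # scale for accumulation
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per 8 mini-batches
        optimizer.zero_grad()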

🧪 Testing

from trainer.utils import Trainer
from configs.configs import config

# Initialize trainer
trainer = Trainer(config)

# Run evaluation on test set
macro_auc, micro_auc, per_class = trainer.test(
    model_path="path/to/checkpoint.pth"
)

print(f"Macro AUC: {macro_auc:.4f}")
print(f"Micro AUC: {micro_auc:.4f}")

📁 Project Structure

chexpert-mae-densenet-fpn/
├── configs/
│   ├── __init__.py
│   └── configs.py          # Configuration parameters
├── data/
│   ├── __init__.py
│   ├── dataset.py          # CheXpert dataset with ZIP caching
│   └── splitter.py         # Data splitting utilities
├── loss/
│   ├── __init__.py
│   └── assymetric.py       # Asymmetric loss for imbalanced data
├── models/
│   ├── __init__.py
│   ├── mae.py              # Masked Autoencoder implementation
│   ├── densenet.py         # DenseNet-169 with CBAM
│   └── classifier.py       # Full classification architecture
├── trainer/
│   ├── __init__.py
│   ├── trainer.py          # Main training script
│   ├── utils.py            # Training utilities and loops
│   └── test.py             # Testing utilities
├── notebooks/
│   ├── chexpert_mae.ipynb              # MAE experiments
│   └── chexpert_mae_mask_classifier.ipynb  # Full pipeline experiments
├── requirements.txt
└── README.md

📈 Model Architecture Details

MAE Encoder

  • Patch size: 16×16
  • Embedding dim: 768
  • Depth: 12 transformer blocks
  • Heads: 8 attention heads
  • MLP ratio: 4×

DenseNet-169

  • Growth rate (k): 64
  • Layers: [6, 12, 24, 16]
  • CBAM: Channel + Spatial attention at each stage
  • Dropout: Progressive (0.05 → 0.1 → 0.1 → 0.1)
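
CBAM applies channel attention followed by spatial attention at each stage. A compact sketch of the standard formulation from Woo et al., 2018 (the version in models/densenet.py may differ in details):

import torch
import torch.nn as nn

class CBAM(nn.Module):
    """Channel attention then spatial attention (Woo et al., 2018)."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction), nn.ReLU(),
            nn.Linear(channels // reduction, channels),
        )
        self.spatial = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        b, c, _, _ = x.shape
        # Channel attention: shared MLP over average- and max-pooled descriptors
        attn = torch.sigmoid(self.mlp(x.mean(dim=(2, 3))) + self.mlp(x.amax(dim=(2, 3))))
        x = x * attn.view(b, c, 1, 1)
        # Spatial attention: 7x7 conv over channel-wise average and max maps
        s = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.spatial(s))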

Cross-Attention Fusion

  • 12 bidirectional cross-attention layers
  • Projection dim: 512
  • Attention heads: 8
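
Each fusion layer lets the MAE tokens attend to the DenseNet tokens and vice versa. A minimal sketch of one such layer built on nn.MultiheadAttention, assuming residual connections (the real implementation adds projections into the 512-dim space and normalization):

import torch.nn as nn

class BidirectionalCrossAttention(nn.Module):
    """MAE tokens query DenseNet tokens, and vice versa."""
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.mae_to_dense = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.dense_to_mae = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, mae_tokens, dense_tokens):
        # Queries come from one branch; keys and values from the other
        mae_out, _ = self.mae_to_dense(mae_tokens, dense_tokens, dense_tokens)
        dense_out, _ = self.dense_to_mae(dense_tokens, mae_tokens, mae_tokens)
        return mae_tokens + mae_out, dense_tokens + dense_out  # residual connections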

FPN

  • Feature levels: P1 (192×192), P2 (96×96), P3 (48×48), P4 (24×24)
  • Channel unification: 256 channels per level
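
Channel unification is typically done with 1×1 lateral convolutions plus a top-down pathway. A minimal sketch of that pattern (input channel counts are placeholders, and the repository's FPN may differ):

import torch.nn as nn
import torch.nn.functional as F

class SimpleFPN(nn.Module):
    def __init__(self, in_channels, out_channels=256):
        super().__init__()
        # 1x1 lateral convs unify every level to 256 channels
        self.lateral = nn.ModuleList(nn.Conv2d(c, out_channels, 1) for c in in_channels)
        self.smooth = nn.ModuleList(
            nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels
        )

    def forward(self, feats):  # feats: high-res P1 first, low-res P4 last
        laterals = [lat(f) for lat, f in zip(self.lateral, feats)]
        # Top-down pathway: upsample the coarser level and add it to the finer one
        for i in range(len(laterals) - 2, -1, -1):
            laterals[i] = laterals[i] + F.interpolate(
                laterals[i + 1], size=laterals[i].shape[-2:], mode="nearest"
            )
        return [s(l) for s, l in zip(self.smooth, laterals)]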

🎓 CheXpert Labels

The model predicts 14 pathologies:

  1. No Finding
  2. Enlarged Cardiomediastinum
  3. Cardiomegaly
  4. Lung Opacity
  5. Lung Lesion
  6. Edema
  7. Consolidation
  8. Pneumonia
  9. Atelectasis
  10. Pneumothorax
  11. Pleural Effusion
  12. Pleural Other
  13. Fracture
  14. Support Devices
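
In code, the label set can be kept as an ordered list so that logit index i maps to pathology i (the variable name is illustrative):

CHEXPERT_LABELS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture",
    "Support Devices",
]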

🔬 Data Augmentation

Training augmentations (conservative for medical images):

  • Horizontal flip (p=0.5)
  • Random affine (translation, scale, rotation ±10°)
  • Random brightness/contrast
  • CLAHE histogram equalization
  • Gaussian blur and noise
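
Assembled with Albumentations, the list above roughly corresponds to the following pipeline (probabilities and parameter values here are illustrative, not the repository's exact settings):

import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Affine(translate_percent=0.05, scale=(0.9, 1.1), rotate=(-10, 10), p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.CLAHE(p=0.3),
    A.GaussianBlur(p=0.2),
    A.GaussNoise(p=0.2),
    A.Normalize(),
    ToTensorV2(),
])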

💾 Checkpoints

The training automatically saves:

  • Best MAE checkpoint: Based on validation reconstruction loss
  • Best classifier checkpoint: Based on validation AUC (macro/micro)
  • Training history: JSON file with all metrics
  • Per-epoch metrics plots: Loss and AUC curves

📊 Monitoring

Training logs are saved to:

  • training_log.txt: Training progress with live metrics
  • val_log.txt: Validation results
  • test_log.txt: Test evaluation results
  • history.json: All metrics across epochs
  • metrics.png: Visualization plots

⚡ Performance Tips

  1. Memory Optimization:

    • Use gradient checkpointing (already enabled)
    • Reduce batch size if OOM occurs
    • Increase gradient accumulation steps
  2. Speed Optimization:

    • Use persistent workers (already enabled)
    • Enable cuDNN benchmark (already enabled)
    • Use ZIP caching for faster data loading
  3. Training Stability:

    • Gradient clipping at norm 1.0
    • Mixed precision with dynamic loss scaling
    • Warmup learning rate schedule
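
Those three pieces combine into a training step that scales the loss, unscales gradients before clipping, and advances the scheduler. A minimal sketch with torch.cuda.amp (model, criterion, optimizer, scheduler, and train_loader are assumed to exist):

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # dynamic loss scaling for FP16

for images, labels in train_loader:
    optimizer.zero_grad()
    with autocast():
        loss = criterion(model(images), labels)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                               # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip at norm 1.0
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()  # linear warmup then cosine annealing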

🐛 Troubleshooting

Q: Out of memory errors?

  • Reduce batch size in configs.py
  • Increase gradient accumulation steps
  • Enable gradient checkpointing

Q: Slow training?

  • Check if ZIP caching is enabled
  • Verify persistent workers are active
  • Monitor GPU utilization

Q: Poor convergence?

  • Ensure MAE is properly pre-trained first
  • Check learning rate and warmup settings
  • Verify class weights are computed correctly

📚 Citation

If you use this code in your research, please cite:

@misc{chexpert-mae-densenet-fpn,
  author = {Adel Elsayed},
  title = {CheXpert Classification with MAE-DenseNet-FPN},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/adelelsayed/chexpert-mae-densenet-fpn}
}

🙏 Acknowledgments

  • CheXpert Dataset: Stanford ML Group
  • Masked Autoencoders: Meta AI Research (He et al., 2021)
  • DenseNet: Huang et al., 2017
  • CBAM: Woo et al., 2018
  • Feature Pyramid Networks: Lin et al., 2017

📄 License

This project is licensed under the MIT License.

📧 Contact

For questions or collaboration, connect on LinkedIn: https://www.linkedin.com/in/adel-elsayed-a5260246/

Note: This is a research project. For clinical use, please ensure proper validation and regulatory approval.
