PyTorch codebase for I-JEPA (the Image-based Joint-Embedding Predictive Architecture), published at CVPR 2023. [arXiv] [JEPAs] [blogpost]
Pretraining uses the Diabetic Retinopathy Kaggle competition dataset. [Kaggle]
arch. | patch size | resolution | epochs | data | download | | |
---|---|---|---|---|---|---|---|
ViT-H | 14x14 | 224x224 | 300 | ImageNet-1K | full checkpoint | logs | configs |
.
├── configs # directory in which all experiment '.yaml' configs are stored
├── src # the package
│ ├── train.py # the I-JEPA training loop
│ ├── helper.py # helper functions for initializing models and optimizers and for loading checkpoints
│ ├── transforms.py # pre-train data transforms
│ ├── datasets # datasets, data loaders, ...
│ ├── models # model definitions
│ ├── masks # mask collators, masking utilities, ...
│ └── utils # shared utilities
├── logs # training logs
│ └── vith14.224-bs.2048-ep.300 # folder containing loss info
│ └── jepa_r0.csv # CSV file containing loss info
├── main_distributed.py # entrypoint for launching distributed I-JEPA pretraining on a SLURM cluster
├── main.py # entrypoint for launching I-JEPA pretraining locally on your machine
├── notebook_prompt.ipynb # (JHA) Replaces the command line. Run this first to pretrain I-JEPA on fundus data.
└── csv_analysis.ipynb # (JHA) Plots the moving average of the loss in each epoch.
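The loss smoothing that csv_analysis.ipynb plots can be sketched with the standard library alone. This is a minimal illustration, not the notebook's actual code: it assumes the log CSV has a column named `loss` (the real header may differ), and `read_losses` and `moving_average` are hypothetical helper names. Plotting is left out.

```python
import csv
import io


def read_losses(csv_file, column="loss"):
    """Read one float per row from the given column (column name is an assumption)."""
    return [float(row[column]) for row in csv.DictReader(csv_file)]


def moving_average(values, window):
    """Trailing moving average; the first window-1 points average what is available."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            # Drop the value that just left the window.
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


# In-memory stand-in for a log file such as jepa_r0.csv (values are made up).
sample_log = io.StringIO("epoch,itr,loss\n1,0,4.0\n1,1,2.0\n2,0,3.0\n2,1,1.0\n")
losses = read_losses(sample_log)
smoothed = moving_average(losses, window=2)
print(smoothed)
```

To use it on a real log, pass an open file handle instead of the in-memory sample and adjust `column` to match the file's header.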
.
└── Input #
├── aptos2019-blindness-detection #
├── diabetic-retinopathy-pre-training #
└── diabetic-retinopathy-resized #
├── trainLabels.csv # CSV recording the diabetic retinopathy level of each image
└── resized_train #
├── resized_train # Resized fundus images. See the Kaggle link above for the definition of "resizing"
└── train # Original fundus images
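As a quick sanity check on the labels, the number of images per diabetic retinopathy level can be counted from trainLabels.csv. The sketch below assumes the file has columns named `image` and `level`; both names are assumptions and should be checked against the actual header. `count_levels` is a hypothetical helper.

```python
import csv
import io
from collections import Counter


def count_levels(csv_file, level_column="level"):
    """Count how many rows carry each level (column name is an assumption)."""
    reader = csv.DictReader(csv_file)
    return Counter(int(row[level_column]) for row in reader)


# In-memory stand-in for trainLabels.csv (rows are made up).
sample_labels = io.StringIO("image,level\n10_left,0\n10_right,0\n13_left,2\n")
counts = count_levels(sample_labels)
for level in sorted(counts):
    print(f"level {level}: {counts[level]} images")
```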
Config files: All experiment parameters are specified in config files rather than as command-line arguments. See the configs/ directory for example config files.
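Because parameters live in '.yaml' files, reading one comes down to parsing it into a nested dictionary. The following is a minimal sketch using pyyaml (listed in the dependencies); the section and key names in the example text are illustrative assumptions, not the repository's actual schema, and `load_config` is a hypothetical helper.

```python
import yaml

# Illustrative config text; the keys are assumptions, not the real schema.
EXAMPLE_YAML = """
data:
  batch_size: 16
  crop_size: 224
meta:
  model_name: vit_huge
optimization:
  epochs: 300
"""


def load_config(text):
    """Parse YAML text into a nested dict and reject anything that is not a mapping."""
    params = yaml.safe_load(text)
    if not isinstance(params, dict):
        raise ValueError("expected a YAML mapping at the top level")
    return params


params = load_config(EXAMPLE_YAML)
print(params["data"]["crop_size"], params["optimization"]["epochs"])
```

For a file on disk, read its contents with `open(path).read()` and pass the text to `load_config`.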
To run this implementation, start from notebook_prompt.ipynb.
- Python 3.8.18
- Torch 2.2.1 + cu121
- torchvision
- Other dependencies: pyyaml, numpy, opencv, submitit
- GPU: NVIDIA RTX A4000 x1 / CPU: Intel Core i9 (13th Gen)
- CUDA: 12.1, cuDNN: 8.9.7
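To compare a local setup against the versions listed above, the installed package versions can be queried without importing the packages themselves. This is a rough sketch: the distribution names in `PACKAGES` are assumptions (OpenCV in particular may be installed as `opencv-python-headless` or another variant), and `installed_versions` is a hypothetical helper.

```python
import sys
from importlib import metadata

# Assumed PyPI distribution names for the dependencies listed above.
PACKAGES = ["torch", "torchvision", "pyyaml", "numpy", "opencv-python", "submitit"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print("Python", sys.version.split()[0])
for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version if version is not None else 'not installed'}")
```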
Last updated: 03/26/2024 09:53 (UTC +09:00)