Skip to content

Commit

Permalink
Add unified support to load pretrained checkpoint for MMM. (#17)
Browse files Browse the repository at this point in the history
Co-authored-by: coco58323 <[email protected]>
  • Loading branch information
victorywys and coco58323 authored Apr 16, 2024
1 parent aafbdcf commit 9267d0a
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 36 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ A deep learning framework for physiological data processing and understanding.

## NEWS

- [NEW🔥] Pretrained checkpoint for paper `Learning Topology-Agnostic EEG Representations with Geometry-Aware Modeling` is now available [here](https://seqml.github.io/MMM/).
- `Learning Topology-Agnostic EEG Representations with Geometry-Aware Modeling` is now available [here](docs/MMM.md).
- `ContiFormer: Continuous-Time Transformer for Irregular Time Series Modeling` is now available [here](docs/contiformer.md).

Expand Down
14 changes: 13 additions & 1 deletion docs/MMM.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,26 @@ We'll release the pretrained checkpoint soon. But without it, you can still trai

3. Run `scripts/SEED_DE.py` to obtain the compatible format of SEED DE feature.

4. *Optional: Download the pretrained checkpoint [here](https://seqml.github.io/MMM/) and set the path to the checkpoint in `docs/configs/mmm_emotion.yml`*

## Finetune with MMM
```bash
# create the output directory
mkdir -p outputs/MMM_SEED/7/

# run the finetuning task
python -m physiopro.entry.mmm_emotion docs/configs/mmm_emotion.yml
# or
# python -m physiopro.entry.mmm_emotion docs/configs/mmm_emotion_from_ckpt.yml
# if you would like to load the pretrained encoder.

# tensorboard
tensorboard --logdir outputs/
```

Then it will run finetuning process on the 7th subject. The results will be saved to `outputs/MMM_SEED/7/` directory. You can run the finetuning process similarly on other subjects by changing `data.subject_index` in the configuration.
Then it will run finetuning process on the 7th subject. The results will be saved to `outputs/MMM_SEED/7/` directory. You can run the finetuning process similarly on other subjects by changing `data.subject_index` in the configuration.

## Regarding the DE feature
We are now aware of the possible issue of using DE feature for SEED (See this [issue](https://github.com/microsoft/PhysioPro/issues/16)). In order to keep it consistent with our paper, we provide the checkpoint pretrained on DE features but please use it wisely.

We're working to figure out the influence to our results, as well as training MMM on the raw EEG signals. We'll keep it updated.
47 changes: 47 additions & 0 deletions docs/configs/mmm_emotion_from_ckpt.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
data:
type: SEED
window_size: 1
# the subject index of the dataset
subject_index: 3
# the upper directory of the dataset
prefix: /home/yansenwang/data/

network:
type: MMM_Encoder
depth: 6
num_heads: 8
encoder_dim: 64
channel_num: 79
in_chans: 5
pe_type: 2d

decoder_network: # used only during pre-training. Can be omitted if only finetuning.
type: MMM_Encoder
depth: 6
encoder_dim: 64
channel_num: 79
in_chans: 16

model:
type: MMM_Finetune
task: multiclassification
# set up pre-trained model path, leave blank for training from scratch
# E.g.
model_path: /path/to/tuh_pretrained_encoder_base.pt
optimizer: Adam
lr: 0.00005
weight_decay: 0.005
loss_fn: cross_entropy
metrics: [accuracy]
observe: accuracy
lower_is_better: False
max_epochs: 100
early_stop: 70
batch_size: 32
out_size: 3
mask_ratio: 0.

runtime:
seed: 51
use_cuda: true
output_dir: outputs/MMM_SEED/3/
35 changes: 0 additions & 35 deletions physiopro/entry/mmm_emotion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import torch

from utilsd import get_output_dir, get_checkpoint_dir, setup_experiment
from utilsd.config import PythonConfig, RegistryConfig, RuntimeConfig, configclass
from ..dataset import DATASETS
Expand All @@ -14,37 +12,6 @@ class Config(PythonConfig):
model: RegistryConfig[MODELS]
runtime: RuntimeConfig = RuntimeConfig()

def load_model(path):
params = torch.load(path)
# remove module from name
keys = list(params.keys())
for name in keys:
print(name)
val=params[name]
if name.startswith('module.'):
name = name[7:]
params[name] = val
del params['module.'+name]

# remove network from name
# keys = list(params.keys())
# for name in keys:
# val=params[name]
# if name.startswith('network.'):
# name = name[8:]
# params[name] = val
# del params['network.'+name]
# else:
# del params[name]

# remove pos_embed and attn_mask
if 'pos_embed' in params:
del params['pos_embed']
if 'attn_mask' in params:
del params['attn_mask']
return params


def run(config):
setup_experiment(config.runtime)
trainset_finetune = config.data.build(dataset_name="train")
Expand All @@ -54,8 +21,6 @@ def run(config):
attn_mask = trainset_finetune.attn_mask,
pe_coordination = pe_coordination,
)
if config.model.model_path is not None:
network.load_state_dict(load_model(config.model.model_path), strict=False)

model = config.model.build(
network=network,
Expand Down
11 changes: 11 additions & 0 deletions physiopro/model/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ def forward(self, inputs, mask_type='random'):

return preds

def load(self, model_path: str, strict=False):
"""Load the model parameter from model path
Args:
model_path (str): The location where the model parameters are saved.
strict (bool, optional): [description]. Defaults to False.
**This is not the case in the BaseModel class because MMM_finetune only loads the network(encoder) parameters.**
"""
state_dict = torch.load(model_path, map_location="cpu")
self.load_state_dict(state_dict, strict=strict)

def fit(
self,
trainset: Dataset,
Expand Down

0 comments on commit 9267d0a

Please sign in to comment.