diff --git a/README.md b/README.md index 59f328b..57e17c8 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ A deep learning framework for physiological data processing and understanding. ## NEWS +- [NEW🔥] Pretrained checkpoint for paper `Learning Topology-Agnostic EEG Representations with Geometry-Aware Modeling` is now available [here](https://seqml.github.io/MMM/). - `Learning Topology-Agnostic EEG Representations with Geometry-Aware Modeling` is now available [here](docs/MMM.md). - `ContiFormer: Continuous-Time Transformer for Irregular Time Series Modeling` is now available [here](docs/contiformer.md). diff --git a/docs/MMM.md b/docs/MMM.md index 5d63645..4a323ca 100644 --- a/docs/MMM.md +++ b/docs/MMM.md @@ -12,14 +12,26 @@ We'll release the pretrained checkpoint soon. But without it, you can still trai 3. Run `scripts/SEED_DE.py` to obtain the compatible format of SEED DE feature. +4. *Optional: Download the pretrained checkpoint [here](https://seqml.github.io/MMM/) and set the path to the checkpoint in `docs/configs/mmm_emotion.yml`* + ## Finetune with MMM ```bash # create the output directory mkdir -p outputs/MMM_SEED/7/ + # run the finetuning task python -m physiopro.entry.mmm_emotion docs/configs/mmm_emotion.yml +# or +# python -m physiopro.entry.mmm_emotion docs/configs/mmm_emotion_from_ckpt.yml +# if you would like to load the pretrained encoder. + # tensorboard tensorboard --logdir outputs/ ``` -Then it will run finetuning process on the 7th subject. The results will be saved to `outputs/MMM_SEED/7/` directory. You can run the finetuning process similarly on other subjects by changing `data.subject_index` in the configuration. \ No newline at end of file +Then it will run finetuning process on the 7th subject. The results will be saved to `outputs/MMM_SEED/7/` directory. You can run the finetuning process similarly on other subjects by changing `data.subject_index` in the configuration. + +## Regarding the DE feature +We are now aware of the possible issue of using DE feature for SEED (See this [issue](https://github.com/microsoft/PhysioPro/issues/16)). In order to keep it consistent with our paper, we provide the checkpoint pretrained on DE features but please use it wisely. + +We're working to figure out the influence to our results, as well as training MMM on the raw EEG signals. We'll keep it updated. \ No newline at end of file diff --git a/docs/configs/mmm_emotion_from_ckpt.yml b/docs/configs/mmm_emotion_from_ckpt.yml new file mode 100644 index 0000000..2eb2366 --- /dev/null +++ b/docs/configs/mmm_emotion_from_ckpt.yml @@ -0,0 +1,47 @@ +data: + type: SEED + window_size: 1 + # the subject index of the dataset + subject_index: 3 + # the upper directory of the dataset + prefix: /home/yansenwang/data/ + +network: + type: MMM_Encoder + depth: 6 + num_heads: 8 + encoder_dim: 64 + channel_num: 79 + in_chans: 5 + pe_type: 2d + +decoder_network: # used only during pre-training. Can be omitted if only finetuning. + type: MMM_Encoder + depth: 6 + encoder_dim: 64 + channel_num: 79 + in_chans: 16 + +model: + type: MMM_Finetune + task: multiclassification + # set up pre-trained model path, leave blank for training from scratch + # E.g. + model_path: /path/to/tuh_pretrained_encoder_base.pt + optimizer: Adam + lr: 0.00005 + weight_decay: 0.005 + loss_fn: cross_entropy + metrics: [accuracy] + observe: accuracy + lower_is_better: False + max_epochs: 100 + early_stop: 70 + batch_size: 32 + out_size: 3 + mask_ratio: 0. + +runtime: + seed: 51 + use_cuda: true + output_dir: outputs/MMM_SEED/3/ \ No newline at end of file diff --git a/physiopro/entry/mmm_emotion.py b/physiopro/entry/mmm_emotion.py index 9375b1d..c63a96c 100644 --- a/physiopro/entry/mmm_emotion.py +++ b/physiopro/entry/mmm_emotion.py @@ -1,5 +1,3 @@ -import torch - from utilsd import get_output_dir, get_checkpoint_dir, setup_experiment from utilsd.config import PythonConfig, RegistryConfig, RuntimeConfig, configclass from ..dataset import DATASETS @@ -14,37 +12,6 @@ class Config(PythonConfig): model: RegistryConfig[MODELS] runtime: RuntimeConfig = RuntimeConfig() -def load_model(path): - params = torch.load(path) - # remove module from name - keys = list(params.keys()) - for name in keys: - print(name) - val=params[name] - if name.startswith('module.'): - name = name[7:] - params[name] = val - del params['module.'+name] - - # remove network from name - # keys = list(params.keys()) - # for name in keys: - # val=params[name] - # if name.startswith('network.'): - # name = name[8:] - # params[name] = val - # del params['network.'+name] - # else: - # del params[name] - - # remove pos_embed and attn_mask - if 'pos_embed' in params: - del params['pos_embed'] - if 'attn_mask' in params: - del params['attn_mask'] - return params - - def run(config): setup_experiment(config.runtime) trainset_finetune = config.data.build(dataset_name="train") @@ -54,8 +21,6 @@ def run(config): attn_mask = trainset_finetune.attn_mask, pe_coordination = pe_coordination, ) - if config.model.model_path is not None: - network.load_state_dict(load_model(config.model.model_path), strict=False) model = config.model.build( network=network, diff --git a/physiopro/model/mmm.py b/physiopro/model/mmm.py index e4b5956..9f7b67a 100644 --- a/physiopro/model/mmm.py +++ b/physiopro/model/mmm.py @@ -185,6 +185,17 @@ def forward(self, inputs, mask_type='random'): return preds + def load(self, model_path: str, strict=False): + """Load the model parameter from model path + + Args: + model_path (str): The location where the model parameters are saved. + strict (bool, optional): [description]. Defaults to False. + **This is not the case in the BaseModel class because MMM_finetune only loads the network(encoder) parameters.** + """ + state_dict = torch.load(model_path, map_location="cpu") + self.load_state_dict(state_dict, strict=strict) + def fit( self, trainset: Dataset,