Training using "facebook/wav2vec2-base" as the backbone consistently fails with the following error:
1020it [01:35, 10.71it/s]
Starting epoch 0 ...
Traceback (most recent call last):
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 365, in <module>
    train(args)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 147, in train
    logits = sfc_model(wav2vec_hidden, out_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/models.py", line 41, in forward
    x = self.transformer(x, src_key_padding_mask=attention_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 198, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 336, in forward
    x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/functional.py", line 2347, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[1024], expected input with shape [*, 1024], but got input of size[28, 999, 768]
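The error can be reproduced in isolation: a LayerNorm built for the large model's 1024-dimensional hidden states rejects the 768-dimensional hidden states that wav2vec2-base produces. A minimal sketch (with a smaller batch than the real [28, 999, 768]):

```python
import torch
import torch.nn as nn

# LayerNorm sized for the xls-r-300m hidden dimension (1024)
norm = nn.LayerNorm(1024)

# Hidden states shaped like wav2vec2-base output: (batch, frames, 768)
x = torch.randn(2, 10, 768)

err = None
try:
    norm(x)
except RuntimeError as e:
    # torch.layer_norm raises because the trailing dim (768) != normalized_shape (1024)
    err = e

print(err)
```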
Training with the default "facebook/wav2vec2-xls-r-300m" under the same setup gives me no issues.
Could this have something to do with the fact that wav2vec2-base uses "do_stable_layer_norm": false, whereas facebook/wav2vec2-xls-r-300m uses "do_stable_layer_norm": true? My first guess would be that the assumptions made here might not hold when "do_stable_layer_norm" is false.
Hi @jairsan, yes, the different hidden dimensionalities of the base and large configurations are the problem here. I will keep the issue open so that I can add this as an argument (rather than a constant) in the future. Thanks for posting it!
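A sketch of the change the maintainer describes, with the hidden size exposed as a command-line argument and a per-backbone default lookup. The argument name, dictionary, and helper functions are illustrative assumptions, not the actual SHAS interface:

```python
import argparse

# Hypothetical per-backbone defaults; 1024 matches the current constant.
BACKBONE_HIDDEN_SIZES = {
    "facebook/wav2vec2-base": 768,
    "facebook/wav2vec2-xls-r-300m": 1024,
}

def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="facebook/wav2vec2-xls-r-300m")
    parser.add_argument("--hidden-size", type=int, default=None,
                        help="backbone hidden dimension; inferred from --model-name if omitted")
    return parser

def resolve_hidden_size(args: argparse.Namespace) -> int:
    # An explicit --hidden-size wins; otherwise fall back to the lookup table.
    if args.hidden_size is not None:
        return args.hidden_size
    return BACKBONE_HIDDEN_SIZES.get(args.model_name, 1024)

args = build_parser().parse_args(["--model-name", "facebook/wav2vec2-base"])
print(resolve_hidden_size(args))  # → 768
```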
SHAS/src/supervised_hybrid/models.py, line 80 in 418b5e6
I will let you know if I find any additional information about this.
EDIT: Actually, it was something much simpler: the wav2vec2-base model has a different hidden dimension (768 instead of 1024). Changing constants.py seems to fix everything:
https://github.com/mt-upc/SHAS/blob/main/src/supervised_hybrid/constants.py#L4
Feel free to close the issue if you think this is obvious.