
Using Wav2Vec2-base as backbone does not work #5

Open
jairsan opened this issue Feb 14, 2023 · 1 comment

jairsan commented Feb 14, 2023

Training using "facebook/wav2vec2-base" as the backbone consistently fails with the following error:

1020it [01:35, 10.71it/s]
Starting epoch 0 ...
Traceback (most recent call last):
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 365, in <module>
    train(args)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/train.py", line 147, in train
    logits = sfc_model(wav2vec_hidden, out_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch/jiranzotmp/trabajo/ICASSP2023_argumentation/software/SHAS/src/supervised_hybrid/models.py", line 41, in forward
    x = self.transformer(x, src_key_padding_mask=attention_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 198, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 336, in forward
    x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 189, in forward
    return F.layer_norm(
  File "/home/jiranzo/anaconda3/envs/shas/lib/python3.9/site-packages/torch/nn/functional.py", line 2347, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[1024], expected input with shape [*, 1024], but got input of size[28, 999, 768]

Training with the default "facebook/wav2vec2-xls-r-300m" backbone in the same setup works without issues.

Could this have something to do with the fact that wav2vec2-base uses "do_stable_layer_norm": false, whereas facebook/wav2vec2-xls-r-300m uses "do_stable_layer_norm": true?
My first guess would be that the assumption made in the following line might not hold when "do_stable_layer_norm" is false:

wav2vec_model.encoder.layer_norm = torch.nn.Identity()
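
The config value in question can be checked directly from the Hugging Face Hub (a quick diagnostic sketch, assuming transformers is installed):

from transformers import AutoConfig

for name in ["facebook/wav2vec2-base", "facebook/wav2vec2-xls-r-300m"]:
    cfg = AutoConfig.from_pretrained(name)
    print(name, cfg.do_stable_layer_norm)
# facebook/wav2vec2-base False
# facebook/wav2vec2-xls-r-300m True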

I will let you know if I find any additional information about this.

EDIT:

Actually, it was something much simpler: the wav2vec2-base model has a different hidden dimension (768 instead of 1024). Changing the value in constants.py seems to fix everything:
https://github.com/mt-upc/SHAS/blob/main/src/supervised_hybrid/constants.py#L4
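
For anyone hitting the same error, the change is a one-line edit along these lines (the exact constant name is whatever is defined on that line of constants.py; the name below is only illustrative):

# src/supervised_hybrid/constants.py (sketch, illustrative constant name)
HIDDEN_SIZE = 768  # 768 for wav2vec2-base, 1024 for the default wav2vec2-xls-r-300m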

Feel free to close the issue if you think this is obvious.

johntsi (Collaborator) commented Feb 14, 2023

Hi @jairsan, yes, the different dimensionalities of the base and large configurations are the problem here. I will keep the issue open so that I can add this as an argument (rather than a constant) in the future. Thanks for posting it!
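
A possible sketch of that change, reading the size from the backbone's Hugging Face config instead of hard-coding it (the function and parameter names are illustrative, not the actual SHAS code):

from transformers import AutoConfig

def backbone_hidden_size(model_name_or_path: str) -> int:
    # 768 for facebook/wav2vec2-base, 1024 for facebook/wav2vec2-xls-r-300m
    return AutoConfig.from_pretrained(model_name_or_path).hidden_size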
