
training problem with rnn_transducer #279

Open · yiqiaoc11 opened this issue Feb 8, 2023 · 8 comments

The transducer in TensorFlowASR\examples\rnn_transducer doesn't work in the current version with either the current or the pretrained config.yml. This is a fundamental piece of functionality. Could the author or someone else give it a try to validate it?

nglehuy (Collaborator) commented Feb 8, 2023

Hi @yiqiaoc11
I'm training on TPUs to validate this.
Are you using the warp-transducer loss or the rnnt loss in tensorflow?
In my testing with the rnnt loss in tensorflow over the past few months, it has had some convergence issues. But I don't have the resources to test on GPUs.
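For reference, both backends compute the same objective: the RNN-T negative log-likelihood from Graves (2012); they differ only in implementation. A minimal numpy sketch of the forward algorithm (a reference illustration, not the code of either implementation):

```python
import numpy as np

def rnnt_loss_ref(log_probs, labels, blank=0):
    """Reference RNN-T negative log-likelihood (Graves 2012), log domain.

    log_probs: (T, U+1, V) log-softmax output of the joint network
    labels:    length-U array of target label ids (no blanks)
    """
    T, U1, _ = log_probs.shape
    U = U1 - 1
    alpha = np.full((T, U + 1), -np.inf)
    alpha[0, 0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # stay: emit blank at frame t-1, keep label position u
            stay = alpha[t - 1, u] + log_probs[t - 1, u, blank] if t > 0 else -np.inf
            # emit: output label u-1 at frame t, advance label position
            emit = (alpha[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
                    if u > 0 else -np.inf)
            alpha[t, u] = np.logaddexp(stay, emit)
    # the sequence ends with a final blank on the last frame
    return -(alpha[T - 1, U] + log_probs[T - 1, U, blank])
```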

yiqiaoc11 (Author) commented Feb 8, 2023

@usimarit Thanks for the comments. All my tests were conducted on GPU with batch size 2.
I first tried TensorFlowASR\examples\rnn_transducer\config.yml with the warp-transducer loss and hit the same problem as #231. Then I switched to the rnnt loss while using the pretrained model's .yml, which contains warmup_steps for the transformer scheduler. The loss is shown below:
Epoch 1/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 339.2118 - val_loss: 160.0162
Epoch 2/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 254.4009 - val_loss: 147.7653
Epoch 3/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 241.0356 - val_loss: 144.2561
Epoch 4/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 231.6980 - val_loss: 140.2191
Epoch 5/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 223.2308 - val_loss: 137.6041
Epoch 6/20
14269/14269 [==============================] - 5923s 415ms/step - loss: 216.7098 - val_loss: 136.0396

With the rnnt loss, 4- and 6-layer encoders worked given adjusted warmup steps, but an 8-layer encoder did not. I'm just trying to recover the performance of the pretrained model. Conformer reportedly works, and it differs from rnn_transducer only in the encoder.
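For context, the "transformer scheduler" mentioned above is typically the Noam schedule from "Attention Is All You Need", in which warmup_steps sets where the learning rate peaks. A minimal sketch (d_model=320 is an assumed value here, not one read from the config):

```python
import tensorflow as tf

class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Noam schedule: lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, d_model=320, warmup_steps=40000):  # d_model is assumed
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        return tf.math.rsqrt(self.d_model) * tf.minimum(
            tf.math.rsqrt(step), step * self.warmup_steps ** -1.5)

optimizer = tf.keras.optimizers.Adam(TransformerSchedule(warmup_steps=40000))
```

The peak learning rate is (d_model * warmup_steps)**-0.5, reached at step = warmup_steps, so larger warmup values lower the peak; that would be consistent with deeper encoders only converging at larger warmup_steps.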

Feel free to advise, and I can try it on GPU here.

nglehuy (Collaborator) commented Feb 9, 2023

@yiqiaoc11 Could you help me train two models for 30 epochs using the rnnt loss:

  1. 4-layer encoder
  2. 8-layer encoder

Then plot the losses of the two models for better comparison?
All other configs stay the same.
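A minimal sketch of such a comparison plot, assuming the History objects returned by Keras model.fit are available (the names here are illustrative):

```python
import matplotlib.pyplot as plt

def plot_losses(histories):
    """histories: e.g. {"4-layer": fit4.history, "8-layer": fit8.history}."""
    for name, hist in histories.items():
        plt.plot(hist["loss"], label=f"{name} train")
        plt.plot(hist["val_loss"], "--", label=f"{name} val")
    plt.xlabel("epoch")
    plt.ylabel("RNN-T loss")
    plt.legend()
    plt.show()
```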

yiqiaoc11 (Author) commented Feb 9, 2023

@usimarit Using the streaming config.yml (https://drive.google.com/file/d/1xYFYi3z94ZqaQZ-cTyiNekBwhITh1Ru4l) with warmup_steps=40000, right?

From the timeline, it seems you applied the warp-transducer loss to get the pretrained .h5 weights.

nglehuy (Collaborator) commented Feb 9, 2023

@yiqiaoc11 Yes, with the pretrained config.

I trained the rnn transducer on TPUs, so the warp-transducer loss could not be applied; only the rnnt loss can be used there. But you can also experiment with the warp-transducer loss, plotting the losses of the two models for better comparison.

yiqiaoc11 (Author) commented

@usimarit I only have 2 x 3090 now, so 2 x 30 epochs will take a fairly long time with the rnnt loss. Currently the 8-layer model doesn't converge, while the 4-layer model converges with > 40000 warmup steps. Conformer using the same rnnt loss works. Could the rnn_transducer have differed when you pretrained it, given the same loss, the same optimizer, and the same number of weights?

nglehuy (Collaborator) commented Feb 9, 2023

@yiqiaoc11 The rnn_transducer structure stays the same in version v1.0.x.
Is the number of weights in your case the same as in the pretrained example?
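One quick way to check this for any built tf.keras model, as a hedged sketch (the model name below is a placeholder):

```python
def layer_params(model):
    """Map layer name -> parameter count for a built tf.keras model."""
    return {layer.name: layer.count_params() for layer in model.layers}

# compare against the per-block counts printed in the pretrained summary, e.g.
# print(layer_params(model_4layer))
```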

yiqiaoc11 (Author) commented Feb 10, 2023

Yes, the number of weights and the distribution of layers are the same, but other config information for the pretrained model isn't traceable. I'm not sure what leads to the underfitting observed.

Preliminary loss curves for the 4- and 8-layer models are posted below for comparison. The green curves are the 8-layer model and the blue ones the 4-layer. The losses of the two models are very similar; both were tuned under the same .yml from GDrive. Neither converges.
[figure: training loss curves; green = 8-layer, blue = 4-layer]

Model summary, 4-layer encoder (logged 2023-02-09 09:03:10):

Layer (type)                                                Output Shape   Param #
===================================================================================
streaming_transducer_encoder_reshape (Reshape)              multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)   multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)   multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)   multiple       5839168
===================================================================================
Total params: 24,339,712
Trainable params: 24,339,712
Non-trainable params: 0

Model summary, 8-layer encoder (logged 2023-02-09 09:03:15):

Layer (type)                                                Output Shape   Param #
===================================================================================
streaming_transducer_encoder_reshape (Reshape)              multiple       0
streaming_transducer_encoder_block_0 (RnnTransducerBlock)   multiple       5511488
streaming_transducer_encoder_block_1 (RnnTransducerBlock)   multiple       7149888
streaming_transducer_encoder_block_2 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_3 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_4 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_5 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_6 (RnnTransducerBlock)   multiple       5839168
streaming_transducer_encoder_block_7 (RnnTransducerBlock)   multiple       5839168
===================================================================================
Total params: 47,696,384
Trainable params: 47,696,384
Non-trainable params: 0
