minor updates to the scripts
JinZr committed Oct 7, 2024
1 parent 266e840 commit 32a7d22
Showing 19 changed files with 422 additions and 46 deletions.
26 changes: 18 additions & 8 deletions egs/librispeech/ASR/zipformer/attention_decoder.py
@@ -236,7 +236,7 @@ def forward(
causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len)
attn_mask = torch.logical_or(
padding_mask.unsqueeze(1), # (batch, 1, seq_len)
torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len)
torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len)
) # (batch, seq_len, seq_len)

if memory is not None:
@@ -367,7 +367,9 @@ def __init__(
self.num_heads = num_heads
self.head_dim = attention_dim // num_heads
assert self.head_dim * num_heads == attention_dim, (
self.head_dim, num_heads, attention_dim
self.head_dim,
num_heads,
attention_dim,
)
self.dropout = dropout
self.name = None # will be overwritten in training code; for diagnostics.
@@ -437,15 +439,19 @@ def forward(
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)

if attn_mask is not None:
assert (
attn_mask.shape == (batch, 1, src_len)
or attn_mask.shape == (batch, tgt_len, src_len)
assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == (
batch,
tgt_len,
src_len,
), attn_mask.shape
attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
attn_weights = attn_weights.masked_fill(
attn_mask.unsqueeze(1), float("-inf")
)

attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -456,7 +462,11 @@ def forward(

# (batch * head, tgt_len, head_dim)
attn_output = torch.bmm(attn_weights, v)
assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
assert attn_output.shape == (
batch * num_heads,
tgt_len,
head_dim,
), attn_output.shape

attn_output = attn_output.transpose(0, 1).contiguous()
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
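
For reference, the first hunk above builds attn_mask by broadcasting a (batch, 1, seq_len) padding mask against a (1, seq_len, seq_len) inverted causal mask. The following standalone sketch is not part of this commit (subsequent_mask is re-implemented here just for illustration); it only shows the shapes involved:

import torch

def subsequent_mask(size, device=torch.device("cpu")):
    # True where position i may attend to position j (lower triangle).
    return torch.tril(torch.ones(size, size, device=device, dtype=torch.bool))

batch, seq_len = 2, 5
padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # True marks padded positions
padding_mask[1, 3:] = True  # pretend the second sequence is shorter

causal_mask = subsequent_mask(seq_len)            # (seq_len, seq_len)
attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                    # (batch, 1, seq_len)
    torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
)  # (batch, seq_len, seq_len); True entries are later filled with -inf before softmax

print(attn_mask.shape)  # torch.Size([2, 5, 5])
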
12 changes: 7 additions & 5 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -487,6 +487,7 @@ def build_inputs_outputs(tensors, i):

add_meta_data(filename=encoder_filename, meta_data=meta_data)


def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
@@ -754,30 +755,31 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")

if(params.fp16) :
if params.fp16:
from onnxconverter_common import float16

logging.info("Generate fp16 models")

encoder = onnx.load(encoder_filename)
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
onnx.save(encoder_fp16,encoder_filename_fp16)
onnx.save(encoder_fp16, encoder_filename_fp16)

decoder = onnx.load(decoder_filename)
decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
onnx.save(decoder_fp16,decoder_filename_fp16)
onnx.save(decoder_fp16, decoder_filename_fp16)

joiner = onnx.load(joiner_filename)
joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
onnx.save(joiner_fp16,joiner_filename_fp16)
onnx.save(joiner_fp16, joiner_filename_fp16)

# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

logging.info("Generate int8 quantization models")

encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
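
The quantize_dynamic call above is cut off in this view. As a rough sketch only (placeholder paths; the exact arguments used by the export scripts are not fully shown in this diff), dynamic int8 quantization with onnxruntime looks like:

from onnxruntime.quantization import QuantType, quantize_dynamic

encoder_filename = "encoder.onnx"            # placeholder input path
encoder_filename_int8 = "encoder.int8.onnx"  # placeholder output path

quantize_dynamic(
    model_input=encoder_filename,
    model_output=encoder_filename_int8,
    weight_type=QuantType.QInt8,  # int8 weights; see the onnxruntime docs linked above
)
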
8 changes: 4 additions & 4 deletions egs/librispeech/ASR/zipformer/export-onnx.py
@@ -592,23 +592,23 @@ def main():
)
logging.info(f"Exported joiner to {joiner_filename}")

if(params.fp16) :
if params.fp16:
logging.info("Generate fp16 models")

encoder = onnx.load(encoder_filename)
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
onnx.save(encoder_fp16,encoder_filename_fp16)
onnx.save(encoder_fp16, encoder_filename_fp16)

decoder = onnx.load(decoder_filename)
decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
onnx.save(decoder_fp16,decoder_filename_fp16)
onnx.save(decoder_fp16, decoder_filename_fp16)

joiner = onnx.load(joiner_filename)
joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
onnx.save(joiner_fp16,joiner_filename_fp16)
onnx.save(joiner_fp16, joiner_filename_fp16)

# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
2 changes: 1 addition & 1 deletion egs/libritts/ASR/local/compute_fbank_libritts.py
@@ -124,7 +124,7 @@ def compute_fbank_libritts(
supervisions=m["supervisions"],
)
if sampling_rate != 24000:
logging.info(f"Resampling audio to {sampling_rate}")
logging.info(f"Resampling audio to {sampling_rate}Hz")
cut_set = cut_set.resample(sampling_rate)
if "train" in partition:
if perturb_speed:
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/download_lm.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/norm_text.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/prepare_lang.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/prepare_lang_bpe.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/prepare_lang_fst.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/prepare_lm_training_data.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/train_bpe_model.py
1 change: 1 addition & 0 deletions egs/libritts/ASR/local/validate_bpe_lexicon.py
87 changes: 86 additions & 1 deletion egs/libritts/ASR/prepare.sh
@@ -7,9 +7,15 @@ set -eou pipefail

stage=0
stop_stage=100
sampling_rate=24000
sampling_rate=16000
nj=32
perturb_speed=true
vocab_sizes=(
# 5000
# 2000
# 1000
500
)

dl_dir=$PWD/download

@@ -27,6 +33,15 @@ log() {

log "dl_dir: $dl_dir"

if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM" # we directly use the librispeech lm here
mkdir -p $dl_dir/lm
if [ ! -e $dl_dir/lm/.done ]; then
./local/download_lm.py --out-dir=$dl_dir/lm
touch $dl_dir/lm/.done
fi
fi

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"

@@ -107,3 +122,73 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
touch data/fbank/.msuan.done
fi
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Train BPE model for normalized text"

if [ ! -f data/texts ]; then
gunzip -c data/manifests/libritts_supervisions_train-clean-100.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/norm_text.py > data/texts

gunzip -c data/manifests/libritts_supervisions_train-clean-360.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/norm_text.py >> data/texts

gunzip -c data/manifests/libritts_supervisions_train-other-500.jsonl.gz \
| jq ".text" | sed 's/"//g' \
| ./local/norm_text.py >> data/texts
fi

for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
mkdir -p $lang_dir

cp data/texts $lang_dir/text

if [ ! -f $lang_dir/bpe.model ]; then
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/text
fi
done
fi
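
Stage 5 pipes the normalized transcripts into ./local/train_bpe_model.py, which is listed among the added files above but whose body is not shown here. Assuming it wraps SentencePiece (an assumption, not confirmed by this diff), a roughly equivalent invocation would be:

import sentencepiece as spm

lang_dir = "data/lang_bpe_500"  # matches the single vocab size (500) enabled above
spm.SentencePieceTrainer.train(
    input=f"{lang_dir}/text",   # the normalized transcripts written in Stage 5
    model_prefix=f"{lang_dir}/bpe",
    model_type="bpe",
    vocab_size=500,
)
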

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Prepare phone based lang"
lang_dir=data/lang_phone
mkdir -p $lang_dir

if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then
log "No lexicon file in $dl_dir/lm, please run :"
log "prepare.sh --stage -1 --stop-stage -1"
exit -1
fi

if [ ! -f $lang_dir/lexicon.txt ]; then
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
fi

if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi

if [ ! -f $lang_dir/L.fst ]; then
log "Converting L.pt to L.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L.pt \
$lang_dir/L.fst
fi

if [ ! -f $lang_dir/L_disambig.fst ]; then
log "Converting L_disambig.pt to L_disambig.fst"
./shared/convert-k2-to-openfst.py \
--olabels aux_labels \
$lang_dir/L_disambig.pt \
$lang_dir/L_disambig.fst
fi
fi