
Merge pull request #251 from TensorSpeech/refactor
Refactor: add support tf2.8 + update imports + update examples + add helpers
nglehuy authored Mar 12, 2022
2 parents c426e8f + 68ead53 commit caee2e7
Showing 40 changed files with 1,225 additions and 1,491 deletions.
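
Across the example scripts, this refactor replaces module-level `argparse` wiring with a `main()` function dispatched by `fire`, so every CLI flag becomes a keyword argument. A minimal sketch of that pattern (a hypothetical standalone script, not one of the files changed here):

```python
# Minimal sketch of the fire-based entry-point style used by the refactored
# example scripts. The script name and flags here are placeholders.
import fire


def main(
    config: str = "config.yml",
    h5: str = None,
    output: str = None,
):
    # Required "flags" are enforced with plain asserts, as in the diffs below.
    assert h5 and output
    print(config, h5, output)


if __name__ == "__main__":
    # fire maps `--config ... --h5 ... --output ...` onto main()'s parameters.
    fire.Fire(main)
```

Such a script would be run as, e.g., `python my_script.py --h5 weights.h5 --output model.tflite` (paths hypothetical).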
examples/conformer/README.md (6 changes: 3 additions & 3 deletions)
@@ -10,11 +10,11 @@ Go to [config.yml](./config.yml)

## Usage

Training, see `python examples/conformer/train_*.py --help`
Training, see `python examples/conformer/train.py --help`

Testing, see `python examples/conformer/test_*.py --help`
Testing, see `python examples/conformer/test.py --help`

TFLite Conversion, see `python examples/conformer/tflite_*.py --help`
TFLite Conversion, see `python examples/conformer/inference/gen_tflite_model.py --help`

## Conformer Subwords - Results on LibriSpeech

examples/conformer/config.yml (20 changes: 11 additions & 9 deletions)
@@ -31,7 +31,7 @@ decoder_config:
beam_width: 0
norm_score: True
corpus_files:
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv

model_config:
name: conformer
@@ -75,8 +75,8 @@ learning_config:
num_masks: 1
mask_factor: 27
data_paths:
- /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
tfrecords_dir: null
shuffle: True
cache: True
buffer_size: 100
@@ -85,8 +85,9 @@ learning_config:

eval_dataset_config:
use_tf: True
data_paths: null
tfrecords_dir: /mnt/Data/MLDL/Datasets/ASR/Raw/LibriSpeech/tfrecords_1030
data_paths:
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv
tfrecords_dir: null
shuffle: False
cache: True
buffer_size: 100
@@ -95,7 +96,8 @@ learning_config:

test_dataset_config:
use_tf: True
data_paths: null
data_paths:
- H:/MLDL/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
tfrecords_dir: null
shuffle: False
cache: True
@@ -113,13 +115,13 @@ learning_config:
batch_size: 2
num_epochs: 50
checkpoint:
filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5
filepath: D:/Models/local/conformer/checkpoints/{epoch:02d}.h5
save_best_only: False
save_weights_only: True
save_freq: epoch
states_dir: /mnt/Miscellanea/Models/local/conformer/states
states_dir: D:/Models/local/conformer/states
tensorboard:
log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard
log_dir: D:/Models/local/conformer/tensorboard
histogram_freq: 1
write_graph: True
write_images: True
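
For context on how a config like the one above is consumed: the refactored scripts parse it with `Config` and build both featurizers through the new `featurizer_helpers.prepare_featurizers` helper (visible in the `gen_saved_model.py` diff below). A minimal sketch, assuming those APIs behave exactly as they are used in this commit:

```python
# Sketch: load the example config and build featurizers via the new helper.
# The config path is a placeholder; the flag values mirror the script defaults.
from tensorflow_asr.configs.config import Config
from tensorflow_asr.helpers import featurizer_helpers

config = Config("examples/conformer/config.yml")

# Replaces the per-script CharFeaturizer / SubwordFeaturizer / SentencePieceFeaturizer
# selection blocks that this commit removes from the example scripts.
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
    config=config,
    subwords=False,
    sentence_piece=False,
)
print(text_featurizer.num_classes)  # used below as the Conformer vocabulary size
```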
examples/conformer/inference/gen_saved_model.py (115 changes: 51 additions & 64 deletions)
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import fire

from tensorflow_asr.utils import env_util

@@ -22,71 +22,58 @@

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Conformer Testing")

parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

parser.add_argument("--h5", type=str, default=None, help="Path to saved h5 weights")

parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")

parser.add_argument("--output_dir", type=str, default=None, help="Output directory for saved model")

args = parser.parse_args()

assert args.h5
assert args.output_dir

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SentencePieceFeaturizer, SubwordFeaturizer
from tensorflow_asr.helpers import featurizer_helpers
from tensorflow_asr.models.transducer.conformer import Conformer

config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

if args.sentence_piece:
logger.info("Use SentencePiece ...")
text_featurizer = SentencePieceFeaturizer(config.decoder_config)
elif args.subwords:
logger.info("Use subwords ...")
text_featurizer = SubwordFeaturizer(config.decoder_config)
else:
logger.info("Use characters ...")
text_featurizer = CharFeaturizer(config.decoder_config)

tf.random.set_seed(0)

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(args.h5, by_name=True)
conformer.summary(line_length=100)
conformer.add_featurizers(speech_featurizer, text_featurizer)


class ConformerModule(tf.Module):
def __init__(self, model: Conformer, name=None):
super().__init__(name=name)
self.model = model
self.num_rnns = config.model_config["prediction_num_rnns"]
self.rnn_units = config.model_config["prediction_rnn_units"]
self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def pred(self, signal):
predicted = tf.constant(0, dtype=tf.int32)
states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
features = self.model.speech_featurizer.tf_extract(signal)
encoded = self.model.encoder_inference(features)
hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
return transcript


module = ConformerModule(model=conformer)
tf.saved_model.save(module, export_dir=args.output_dir, signatures=module.pred.get_concrete_function())
def main(
config: str = DEFAULT_YAML,
h5: str = None,
sentence_piece: bool = False,
subwords: bool = False,
output_dir: str = None,
):
assert h5 and output_dir
config = Config(config)
tf.random.set_seed(0)
tf.keras.backend.clear_session()

speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
config=config,
subwords=subwords,
sentence_piece=sentence_piece,
)

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(h5, by_name=True)
conformer.summary(line_length=100)
conformer.add_featurizers(speech_featurizer, text_featurizer)

class ConformerModule(tf.Module):
def __init__(self, model: Conformer, name=None):
super().__init__(name=name)
self.model = model
self.num_rnns = config.model_config["prediction_num_rnns"]
self.rnn_units = config.model_config["prediction_rnn_units"]
self.rnn_nstates = 2 if config.model_config["prediction_rnn_type"] == "lstm" else 1

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def pred(self, signal):
predicted = tf.constant(0, dtype=tf.int32)
states = tf.zeros([self.num_rnns, self.rnn_nstates, 1, self.rnn_units], dtype=tf.float32)
features = self.model.speech_featurizer.tf_extract(signal)
encoded = self.model.encoder_inference(features)
hypothesis = self.model._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=False)
transcript = self.model.text_featurizer.indices2upoints(hypothesis.prediction)
return transcript

module = ConformerModule(model=conformer)
tf.saved_model.save(module, export_dir=output_dir, signatures=module.pred.get_concrete_function())


if __name__ == "__main__":
fire.Fire(main)
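
Once the SavedModel has been exported, it can be driven the same way `run_saved_model.py` below does; a minimal sketch, with a placeholder export directory and audio file:

```python
# Sketch: load the SavedModel produced by gen_saved_model.py and transcribe one file.
# "./saved_conformer" and "test.wav" are placeholders.
import tensorflow as tf
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio

module = tf.saved_model.load(export_dir="./saved_conformer")
signal = read_raw_audio("test.wav")
transcript = module.pred(signal)  # returns unicode code points
print("".join(chr(u) for u in transcript))
```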
examples/conformer/inference/gen_tflite_model.py (67 changes: 27 additions & 40 deletions)
@@ -13,58 +13,45 @@
# limitations under the License.

import os
import argparse
from tensorflow_asr.utils import env_util, file_util
import fire
from tensorflow_asr.utils import env_util

logger = env_util.setup_environment()
import tensorflow as tf

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, CharFeaturizer
from tensorflow_asr.helpers import exec_helpers, featurizer_helpers
from tensorflow_asr.models.transducer.conformer import Conformer

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()
tf.compat.v1.enable_control_flow_v2()

parser = argparse.ArgumentParser(prog="Conformer TFLite")
def main(
config: str = DEFAULT_YAML,
h5: str = None,
subwords: bool = False,
sentence_piece: bool = False,
output: str = None,
):
assert h5 and output
tf.keras.backend.clear_session()
tf.compat.v1.enable_control_flow_v2()

parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
config = Config(config)
speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(
config=config,
subwords=subwords,
sentence_piece=sentence_piece,
)

parser.add_argument("--h5", type=str, default=None, help="Path to saved model")
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(h5, by_name=True)
conformer.summary(line_length=100)
conformer.add_featurizers(speech_featurizer, text_featurizer)

parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
exec_helpers.convert_tflite(model=conformer, output=output)

parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")

args = parser.parse_args()

assert args.h5 and args.output

config = Config(args.config)
speech_featurizer = TFSpeechFeaturizer(config.speech_config)

if args.subwords:
text_featurizer = SubwordFeaturizer(config.decoder_config)
else:
text_featurizer = CharFeaturizer(config.decoder_config)

# build model
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
conformer.make(speech_featurizer.shape)
conformer.load_weights(args.h5, by_name=True)
conformer.summary(line_length=100)
conformer.add_featurizers(speech_featurizer, text_featurizer)

concrete_func = conformer.make_tflite_function().get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

args.output = file_util.preprocess_paths(args.output)
with open(args.output, "wb") as tflite_out:
tflite_out.write(tflite_model)
if __name__ == "__main__":
fire.Fire(main)
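
The inline converter code removed above now lives behind `exec_helpers.convert_tflite`. A sketch of what that helper presumably does, reconstructed from the deleted lines in this diff (the helper's real signature and internals are an assumption):

```python
# Sketch of the TFLite conversion this commit moves into exec_helpers.convert_tflite,
# reconstructed from the inline code removed above; not the helper's actual source.
import tensorflow as tf
from tensorflow_asr.utils import file_util


def convert_tflite(model, output):
    concrete_func = model.make_tflite_function().get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = True
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # keep builtin kernels where possible
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops for the rest
    ]
    tflite_model = converter.convert()

    output = file_util.preprocess_paths(output)
    with open(output, "wb") as tflite_out:
        tflite_out.write(tflite_model)
```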
examples/conformer/inference/run_saved_model.py (24 changes: 13 additions & 11 deletions)
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import fire

from tensorflow_asr.utils import env_util

@@ -22,21 +22,23 @@

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser()
from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio

parser.add_argument("--saved_model", type=str, default=None, help="The file path of saved model")

parser.add_argument("filename", type=str, default=None, help="Audio file path")
def main(
saved_model: str = None,
filename: str = None,
):
tf.keras.backend.clear_session()

args = parser.parse_args()
module = tf.saved_model.load(export_dir=saved_model)

from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
signal = read_raw_audio(filename)
transcript = module.pred(signal)

module = tf.saved_model.load(export_dir=args.saved_model)
print("Transcript: ", "".join([chr(u) for u in transcript]))

signal = read_raw_audio(args.filename)
transcript = module.pred(signal)

print("Transcript: ", "".join([chr(u) for u in transcript]))
if __name__ == "__main__":
fire.Fire(main)
examples/conformer/inference/run_tflite_model.py (49 changes: 23 additions & 26 deletions)
@@ -12,39 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import fire
import tensorflow as tf

from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio

parser = argparse.ArgumentParser()

parser.add_argument("filename", metavar="FILENAME", help="Audio file to be played back")
def main(
filename: str,
tflite: str = None,
blank: int = 0,
num_rnns: int = 1,
nstates: int = 2,
statesize: int = 320,
):
tflitemodel = tf.lite.Interpreter(model_path=tflite)

parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite")
signal = read_raw_audio(filename)

parser.add_argument("--blank", type=int, default=0, help="Blank index")
input_details = tflitemodel.get_input_details()
output_details = tflitemodel.get_output_details()
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
tflitemodel.allocate_tensors()
tflitemodel.set_tensor(input_details[0]["index"], signal)
tflitemodel.set_tensor(input_details[1]["index"], tf.constant(blank, dtype=tf.int32))
tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([num_rnns, nstates, 1, statesize], dtype=tf.float32))
tflitemodel.invoke()
hyp = tflitemodel.get_tensor(output_details[0]["index"])

parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network")
print("".join([chr(u) for u in hyp]))

parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network")

parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network")

args = parser.parse_args()

tflitemodel = tf.lite.Interpreter(model_path=args.tflite)

signal = read_raw_audio(args.filename)

input_details = tflitemodel.get_input_details()
output_details = tflitemodel.get_output_details()
tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape)
tflitemodel.allocate_tensors()
tflitemodel.set_tensor(input_details[0]["index"], signal)
tflitemodel.set_tensor(input_details[1]["index"], tf.constant(args.blank, dtype=tf.int32))
tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32))
tflitemodel.invoke()
hyp = tflitemodel.get_tensor(output_details[0]["index"])

print("".join([chr(u) for u in hyp]))
if __name__ == "__main__":
fire.Fire(main)
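
The `--num_rnns`, `--nstates`, and `--statesize` defaults have to match the prediction network described in `model_config`; a small sketch of deriving them the same way `ConformerModule` in `gen_saved_model.py` above does (the example values are assumptions, read the real ones from `config.yml`):

```python
# Derive the shape of the zero state fed to input_details[2] above from the
# model_config keys used by ConformerModule; the values shown are placeholders.
model_config = {
    "prediction_num_rnns": 1,
    "prediction_rnn_units": 320,
    "prediction_rnn_type": "lstm",
}

num_rnns = model_config["prediction_num_rnns"]
statesize = model_config["prediction_rnn_units"]
nstates = 2 if model_config["prediction_rnn_type"] == "lstm" else 1  # LSTM keeps (h, c)

state_shape = [num_rnns, nstates, 1, statesize]
print(state_shape)  # e.g. [1, 2, 1, 320], matching the script defaults
```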
