diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 9b7a9701c4f2..4c9ce8f08cf0 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -5,9 +5,9 @@ # These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: - onset: 0.53 # Onset threshold for detecting the beginning and end of a speech - offset: 0.49 # Offset threshold for detecting the end of a speech - pad_onset: 0.23 # Adding durations before each speech segment - pad_offset: 0.01 # Adding durations after each speech segment - min_duration_on: 0.42 # Threshold for small non-speech deletion - min_duration_off: 0.34 # Threshold for short speech segment deletion \ No newline at end of file + onset: 0.53 # Onset threshold for detecting the beginning of a speech segment + offset: 0.49 # Offset threshold for detecting the end of a speech segment + pad_onset: 0.23 # Adds the specified duration at the beginning of each speech segment + pad_offset: 0.01 # Adds the specified duration at the end of each speech segment + min_duration_on: 0.42 # Removes short silences if the duration is less than the specified minimum duration + min_duration_off: 0.34 # Removes short speech segments if the duration is less than the specified minimum duration \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml index ebf994c10f2e..e168a97083bc 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml @@ -5,9 +5,9 @@ # These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477). # Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. 
parameters: - onset: 0.64 # Onset threshold for detecting the beginning and end of a speech - offset: 0.74 # Offset threshold for detecting the end of a speech - pad_onset: 0.06 # Adding durations before each speech segment - pad_offset: 0.0 # Adding durations after each speech segment - min_duration_on: 0.1 # Threshold for small non-speech deletion - min_duration_off: 0.15 # Threshold for short speech segment deletion \ No newline at end of file + onset: 0.64 # Onset threshold for detecting the beginning of a speech segment + offset: 0.74 # Offset threshold for detecting the end of a speech segment + pad_onset: 0.06 # Adds the specified duration at the beginning of each speech segment + pad_offset: 0.0 # Adds the specified duration at the end of each speech segment + min_duration_on: 0.1 # Removes short silences if the duration is less than the specified minimum duration + min_duration_off: 0.15 # Removes short speech segments if the duration is less than the specified minimum duration \ No newline at end of file diff --git a/tutorials/speaker_tasks/End_to_end_Diarization_Inference.ipynb b/tutorials/speaker_tasks/End_to_end_Diarization_Inference.ipynb new file mode 100644 index 000000000000..367e10254f4b --- /dev/null +++ b/tutorials/speaker_tasks/End_to_end_Diarization_Inference.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell.\n", + "\n", + "## Install dependencies\n", + "!pip install wget\n", + "!apt-get install sox libsndfile1 ffmpeg\n", + "!pip install text-unidecode\n", + "!pip install ipython\n", + "\n", + "# ## Install NeMo\n", + "BRANCH = 'main'\n", + "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]\n", + "\n", + "## Install TorchAudio\n", + "!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# End-to-End Speaker Diarization in NeMo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Traditional cascaded (also referred to as modular or pipelined) speaker diarization systems consist of multiple modules such as a speaker activity detection (SAD) module and a speaker embedding extractor module. Cascaded systems are often time-consuming to develop since each module should be separately trained and optimized." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"End-to-end" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "On the other hand, end-to-end speaker diarization systems pursue a much more simplified version of a system where a single neural network model accepts raw audio signals and outputs speaker activity for each audio frame. 
Therefore, end-to-end diarization models have an advantage in ease of optimization and deployment." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sortformer Diarization Inference\n", + "\n", + "As explained in the [Sortformer Diarization Training](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb) tutorial, Sortformer is trained with Sort-Loss to generate speaker segments in arrival-time order. If a diarization model generates speaker segments according to a pre-defined rule or order, we do not need to match permutations when training the diarization model together with multi-speaker automatic speech recognition (ASR) models. Likewise, we do not need to match permutations across windows when the diarization model runs in streaming mode, where sequences of audio chunks are processed and permutation matching between inference windows would otherwise become a problem. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### A toy example for speaker diarization with a Sortformer model " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download a toy example audio file (`an4_diarize_test.wav`) and its ground-truth label file (`an4_diarize_test.rttm`)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import wget\n", + "ROOT = os.getcwd()\n", + "data_dir = os.path.join(ROOT,'data')\n", + "os.makedirs(data_dir, exist_ok=True)\n", + "an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')\n", + "an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')\n", + "if not os.path.exists(an4_audio):\n", + " an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n", + " an4_audio = wget.download(an4_audio_url, data_dir)\n", + "if not os.path.exists(an4_rttm):\n", + " an4_rttm_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm\"\n", + " an4_rttm = wget.download(an4_rttm_url, data_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's plot and listen to the audio. Pay attention to the two speakers in the audio clip." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import IPython\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import librosa\n", + "\n", + "sr = 16000\n", + "signal, sr = librosa.load(an4_audio,sr=sr) \n", + "\n", + "fig,ax = plt.subplots(1,1)\n", + "fig.set_figwidth(20)\n", + "fig.set_figheight(2)\n", + "plt.plot(np.arange(len(signal)),signal,'gray')\n", + "fig.suptitle('Reference merged an4 audio', fontsize=16)\n", + "plt.xlabel('time (secs)', fontsize=18)\n", + "ax.margins(x=0)\n", + "plt.ylabel('signal strength', fontsize=16)\n", + "a,_ = plt.xticks();plt.xticks(a,a/sr)\n", + "\n", + "IPython.display.Audio(an4_audio)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have downloaded the example audio file, which contains multiple speakers, we can feed the audio clip into the Sortformer diarizer and get the speaker diarization results."
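Optionally, you can inspect the ground-truth annotation before running the diarizer. The short sketch below uses `rttm_to_labels`, the same NeMo helper that appears later in this tutorial; each returned label is a plain "start end speaker" string.

```python
from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels

# Each label string is "<start_time> <end_time> <speaker>", read from the reference RTTM file.
for label in rttm_to_labels(an4_rttm):
    print(label)
```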
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download Sortformer diarizer model\n", + "\n", + "To download the Sortformer diarizer from the [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1), you need to get a [HuggingFace Access Token](https://huggingface.co/docs/hub/en/security-tokens) and register it in your Python environment using the [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli).\n", + "\n", + "If you are having trouble getting a HuggingFace token, you can download the Sortformer model from the [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1) and specify the path to the downloaded model.\n", + "\n", + "### Timestamp output and tensor output\n", + "\n", + "When executing the `diarize()` function, if you specify `include_tensor_outputs=True`, the diarization model returns the predicted speaker-labeled segments together with tensors containing T-by-N sigmoid values (where N is the maximum number of speakers) for each audio frame. \n", + "\n", + "Without the `include_tensor_outputs` argument, or with `include_tensor_outputs=False`, only speaker-labeled segments are returned." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.models import SortformerEncLabelModel\n", + "from huggingface_hub import get_token as get_hf_token\n", + "import torch\n", + "\n", + "if get_hf_token() is not None and get_hf_token().startswith(\"hf_\"):\n", + " # If you have logged into the HuggingFace hub and have an access token\n", + " diar_model = SortformerEncLabelModel.from_pretrained(\"nvidia/diar_sortformer_4spk-v1\")\n", + "else:\n", + " # You can download the \".nemo\" file from https://huggingface.co/nvidia/diar_sortformer_4spk-v1 and specify the path.\n", + " diar_model = SortformerEncLabelModel.restore_from(restore_path=\"/path/to/diar_sortformer_4spk-v1.nemo\", map_location=torch.device('cuda'), strict=False)\n", + "\n", + "pred_list, pred_tensor_list = diar_model.diarize(audio=an4_audio, include_tensor_outputs=True)\n", + "\n", + "print(pred_list[0])\n", + "print(pred_tensor_list[0].shape)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Now let's visualize the predicted speaker diarization results. The diarization model outputs a tensor where each row represents a speaker and each column represents a frame. The sigmoid values in the tensor represent the probability that the frame is spoken by that speaker.\n", + "\n", + "In the following code, we'll plot the predicted speaker diarization results for the sample audio file."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "def plot_diarout(preds):\n", + " preds_mat = preds.cpu().numpy().transpose()\n", + " cmap_str, grid_color_p= 'viridis', 'gray'\n", + " LW, FS = 0.4, 36\n", + "\n", + " yticklabels = [\"spk0\", \"spk1\", \"spk2\", \"spk3\"]\n", + " yticks = np.arange(len(yticklabels))\n", + " fig, axs = plt.subplots(1, 1, figsize=(30, 3)) # single axis for the prediction matrix\n", + "\n", + " axs.imshow(preds_mat, cmap=cmap_str, interpolation='nearest')\n", + " axs.set_title('Predictions', fontsize=FS)\n", + " axs.set_xticks(np.arange(-.5, preds_mat.shape[1], 1), minor=True)\n", + " axs.set_yticks(yticks)\n", + " axs.set_yticklabels(yticklabels)\n", + " axs.set_xlabel(f\"80 ms Frames\", fontsize=FS)\n", + " axs.grid(which='minor', color=grid_color_p, linestyle='-', linewidth=LW)\n", + "\n", + " plt.savefig('plot.png', dpi=300)\n", + " plt.show()\n", + "\n", + "\n", + "plot_diarout(pred_tensor_list[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the first speaker always gets the first generic speaker label `spk0`. The Sortformer model is trained to generate speaker labels in arrival-time order, so the permutation of speakers is always predictable if we know the order in which the speakers first appear." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Post-Processing of Speaker Timestamps\n", + "\n", + "In the previous steps, we obtained predictions of speaker activity in the form of tensors. While the floating-point values can be interpreted as speaker probabilities, they are not meant to be consumed as-is; they still need to be converted into segment information with a start and end time for each speech segment.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "import os\n", + "import wget\n", + "import pandas as pd\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n", + "from pyannote.metrics.diarization import DiarizationErrorRate\n", + "\n", + "ROOT = os.getcwd()\n", + "data_dir = os.path.join(ROOT,'data')\n", + "\n", + "# MODEL_CONFIG = os.path.join(data_dir,'sortformer_diar_4spk-v1_callhome-part1.yaml')\n", + "yaml_name=\"sortformer_diar_4spk-v1_dihard3-dev.yaml\"\n", + "MODEL_CONFIG = os.path.join(data_dir, yaml_name)\n", + "if not os.path.exists(MODEL_CONFIG):\n", + " config_url = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/post_processing/{yaml_name}\"\n", + " MODEL_CONFIG = wget.download(config_url, data_dir)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The post-processing YAML file `sortformer_diar_4spk-v1_dihard3-dev.yaml` contains parameters that were optimized to yield the lowest Diarization Error Rate (DER) on the [DiHARD 3 development set](https://catalog.ldc.upenn.edu/LDC2022S12)."
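To build intuition for the post-processing parameters described in the table further below, here is a simplified, illustrative sketch of how such parameters are typically applied to one speaker's frame-level probabilities. This is not NeMo's implementation: the 80 ms frame length matches the frame grid shown in the plots above, and the thresholding, padding, and merge logic follow the parameter descriptions given in this tutorial.

```python
import numpy as np

def frames_to_segments(probs, onset=0.64, offset=0.74, pad_onset=0.06, pad_offset=0.0,
                       min_duration_on=0.1, min_duration_off=0.15, frame_len=0.08):
    """Turn one speaker's frame-level probabilities into (start, end) segments in seconds."""
    segments, start, active = [], 0.0, False
    for i, p in enumerate(probs):
        if not active and p >= onset:        # a segment starts when the probability crosses the onset threshold
            start, active = i * frame_len, True
        elif active and p < offset:          # a segment ends when the probability drops below the offset threshold
            segments.append([max(0.0, start - pad_onset), i * frame_len + pad_offset])
            active = False
    if active:
        segments.append([max(0.0, start - pad_onset), len(probs) * frame_len + pad_offset])

    merged = []                              # min_duration_on: bridge short silences between segments
    for seg in segments:
        if merged and seg[0] - merged[-1][1] < min_duration_on:
            merged[-1][1] = seg[1]
        else:
            merged.append(seg)
    # min_duration_off: drop speech segments that are still too short
    return [(round(s, 2), round(e, 2)) for s, e in merged if e - s >= min_duration_off]

probs = np.array([0.1, 0.9, 0.8, 0.2, 0.7, 0.9, 0.9, 0.1, 0.05, 0.05, 0.05, 0.9, 0.1])
print(frames_to_segments(probs))
```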
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.utils.vad_utils import load_postprocessing_from_yaml\n", + "import json\n", + "from omegaconf import OmegaConf \n", + "post_processing_params = load_postprocessing_from_yaml(MODEL_CONFIG)\n", + "print(json.dumps(OmegaConf.to_container(post_processing_params), indent=4))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Post-processing of Speaker Timestamps\n", + "\n", + "The parameters in post-processing yaml configurations perform the following tasks:\n", + "\n", + "| **Parameter** | **Description** |\n", + "|-------------------:|--------------------------------------------------------------|\n", + "| **onset** | Onset threshold for detecting the beginning of a speech segment. |\n", + "| **offset** | Offset threshold for detecting the end of a speech segment. |\n", + "| **pad_onset** | Adds the specified duration at the beginning of each speech segment. |\n", + "| **pad_offset** | Adds the specified duration at the end of each speech segment. |\n", + "| **min_duration_on**| Removes short silences if the duration is less than the specified minimum duration. |\n", + "| **min_duration_off**| Removes short speech segments if the duration is less than the specified minimum duration. |\n", + "\n", + "Now let's check the diarization output timestamps and compare how post-processing changes the diarization output." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def show_diar_df(pred_session_list):\n", + " data = [segment.split() for segment in pred_session_list]\n", + " df = pd.DataFrame(data, columns=['Start Time', 'End Time', 'Speaker'])\n", + " print(df)\n", + "\n", + "# Call `diarize` method without postprocessing params\n", + "pred_list_bn = diar_model.diarize(audio=an4_audio)\n", + "# \"data/input_manifest.json\")\n", + "print(f\" [Default Binarized Diarization Output]: \")\n", + "show_diar_df(pred_list_bn[0])\n", + "\n", + "# Call `diarize` method with postprocessing params\n", + "pred_list_pp = diar_model.diarize(audio=an4_audio, postprocessing_yaml=MODEL_CONFIG)\n", + "print(f\" [Post-processed Diarization Output]: \")\n", + "show_diar_df(pred_list_pp[0])\n", + "\n", + "print(f\" [Ground-truth Diarization Output]: \")\n", + "ref_labels = rttm_to_labels(an4_rttm)\n", + "show_diar_df(ref_labels)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that the post-processed segments have more on-set padding while the differences are not significant. \n", + "\n", + "Now let's calculate DER (Diarization Error Rate) between the predicted labels and the ground-truth labels for both raw binarized and post-processed diarization outputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the refernce labels from ground-truth RTTM file\n", + "ref_labels = rttm_to_labels(an4_rttm)\n", + "\n", + "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"binarize\")\n", + "hypothesis1 = labels_to_pyannote_object(pred_list_bn[0], uniq_name=\"binarize\")\n", + "der_metric1 = DiarizationErrorRate(collar=0, skip_overlap=False)\n", + "der_metric1(reference, hypothesis1, detailed=True)\n", + "print(f\"Raw Binarization DER: {abs(der_metric1):.6f}\")\n", + "\n", + "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"post_processing\")\n", + "hypothesis2 = labels_to_pyannote_object(pred_list_pp[0], uniq_name=\"post_processing\")\n", + "der_metric2 = DiarizationErrorRate(collar=0, skip_overlap=False)\n", + "der_metric2(reference, hypothesis2, detailed=True)\n", + "print(f\"Post-Processing DER: {abs(der_metric2):.6f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the diarization output with post-processing is optimized to get the lowest DER for the given sigmoid tensor outputs, it generaly gives lower DER values when compared to the raw binarized diarization output. To achieve the lowest DER, it is recommended to use the post-processing parameters that are optimized for your dataset of interest and your diarization model. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/speaker_tasks/End_to_end_Diarization_Training.ipynb b/tutorials/speaker_tasks/End_to_end_Diarization_Training.ipynb new file mode 100644 index 000000000000..ae27daa045e2 --- /dev/null +++ b/tutorials/speaker_tasks/End_to_end_Diarization_Training.ipynb @@ -0,0 +1,893 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell.\n", + "\n", + "NEMO_DIR_PATH = \"NeMo\"\n", + "BRANCH = 'main'\n", + "\n", + "! git clone https://github.com/NVIDIA/NeMo\n", + "%cd NeMo\n", + "! python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + "%cd .." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# End-to-end Speaker Diarization with Sortformer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sortformer: Bridging the Gap between tokens (ASR) and Timestamps (Diarization)\n", + "\n", + "### Speaker Diarization as a part of ASR system\n", + "\n", + "Speaker diarization is all about figuring out who’s speaking when in an audio recording. In the world of automatic speech recognition (ASR), this becomes even more important for handling conversations with multiple speakers. Multispeaker ASR (also called speaker-attributed or multitalker ASR) uses this process to not just transcribe what’s being said, but also to label each part of the transcript with the right speaker.\n", + "\n", + "As ASR technology continues to advance, speaker diarization is increasingly becoming part of the ASR workflow itself. Some systems now handle speaker labeling and transcription at the same time during decoding. This means you don’t just get accurate text—you're also getting insights into who said what, making it more useful for conversational analysis.\n", + "\n", + "### Challenges in Integrating Speaker Diarization and ASR\n", + "\n", + "However, despite significant advancements, integrating speaker diarization and ASR into a unified, seamless system remains a considerable challenge. A key obstacle lies in the need for extensive high-quality, annotated audio data featuring multiple speakers. Acquiring such data is far more complex than collecting single-speaker audio or image datasets. This challenge is particularly pronounced for low-resource languages and domains like healthcare, where strict privacy regulations further constrain data availability.\n", + "\n", + "On top of that, many real-world use cases need these models to handle really long audio files—sometimes hours of conversation at a time. Training on such lengthy data is even more complicated because it’s hard to find or annotate. This creates a big gap between what’s needed and what’s available, making multispeaker ASR one of the toughest nuts to crack in the field of speech technology." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Sortformer: Simplifying Multispeaker ASR with Arrival Time Sorting\n", + "\n", + "To tackle the complexities of multispeaker automatic speech recognition (ASR), we introduce _*Sortformer*_, a new approach that incorporates _*Sort-Loss*_ and techniques to align timestamps with text tokens. Traditional approaches like permutation-invariant loss (PIL) face challenges when applied in batchable and differentiable computational graphs, especially since token-based objectives struggle to incorporate speaker-specific attributes into PIL-based loss functions.\n", + "\n", + "To address this, we propose an arrival time sorting (ATS) approach. In this method, speaker tokens from ASR outputs and speaker timestamps from diarization outputs are sorted by their arrival times to resolve permutations. This approach allows the multispeaker ASR system to be trained or fine-tuned using token-based cross-entropy loss, eliminating the need for timestamp-based or frame-level objectives with PIL." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The ATS-based multispeaker ASR system is powered by an end-to-end neural diarizer model, Sortformer, which generates speaker-label timestamps in arrival time order (ATO). To train the neural diarizer to produce sorted outputs, we introduce Sort Loss, a method that creates gradients enabling the Transformer model to learn the ATS mechanism.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, as shown in the above figure, our diarization system integrates directly with the ASR encoder. By embedding speaker supervision data as speaker kernels into the ASR encoder states, the system seamlessly combines speaker and transcription information. This unified approach improves performance and simplifies the overall architecture.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a result, our end-to-end multispeaker ASR system is fully or partially trainable with token objectives, allowing both the ASR and speaker diarization modules to be trained or fine-tuned using these objectives. Additionally, during the multispeaker ASR training phase, no specialized loss calculation functions are needed when using Sortformer, as frameworks for standard single-speaker ASR models can be employed. These compatibilities greatly simplify and accelerate the training and fine-tuning process of multispeaker ASR systems. \n", + "\n", + "On top of all these benefits, _*Sortformer*_ can be used as a stand-alone end-to-end speaker diarization model. By training a Sortformer diarizer model especially on high-quality simulated data with accurate time-stamps, you can boost the performance of multi-speaker ASR systems, just by integrating the _*Sortformer*_ model as _*Speaker Supervision*_ model in a computation graph.\n", + "\n", + "In this tutorial, we will walk you through the process of training a Sortformer diarizer model with toy dataset. Before starting, we will introduce the concepts of Sort-Loss calculation and the Hybrid loss technique." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## _*Sort-Loss*_ for _*Sortformer*_ Diarizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_*Sort-Loss*_ is designed to compare the predicted outputs with the true labels, typically sorted in arrival-time order or another relevant metric. The key distinction that _*Sortformer*_ introduces compared to previous end-to-end diarization systems such as [EEND-SA](https://arxiv.org/pdf/1909.06247), [EEND-EDA](https://arxiv.org/abs/2106.10654) lies in the organization of class presence $\\mathbf{\\hat{Y}}$.\n", + "\n", + "The figure below illustrates the difference between _*Sort-Loss*_ and permutation-invariant loss (PIL) or permutation-free loss.\n", + "\n", + " - PIL is calculated by finding the permutation of the target that minimizes the loss value between the prediction and the target.\n", + "\n", + " - _*Sort-Loss*_ simply compares the arrival-time-sorted version of speaker activity outputs for both the prediction and the target. 
Note that sometimes the same ground-truth labels lead to different target matrices for _*Sort-Loss*_ and PIL.\n", + "\n", + "For example, the figure below shows two identical source target matrices (the two matrices at the top), but the resulting target matrices for _*Sort-Loss*_ and PIL are different." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"PIL" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In mathmatical terms, _*Sort-Loss*_ can be expressed as follows:\n", + "\n", + "* **Arrival Time Sorting Function with $\\Psi$ function** \n", + "\n", + " Let $\\Psi$ be a function that determines the first segment's arrival time for a speaker bin:\n", + "$$\n", + " \\Psi\\big(\\mathbf{y}_{k}\\big) = \\min \\{ t' \\mid y_{k, t'} \\neq 0, t' \\in [1, T] \\} = t^{0}_{k},\n", + "$$\n", + " where $t^{0}_{k}$ is the frame index of the first active segment for the $k$-th speaker.\n", + "\n", + "Sortformer aims to generate predictions $\\mathbf{\\hat{y}}_{k}$ for each speaker $k$ such that:\n", + "$$\n", + "\\Psi(\\mathbf{\\hat{y}}_{1}) \\leq \\Psi(\\mathbf{\\hat{y}}_{2}) \\leq \\cdots \\leq \\Psi(\\mathbf{\\hat{y}}_{K}),\n", + "$$\n", + "ensuring the model produces class presence outputs ($\\mathbf{\\hat{Y}}$) sorted by arrival time.\n", + "\n", + "* **Sorting Function $\\eta$ and Sorted Targets $\\mathbf{Y}_{\\eta}$** \n", + "\n", + "\n", + "Let $\\eta$ be the sorting function applied to speaker indices $\\{1, \\dots, K\\}$. The sorted ground-truth matrix $\\mathbf{Y}_{\\eta}$ is defined as:\n", + "$$\n", + "\\eta\\big( \\mathbf{Y} \\big) = \\mathbf{Y}_{\\eta} = \\left(\\mathbf{y}_{\\eta(1)}, \\dots, \\mathbf{y}_{\\eta(K)} \\right).\n", + "$$\n", + "Using $\\Psi$, the following condition holds for the sorted ground-truth labels $\\mathbf{y}_{\\eta(k)}$:\n", + "$$\n", + "\\Psi(\\mathbf{y}_{\\eta(1)}) \\leq \\Psi(\\mathbf{y}_{\\eta(2)}) \\leq \\cdots \\leq \\Psi(\\mathbf{y}_{\\eta(K)}).\n", + "$$\n", + "\n", + "* **_*Sort-Loss*_ ($\\mathcal{L}_{\\text{Sort}}$) Definition** \n", + "\n", + "\n", + "_*Sort-Loss*_ is computed as:\n", + "$$\n", + "\\mathcal{L}_{\\text{Sort}}\\left(\\mathbf{Y}, \\mathbf{P}\\right) = \\mathcal{L}_{\\text{BCE}} \\left(\\mathbf{Y}_{\\eta}, \\mathbf{P}\\right) = \\frac{1}{K} \\sum_{k=1}^{K} \\mathcal{L}_{\\text{BCE}}(\\mathbf{y}_{\\eta(k)}, \\mathbf{q}_k),\n", + "$$\n", + "where:\n", + "\n", + "- $\\mathbf{y}_{\\eta(k)}$: True labels sorted by arrival time using the sorting function $\\eta$.\n", + "- $\\mathbf{q}_k$: Predicted outputs for the $k$-th speaker.\n", + "- $\\mathcal{L}_{\\text{BCE}}(\\mathbf{y}_{\\eta(k)}, \\mathbf{q}_k)$: Binary cross-entropy (BCE) loss for the $k$-th speaker.\n", + "- $K$: Total number of speakers.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we learn the concept of _*Sort-Loss*_ and Sortformer, we can now calculate _*Sort-Loss*_ based target matrix and PIL-based target matrix to compare the difference in target-value setting atrix and loss calculation.\n", + "\n", + "- raw target matrix $\\mathbf{Y}$: `raw_targets`\n", + "- prediction matrix $\\mathbf{P}$: `preds`\n", + "- ATS target matrix $\\mathbf{Y}_{\\eta}$: `ats_targets`\n", + "- PIL target matrix $\\mathbf{Y}_{\\text{PIL}}$: `pil_targets`\n", + "\n", + "First, assign the values in the above examples to the respective variables to create tensors." 
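As a quick aside before using NeMo's `get_ats_targets` and `get_pil_targets` helpers in the next cells, the sorting described above ($\Psi$ picks each speaker's first active frame, $\eta$ reorders the speaker bins accordingly) can be sketched in a few lines of plain PyTorch. This is an illustration only, assuming a binary target matrix of shape (T, K); it is not NeMo's implementation.

```python
import torch

def arrival_time_sort(targets: torch.Tensor) -> torch.Tensor:
    """Sort the speaker columns of a (T, K) binary target matrix by each speaker's first active frame."""
    T, K = targets.shape
    first_active = torch.full((K,), T, dtype=torch.long)   # Psi(y_k); stays T if the speaker is never active
    for k in range(K):
        active_frames = torch.nonzero(targets[:, k]).flatten()
        if active_frames.numel() > 0:
            first_active[k] = active_frames[0]
    order = torch.argsort(first_active)                     # eta: permutation of speakers by arrival time
    return targets[:, order]

y = torch.tensor([[0, 0, 1, 0],
                  [1, 0, 1, 0],
                  [1, 1, 1, 0],
                  [0, 1, 0, 1]])
print(arrival_time_sort(y))   # columns reordered so earlier-arriving speakers come first
```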
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "# Define the binary grid as a Python list\n", + "raw_targets_list = [[\n", + " [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]\n", + "],]\n", + "\n", + "preds_list = [[\n", + " [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0],\n", + " [0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0],\n", + " [0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1],\n", + " [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1],\n", + "],]\n", + "\n", + "# Convert the list to a PyTorch tensor\n", + "raw_targets = torch.tensor(raw_targets_list).transpose(1,2)\n", + "preds = torch.tensor(preds_list).transpose(1,2)\n", + "\n", + "print(raw_targets.shape)\n", + "print(preds.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, import `get_ats_targets` and `get_pil_targets` functions from the `nemo.collections.asr.parts.utils.asr_multispeaker_utils` module to calculate the ATS and PIL targets. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "import torch\n", + "import nemo\n", + "from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets\n", + "\n", + "max_num_of_spks = 4 # Number of speakers\n", + "speaker_inds = list(range(max_num_of_spks))\n", + "speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations\n", + "\n", + "\n", + "ats_target = get_ats_targets(labels=raw_targets.clone(), preds=preds, speaker_permutations=speaker_permutations)\n", + "pil_target = get_pil_targets(labels=raw_targets.clone(), preds=preds, speaker_permutations=speaker_permutations)\n", + "\n", + "print(f\"Predicted tensor:\")\n", + "print(preds[0].T)\n", + "\n", + "print(f\"\\nATS target:\")\n", + "print(ats_target[0].T)\n", + "\n", + "print(f\"\\nPIL target:\")\n", + "print(pil_target[0].T)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that ATS target and PIL target are different. Now, you will display the ATS and PIL target matrices to visually compare the difference and also calculate loss values using the BCE loss." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from nemo.collections.asr.losses.bce_loss import BCELoss \n", + "\n", + "bce_loss = BCELoss()\n", + "# reduction='mean', class_normalization=False)\n", + "\n", + "def plot_diarout(preds, title_text, cmap_str):\n", + "\n", + " preds_mat = preds.cpu().numpy().transpose()\n", + " grid_color_p = 'gray'\n", + " LW, FS = 0.5, 10\n", + "\n", + " yticklabels = [\"spk0\", \"spk1\", \"spk2\", \"spk3\"]\n", + " yticks = np.arange(len(yticklabels))\n", + " fig, axs = plt.subplots(1, 1, figsize=(20, 2)) # 1 row, 2 columns for preds and targets\n", + "\n", + " axs.imshow(preds_mat, cmap=cmap_str, interpolation='nearest')\n", + " axs.set_title(title_text, fontsize=FS)\n", + " axs.set_xticks(np.arange(-0.5, preds_mat.shape[1], 1), minor=True)\n", + " axs.set_yticks(np.arange(-0.5, preds_mat.shape[0], 1), minor=True)\n", + " axs.set_yticks(yticks)\n", + " axs.set_yticklabels(yticklabels)\n", + " axs.set_xlabel(f\"80 ms Frames\", fontsize=FS)\n", + " \n", + " # Enable grid\n", + " axs.grid(which='minor', color=grid_color_p, linestyle='-', linewidth=LW)\n", + " axs.tick_params(which=\"minor\", size=0) # Hide minor ticks\n", + " axs.tick_params(which=\"major\", size=5) # Show major ticks\n", + "\n", + " plt.savefig('plot.png', dpi=300) # bbox_inches='tight')\n", + " plt.show()\n", + "\n", + "target_length = torch.tensor([ats_target.shape[1]]) \n", + "print(f\"Target length: {target_length}\")\n", + "plot_diarout(preds[0], title_text='Predictions', cmap_str='viridis')\n", + "\n", + "loss_ats = bce_loss(probs=preds.float(), labels=ats_target.float(), target_lens=target_length)\n", + "print(f\"[ {loss_ats:.4f} ] is the loss from Arrival Time Sort Target: \")\n", + "plot_diarout(ats_target[0], title_text='ATS Target', cmap_str='summer')\n", + "\n", + "loss_pil = bce_loss(probs=preds.float(), labels=pil_target.float(), target_lens=target_length)\n", + "print(f\"[ {loss_pil:.4f} ] is the loss from Permutation Invariant Loss Target\")\n", + "plot_diarout(pil_target[0], title_text='PIL Target', cmap_str='inferno')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While _*Sortformer*_ can be trained solely using _*Sort-Loss*_, there is a limitation: the arrival time estimation is not always accurate. This issue becomes more pronounced as the number of speakers increases during the training session.\n", + "\n", + "Note that _*Sortformer*_ models can be trained using _*Sort-Loss*_ only, PIL only, or a hybrid loss by adjusting the weight between these two loss components. 
The hybrid loss $\\mathcal{L}_{\\text{hybrid}}$ can be described as follows:\n", + "\n", + "\n", + "$$\n", + "\\mathcal{L}_{\\text{hybrid}} = \\alpha \\cdot \\mathcal{L}_{\\text{Sort}} + \\beta \\cdot \\mathcal{L}_{\\text{PIL}},\n", + "$$\n", + "\n", + "The weight between ATS and PIL can be adjusted with the variable `model.ats_weight`($\\alpha$) and `model.pil_weight`($\\beta$) in the YAML file of the _*Sortformer*_ diarizer model as follows:\n", + "\n", + "```yaml\n", + "model: \n", + " sample_rate: 16000\n", + " pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model\n", + " ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model\n", + " max_num_of_spks: 4 \n", + " ...\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train a _*Sortformer*_ Diarizer Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example Data Creation\n", + "\n", + "In this tutorial, we will create a simple toy training dataset using the [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb), with Librispeech as the source dataset for demonstration purposes. If you already have datasets with proper speaker annotations (RTTM files), you can replace the simulated dataset with your own.\n", + "\n", + "If you don’t have access to any speaker diarization datasets, the [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb) can be used to generate a sufficient amount of data samples to meet your requirements.\n", + "\n", + "For more details on the data simulator, refer to the documentation in the [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb). This tutorial will not cover the configurations and detailed process of data simulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install dependencies for data simulator\n", + "!apt-get install sox libsndfile1 ffmpeg\n", + "!pip install wget\n", + "!pip install unidecode\n", + "!pip install \"matplotlib>=3.3.2\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Data Simulation Step 1: Download Required Resources\n", + "\n", + "We need to download the LibriSpeech corpus and corresponding word alignments for generating synthetic multi-speaker audio sessions. In addition, we need to download necessary data cleaning scripts from NeMo git." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "NEMO_DIR_PATH = \"NeMo\"\n", + "BRANCH = 'main'\n", + "\n", + "# download scripts if not already there \n", + "if not os.path.exists('NeMo/scripts'):\n", + " print(\"Downloading necessary scripts\")\n", + " !mkdir -p NeMo/scripts/dataset_processing\n", + " !mkdir -p NeMo/scripts/speaker_tasks\n", + " !wget -P NeMo/scripts/dataset_processing/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/dataset_processing/get_librispeech_data.py\n", + " !wget -P NeMo/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speaker_tasks/create_alignment_manifest.py\n", + " !wget -P NeMo/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have downloaded all the necessary scripts for data creation and preparation, we can start the data simulation step by downloading the LibriSpeech corpus." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -p LibriSpeech\n", + "!python {NEMO_DIR_PATH}/scripts/dataset_processing/get_librispeech_data.py \\\n", + " --data_root LibriSpeech \\\n", + " --data_sets dev_clean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can get forced word alignments for the LibriSpeech corpus from [this repository](https://github.com/CorentinJ/librispeech-alignments). The full forced alignment data can be downloaded from [this Google Drive link](https://drive.google.com/file/d/1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE/view?usp=sharing). Here, we will download only the subset of the forced alignment data that covers the dev-clean split." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!wget -nc https://dldata-public.s3.us-east-2.amazonaws.com/LibriSpeech_Alignments.tar.gz\n", + "!tar -xzf LibriSpeech_Alignments.tar.gz\n", + "!rm -f LibriSpeech_Alignments.tar.gz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Data Simulation Step 2: Produce Manifest File with Forced Alignments\n", + "\n", + "We will merge the LibriSpeech manifest files and the LibriSpeech forced alignments into one manifest file for ease of use when generating synthetic data. Create the alignment manifest by running the following script.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!python NeMo/scripts/speaker_tasks/create_alignment_manifest.py \\\n", + " --input_manifest_filepath LibriSpeech/dev_clean.json \\\n", + " --base_alignment_path LibriSpeech_Alignments \\\n", + " --output_manifest_filepath ./dev-clean-align.json \\\n", + " --ctm_output_directory ./ctm_out \\\n", + " --libri_dataset_split dev-clean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Data Simulation Step 3: Set data simulation parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have downloaded all the sources we need for data creation, we need to download the data simulator configuration in `.yaml` format.
Download the YAML file and download `data_simulator.py` script from NeMo repository." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "import os\n", + "ROOT = os.getcwd()\n", + "conf_dir = os.path.join(ROOT,'conf')\n", + "!mkdir -p $conf_dir\n", + "CONFIG_PATH = os.path.join(conf_dir, 'data_simulator.yaml')\n", + "if not os.path.exists(CONFIG_PATH):\n", + " !wget -P $conf_dir https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/tools/speech_data_simulator/conf/data_simulator.yaml\n", + "\n", + "config = OmegaConf.load(CONFIG_PATH)\n", + "print(OmegaConf.to_yaml(config))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Data Simulation Step 4: Generate Simulated Audio Session\n", + "\n", + "We will generate a set of example sessions with the following specifications:\n", + "\n", + "- 10 example sessions for train \n", + "- 10 example sessions for validation\n", + "- 2-speakers in each session\n", + "- 90 seconds of recordings\n", + "\n", + "We need to setup different seed for train and validation sets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator\n", + "\n", + "# Generate train set \n", + "ROOT = os.getcwd()\n", + "data_dir = os.path.join(ROOT,'simulated_train')\n", + "config.data_simulator.random_seed=10\n", + "config.data_simulator.manifest_filepath=\"./dev-clean-align.json\"\n", + "config.data_simulator.outputs.output_dir=data_dir\n", + "config.data_simulator.session_config.num_sessions=10\n", + "config.data_simulator.session_config.num_speakers=2\n", + "config.data_simulator.session_config.session_length=90\n", + "config.data_simulator.background_noise.add_bg=False \n", + "\n", + "lg = MultiSpeakerSimulator(cfg=config)\n", + "lg.generate_sessions()\n", + "\n", + "# Generate validation set \n", + "data_dir = os.path.join(ROOT,'simulated_valid')\n", + "config.data_simulator.random_seed=20\n", + "config.data_simulator.outputs.output_dir=data_dir\n", + "\n", + "lg = MultiSpeakerSimulator(cfg=config)\n", + "lg.generate_sessions()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that parameter setting is done, generate the samples by launching `generate_sessions()`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lg = MultiSpeakerSimulator(cfg=config)\n", + "lg.generate_sessions()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data preparation step 5: Listen to and Visualize Session\n", + "\n", + "Listen to the audio and visualize the corresponding speaker timestamps (recorded in a RTTM file for each session)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import os\n", + "import IPython\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import librosa\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n", + "\n", + "ROOT = os.getcwd()\n", + "data_dir = os.path.join(ROOT,'simulated_train')\n", + "audio = os.path.join(data_dir,'multispeaker_session_0.wav')\n", + "rttm = os.path.join(data_dir,'multispeaker_session_0.rttm')\n", + "\n", + "sr = 16000\n", + "signal, sr = librosa.load(audio,sr=sr) \n", + "\n", + "fig,ax = plt.subplots(1,1)\n", + "fig.set_figwidth(20)\n", + "fig.set_figheight(2)\n", + "plt.plot(np.arange(len(signal)),signal,'gray')\n", + "fig.suptitle('Synthetic Audio Session', fontsize=16)\n", + "plt.xlabel('Time (s)', fontsize=18)\n", + "plt.yticks([], [])\n", + "ax.margins(x=0)\n", + "a,_ = plt.xticks()\n", + "plt.xticks(a[:-1],a[:-1]/sr);\n", + "IPython.display.Audio(audio)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can visually check the ground-truth file of the first sample by running the following commands. We can see that it has plenty of overlap between two speakers. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# display speaker labels for reference\n", + "labels = rttm_to_labels(rttm)\n", + "reference = labels_to_pyannote_object(labels)\n", + "reference" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can check that corresponding RTTM files are generated as ground-truth labels for training and evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!cat simulated_train/multispeaker_session_0.rttm | head -10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data preparation step 6: Check out the created files\n", + "\n", + "The following files are generated from data simulator:\n", + "\n", + "* _wav files_ (one per audio session) - the output audio sessions\n", + "* _rttm files_ (one per audio session) - the speaker timestamps for the corresponding audio session (used for diarization training)\n", + "* _list files_ (one per file type per batch of sessions) - a list of generated files of the given type (e.g., wav, rttm), used primarily for manifest creation\n", + "\n", + "Check if the files we need are generated by running the following commands." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\n Training audio files:\")\n", + "!ls simulated_train/*.wav\n", + "print(\"\\n Training RTTM files:\")\n", + "!ls simulated_train/*.rttm\n", + "print(\"\\n Training WAV list content:\")\n", + "!cat simulated_train/synthetic_wav.list\n", + "print(\"\\n Training RTTM list content:\")\n", + "!cat simulated_train/synthetic_rttm.list\n", + "\n", + "print(\"\\n Validation audio files:\")\n", + "!ls simulated_valid/*.wav\n", + "print(\"\\n Validation RTTM files:\")\n", + "!ls simulated_valid/*.rttm\n", + "print(\"\\n Validation WAV list content:\")\n", + "!cat simulated_valid/synthetic_wav.list\n", + "print(\"\\n Validation RTTM list content:\")\n", + "!cat simulated_valid/synthetic_rttm.list" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare Training Data for the _*Sortformer*_ Diarizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have datasets for both training and validation (dev), we can start preparing and cleaning the data samples for training. Make sure you have the following list of files:\n", + "\n", + "**Training set** \n", + "\n", + "- Train audio files `.wav`\n", + "- A train audio list file `.list`\n", + "- Train RTTM files `.rttm`\n", + "- A train RTTM list file `.list`\n", + "\n", + "**Validation set** \n", + "\n", + "- Validation audio files `.wav`\n", + "- A validation audio list file `.list`\n", + "- Validation RTTM files `.rttm`\n", + "- A validation RTTM list file `.list`\n", + "\n", + "\n", + "Based on these files, we need to create manifest files containing the data samples we have. If you don't have the `.list` files, you need to create a `.list` file for the `.wav` files and the `.rttm` files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create a NeMo manifest (.json) file for the training dataset\n", + "!python NeMo/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py \\\n", + " --add_duration \\\n", + " --paths2audio_files='simulated_train/synthetic_wav.list' \\\n", + " --paths2rttm_files='simulated_train/synthetic_rttm.list' \\\n", + " --manifest_filepath='simulated_train/sortformer_train.json'\n", + "\n", + "# create a NeMo manifest (.json) file for the validation (dev) dataset\n", + "!python NeMo/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py \\\n", + " --add_duration \\\n", + " --paths2audio_files='simulated_valid/synthetic_wav.list' \\\n", + " --paths2rttm_files='simulated_valid/synthetic_rttm.list' \\\n", + " --manifest_filepath='simulated_valid/sortformer_valid.json'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you print the content of the created manifest files, you can see that the `.rttm` files and `.wav` files from the lists are grouped together in the generated manifest entries."
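To sanity-check the result, you can also load one entry of the generated manifest and pretty-print it (a small sketch; the exact set of fields depends on the version of `pathfiles_to_diarize_manifest.py`, but each line is a JSON object that pairs an audio file with its RTTM file):

```python
import json

# Pretty-print the first entry of the generated training manifest.
with open('simulated_train/sortformer_train.json', 'r') as f:
    first_entry = json.loads(f.readline())
print(json.dumps(first_entry, indent=2))
```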
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\"\\nTraining Dataset:\")\n", + "!cat simulated_train/sortformer_train.json | tail -5\n", + "print(\"\\nValidation Dataset:\")\n", + "!cat simulated_valid/sortformer_valid.json | tail -5 " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train a _*Sortformer*_ Diarizer Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have prepared all the necessary data, we can train a _*Sortformer*_ diarizer model on the prepared dataset. Download the training YAML file from the NeMo repository and load the configuration into the `config` variable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import nemo\n", + "import os\n", + "import lightning.pytorch as pl\n", + "from omegaconf import OmegaConf\n", + "from nemo.collections.asr.models import SortformerEncLabelModel\n", + "from nemo.utils.exp_manager import exp_manager\n", + "\n", + "NEMO_ROOT = os.getcwd()\n", + "!mkdir -p conf \n", + "!wget -P conf https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml\n", + "MODEL_CONFIG = os.path.join(NEMO_ROOT,'conf/sortformer_diarizer_hybrid_loss_4spk-v1.yaml')\n", + "config = OmegaConf.load(MODEL_CONFIG)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up the `manifest_filepath` for `train_ds` and `validation_ds` by providing the `.json` file paths of the datasets created above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "curr_dir = os.getcwd() + \"/\"\n", + "config.model.train_ds.manifest_filepath = f'{curr_dir}simulated_train/sortformer_train.json'\n", + "config.model.test_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", + "config.model.validation_ds.manifest_filepath = f'{curr_dir}simulated_valid/sortformer_valid.json'\n", + "config.trainer.strategy = \"ddp_notebook\"\n", + "config.batch_size = 3\n", + "\n", + "config.trainer.devices=1\n", + "config.accelerator=\"gpu\"\n", + "\n", + "print(\"config.model.train_ds.manifest_filepath:\", config.model.train_ds.manifest_filepath)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up a trainer object with the given configuration."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50,\n", + " enable_checkpointing=False, logger=False,\n", + " log_every_n_steps=5, check_val_every_n_epoch=10)\n", + "\n", + "exp_manager(trainer, config.get(\"exp_manager\", None))\n", + "sortformer_model = SortformerEncLabelModel(cfg=config.model, trainer=trainer)\n", + "sortformer_model.maybe_init_from_pretrained_checkpoint(config)\n", + "trainer.fit(sortformer_model)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/speaker_tasks/images/ats.png b/tutorials/speaker_tasks/images/ats.png new file mode 100644 index 000000000000..34af0e0bdfe3 Binary files /dev/null and b/tutorials/speaker_tasks/images/ats.png differ diff --git a/tutorials/speaker_tasks/images/cascaded_diar_diagram.png b/tutorials/speaker_tasks/images/cascaded_diar_diagram.png new file mode 100644 index 000000000000..2259704ed49c Binary files /dev/null and b/tutorials/speaker_tasks/images/cascaded_diar_diagram.png differ diff --git a/tutorials/speaker_tasks/images/e2e_diar_diagram.png b/tutorials/speaker_tasks/images/e2e_diar_diagram.png new file mode 100644 index 000000000000..4b8cf75c3574 Binary files /dev/null and b/tutorials/speaker_tasks/images/e2e_diar_diagram.png differ diff --git a/tutorials/speaker_tasks/images/intro_comparison.png b/tutorials/speaker_tasks/images/intro_comparison.png new file mode 100644 index 000000000000..a6cd98152aec Binary files /dev/null and b/tutorials/speaker_tasks/images/intro_comparison.png differ diff --git a/tutorials/speaker_tasks/images/loss_types.png b/tutorials/speaker_tasks/images/loss_types.png new file mode 100644 index 000000000000..fc365500fba1 Binary files /dev/null and b/tutorials/speaker_tasks/images/loss_types.png differ diff --git a/tutorials/speaker_tasks/images/main_dataflow.png b/tutorials/speaker_tasks/images/main_dataflow.png new file mode 100644 index 000000000000..5124a7442134 Binary files /dev/null and b/tutorials/speaker_tasks/images/main_dataflow.png differ diff --git a/tutorials/speaker_tasks/images/sortformer.png b/tutorials/speaker_tasks/images/sortformer.png new file mode 100644 index 000000000000..b98e821ec6fc Binary files /dev/null and b/tutorials/speaker_tasks/images/sortformer.png differ