Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests #136

Merged
merged 9 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 76 additions & 39 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: Test & Build CI/CD

on:
push:
Expand All @@ -7,46 +7,83 @@ on:
tags:
- v*
pull_request:
branches:
- main
branches: [ main ]
types: [opened, synchronize, reopened]

jobs:
build-and-push-package:
test:
runs-on: ubuntu-latest
timeout-minutes: 60
strategy:
matrix:
python-version: [3.8, 3.9, '3.10', '3.11']
steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Cache Python dependencies
uses: actions/cache@v2
with:
path: |
~/.cache/pip
!~/.cache/pip/log
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-

- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev

- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/server.txt --extra-index-url https://download.pytorch.org/whl/cpu
pip install -r requirements/client.txt

- name: Run tests
run: |
echo "Running tests with Python ${{ matrix.python-version }}"
python -m unittest discover -s tests

build-and-push:
needs: test
runs-on: ubuntu-latest
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
steps:
- name: Check Out Repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Set up FFmpeg
uses: FedericoCarboni/setup-ffmpeg@v2

- name: Install Additional requirements
run: |
sudo apt-get -y install portaudio19-dev wget
shell: bash

- name: Install Client Requirements
run: pip install -r requirements/client.txt

- name: Install Server Requirements
run: pip install -r requirements/server.txt

- name: Install Wheel for build
run: pip install wheel twine

- name: Build wheel
run: |
python setup.py sdist bdist_wheel

- name: Push package on Test PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v2

- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Cache Python dependencies
uses: actions/cache@v2
with:
path: |
~/.cache/pip
!~/.cache/pip/log
key: ubuntu-latest-pip-3.8-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }}
restore-keys: |
ubuntu-latest-pip-3.8-

- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev

- name: Install Python dependencies
run: |
pip install -r requirements/server.txt
pip install -r requirements/client.txt

- name: Build package
run: python setup.py sdist bdist_wheel

- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
Binary file added assets/jfk.flac
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements/server.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ kaldialign
soundfile
ffmpeg-python
scipy
jiwer
evaluate
Empty file added tests/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import json
import os
import scipy
import websocket
import unittest
from unittest.mock import patch, MagicMock
from whisper_live.client import TranscriptionClient, resample


class BaseTestCase(unittest.TestCase):
@patch('whisper_live.client.websocket.WebSocketApp')
@patch('whisper_live.client.pyaudio.PyAudio')
def setUp(self, mock_pyaudio, mock_websocket):
self.mock_pyaudio_instance = MagicMock()
mock_pyaudio.return_value = self.mock_pyaudio_instance
self.mock_stream = MagicMock()
self.mock_pyaudio_instance.open.return_value = self.mock_stream

self.mock_ws_app = mock_websocket.return_value
self.mock_ws_app.send = MagicMock()

self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client

self.mock_pyaudio = mock_pyaudio
self.mock_websocket = mock_websocket

def tearDown(self):
self.client.close_websocket()
self.mock_pyaudio.stop()
self.mock_websocket.stop()
del self.client


class TestClientWebSocketCommunication(BaseTestCase):
def test_websocket_communication(self):
expected_url = 'ws://localhost:9090'
self.mock_websocket.assert_called()
self.assertEqual(self.mock_websocket.call_args[0][0], expected_url)


class TestClientCallbacks(BaseTestCase):
def test_on_open(self):
expected_message = json.dumps({
"uid": self.client.uid,
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)

def test_on_message(self):
message = json.dumps(
{
"uid": self.client.uid,
"message": "SERVER_READY",
"backend": "faster_whisper"
}
)
self.client.on_message(self.mock_ws_app, message)

message = json.dumps({
"uid": self.client.uid,
"segments": [
{"start": 0, "end": 1, "text": "Test transcript"},
{"start": 1, "end": 2, "text": "Test transcript 2"},
{"start": 2, "end": 3, "text": "Test transcript 3"}
]
})
self.client.on_message(self.mock_ws_app, message)

# Assert that the transcript was updated correctly
self.assertEqual(len(self.client.transcript), 2)
self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")

def test_on_close(self):
close_status_code = 1000
close_msg = "Normal closure"
self.client.on_close(self.mock_ws_app, close_status_code, close_msg)

self.assertFalse(self.client.recording)
self.assertFalse(self.client.server_error)
self.assertFalse(self.client.waiting)

def test_on_error(self):
error_message = "Test Error"
self.client.on_error(self.mock_ws_app, error_message)

self.assertTrue(self.client.server_error)
self.assertEqual(self.client.error_message, error_message)


class TestAudioResampling(unittest.TestCase):
def test_resample_audio(self):
original_audio = "assets/jfk.flac"
expected_sr = 16000
resampled_audio = resample(original_audio, expected_sr)

sr, _ = scipy.io.wavfile.read(resampled_audio)
self.assertEqual(sr, expected_sr)

os.remove(resampled_audio)


class TestSendingAudioPacket(BaseTestCase):
def test_send_packet(self):
mock_audio_packet = b'\x00\x01\x02\x03'
self.client.send_packet_to_server(mock_audio_packet)
self.client.client_socket.send.assert_called_with(mock_audio_packet, websocket.ABNF.OPCODE_BINARY)
105 changes: 105 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import subprocess
import time
import json
import unittest
from unittest import mock

import numpy as np
import evaluate
from whisper_live.server import TranscriptionServer
from whisper_live.client import TranscriptionClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
self.assertEqual(server.max_clients, 4)
self.assertEqual(server.max_connection_time, 600)
self.assertDictEqual(server.clients, {})
self.assertDictEqual(server.websockets, {})
self.assertDictEqual(server.clients_start_time, {})


class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
self.server.clients_start_time = {
'client1': time.time() - 120,
'client2': time.time() - 300
}
self.server.max_connection_time = 600

def test_get_wait_time(self):
expected_wait_time = (600 - (time.time() - self.server.clients_start_time['client2'])) / 60
print(self.server.get_wait_time(), expected_wait_time)
self.assertAlmostEqual(self.server.get_wait_time(), expected_wait_time, places=2)


class TestServerConnection(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()

@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection(self, mock_websocket):
mock_websocket.recv.return_value = json.dumps({
'uid': 'test_client',
'language': 'en',
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, "faster_whisper")


@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = [json.dumps({
'uid': 'test_client',
'language': 'en',
'task': 'transcribe',
'model': 'tiny.en'
}), np.array([1, 2, 3]).tobytes()]

with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, "faster_whisper")

self.assertNotIn(mock_websocket, self.server.clients)


class TestServerInferenceAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.server_process = subprocess.Popen(["python", "run_server.py"]) # Adjust the command as needed
time.sleep(2)

@classmethod
def tearDownClass(cls):
cls.server_process.terminate()
cls.server_process.wait()

@mock.patch('pyaudio.PyAudio')
def setUp(self, mock_pyaudio):
self.mock_pyaudio = mock_pyaudio.return_value
self.mock_stream = mock.MagicMock()
self.mock_pyaudio.open.return_value = self.mock_stream
self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()
self.client = TranscriptionClient(
"localhost", "9090", model="base.en", lang="en",
)

def test_inference(self):
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!"
self.client("assets/jfk.flac")
with open("output.srt", "r") as f:
lines = f.readlines()
prediction = " ".join([l.strip() for l in lines[2::4]])
prediction_normalized = self.normalizer(prediction)
gt_normalized = self.normalizer(gt)

# calculate WER
wer = self.metric.compute(
predictions=[prediction_normalized],
references=[gt_normalized]
)
self.assertLess(wer, 0.05)
28 changes: 28 additions & 0 deletions tests/test_vad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
import numpy as np
import torch
import scipy.io as sio
from whisper_live.tensorrt_utils import load_audio
from whisper_live.vad import VoiceActivityDetection


class TestVoiceActivityDetection(unittest.TestCase):
def setUp(self):
self.vad = VoiceActivityDetection()
self.sample_rate = 16000

def generate_silence(self, duration_seconds):
return np.zeros(int(self.sample_rate * duration_seconds), dtype=np.float32)

def load_speech_segment(self, filepath):
return load_audio(filepath)

def test_vad_silence_detection(self):
silence = self.generate_silence(3)
speech_prob = self.vad(torch.from_numpy(silence.copy()), self.sample_rate).item()
self.assertLess(speech_prob, 0.5, "VAD incorrectly identified silence as speech.")

def test_vad_speech_detection(self):
audio_tensor = torch.from_numpy(load_audio("assets/jfk.flac"))
speech_prob = self.vad(audio_tensor, self.sample_rate).item()
self.assertGreater(speech_prob, 0.5, "VAD failed to identify speech segment.")
7 changes: 6 additions & 1 deletion whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,15 @@ def on_message(self, ws, message):
print(element)

def on_error(self, ws, error):
print(error)
print(f"[ERROR] WebSocket Error: {error}")
self.server_error = True
self.error_message = error

def on_close(self, ws, close_status_code, close_msg):
print(f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}")
self.recording = False
self.server_error = False
self.waiting = False

def on_open(self, ws):
"""
Expand Down
Loading
Loading