Merge pull request #284 from makaveli10/expose_client_manager_args
Expose client manager args.
zoq authored Oct 31, 2024
2 parents: be71657 + 8b87a05 · commit 00f0ff1
Showing 6 changed files with 45 additions and 21 deletions.
README.md (7 changes: 6 additions & 1 deletion)
@@ -77,6 +77,9 @@ If you don't want this, set `--no_single_model`.
- `use_vad`: Whether to use `Voice Activity Detection` on the server.
- `save_output_recording`: Set to True to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
- `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
+ - `max_clients`: Specifies the maximum number of clients the server should allow. Defaults to 4.
+ - `max_connection_time`: Maximum connection time for each client in seconds. Defaults to 600.

```python
from whisper_live.client import TranscriptionClient
client = TranscriptionClient(
@@ -87,7 +90,9 @@ client = TranscriptionClient(
model="small",
use_vad=False,
save_output_recording=True, # Only used for microphone input, False by Default
- output_recording_filename="./output_recording.wav" # Only used for microphone input
+ output_recording_filename="./output_recording.wav", # Only used for microphone input
+ max_clients=4,
+ max_connection_time=600
)
```
It connects to the server running on localhost at port 9090. When a multilingual model is used, the transcription language is detected automatically; you can also use the `language` option to specify it explicitly, in this case English ("en"). Set the `translate` option to `True` to translate from the source language into English, or to `False` to transcribe in the source language.
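A minimal invocation sketch (not part of this diff) may help here; it assumes a WhisperLive server is reachable on localhost:9090 and that `TranscriptionClient` instances are callable, with a file path for pre-recorded audio and no arguments for live microphone input:

```python
from whisper_live.client import TranscriptionClient

# Hedged sketch: construct a client with the new limits and run a transcription.
# Assumptions: a server is running on localhost:9090; the callable interface
# (file path vs. no argument) follows the project's usual usage.
client = TranscriptionClient(
    "localhost", 9090, lang="en", model="small",
    max_clients=4, max_connection_time=600,
)

client("path/to/audio.wav")  # transcribe a pre-recorded file
# client()                   # or stream from the microphone instead
```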
requirements/server.txt (2 changes: 1 addition & 1 deletion)
@@ -10,4 +10,4 @@ jiwer
evaluate
numpy<2
tiktoken==0.3.3
- openai-whisper
+ openai-whisper==20231117
tests/test_client.py (4 changes: 3 additions & 1 deletion)
@@ -48,7 +48,9 @@ def test_on_open(self):
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
"use_vad": True
"use_vad": True,
"max_clients": 4,
"max_connection_time": 600,
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)
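One subtlety about the assertion above: `assert_called_with` compares the serialized strings, so the key order in the test's dict must match the order used in `Client.on_open`. A standalone illustrative sketch (not from the PR):

```python
import json

# json.dumps preserves dict insertion order, so two equal mappings serialized
# with different key orders produce different strings -- and a string-level
# assertion like assert_called_with(expected_message) would fail.
a = json.dumps({"use_vad": True, "max_clients": 4, "max_connection_time": 600})
b = json.dumps({"max_clients": 4, "max_connection_time": 600, "use_vad": True})
assert json.loads(a) == json.loads(b)  # same data
assert a != b                          # different strings
```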
tests/test_server.py (26 changes: 12 additions & 14 deletions)
@@ -5,17 +5,18 @@
from unittest import mock

import numpy as np
- import evaluate
+ import jiwer

from websockets.exceptions import ConnectionClosed
- from whisper_live.server import TranscriptionServer
+ from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
server = TranscriptionServer()
+ server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.assertEqual(server.client_manager.max_clients, 4)
self.assertEqual(server.client_manager.max_connection_time, 600)
self.assertDictEqual(server.client_manager.clients, {})
@@ -25,6 +26,7 @@ def test_initialization(self):
class TestGetWaitTime(unittest.TestCase):
def setUp(self):
self.server = TranscriptionServer()
+ self.server.client_manager = ClientManager(max_clients=4, max_connection_time=600)
self.server.client_manager.start_times = {
'client1': time.time() - 120,
'client2': time.time() - 300
@@ -49,7 +51,7 @@ def test_connection(self, mock_websocket):
'task': 'transcribe',
'model': 'tiny.en'
})
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_recv_audio_exception_handling(self, mock_websocket):
@@ -61,7 +63,7 @@ def test_recv_audio_exception_handling(self, mock_websocket):
}), np.array([1, 2, 3]).tobytes()]

with self.assertLogs(level="ERROR"):
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))

self.assertNotIn(mock_websocket, self.server.client_manager.clients)

@@ -82,7 +84,6 @@ def tearDownClass(cls):
cls.server_process.wait()

def setUp(self):
- self.metric = evaluate.load("wer")
self.normalizer = EnglishTextNormalizer()

def check_prediction(self, srt_path):
@@ -94,11 +95,8 @@ def check_prediction(self, srt_path):
gt_normalized = self.normalizer(gt)

# calculate WER
- wer = self.metric.compute(
- predictions=[prediction_normalized],
- references=[gt_normalized]
- )
- self.assertLess(wer, 0.05)
+ wer_score = jiwer.wer(gt_normalized, prediction_normalized)
+ self.assertLess(wer_score, 0.05)

def test_inference(self):
client = TranscriptionClient(
@@ -124,26 +122,26 @@ def setUp(self):

@mock.patch('websockets.WebSocketCommonProtocol')
def test_connection_closed_exception(self, mock_websocket):
- mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed")
+ mock_websocket.recv.side_effect = ConnectionClosed(1001, "testing connection closed", rcvd_then_sent=mock.Mock())

with self.assertLogs(level="INFO") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Connection closed by client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_json_decode_exception(self, mock_websocket):
mock_websocket.recv.return_value = "invalid json"

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
self.assertTrue(any("Failed to decode JSON from client" in message for message in log.output))

@mock.patch('websockets.WebSocketCommonProtocol')
def test_unexpected_exception_handling(self, mock_websocket):
mock_websocket.recv.side_effect = RuntimeError("Unexpected error")

with self.assertLogs(level="ERROR") as log:
self.server.recv_audio(mock_websocket, "faster_whisper")
self.server.recv_audio(mock_websocket, BackendType("faster_whisper"))
for message in log.output:
print(message)
print()
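The switch from `evaluate.load("wer")` to `jiwer.wer` above drops the extra `evaluate` dependency; `jiwer` computes the word error rate directly. A standalone sketch of the call used in `check_prediction` (the sentences below are illustrative, not from the test data):

```python
import jiwer
from whisper.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# Normalize both strings before scoring, as check_prediction does.
reference = normalizer("And so, my fellow Americans, ask not what your country can do for you.")
hypothesis = normalizer("and so my fellow americans ask not what your country can do for you")

# jiwer.wer(reference, hypothesis) returns the word error rate as a float.
wer_score = jiwer.wer(reference, hypothesis)
print(f"WER: {wer_score:.3f}")
```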
whisper_live/client.py (19 changes: 16 additions & 3 deletions)
@@ -30,7 +30,9 @@ def __init__(
model="small",
srt_file_path="output.srt",
use_vad=True,
- log_transcription=True
+ log_transcription=True,
+ max_clients=4,
+ max_connection_time=600,
):
"""
Initializes a Client instance for audio recording and streaming to a server.
@@ -59,6 +61,8 @@ def __init__(
self.last_segment = None
self.last_received_segment = None
self.log_transcription = log_transcription
+ self.max_clients = max_clients
+ self.max_connection_time = max_connection_time

if translate:
self.task = "translate"
@@ -199,7 +203,9 @@ def on_open(self, ws):
"language": self.language,
"task": self.task,
"model": self.model,
"use_vad": self.use_vad
"use_vad": self.use_vad,
"max_clients": self.max_clients,
"max_connection_time": self.max_connection_time,
}
)
)
@@ -681,8 +687,15 @@ def __init__(
output_recording_filename="./output_recording.wav",
output_transcription_path="./output.srt",
log_transcription=True,
+ max_clients=4,
+ max_connection_time=600,
):
- self.client = Client(host, port, lang, translate, model, srt_file_path=output_transcription_path, use_vad=use_vad, log_transcription=log_transcription)
+ self.client = Client(
+ host, port, lang, translate, model, srt_file_path=output_transcription_path,
+ use_vad=use_vad, log_transcription=log_transcription, max_clients=max_clients,
+ max_connection_time=max_connection_time
+ )

if save_output_recording and not output_recording_filename.endswith(".wav"):
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
if not output_transcription_path.endswith(".srt"):
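For reference, the handshake that `Client.on_open` now sends carries the two new fields. A hedged sketch of the payload is below; fields outside the hunk shown above, such as a client `uid`, are omitted, and the concrete values are illustrative defaults:

```python
import json

# Illustrative options payload sent over the WebSocket when the connection opens.
handshake = json.dumps({
    "language": "en",
    "task": "transcribe",
    "model": "small",
    "use_vad": True,
    "max_clients": 4,            # forwarded from the new constructor argument
    "max_connection_time": 600,  # seconds; also forwarded from the constructor
})
print(handshake)
```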
whisper_live/server.py (8 changes: 7 additions & 1 deletion)
@@ -147,7 +147,7 @@ class TranscriptionServer:
RATE = 16000

def __init__(self):
- self.client_manager = ClientManager()
+ self.client_manager = None
self.no_voice_activity_chunks = 0
self.use_vad = True
self.single_model = False
@@ -224,6 +224,12 @@ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
logging.info("New client connected")
options = websocket.recv()
options = json.loads(options)

+ if self.client_manager is None:
+ max_clients = options.get('max_clients', 4)
+ max_connection_time = options.get('max_connection_time', 600)
+ self.client_manager = ClientManager(max_clients, max_connection_time)

self.use_vad = options.get('use_vad')
if self.client_manager.is_server_full(websocket, options):
websocket.close()
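A consequence of the lazy initialization above is that the first client to connect fixes the server-wide limits; values sent by later clients are ignored. A standalone sketch of that behaviour, using hypothetical option dicts rather than the PR's actual code path:

```python
from whisper_live.server import ClientManager

# Hypothetical handshake options from two successive clients.
options_first = {"max_clients": 8, "max_connection_time": 1200}
options_second = {"max_clients": 2}  # arrives after the manager exists

client_manager = None
for options in (options_first, options_second):
    if client_manager is None:  # mirrors the guard in handle_new_connection
        client_manager = ClientManager(
            options.get("max_clients", 4),            # server default: 4 clients
            options.get("max_connection_time", 600),  # server default: 600 seconds
        )

assert client_manager.max_clients == 8  # set by the first connection only
```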
