diff --git a/tests/test_client.py b/tests/test_client.py index 836b616..4610ea9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -68,15 +68,15 @@ def test_on_message(self): message = json.dumps({ "uid": self.client.uid, "segments": [ - {"start": 0, "end": 1, "text": "Test transcript"}, - {"start": 1, "end": 2, "text": "Test transcript 2"}, - {"start": 2, "end": 3, "text": "Test transcript 3"} + {"start": 0, "end": 1, "text": "Test transcript", "completed": True}, + {"start": 1, "end": 2, "text": "Test transcript 2", "completed": True}, + {"start": 2, "end": 3, "text": "Test transcript 3", "completed": True} ] }) self.client.on_message(self.mock_ws_app, message) # Assert that the transcript was updated correctly - self.assertEqual(len(self.client.transcript), 2) + self.assertEqual(len(self.client.transcript), 3) self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2") def test_on_close(self): diff --git a/whisper_live/client.py b/whisper_live/client.py index 4cfb63c..15b6306 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -112,9 +112,9 @@ def process_segments(self, segments): for i, seg in enumerate(segments): if not text or text[-1] != seg["text"]: text.append(seg["text"]) - if i == len(segments) - 1: + if i == len(segments) - 1 and not seg["completed"]: self.last_segment = seg - elif (self.server_backend == "faster_whisper" and + elif (self.server_backend == "faster_whisper" and seg["completed"] and (not self.transcript or float(seg['start']) >= float(self.transcript[-1]['end']))): self.transcript.append(seg) @@ -259,7 +259,7 @@ def write_srt_file(self, output_path="output.srt"): """ if self.server_backend == "faster_whisper": - if (self.last_segment): + if (self.last_segment) and self.transcript[-1]["text"] != self.last_segment["text"]: self.transcript.append(self.last_segment) utils.create_srt_file(self.transcript, output_path) diff --git a/whisper_live/server.py b/whisper_live/server.py index 575a2cf..b68df5a 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -1001,7 +1001,7 @@ def speech_to_text(self): logging.error(f"[ERROR]: Failed to transcribe audio chunk: {e}") time.sleep(0.01) - def format_segment(self, start, end, text): + def format_segment(self, start, end, text, completed=False): """ Formats a transcription segment with precise start and end times alongside the transcribed text. @@ -1018,7 +1018,8 @@ def format_segment(self, start, end, text): return { 'start': "{:.3f}".format(start), 'end': "{:.3f}".format(end), - 'text': text + 'text': text, + 'completed': completed } def update_segments(self, segments, duration): @@ -1058,7 +1059,7 @@ def update_segments(self, segments, duration): if s.no_speech_prob > self.no_speech_thresh: continue - self.transcript.append(self.format_segment(start, end, text_)) + self.transcript.append(self.format_segment(start, end, text_, completed=True)) offset = min(duration, s.end) # only process the segments if it satisfies the no_speech_thresh @@ -1067,7 +1068,8 @@ def update_segments(self, segments, duration): last_segment = self.format_segment( self.timestamp_offset + segments[-1].start, self.timestamp_offset + min(duration, segments[-1].end), - self.current_out + self.current_out, + completed=False ) # if same incomplete segment is seen multiple times then update the offset @@ -1083,7 +1085,8 @@ def update_segments(self, segments, duration): self.transcript.append(self.format_segment( self.timestamp_offset, self.timestamp_offset + duration, - self.current_out + self.current_out, + completed=True )) self.current_out = '' offset = duration