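"""Meeting Summarizer: transcribes an uploaded audio file with whisper.cpp and
summarizes the transcript with a model served from Ollama, through a Gradio UI."""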
import subprocess
import os
import gradio as gr
import requests
import json

OLLAMA_SERVER_URL = "http://localhost:11434"  # Replace this with your actual Ollama server URL if different
WHISPER_MODEL_DIR = "./whisper.cpp/models"  # Directory where whisper models are stored
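# Note: whisper.cpp stores downloaded models as ggml-<name>.bin
# (e.g. ggml-base.bin, ggml-large-v3.bin); the helpers below rely on that convention.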


def get_available_models() -> list[str]:
"""
Retrieves a list of all available models from the Ollama server and extracts the model names.
Returns:
A list of model names available on the Ollama server.
"""
response = requests.get(f"{OLLAMA_SERVER_URL}/api/tags")
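    # Expected /api/tags payload (abridged):
    #   {"models": [{"name": "llama3:latest", "model": "llama3:latest", ...}, ...]}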
if response.status_code == 200:
models = response.json()["models"]
llm_model_names = [model["model"] for model in models] # Extract model names
return llm_model_names
else:
raise Exception(
f"Failed to retrieve models from Ollama server: {response.text}"
)


def get_available_whisper_models() -> list[str]:
"""
Retrieves a list of available Whisper models based on downloaded .bin files in the whisper.cpp/models directory.
Filters out test models and only includes official Whisper models (e.g., base, small, medium, large).
Returns:
        A list of available Whisper model names (e.g., 'base', 'small', 'medium', 'large-v3').
    """
    # List of acceptable official Whisper models; whisper.cpp uses lowercase names (e.g. ggml-large-v3.bin)
    valid_models = ["base", "small", "medium", "large", "large-v3"]
# Get the list of model files in the models directory
model_files = [f for f in os.listdir(WHISPER_MODEL_DIR) if f.endswith(".bin")]
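    # model_files now holds names like ["ggml-base.bin", "ggml-large-v3.bin", ...]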
# Filter out test models and models that aren't in the valid list
whisper_models = [
os.path.splitext(f)[0].replace("ggml-", "")
for f in model_files
if any(valid_model in f for valid_model in valid_models) and "test" not in f
]
# Remove any potential duplicates
whisper_models = list(set(whisper_models))
return whisper_models


def summarize_with_model(llm_model_name: str, context: str, text: str) -> str:
"""
Uses a specified model on the Ollama server to generate a summary.
Handles streaming responses by processing each line of the response.
Args:
llm_model_name (str): The name of the model to use for summarization.
context (str): Optional context for the summary, provided by the user.
text (str): The transcript text to summarize.
Returns:
str: The generated summary text from the model.
"""
prompt = f"""You are given a transcript from a meeting, along with some optional context.
Context: {context if context else 'No additional context provided.'}
The transcript is as follows:
{text}
Please summarize the transcript."""
headers = {"Content-Type": "application/json"}
data = {"model": llm_model_name, "prompt": prompt}
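    # /api/generate streams by default: the response body is newline-delimited
    # JSON, one object per chunk, e.g. {"model": "...", "response": "<token>", "done": false},
    # ending with a final object where "done" is true.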
response = requests.post(
f"{OLLAMA_SERVER_URL}/api/generate", json=data, headers=headers, stream=True
)
if response.status_code == 200:
full_response = ""
try:
# Process the streaming response line by line
for line in response.iter_lines():
if line:
# Decode each line and parse it as a JSON object
decoded_line = line.decode("utf-8")
json_line = json.loads(decoded_line)
# Extract the "response" part from each JSON object
full_response += json_line.get("response", "")
# If "done" is True, break the loop
if json_line.get("done", False):
break
return full_response
except json.JSONDecodeError:
print("Error: Response contains invalid JSON data.")
return f"Failed to parse the response from the server. Raw response: {response.text}"
else:
raise Exception(
f"Failed to summarize with model {llm_model_name}: {response.text}"
)


def preprocess_audio_file(audio_file_path: str) -> str:
"""
Converts the input audio file to a WAV format with 16kHz sample rate and mono channel.
Args:
audio_file_path (str): Path to the input audio file.
Returns:
str: The path to the preprocessed WAV file.
"""
output_wav_file = f"{os.path.splitext(audio_file_path)[0]}_converted.wav"
    # Resample to 16 kHz mono, which whisper.cpp expects; list-form args avoid shell-quoting issues
    cmd = ["ffmpeg", "-y", "-i", audio_file_path, "-ar", "16000", "-ac", "1", output_wav_file]
    subprocess.run(cmd, check=True)
return output_wav_file


def translate_and_summarize(
audio_file_path: str, context: str, whisper_model_name: str, llm_model_name: str
) -> tuple[str, str]:
"""
    Transcribes the audio file into text using the whisper.cpp model and generates a summary using Ollama.
Also provides the transcript file for download.
Args:
audio_file_path (str): Path to the input audio file.
context (str): Optional context to include in the summary.
whisper_model_name (str): Whisper model to use for audio-to-text conversion.
llm_model_name (str): Model to use for summarizing the transcript.
Returns:
tuple[str, str]: A tuple containing the summary and the path to the transcript file for download.
"""
output_file = "output.txt"
print("Processing audio file:", audio_file_path)
# Convert the input file to WAV format if necessary
audio_file_wav = preprocess_audio_file(audio_file_path)
print("Audio preprocessed:", audio_file_wav)
# Call the whisper.cpp binary
whisper_command = f'./whisper.cpp/main -m ./whisper.cpp/models/ggml-{whisper_model_name}.bin -f "{audio_file_wav}" > {output_file}'
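    # The whisper.cpp binary prints the transcript to stdout; the shell
    # redirect above captures it in output_file (hence shell=True here).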
subprocess.run(whisper_command, shell=True, check=True)
print("Whisper.cpp executed successfully")
# Read the output from the transcript
with open(output_file, "r") as f:
transcript = f.read()
# Save the transcript to a downloadable file
transcript_file = "transcript.txt"
with open(transcript_file, "w") as transcript_f:
transcript_f.write(transcript)
# Generate summary from the transcript using Ollama's model
summary = summarize_with_model(llm_model_name, context, transcript)
# Clean up temporary files
os.remove(audio_file_wav)
os.remove(output_file)
    # Return the summary text and the path to the downloadable transcript
return summary, transcript_file


# Gradio interface
def gradio_app(
audio, context: str, whisper_model_name: str, llm_model_name: str
) -> tuple[str, str]:
"""
    Gradio handler for file upload, model selection, and summary generation.
Args:
audio: The uploaded audio file.
context (str): Optional context provided by the user.
whisper_model_name (str): The selected Whisper model name.
llm_model_name (str): The selected language model for summarization.
Returns:
tuple[str, str]: A tuple containing the summary text and a downloadable transcript file.
"""
return translate_and_summarize(audio, context, whisper_model_name, llm_model_name)


# Launch the Gradio interface when run as a script
if __name__ == "__main__":
# Retrieve available models for Gradio dropdown input
ollama_models = get_available_models() # Retrieve models from Ollama server
whisper_models = (
get_available_whisper_models()
) # Dynamically detect downloaded Whisper models
# Ensure the first model is selected by default
iface = gr.Interface(
fn=gradio_app,
inputs=[
gr.Audio(type="filepath", label="Upload an audio file"),
gr.Textbox(
label="Context (optional)",
placeholder="Provide any additional context for the summary",
),
gr.Dropdown(
choices=whisper_models,
label="Select a Whisper model for audio-to-text conversion",
                value=whisper_models[0] if whisper_models else None,
),
gr.Dropdown(
choices=ollama_models,
label="Select a model for summarization",
value=ollama_models[0] if ollama_models else None,
),
],
outputs=[
gr.Textbox(
label="Summary",
show_copy_button=True,
), # Display the summary generated by the Ollama model
gr.File(
label="Download Transcript"
), # Provide the transcript as a downloadable file
],
analytics_enabled=False,
title="Meeting Summarizer",
description="Upload an audio file of a meeting and get a summary of the key concepts discussed.",
)
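
    # Serve the app locally; Gradio listens on http://127.0.0.1:7860 by default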
iface.launch(debug=True)