diff --git a/whisper_live/server.py b/whisper_live/server.py index 3c38d2a..e3346d2 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -787,9 +787,15 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli self.no_speech_thresh = 0.45 device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": + major, _ = torch.cuda.get_device_capability(device) + self.compute_type = "float16" if major >= 7 else "float32" + else: + self.compute_type = "int8" if self.model_size_or_path is None: return + logging.info(f"Using Device={device} with precision {self.compute_type}") if single_model: if ServeClientFasterWhisper.SINGLE_MODEL is None: @@ -822,7 +828,7 @@ def create_model(self, device): self.transcriber = WhisperModel( self.model_size_or_path, device=device, - compute_type="int8" if device == "cpu" else "float16", + compute_type=self.compute_type, local_files_only=False, )