Skip to content
This repository has been archived by the owner on May 12, 2024. It is now read-only.

Commit

Permalink
catalan ud
Browse files Browse the repository at this point in the history
  • Loading branch information
Jemoka committed Oct 24, 2023
1 parent 357dd8a commit 11cdc2f
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 49 deletions.
63 changes: 21 additions & 42 deletions baln/asrengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# PYTORCH_ENABLE_MPS_FALLBACK=1
# pretrained model path
# PRETRAINED = "openai/whisper-small"
# FILE = "./data/test.wav"
# FILE = "../talkbank-alignment/broken2/input/53.wav"
# # PRETRAINED = "openai/whisper-small"
# PRETRAINED = "talkbank/CHATWhisper-en-large-v1"
# # PRETRAINED = "openai/whisper-large-v2"
# # FILE = "./data/test.wav"
# FILE = "../talkbank-alignment/testing_playground_2/input/test.wav"
# # FILE = "../talkbank-alignment/broken2/input/53.wav"

@dataclass
class ASRAudioFile:
Expand Down Expand Up @@ -93,7 +96,9 @@ def __init__(self, model, base="openai/whisper-large-v2", language="english", ta
processor = WhisperProcessor.from_pretrained(base)

# forced decoder IDs that select the language and the transcribe task
self.__decoder_ids = processor.get_decoder_prompt_ids(language=language, task="transcribe")
self.__decoder_ids = processor.get_decoder_prompt_ids(language=language,
task="transcribe")
self.__prompt_ids = processor.get_prompt_ids("um hello.")

# save the target sample rate
self.sample_rate = target_sample_rate
Expand Down Expand Up @@ -166,17 +171,23 @@ def __call__(self, data, segments):

words = self.pipe(data.cpu().numpy(),
batch_size=8,
generate_kwargs = {"forced_decoder_ids": self.__decoder_ids,
"repetition_penalty": 1.01
generate_kwargs = {
"forced_decoder_ids": self.__decoder_ids,
"repetition_penalty": 1.01,
"prompt_ids": self.__prompt_ids
})
# "do_sample": True,
# "temperature": 0.1
# })
})
# breakpoint()
# "temperature": 0,
#"temperature": 0.75,
# })
words = words["chunks"]
# drop the first two chunks, which correspond to the two-word prompt
words = words["chunks"][2:]

# filter out the elements from the prompt, which have timestamp (0.0, 0.0)
# words = list(filter(lambda x:x["timestamp"] != (0.0, 0.0), words))


for word in words:
groups.append({
Expand Down Expand Up @@ -218,38 +229,6 @@ def __call__(self, data, segments):
"speaker": current_speaker[0] if type(current_speaker) == tuple else current_speaker
})

return {
"monologues": turns
}


# mid_step = 0.1
# dia_cls = raw_dia[0]

# # create the segments
# groups = []
# cur_start = 0
# cur_spk = dia_cls[0]

# for indx, i in zip(secs, dia_cls):
# if i != cur_spk:
# # results is by 0.1 second steps
# groups.append({
# "type": "segment",
# "start": cur_start/10,
# "end": indx/10,
# "payload": cur_spk
# })
# cur_start = indx
# cur_spk = i

# groups

# raw_dia[:1000]

# # e = ASREngine(PRETRAINED, "english")
# # audio, segments = e.load(FILE, 2)
# # result = e(audio.all(), segments)
# # words = raw["chunks"]
return ({"monologues": turns})


4 changes: 2 additions & 2 deletions baln/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from multiprocessing import Process, freeze_support

VERSION="0.3.45"
NOTES="releasing new whisper"
VERSION="0.3.46"
NOTES="catalan bugs"

#################### OPTIONS ################################

Expand Down
8 changes: 4 additions & 4 deletions baln/retokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,10 @@ def retokenize(infile, outfile, utterance_engine, interactive=False, provider=AS
# we will now use UtteranceEngine to redo
# utterance tokenization
# chunk the passage into utterances
if lang == "en":
chunked_passage = utterance_engine(passage)
else:
chunked_passage = [i.replace("...", ".") for i in sent_tokenize(passage)]
# if lang == "en":
# chunked_passage = utterance_engine(passage)
# else:
chunked_passage = [i.replace("...", ".") for i in sent_tokenize(passage)]
# remove the end delimiters (again!) because
# all we care about here are the split utterances
# we will add "." later
Expand Down
2 changes: 2 additions & 0 deletions baln/ud.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def handler(word):
target = target.replace(',', '')
target = target.replace('\'', '')
target = target.replace('~', '')
target = target.replace('/100', '')
target = target.replace('/r', '')

# remove attachments
if "|" in target:
Expand Down
2 changes: 1 addition & 1 deletion meta.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% set name = "batchalign" %}
{% set version = "0.3.45" %}
{% set version = "0.3.46" %}

package:
name: {{ name }}
Expand Down

0 comments on commit 11cdc2f

Please sign in to comment.