Skip to content

Commit

Permalink
ovos-audio compatibility updates (#144)
Browse files Browse the repository at this point in the history
* Update ident handling and loosen ovos-audio dependency spec

* Troubleshooting playback_thread handling changes

* Troubleshooting playback_thread handling changes

* Troubleshooting Neon overrides

* Remove unused PIDLock

* Troubleshooting class signature breaking changes

* Remove unused `kwargs` param

* Update imports and inheritance testing in unit_tests.py

* Cleanup ident handling in tests

* Revert extra change from troubleshooting

---------

Co-authored-by: Daniel McKnight <[email protected]>
  • Loading branch information
NeonDaniel and NeonDaniel authored Oct 11, 2023
1 parent c0a2d8d commit ee4d469
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 51 deletions.
31 changes: 21 additions & 10 deletions neon_audio/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import ovos_plugin_manager.templates.tts

from threading import Event
from ovos_utils.log import LOG
from ovos_utils.log import LOG, log_deprecation
from neon_audio.tts import TTSFactory
from neon_utils.messagebus_utils import get_messagebus
from neon_utils.metrics_utils import Stopwatch
Expand Down Expand Up @@ -94,6 +94,8 @@ def __init__(self, ready_hook=on_ready, error_hook=on_error,
ovos_plugin_manager.templates.tts.check_for_signal = check_for_signal
ovos_plugin_manager.templates.tts.create_signal = create_signal

from neon_audio.tts.neon import NeonPlaybackThread
ovos_audio.service.PlaybackThread = NeonPlaybackThread
PlaybackService.__init__(self, ready_hook, error_hook, stopping_hook,
alive_hook, started_hook, watchdog, bus)
LOG.debug(f'Initialized tts={self._tts_hash} | '
Expand All @@ -107,30 +109,39 @@ def handle_speak(self, message):
if isinstance(message.context['destination'], str):
message.context['destination'] = [message.context['destination']]
if "audio" not in message.context['destination']:
LOG.warning("Adding audio to destination context")
log_deprecation("Adding audio to destination context", "2.0.0")
message.context['destination'].append('audio')

audio_finished = Event()

message.context.setdefault("timing", dict())
message.context["timing"].setdefault("speech_start", time())
ident = message.data.get('speak_ident') or message.context.get('ident')
if not ident:
LOG.warning(f"Ident missing for speak: {message.data}")

if message.context.get('ident'):
log_deprecation("ident context is deprecated. Use `session`",
"2.0.0")
if not message.context.get('session'):
LOG.info("No session context. Adding session from ident.")

speak_id = message.data.get('speak_ident') or \
message.context.get('ident') or message.data.get('ident')
message.context['speak_ident'] = speak_id
if not speak_id:
LOG.warning(f"`speak_ident` data missing: {message.data}")

def handle_finished(_):
audio_finished.set()
if ident:
self.bus.once(ident, handle_finished)
if speak_id:
self.bus.once(speak_id, handle_finished)
else:
audio_finished.set()

PlaybackService.handle_speak(self, message)
if not audio_finished.wait(self._playback_timeout):
LOG.warning(f"Playback not completed for {ident} within "
LOG.warning(f"Playback not completed for {speak_id} within "
f"{self._playback_timeout} seconds")
elif ident:
LOG.debug(f"Playback completed for: {ident}")
elif speak_id:
LOG.debug(f"Playback completed for: {speak_id}")

def handle_get_tts(self, message):
"""
Expand Down
36 changes: 22 additions & 14 deletions neon_audio/tts/neon.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ovos_bus_client.message import Message
from ovos_plugin_manager.language import OVOSLangDetectionFactory,\
OVOSLangTranslationFactory
from ovos_plugin_manager.templates.tts import TTS, PlaybackThread
from ovos_plugin_manager.templates.tts import TTS
from ovos_utils.enclosure.api import EnclosureAPI

from neon_utils.file_utils import encode_file_to_base64_string
Expand All @@ -44,7 +44,7 @@
from neon_utils.signal_utils import create_signal, check_for_signal,\
init_signal_bus
from ovos_utils.log import LOG, log_deprecation

from ovos_audio.playback import PlaybackThread
from ovos_config.config import Configuration


Expand Down Expand Up @@ -132,8 +132,9 @@ def get_requested_tts_languages(msg) -> list:


class NeonPlaybackThread(PlaybackThread):
def __init__(self, queue):
PlaybackThread.__init__(self, queue)
def __init__(self, queue, bus=None):
LOG.info("Initializing NeonPlaybackThread")
PlaybackThread.__init__(self, queue, bus=bus)

def begin_audio(self, message=None):
# TODO: Mark signals for deprecation
Expand All @@ -151,8 +152,9 @@ def _play(self):
if not ident and len(self._now_playing) >= 5 and \
isinstance(self._now_playing[4], Message):
LOG.debug("Handling new style playback")
ident = self._now_playing[4].context.get('session',
{}).get('session_id')
ident = self._now_playing[4].context.get('ident') or \
self._now_playing[4].context.get('session',
{}).get('session_id')
super()._play()
LOG.info(f"Played {ident}")
self.bus.emit(Message(ident))
Expand Down Expand Up @@ -182,10 +184,13 @@ def _init_neon(base_engine, *args, **kwargs):
base_engine.lang = base_engine.lang or language_config.get("user",
"en-us")
try:
base_engine.lang_detector = \
OVOSLangDetectionFactory.create(language_config)
base_engine.translator = \
OVOSLangTranslationFactory.create(language_config)
if language_config.get('detection_module'):
# Prevent loading a detector if not configured
base_engine.lang_detector = \
OVOSLangDetectionFactory.create(language_config)
if language_config.get('translation_module'):
base_engine.translator = \
OVOSLangTranslationFactory.create(language_config)
except ValueError as e:
LOG.error(e)
base_engine.lang_detector = None
Expand All @@ -199,18 +204,21 @@ def _init_neon(base_engine, *args, **kwargs):
base_engine.cached_translations = cached_translations
return base_engine

def _init_playback(self):
def _init_playback(self, playback_thread: NeonPlaybackThread = None):
# shutdown any previous thread
if TTS.playback:
TTS.playback.shutdown()

if not isinstance(playback_thread, NeonPlaybackThread):
LOG.exception("Received invalid playback_thread")
playback_thread = None
init_signal_bus(self.bus)
TTS.playback = NeonPlaybackThread(TTS.queue)
TTS.playback = playback_thread or NeonPlaybackThread(TTS.queue)
TTS.playback.set_bus(self.bus)
TTS.playback.attach_tts(self)
if not TTS.playback.enclosure:
TTS.playback.enclosure = EnclosureAPI(self.bus)
TTS.playback.start()
if not TTS.playback.is_running:
TTS.playback.start()

def _get_tts(self, sentence: str, request: dict = None, **kwargs):
# TODO: Signature should be made to match ovos-audio
Expand Down
6 changes: 2 additions & 4 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
ovos-audio~=0.0.1,>=0.0.2a12,<0.0.2a14
# TODO: ovos-audio 0.0.2a14 introduces a breaking change around `ident` handling
ovos-audio~=0.0.1,>=0.0.2a14
ovos-utils==0.0.35
ovos-config~=0.0.10
phoneme-guesser~=0.1
ovos-plugin-manager~=0.0.22,>=0.0.24a5,<0.0.24a10
# TODO: ovos-plugin-manager 0.0.24a10 depends on an ovos-audio version newer than 0.0.2a14
ovos-plugin-manager~=0.0.22,>=0.0.24a5
neon-utils[network]~=1.6
click~=8.0
click-default-group~=1.2
Expand Down
33 changes: 11 additions & 22 deletions tests/api_method_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,48 +129,37 @@ def test_handle_speak(self):
mock_tts = Mock()
self.audio_service.execute_tts = mock_tts

# TODO: this destination handling should be deprecated
# 'audio' not in destination
message_invalid_destination = Message("speak",
{"utterance": "test"},
{"ident": "test",
"destination": ['invalid']})
self.audio_service.handle_speak(message_invalid_destination)
mock_tts.assert_called_with("test", "test", False)
session = {"session_id": "test_session"}

# 'audio' in destination
message_valid_destination = Message("speak",
{"utterance": "test1"},
{"ident": "test2",
"destination": ['invalid',
'audio']})
'audio'],
"session": session})
self.audio_service.handle_speak(message_valid_destination)
mock_tts.assert_called_with("test1", "test2", False)
mock_tts.assert_called_with("test1", "test_session", False,
message_valid_destination)

# str 'audio' destination
message_valid_destination = Message("speak",
{"utterance": "test5"},
{"ident": "test6",
"destination": 'audio'})
"destination": 'audio',
"session": session})
self.audio_service.handle_speak(message_valid_destination)
mock_tts.assert_called_with("test5", "test6", False)

# TODO: this destination handling should be deprecated
# no destination context
message_no_destination = Message("speak",
{"utterance": "test3"},
{"ident": "test4"})
self.audio_service.handle_speak(message_no_destination)
mock_tts.assert_called_with("test3", "test4", False)
mock_tts.assert_called_with("test5", "test_session", False,
message_valid_destination)

# Setup bus API handling
self.audio_service._playback_timeout = 60
self.audio_service._playback_timeout = 5
msg = None

def handle_tts(*args, **kwargs):
nonlocal msg
msg = dig_for_message()
ident = msg.data.get('speak_ident') or msg.data.get('ident')
ident = msg.context.get('speak_ident')
if ident:
self.bus.emit(Message(ident))

Expand Down
6 changes: 5 additions & 1 deletion tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from unittest.mock import Mock, patch
from click.testing import CliRunner
from ovos_bus_client import Message
from ovos_plugin_manager.templates.tts import PlaybackThread
# from ovos_plugin_manager.templates.tts import PlaybackThread
from ovos_utils.messagebus import FakeBus

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
Expand Down Expand Up @@ -93,7 +93,11 @@ def test_class_init(self):

self.assertEqual(self.tts.voice, "default")
self.assertTrue(self.tts.queue.empty())

from ovos_audio.playback import PlaybackThread
from neon_audio.tts.neon import NeonPlaybackThread
self.assertIsInstance(self.tts.playback, PlaybackThread)
self.assertIsInstance(self.tts.playback, NeonPlaybackThread)

self.assertIsInstance(self.tts.spellings, dict)
self.assertEqual(self.tts.tts_name, "DummyTTS")
Expand Down

0 comments on commit ee4d469

Please sign in to comment.