diff --git a/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py b/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py
new file mode 100644
index 000000000..165770dab
--- /dev/null
+++ b/crates/voicevox_core_python_api/python/test/test_pseudo_raii_for_synthesizer.py
@@ -0,0 +1,45 @@
+"""
+Tests that ``Synthesizer`` supports RAII (in the broad sense).
+"""
+
+import conftest
+import pytest
+import pytest_asyncio
+from voicevox_core import OpenJtalk, Synthesizer, VoicevoxError
+
+
+def test_enter_returns_workable_self(synthesizer: Synthesizer) -> None:
+    with synthesizer as ctx:
+        assert ctx is synthesizer
+        _ = synthesizer.metas
+
+
+def test_closing_multiple_times_is_allowed(synthesizer: Synthesizer) -> None:
+    with synthesizer:
+        with synthesizer:
+            pass
+    synthesizer.close()
+    synthesizer.close()
+
+
+def test_access_after_close_denied(synthesizer: Synthesizer) -> None:
+    synthesizer.close()
+    with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
+        _ = synthesizer.metas
+
+
+def test_access_after_exit_denied(synthesizer: Synthesizer) -> None:
+    with synthesizer:
+        pass
+    with pytest.raises(VoicevoxError, match="^The `Synthesizer` is closed$"):
+        _ = synthesizer.metas
+
+
+@pytest_asyncio.fixture
+async def synthesizer(open_jtalk: OpenJtalk) -> Synthesizer:
+    return await Synthesizer.new_with_initialize(open_jtalk)
+
+
+@pytest.fixture(scope="module")
+def open_jtalk() -> OpenJtalk:
+    return OpenJtalk(conftest.open_jtalk_dic_dir)
diff --git a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi
index 345f483be..66625d230 100644
--- a/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi
+++ b/crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi
@@ -90,6 +90,8 @@ class Synthesizer:
         """
         ...
     def __repr__(self) -> str: ...
+    def __enter__(self) -> "Synthesizer": ...
+    def __exit__(self, exc_type, exc_value, traceback) -> None: ...
     @property
     def is_gpu_mode(self) -> bool:
         """Whether hardware acceleration is running in GPU mode."""
@@ -219,6 +221,7 @@ class Synthesizer:
         :returns: The WAV data.
         """
         ...
+    def close(self) -> None: ...
 
 class UserDict:
     """A user dictionary.
diff --git a/crates/voicevox_core_python_api/src/lib.rs b/crates/voicevox_core_python_api/src/lib.rs
index 09cea260e..40b9e31bd 100644
--- a/crates/voicevox_core_python_api/src/lib.rs
+++ b/crates/voicevox_core_python_api/src/lib.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{marker::PhantomData, sync::Arc};
 
 mod convert;
 use convert::*;
@@ -9,7 +9,7 @@ use pyo3::{
     exceptions::PyException,
     pyclass, pyfunction, pymethods, pymodule,
     types::{IntoPyDict as _, PyBytes, PyDict, PyList, PyModule},
-    wrap_pyfunction, PyAny, PyObject, PyResult, Python, ToPyObject,
+    wrap_pyfunction, PyAny, PyObject, PyRef, PyResult, PyTypeInfo, Python, ToPyObject,
 };
 use tokio::{runtime::Runtime, sync::Mutex};
 use uuid::Uuid;
@@ -114,7 +114,7 @@ impl OpenJtalk {
 
 #[pyclass]
 struct Synthesizer {
-    synthesizer: Arc<Mutex<voicevox_core::Synthesizer>>,
+    synthesizer: Closable<Arc<Mutex<voicevox_core::Synthesizer>>, Self>,
 }
 
 #[pymethods]
@@ -143,9 +143,10 @@ impl Synthesizer {
                 },
             )
             .await
-            .into_py_result()?;
+            .into_py_result()?
+            .into();
             Ok(Self {
-                synthesizer: Arc::new(Mutex::new(synthesizer)),
+                synthesizer: Closable::new(Arc::new(synthesizer)),
             })
         })
     }
@@ -154,14 +155,30 @@ impl Synthesizer {
         "Synthesizer { .. }"
}" } + fn __enter__(slf: PyRef<'_, Self>) -> PyResult> { + slf.synthesizer.get()?; + Ok(slf) + } + + fn __exit__( + &mut self, + #[allow(unused_variables)] exc_type: &PyAny, + #[allow(unused_variables)] exc_value: &PyAny, + #[allow(unused_variables)] traceback: &PyAny, + ) { + self.close(); + } + #[getter] - fn is_gpu_mode(&self) -> bool { - RUNTIME.block_on(self.synthesizer.lock()).is_gpu_mode() + fn is_gpu_mode(&self) -> PyResult { + let synthesizer = self.synthesizer.get()?; + Ok(RUNTIME.block_on(synthesizer.lock()).is_gpu_mode()) } #[getter] - fn metas<'py>(&self, py: Python<'py>) -> Vec<&'py PyAny> { - to_pydantic_voice_model_meta(RUNTIME.block_on(self.synthesizer.lock()).metas(), py).unwrap() + fn metas<'py>(&self, py: Python<'py>) -> PyResult> { + let synthesizer = self.synthesizer.get()?; + to_pydantic_voice_model_meta(RUNTIME.block_on(synthesizer.lock()).metas(), py) } fn load_voice_model<'py>( @@ -170,7 +187,7 @@ impl Synthesizer { py: Python<'py>, ) -> PyResult<&'py PyAny> { let model: VoiceModel = model.extract()?; - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); pyo3_asyncio::tokio::future_into_py(py, async move { synthesizer .lock() @@ -183,15 +200,15 @@ impl Synthesizer { fn unload_voice_model(&mut self, voice_model_id: &str) -> PyResult<()> { RUNTIME - .block_on(self.synthesizer.lock()) + .block_on(self.synthesizer.get()?.lock()) .unload_voice_model(&VoiceModelId::new(voice_model_id.to_string())) .into_py_result() } - fn is_loaded_voice_model(&self, voice_model_id: &str) -> bool { - RUNTIME - .block_on(self.synthesizer.lock()) - .is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string())) + fn is_loaded_voice_model(&self, voice_model_id: &str) -> PyResult { + Ok(RUNTIME + .block_on(self.synthesizer.get()?.lock()) + .is_loaded_voice_model(&VoiceModelId::new(voice_model_id.to_string()))) } #[pyo3(signature=(text,style_id,kana = AudioQueryOptions::default().kana))] @@ -202,7 +219,7 @@ impl Synthesizer { kana: bool, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); let text = text.to_owned(); pyo3_asyncio::tokio::future_into_py_with_locals( py, @@ -232,7 +249,7 @@ impl Synthesizer { kana: bool, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); let text = text.to_owned(); pyo3_asyncio::tokio::future_into_py_with_locals( py, @@ -267,7 +284,7 @@ impl Synthesizer { style_id: u32, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); modify_accent_phrases( accent_phrases, StyleId::new(style_id), @@ -282,7 +299,7 @@ impl Synthesizer { style_id: u32, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); modify_accent_phrases( accent_phrases, StyleId::new(style_id), @@ -297,7 +314,7 @@ impl Synthesizer { style_id: u32, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); modify_accent_phrases( accent_phrases, StyleId::new(style_id), @@ -314,7 +331,7 @@ impl Synthesizer { enable_interrogative_upspeak: bool, py: Python<'py>, ) -> PyResult<&'py PyAny> { - let synthesizer = self.synthesizer.clone(); + let synthesizer = self.synthesizer.get()?.clone(); 
         pyo3_asyncio::tokio::future_into_py_with_locals(
             py,
             pyo3_asyncio::tokio::get_current_locals(py)?,
@@ -355,7 +372,7 @@ impl Synthesizer {
             kana,
             enable_interrogative_upspeak,
         };
-        let synthesizer = self.synthesizer.clone();
+        let synthesizer = self.synthesizer.get()?.clone();
         let text = text.to_owned();
         pyo3_asyncio::tokio::future_into_py_with_locals(
             py,
@@ -371,6 +388,52 @@ impl Synthesizer {
             },
         )
     }
+
+    fn close(&mut self) {
+        self.synthesizer.close()
+    }
+}
+
+struct Closable<T, C: PyTypeInfo> {
+    content: MaybeClosed<T>,
+    marker: PhantomData<C>,
+}
+
+enum MaybeClosed<T> {
+    Open(T),
+    Closed,
+}
+
+impl<T, C: PyTypeInfo> Closable<T, C> {
+    fn new(content: T) -> Self {
+        Self {
+            content: MaybeClosed::Open(content),
+            marker: PhantomData,
+        }
+    }
+
+    fn get(&self) -> PyResult<&T> {
+        match &self.content {
+            MaybeClosed::Open(content) => Ok(content),
+            MaybeClosed::Closed => Err(VoicevoxError::new_err(format!(
+                "The `{}` is closed",
+                C::NAME,
+            ))),
+        }
+    }
+
+    fn close(&mut self) {
+        if matches!(self.content, MaybeClosed::Open(_)) {
+            debug!("Closing a {}", C::NAME);
+        }
+        self.content = MaybeClosed::Closed;
+    }
+}
+
+impl<T, C: PyTypeInfo> Drop for Closable<T, C> {
+    fn drop(&mut self) {
+        self.close();
+    }
 }
 
 #[pyfunction]
@@ -451,9 +514,3 @@ impl UserDict {
         Ok(words.into_py_dict(py))
     }
 }
-
-impl Drop for Synthesizer {
-    fn drop(&mut self) {
-        debug!("Destructing a VoicevoxCore");
-    }
-}
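
Note (not part of the diff): below is a minimal usage sketch of the behaviour this patch introduces, assuming an Open JTalk dictionary directory is available locally (the path is a placeholder). `Synthesizer` now works as a context manager, and any access after `close()` or `__exit__` raises `VoicevoxError`.

```python
import asyncio

from voicevox_core import OpenJtalk, Synthesizer, VoicevoxError

# Placeholder path; point this at a real Open JTalk dictionary directory.
OPEN_JTALK_DICT_DIR = "./open_jtalk_dic_utf_8-1.11"


async def main() -> None:
    synthesizer = await Synthesizer.new_with_initialize(OpenJtalk(OPEN_JTALK_DICT_DIR))

    # __enter__ returns the synthesizer itself, so it stays usable inside the block.
    with synthesizer as ctx:
        print(ctx.metas)

    # __exit__ closed it; closing again is allowed, but further use raises.
    synthesizer.close()
    try:
        _ = synthesizer.metas
    except VoicevoxError as e:
        print(e)  # The `Synthesizer` is closed


asyncio.run(main())
```

On the Rust side this is what `Closable` provides: `get()` either returns the wrapped `Arc<Mutex<voicevox_core::Synthesizer>>` or fails with the "closed" error, and `close()` swaps the content for `MaybeClosed::Closed`, so the Python object keeps existing while the underlying resources can be released deterministically.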