Reflect comments from #538 (#546)
* Refactor: move the Python API's conversion code

* Add: add validators to pydantic

* Add: declare VoicevoxError in the pyi

Co-Authored-By: Qryxip <[email protected]>

* Add: run surfaces through to_zenkaku

* Update crates/voicevox_core_python_api/python/voicevox_core/_rust.pyi

Co-authored-by: Ryo Yamashita <[email protected]>

---------

Co-authored-by: Qryxip <[email protected]>
Co-authored-by: Ryo Yamashita <[email protected]>
3 people authored Jul 28, 2023
1 parent cb8e83c commit f1dd63b
Showing 7 changed files with 201 additions and 150 deletions.
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/user_dict/word.rs
@@ -121,7 +121,7 @@ impl UserDictWord {
}

/// Determines whether a katakana string is valid as a pronunciation.
-fn validate_pronunciation(pronunciation: &str) -> InvalidWordResult<()> {
+pub fn validate_pronunciation(pronunciation: &str) -> InvalidWordResult<()> {
// Original implementation: https://github.com/VOICEVOX/voicevox_engine/blob/39747666aa0895699e188f3fd03a0f448c9cf746/voicevox_engine/model.py#L190-L210
if !PRONUNCIATION_REGEX.is_match(pronunciation) {
return Err(InvalidWordError::InvalidPronunciation(
@@ -182,7 +182,7 @@ fn calculate_mora_count(pronunciation: &str, accent_type: usize) -> InvalidWordR
/// - "!"から"~"までの範囲の文字(数字やアルファベット)は、対応する全角文字に
/// - " "などの目に見えない文字は、まとめて全角スペース(0x3000)に
/// 変換する。
-fn to_zenkaku(surface: &str) -> String {
+pub fn to_zenkaku(surface: &str) -> String {
// Original implementation: https://github.com/VOICEVOX/voicevox/blob/69898f5dd001d28d4de355a25766acb0e0833ec2/src/components/DictionaryManageDialog.vue#L379-L387
SPACE_REGEX
.replace_all(surface, "\u{3000}")
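Making these helpers `pub` lets the Python bindings re-export them as `_validate_pronunciation` and `_to_zenkaku` (see the `_rust.pyi` and `_models.py` changes below). A minimal sketch of the documented behavior, assuming a wheel built from this revision:

```python
from voicevox_core._rust import _to_zenkaku, _validate_pronunciation

# Characters in the "!".."~" range become their full-width counterparts,
# and invisible characters such as " " collapse to the full-width space U+3000.
assert _to_zenkaku("abc 123!") == "ａｂｃ\u3000１２３！"

# A katakana pronunciation passes silently; anything else raises VoicevoxError.
_validate_pronunciation("バー")
```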
@@ -49,7 +49,7 @@ async def test_user_dict_load() -> None:

    # Export the user dictionary
    dict_c = voicevox_core.UserDict()
-    uuid_c=dict_c.add_word(
+    uuid_c = dict_c.add_word(
        voicevox_core.UserDictWord(
            surface="bar",
            pronunciation="バー",
@@ -66,3 +66,12 @@ async def test_user_dict_load() -> None:
    dict_a.remove_word(uuid_a)
    assert uuid_a not in dict_a.words
    assert uuid_c in dict_a.words

    # Word validation
    with pytest.raises(voicevox_core.VoicevoxError):
        dict_a.add_word(
            voicevox_core.UserDictWord(
                surface="",
                pronunciation="カタカナ以外の文字",
            )
        )
@@ -13,6 +13,7 @@
    OpenJtalk,
    Synthesizer,
    VoiceModel,
    VoicevoxError,
    UserDict,
    supported_devices,
) # noqa: F401
@@ -26,6 +27,7 @@
"SpeakerMeta",
"SupportedDevices",
"Synthesizer",
"VoicevoxError",
"VoiceModel",
"supported_devices",
"UserDict",
11 changes: 11 additions & 0 deletions crates/voicevox_core_python_api/python/voicevox_core/_models.py
@@ -4,6 +4,8 @@

import pydantic

from ._rust import _validate_pronunciation, _to_zenkaku


@pydantic.dataclasses.dataclass
class StyleMeta:
@@ -89,3 +91,12 @@ class UserDictWord:
        default=UserDictWordType.COMMON_NOUN
    )
    priority: int = dataclasses.field(default=5)

    @pydantic.validator("pronunciation")
    def validate_pronunciation(cls, v):
        _validate_pronunciation(v)
        return v

    @pydantic.validator("surface")
    def validate_surface(cls, v):
        return _to_zenkaku(v)
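With these validators, surface normalization and the pronunciation check run as soon as the dataclass is constructed, rather than only when the word reaches the Rust side through `UserDict.add_word`. A hypothetical sketch (the remaining fields are assumed to keep their defaults):

```python
import voicevox_core

# The surface validator routes the value through _to_zenkaku.
word = voicevox_core.UserDictWord(surface="test", pronunciation="テスト")
assert word.surface == "ｔｅｓｔ"

# A non-katakana pronunciation is rejected by _validate_pronunciation.
try:
    voicevox_core.UserDictWord(surface="x", pronunciation="hello")
except voicevox_core.VoicevoxError:
    print("rejected before the word ever reaches a UserDict")
```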
@@ -331,3 +331,11 @@ class UserDict:
            The user dictionary to import.
        """
        ...

class VoicevoxError(Exception):
    """An error that occurred in VOICEVOX."""

    ...

def _validate_pronunciation(pronunciation: str) -> None: ...
def _to_zenkaku(text: str) -> str: ...
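Together with the re-export from `__init__.py` above, declaring `VoicevoxError` in the stub makes it a typed, importable part of the public API, so callers can catch it by name instead of a bare `Exception`. A hedged usage sketch mirroring the new test:

```python
import voicevox_core

user_dict = voicevox_core.UserDict()
try:
    user_dict.add_word(
        voicevox_core.UserDictWord(surface="", pronunciation="not katakana")
    )
except voicevox_core.VoicevoxError as e:
    print(f"invalid word: {e}")
```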
148 changes: 148 additions & 0 deletions crates/voicevox_core_python_api/src/convert.rs
@@ -0,0 +1,148 @@
use crate::VoicevoxError;
use std::{fmt::Display, future::Future, path::PathBuf};

use easy_ext::ext;
use pyo3::{types::PyList, FromPyObject as _, PyAny, PyObject, PyResult, Python, ToPyObject};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::json;
use uuid::Uuid;
use voicevox_core::{
    AccelerationMode, AccentPhraseModel, StyleId, UserDictWordType, VoiceModelMeta,
};

// Maps the Python-side `AccelerationMode` value onto the Rust enum.
pub fn from_acceleration_mode(ob: &PyAny) -> PyResult<AccelerationMode> {
    let py = ob.py();

    let class = py.import("voicevox_core")?.getattr("AccelerationMode")?;
    let mode = class.get_item(ob)?;

    if mode.eq(class.getattr("AUTO")?)? {
        Ok(AccelerationMode::Auto)
    } else if mode.eq(class.getattr("CPU")?)? {
        Ok(AccelerationMode::Cpu)
    } else if mode.eq(class.getattr("GPU")?)? {
        Ok(AccelerationMode::Gpu)
    } else {
        unreachable!("{} should be one of {{AUTO, CPU, GPU}}", mode.repr()?);
    }
}

// Extracts a filesystem path as a UTF-8 `String`, erroring if it is not valid UTF-8.
pub fn from_utf8_path(ob: &PyAny) -> PyResult<String> {
    PathBuf::extract(ob)?
        .into_os_string()
        .into_string()
        .map_err(|s| VoicevoxError::new_err(format!("{s:?} cannot be encoded to UTF-8")))
}

// Deserializes a Python dataclass into a Rust value via `dataclasses.asdict` and JSON.
pub fn from_dataclass<T: DeserializeOwned>(ob: &PyAny) -> PyResult<T> {
    let py = ob.py();

    let ob = py.import("dataclasses")?.call_method1("asdict", (ob,))?;
    let json = &py
        .import("json")?
        .call_method1("dumps", (ob,))?
        .extract::<String>()?;
    serde_json::from_str(json).into_py_result()
}

pub fn to_pydantic_voice_model_meta<'py>(
    metas: &VoiceModelMeta,
    py: Python<'py>,
) -> PyResult<Vec<&'py PyAny>> {
    let class = py
        .import("voicevox_core")?
        .getattr("SpeakerMeta")?
        .downcast()?;

    metas
        .iter()
        .map(|m| to_pydantic_dataclass(m, class))
        .collect::<PyResult<Vec<_>>>()
}

// Serializes a Rust value to JSON and feeds it to the given pydantic dataclass constructor.
pub fn to_pydantic_dataclass(x: impl Serialize, class: &PyAny) -> PyResult<&PyAny> {
    let py = class.py();

    let x = serde_json::to_string(&x).into_py_result()?;
    let x = py.import("json")?.call_method1("loads", (x,))?.downcast()?;
    class.call((), Some(x))
}

// Converts the given accent phrases to Rust, applies `method` asynchronously,
// and converts the result back into a Python list of `AccentPhrase`.
pub fn modify_accent_phrases<'py, Fun, Fut>(
    accent_phrases: &'py PyList,
    speaker_id: StyleId,
    py: Python<'py>,
    method: Fun,
) -> PyResult<&'py PyAny>
where
    Fun: FnOnce(Vec<AccentPhraseModel>, StyleId) -> Fut + Send + 'static,
    Fut: Future<Output = voicevox_core::Result<Vec<AccentPhraseModel>>> + Send + 'static,
{
    let rust_accent_phrases = accent_phrases
        .iter()
        .map(from_dataclass)
        .collect::<PyResult<Vec<AccentPhraseModel>>>()?;
    pyo3_asyncio::tokio::future_into_py_with_locals(
        py,
        pyo3_asyncio::tokio::get_current_locals(py)?,
        async move {
            let replaced_accent_phrases = method(rust_accent_phrases, speaker_id)
                .await
                .into_py_result()?;
            Python::with_gil(|py| {
                let replaced_accent_phrases = replaced_accent_phrases
                    .iter()
                    .map(move |accent_phrase| {
                        to_pydantic_dataclass(
                            accent_phrase,
                            py.import("voicevox_core")?.getattr("AccentPhrase")?,
                        )
                    })
                    .collect::<PyResult<Vec<_>>>()?;
                let replaced_accent_phrases = PyList::new(py, replaced_accent_phrases);
                Ok(replaced_accent_phrases.to_object(py))
            })
        },
    )
}
// `uuid.UUID` <-> `uuid::Uuid` conversions.
pub fn to_rust_uuid(ob: &PyAny) -> PyResult<Uuid> {
    let uuid = ob.getattr("hex")?.extract::<String>()?;
    uuid.parse().into_py_result()
}
pub fn to_py_uuid(py: Python, uuid: Uuid) -> PyResult<PyObject> {
    let uuid = uuid.hyphenated().to_string();
    let uuid = py.import("uuid")?.call_method1("UUID", (uuid,))?;
    Ok(uuid.to_object(py))
}
pub fn to_rust_user_dict_word(ob: &PyAny) -> PyResult<voicevox_core::UserDictWord> {
    voicevox_core::UserDictWord::new(
        ob.getattr("surface")?.extract()?,
        ob.getattr("pronunciation")?.extract()?,
        ob.getattr("accent_type")?.extract()?,
        to_rust_word_type(ob.getattr("word_type")?.extract()?)?,
        ob.getattr("priority")?.extract()?,
    )
    .into_py_result()
}
pub fn to_py_user_dict_word<'py>(
    py: Python<'py>,
    word: &voicevox_core::UserDictWord,
) -> PyResult<&'py PyAny> {
    let class = py
        .import("voicevox_core")?
        .getattr("UserDictWord")?
        .downcast()?;
    to_pydantic_dataclass(word, class)
}
pub fn to_rust_word_type(word_type: &PyAny) -> PyResult<UserDictWordType> {
    let name = word_type.getattr("name")?.extract::<String>()?;

    serde_json::from_value::<UserDictWordType>(json!(name)).into_py_result()
}

// Maps any displayable error into a `VoicevoxError` for the Python side.
#[ext]
pub impl<T, E: Display> Result<T, E> {
    fn into_py_result(self) -> PyResult<T> {
        self.map_err(|e| VoicevoxError::new_err(e.to_string()))
    }
}
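These helpers back the `UserDict` and `Synthesizer` bindings; at the Python surface, `to_py_uuid`/`to_rust_uuid` mean that `add_word` hands back a standard `uuid.UUID` which can be passed straight back to `remove_word`. A small sketch grounded in the test above (remaining `UserDictWord` fields assumed to keep their defaults):

```python
import uuid

import voicevox_core

user_dict = voicevox_core.UserDict()
word_id = user_dict.add_word(
    voicevox_core.UserDictWord(surface="bar", pronunciation="バー")
)
assert isinstance(word_id, uuid.UUID)  # produced by to_py_uuid
assert word_id in user_dict.words      # same membership check as the test above
user_dict.remove_word(word_id)         # converted back via to_rust_uuid
```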