From db228b1438af6124ba1aed3712c4527cfff2c445 Mon Sep 17 00:00:00 2001 From: Ryo Yamashita Date: Thu, 10 Oct 2024 01:09:43 +0900 Subject: [PATCH] =?UTF-8?q?refactor:=20`InferenceDomainMapValues`=E3=81=AE?= =?UTF-8?q?=E3=82=A4=E3=83=B3=E3=82=B9=E3=82=BF=E3=83=B3=E3=82=B9=E3=82=92?= =?UTF-8?q?=E3=83=9E=E3=82=AF=E3=83=AD=E3=81=A7=E4=BD=9C=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #737 に向け。また #851 の後にdecode.onnx入りのVVMに対応するときも同様に 役に立つはず。 --- Cargo.lock | 82 ++++-------------- Cargo.toml | 1 + crates/voicevox_core/src/infer.rs | 3 +- crates/voicevox_core/src/infer/domains.rs | 9 ++ .../voicevox_core/src/infer/domains/talk.rs | 3 +- crates/voicevox_core/src/manifest.rs | 4 +- crates/voicevox_core/src/status.rs | 7 +- crates/voicevox_core/src/voice_model.rs | 11 +-- crates/voicevox_core_macros/Cargo.toml | 3 +- .../src/inference_domains.rs | 83 +++++++++++++++++++ crates/voicevox_core_macros/src/lib.rs | 22 +++++ 11 files changed, 147 insertions(+), 81 deletions(-) create mode 100644 crates/voicevox_core_macros/src/inference_domains.rs diff --git a/Cargo.lock b/Cargo.lock index 1dd3d5d53..95113d18f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -366,7 +366,7 @@ dependencies = [ "bitflags 2.5.0", "cexpr", "clang-sys", - "itertools 0.11.0", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -1035,6 +1035,17 @@ dependencies = [ "syn 1.0.102", ] +[[package]] +name = "derive-syn-parse" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d65d7ce8132b7c0e54497a4d9a55a1c2a0912a0d786cf894472ba818fba45762" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "derive_builder" version = "0.20.0" @@ -1881,15 +1892,6 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itertools" version = "0.12.1" @@ -2003,7 +2005,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.48.0", + "windows-targets 0.52.6", ] [[package]] @@ -4435,6 +4437,7 @@ dependencies = [ name = "voicevox_core_macros" version = "0.0.0" dependencies = [ + "derive-syn-parse", "indexmap 2.6.0", "proc-macro2", "quote", @@ -4724,21 +4727,6 @@ dependencies = [ "windows_x86_64_msvc 0.42.2", ] -[[package]] -name = "windows-targets" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" -dependencies = [ - "windows_aarch64_gnullvm 0.48.0", - "windows_aarch64_msvc 0.48.0", - "windows_i686_gnu 0.48.0", - "windows_i686_msvc 0.48.0", - "windows_x86_64_gnu 0.48.0", - "windows_x86_64_gnullvm 0.48.0", - "windows_x86_64_msvc 0.48.0", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -4761,12 +4749,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" - [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -4785,12 +4767,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -4809,12 +4785,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" -[[package]] -name = "windows_i686_gnu" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" - [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -4839,12 +4809,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" -[[package]] -name = "windows_i686_msvc" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" - [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -4863,12 +4827,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -4881,12 +4839,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -4905,12 +4857,6 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 69ecbb58a..a4260b25e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ const_format = "0.2.33" cstr = "0.2.12" # https://github.com/dtolnay/syn/issues/1502 derive-getters = "0.2.0" derive-new = "0.5.9" +derive-syn-parse = "0.2.0" derive_more = "0.99.17" duct = "0.13.7" duplicate = "1.0.0" diff --git a/crates/voicevox_core/src/infer.rs b/crates/voicevox_core/src/infer.rs index 0dc322049..e827ddd7d 100644 --- a/crates/voicevox_core/src/infer.rs +++ b/crates/voicevox_core/src/infer.rs @@ -3,7 +3,7 @@ mod model_file; pub(crate) mod runtimes; pub(crate) mod session_set; -use std::{borrow::Cow, collections::BTreeSet, fmt::Debug}; +use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc}; use derive_new::new; use duplicate::duplicate_item; @@ -51,6 +51,7 @@ pub(crate) trait InferenceRuntime: 'static { /// 共に扱われるべき推論操作の集合を示す。 pub(crate) trait InferenceDomain: Sized { type Operation: InferenceOperation; + type Manifest: Index>; /// 対応する`StyleType`。 /// diff --git a/crates/voicevox_core/src/infer/domains.rs b/crates/voicevox_core/src/infer/domains.rs index 72e1e0886..54589b6e4 100644 --- a/crates/voicevox_core/src/infer/domains.rs +++ b/crates/voicevox_core/src/infer/domains.rs @@ -63,3 +63,12 @@ pub(crate) trait InferenceDomainMapValues { impl InferenceDomainMapValues for (T,) { type Talk = T; } + +macro_rules! inference_domain_map_values { + (for<$arg:ident> $body:ty) => { + (::macros::substitute_type!( + $body where $arg = crate::infer::domains::TalkDomain as crate::infer::InferenceDomain + ),) + }; +} +pub(crate) use inference_domain_map_values; diff --git a/crates/voicevox_core/src/infer/domains/talk.rs b/crates/voicevox_core/src/infer/domains/talk.rs index b2470c124..b7f7c1470 100644 --- a/crates/voicevox_core/src/infer/domains/talk.rs +++ b/crates/voicevox_core/src/infer/domains/talk.rs @@ -4,7 +4,7 @@ use enum_map::Enum; use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature}; use ndarray::{Array0, Array1, Array2}; -use crate::StyleType; +use crate::{manifest::TalkManifest, StyleType}; use super::super::{ InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor, @@ -14,6 +14,7 @@ pub(crate) enum TalkDomain {} impl InferenceDomain for TalkDomain { type Operation = TalkOperation; + type Manifest = TalkManifest; fn style_types() -> &'static BTreeSet { static STYLE_TYPES: LazyLock> = diff --git a/crates/voicevox_core/src/manifest.rs b/crates/voicevox_core/src/manifest.rs index 203fc76a9..465d64314 100644 --- a/crates/voicevox_core/src/manifest.rs +++ b/crates/voicevox_core/src/manifest.rs @@ -12,7 +12,7 @@ use serde::{de, Deserialize, Deserializer, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use crate::{ - infer::domains::{InferenceDomainMap, TalkOperation}, + infer::domains::{inference_domain_map_values, InferenceDomainMap, TalkOperation}, StyleId, VoiceModelId, }; @@ -79,7 +79,7 @@ pub struct Manifest { domains: InferenceDomainMap, } -pub(crate) type ManifestDomains = (Option,); +pub(crate) type ManifestDomains = inference_domain_map_values!(for Option); #[derive(Deserialize, IndexForFields)] #[cfg_attr(test, derive(Default))] diff --git a/crates/voicevox_core/src/status.rs b/crates/voicevox_core/src/status.rs index 95a1f6c0c..44c7fe22e 100644 --- a/crates/voicevox_core/src/status.rs +++ b/crates/voicevox_core/src/status.rs @@ -9,7 +9,7 @@ use itertools::iproduct; use crate::{ error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult}, infer::{ - domains::{InferenceDomainMap, TalkDomain, TalkOperation}, + domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain}, session_set::{InferenceSessionCell, InferenceSessionSet}, InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions, InferenceSignature, @@ -338,10 +338,11 @@ impl InferenceDomainMap { } } -type SessionOptionsByDomain = (EnumMap,); +type SessionOptionsByDomain = + inference_domain_map_values!(for EnumMap); type SessionSetsWithInnerVoiceIdsByDomain = - (Option<(StyleIdToInnerVoiceId, InferenceSessionSet)>,); + inference_domain_map_values!(for Option<(StyleIdToInnerVoiceId, InferenceSessionSet)>); #[cfg(test)] mod tests { diff --git a/crates/voicevox_core/src/voice_model.rs b/crates/voicevox_core/src/voice_model.rs index 7f39b22cb..e6b08e634 100644 --- a/crates/voicevox_core/src/voice_model.rs +++ b/crates/voicevox_core/src/voice_model.rs @@ -23,10 +23,10 @@ use crate::{ asyncs::{Async, Mutex as _}, error::{LoadModelError, LoadModelErrorKind, LoadModelResult}, infer::{ - domains::{InferenceDomainMap, TalkDomain, TalkOperation}, + domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation}, InferenceDomain, }, - manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId, TalkManifest}, + manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId}, SpeakerMeta, StyleMeta, StyleType, VoiceModelMeta, }; @@ -35,8 +35,9 @@ use crate::{ /// [`VoiceModelId`]: VoiceModelId pub type RawVoiceModelId = Uuid; -pub(crate) type ModelBytesWithInnerVoiceIdsByDomain = - (Option<(StyleIdToInnerVoiceId, EnumMap>)>,); +pub(crate) type ModelBytesWithInnerVoiceIdsByDomain = inference_domain_map_values!( + for Option<(StyleIdToInnerVoiceId, EnumMap>)> +); /// 音声モデルID。 #[derive( @@ -251,7 +252,7 @@ impl Inner { } type InferenceModelEntries<'manifest> = - (Option>,); + inference_domain_map_values!(for Option>); struct InferenceModelEntry { indices: EnumMap, diff --git a/crates/voicevox_core_macros/Cargo.toml b/crates/voicevox_core_macros/Cargo.toml index 1be131d65..f0613f291 100644 --- a/crates/voicevox_core_macros/Cargo.toml +++ b/crates/voicevox_core_macros/Cargo.toml @@ -10,10 +10,11 @@ name = "macros" proc-macro = true [dependencies] +derive-syn-parse.workspace = true indexmap.workspace = true proc-macro2.workspace = true quote.workspace = true -syn = { workspace = true, features = ["extra-traits", "full"] } +syn = { workspace = true, features = ["extra-traits", "full", "visit-mut"] } [lints.rust] unsafe_code = "forbid" diff --git a/crates/voicevox_core_macros/src/inference_domains.rs b/crates/voicevox_core_macros/src/inference_domains.rs new file mode 100644 index 000000000..18b67fa95 --- /dev/null +++ b/crates/voicevox_core_macros/src/inference_domains.rs @@ -0,0 +1,83 @@ +use derive_syn_parse::Parse; +use quote::ToTokens as _; +use syn::{ + parse_quote, + visit_mut::{self, VisitMut}, + Path, PathArguments, PathSegment, Token, Type, TypePath, +}; + +pub(crate) fn substitute_type(input: Substitution) -> syn::Result { + let Substitution { + mut body, + arg, + replacement, + replacement_as, + .. + } = input; + + Substitute { + arg, + replacement, + replacement_as, + } + .visit_type_mut(&mut body); + + return Ok(body.to_token_stream()); + + struct Substitute { + arg: syn::Ident, + replacement: Path, + replacement_as: Path, + } + + impl VisitMut for Substitute { + fn visit_type_mut(&mut self, i: &mut Type) { + visit_mut::visit_type_mut(self, i); + + let Type::Path(TypePath { + qself: None, + path: + Path { + leading_colon: None, + segments, + }, + }) = i + else { + return; + }; + + match &mut *segments.iter_mut().collect::>() { + [PathSegment { + ident, + arguments: PathArguments::None, + }] if *ident == self.arg => { + let replacement = self.replacement.clone(); + *i = parse_quote!(#replacement); + } + [PathSegment { + ident: ident1, + arguments: PathArguments::None, + }, seg] + if *ident1 == self.arg => + { + let replacement = self.replacement.clone(); + let replacement_as = self.replacement_as.clone(); + *i = parse_quote!(<#replacement as #replacement_as>::#seg); + } + _ => {} + } + } + } +} + +/// `$body:ty where $arg:ident = $replacement:path as $replacement_as:path` +#[derive(Parse)] +pub(crate) struct Substitution { + body: Type, + _where_token: Token![where], + arg: syn::Ident, + _eq_token: Token![=], + replacement: Path, + _as_token: Token![as], + replacement_as: Path, +} diff --git a/crates/voicevox_core_macros/src/lib.rs b/crates/voicevox_core_macros/src/lib.rs index ff0b83037..933d51373 100644 --- a/crates/voicevox_core_macros/src/lib.rs +++ b/crates/voicevox_core_macros/src/lib.rs @@ -2,6 +2,7 @@ mod extract; mod inference_domain; +mod inference_domains; mod manifest; use syn::parse_macro_input; @@ -131,6 +132,27 @@ pub fn derive_index_for_fields(input: proc_macro::TokenStream) -> proc_macro::To from_syn(manifest::derive_index_for_fields(input)) } +/// # Example +/// +/// ``` +/// type ManifestDomains = +/// (substitute_type!(Option where D = TalkDomain as InferenceDomain),); +/// ``` +/// +/// ↓ +/// +/// ``` +/// type ManifestDomains = (Option<::Manifest>,); +/// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +/// // T ← +/// ``` +#[cfg(not(doctest))] +#[proc_macro] +pub fn substitute_type(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input); + from_syn(inference_domains::substitute_type(input)) +} + fn from_syn(result: syn::Result) -> proc_macro::TokenStream { result.unwrap_or_else(|e| e.to_compile_error()).into() }