Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: InferenceDomainMapValuesのインスタンスをマクロで作る #852

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 14 additions & 68 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const_format = "0.2.33"
cstr = "0.2.12" # https://github.com/dtolnay/syn/issues/1502
derive-getters = "0.2.0"
derive-new = "0.5.9"
derive-syn-parse = "0.2.0"
derive_more = "0.99.17"
duct = "0.13.7"
duplicate = "1.0.0"
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod model_file;
pub(crate) mod runtimes;
pub(crate) mod session_set;

use std::{borrow::Cow, collections::BTreeSet, fmt::Debug};
use std::{borrow::Cow, collections::BTreeSet, fmt::Debug, ops::Index, sync::Arc};

use derive_new::new;
use duplicate::duplicate_item;
Expand Down Expand Up @@ -51,6 +51,7 @@ pub(crate) trait InferenceRuntime: 'static {
/// 共に扱われるべき推論操作の集合を示す。
pub(crate) trait InferenceDomain: Sized {
type Operation: InferenceOperation;
type Manifest: Index<Self::Operation, Output = Arc<str>>;

/// 対応する`StyleType`。
///
Expand Down
9 changes: 9 additions & 0 deletions crates/voicevox_core/src/infer/domains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ pub(crate) trait InferenceDomainMapValues {
impl<T> InferenceDomainMapValues for (T,) {
type Talk = T;
}

macro_rules! inference_domain_map_values {
(for<$arg:ident> $body:ty) => {
(::macros::substitute_type!(
$body where $arg = crate::infer::domains::TalkDomain as crate::infer::InferenceDomain
),)
};
}
pub(crate) use inference_domain_map_values;
3 changes: 2 additions & 1 deletion crates/voicevox_core/src/infer/domains/talk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use enum_map::Enum;
use macros::{InferenceInputSignature, InferenceOperation, InferenceOutputSignature};
use ndarray::{Array0, Array1, Array2};

use crate::StyleType;
use crate::{manifest::TalkManifest, StyleType};

use super::super::{
InferenceDomain, InferenceInputSignature as _, InferenceOutputSignature as _, OutputTensor,
Expand All @@ -14,6 +14,7 @@ pub(crate) enum TalkDomain {}

impl InferenceDomain for TalkDomain {
type Operation = TalkOperation;
type Manifest = TalkManifest;

fn style_types() -> &'static BTreeSet<StyleType> {
static STYLE_TYPES: LazyLock<BTreeSet<StyleType>> =
Expand Down
4 changes: 2 additions & 2 deletions crates/voicevox_core/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use serde::{de, Deserialize, Deserializer, Serialize};
use serde_with::{serde_as, DisplayFromStr};

use crate::{
infer::domains::{InferenceDomainMap, TalkOperation},
infer::domains::{inference_domain_map_values, InferenceDomainMap, TalkOperation},
StyleId, VoiceModelId,
};

Expand Down Expand Up @@ -79,7 +79,7 @@ pub struct Manifest {
domains: InferenceDomainMap<ManifestDomains>,
}

pub(crate) type ManifestDomains = (Option<TalkManifest>,);
pub(crate) type ManifestDomains = inference_domain_map_values!(for<D> Option<D::Manifest>);

#[derive(Deserialize, IndexForFields)]
#[cfg_attr(test, derive(Default))]
Expand Down
7 changes: 4 additions & 3 deletions crates/voicevox_core/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use itertools::iproduct;
use crate::{
error::{ErrorRepr, LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain},
session_set::{InferenceSessionCell, InferenceSessionSet},
InferenceDomain, InferenceInputSignature, InferenceRuntime, InferenceSessionOptions,
InferenceSignature,
Expand Down Expand Up @@ -338,10 +338,11 @@ impl InferenceDomainMap<ModelBytesWithInnerVoiceIdsByDomain> {
}
}

type SessionOptionsByDomain = (EnumMap<TalkOperation, InferenceSessionOptions>,);
type SessionOptionsByDomain =
inference_domain_map_values!(for<D> EnumMap<D::Operation, InferenceSessionOptions>);

type SessionSetsWithInnerVoiceIdsByDomain<R> =
(Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, TalkDomain>)>,);
inference_domain_map_values!(for<D> Option<(StyleIdToInnerVoiceId, InferenceSessionSet<R, D>)>);

#[cfg(test)]
mod tests {
Expand Down
11 changes: 6 additions & 5 deletions crates/voicevox_core/src/voice_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ use crate::{
asyncs::{Async, Mutex as _},
error::{LoadModelError, LoadModelErrorKind, LoadModelResult},
infer::{
domains::{InferenceDomainMap, TalkDomain, TalkOperation},
domains::{inference_domain_map_values, InferenceDomainMap, TalkDomain, TalkOperation},
InferenceDomain,
},
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId, TalkManifest},
manifest::{Manifest, ManifestDomains, StyleIdToInnerVoiceId},
SpeakerMeta, StyleMeta, StyleType, VoiceModelMeta,
};

Expand All @@ -35,8 +35,9 @@ use crate::{
/// [`VoiceModelId`]: VoiceModelId
pub type RawVoiceModelId = Uuid;

pub(crate) type ModelBytesWithInnerVoiceIdsByDomain =
(Option<(StyleIdToInnerVoiceId, EnumMap<TalkOperation, Vec<u8>>)>,);
pub(crate) type ModelBytesWithInnerVoiceIdsByDomain = inference_domain_map_values!(
for<D> Option<(StyleIdToInnerVoiceId, EnumMap<D::Operation, Vec<u8>>)>
);

/// 音声モデルID。
#[derive(
Expand Down Expand Up @@ -251,7 +252,7 @@ impl<A: Async> Inner<A> {
}

type InferenceModelEntries<'manifest> =
(Option<InferenceModelEntry<TalkDomain, &'manifest TalkManifest>>,);
inference_domain_map_values!(for<D> Option<InferenceModelEntry<D, &'manifest D::Manifest>>);

struct InferenceModelEntry<D: InferenceDomain, M> {
indices: EnumMap<D::Operation, usize>,
Expand Down
3 changes: 2 additions & 1 deletion crates/voicevox_core_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ name = "macros"
proc-macro = true

[dependencies]
derive-syn-parse.workspace = true
indexmap.workspace = true
proc-macro2.workspace = true
quote.workspace = true
syn = { workspace = true, features = ["extra-traits", "full"] }
syn = { workspace = true, features = ["extra-traits", "full", "visit-mut"] }

[lints.rust]
unsafe_code = "forbid"
Expand Down
83 changes: 83 additions & 0 deletions crates/voicevox_core_macros/src/inference_domains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use derive_syn_parse::Parse;
use quote::ToTokens as _;
use syn::{
parse_quote,
visit_mut::{self, VisitMut},
Path, PathArguments, PathSegment, Token, Type, TypePath,
};

pub(crate) fn substitute_type(input: Substitution) -> syn::Result<proc_macro2::TokenStream> {
let Substitution {
mut body,
arg,
replacement,
replacement_as,
..
} = input;

Substitute {
arg,
replacement,
replacement_as,
}
.visit_type_mut(&mut body);

return Ok(body.to_token_stream());

struct Substitute {
arg: syn::Ident,
replacement: Path,
replacement_as: Path,
}

impl VisitMut for Substitute {
fn visit_type_mut(&mut self, i: &mut Type) {
visit_mut::visit_type_mut(self, i);

let Type::Path(TypePath {
qself: None,
path:
Path {
leading_colon: None,
segments,
},
}) = i
else {
return;
};

match &mut *segments.iter_mut().collect::<Vec<_>>() {
[PathSegment {
ident,
arguments: PathArguments::None,
}] if *ident == self.arg => {
let replacement = self.replacement.clone();
*i = parse_quote!(#replacement);
}
[PathSegment {
ident: ident1,
arguments: PathArguments::None,
}, seg]
if *ident1 == self.arg =>
{
let replacement = self.replacement.clone();
let replacement_as = self.replacement_as.clone();
*i = parse_quote!(<#replacement as #replacement_as>::#seg);
}
_ => {}
}
}
}
}

/// `$body:ty where $arg:ident = $replacement:path as $replacement_as:path`
#[derive(Parse)]
pub(crate) struct Substitution {
body: Type,
_where_token: Token![where],
arg: syn::Ident,
_eq_token: Token![=],
replacement: Path,
_as_token: Token![as],
replacement_as: Path,
}
Loading
Loading