diff --git a/src/vak/models/get.py b/src/vak/models/get.py index be3603455..b9a137cf8 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -1,6 +1,7 @@ """Function that gets an instance of a model, given its name and a configuration as a dict.""" from __future__ import annotations +import inspect from typing import Callable from . import registry @@ -52,20 +53,29 @@ def get(name: str, f"Valid model names are: {registry.MODEL_NAMES}" ) from e - # still need to special case model logic here - if name in ('TweetyNet', 'TeenyTweetyNet', 'ED_TCN'): - num_input_channels = input_shape[-3] - num_freqbins = input_shape[-2] - config["network"].update( - num_classes=num_classes, - num_input_channels=num_input_channels, - num_freqbins=num_freqbins - ) - else: - model_names = list(all_models_dict.keys()) - raise ValueError( - f"Invalid model name: '{name}'.\nValid model names are: {model_names}" + model_family = registry.MODEL_FAMILY_FROM_NAME[name] + + if model_family == 'FrameClassificationModel': + # still need to special case model logic here + net_init_params = list( + inspect.signature( + model_class.definition.network.__init__ + ).parameters.keys() ) + if ('num_input_channels' in net_init_params) and ('num_freqbins' in net_init_params): + num_input_channels = input_shape[-3] + num_freqbins = input_shape[-2] + config["network"].update( + num_classes=num_classes, + num_input_channels=num_input_channels, + num_freqbins=num_freqbins + ) + else: + raise ValueError( + f"Detected that model with name '{name}' was family '{model_family}', but " + f"unable to determine network init arguments for model. Currently all models " + f"in this family must have networks with parameters ``num_input_channels`` and ``num_freqbins``" + ) model = model_class.from_config(config=config, labelmap=labelmap, post_tfm=post_tfm)