diff --git a/lavis/common/config.py b/lavis/common/config.py index 2264b0578..8916cb3a2 100644 --- a/lavis/common/config.py +++ b/lavis/common/config.py @@ -61,7 +61,10 @@ def build_model_config(config, **kwargs): model_cls = registry.get_model_class(model.arch) assert model_cls is not None, f"Model '{model.arch}' has not been registered." - model_type = kwargs.get("model.model_type", None) + model_type = None + if 'model' in kwargs: + if "model_type" in kwargs['model']: + model_type = kwargs['model']["model_type"] if not model_type: model_type = model.get("model_type", None) # else use the model type selected by user.