diff --git a/keras_cv/models/task.py b/keras_cv/models/task.py index ee2ed6ef1c..83684ebaec 100644 --- a/keras_cv/models/task.py +++ b/keras_cv/models/task.py @@ -127,13 +127,12 @@ def from_preset( ) return cls(backbone=backbone, **kwargs) - if input_shape is not None: - kwargs.update({"input_shape": input_shape}) # Task case. return load_from_preset( preset, load_weights=load_weights, + input_shape=input_shape, config_overrides=kwargs, ) diff --git a/keras_cv/utils/preset_utils.py b/keras_cv/utils/preset_utils.py index 3977e2bc9d..7b382b98c9 100644 --- a/keras_cv/utils/preset_utils.py +++ b/keras_cv/utils/preset_utils.py @@ -126,6 +126,7 @@ def save_to_preset( def load_from_preset( preset, load_weights=None, + input_shape=None, config_file="config.json", config_overrides={}, ): @@ -136,6 +137,8 @@ def load_from_preset( config = json.load(config_file) config["config"] = {**config["config"], **config_overrides} layer = keras.saving.deserialize_keras_object(config) + if input_shape is not None: + layer.build(input_shape) # Check load_weights flag does not violate preset config. if load_weights is True and config["weights"] is None: