Skip to content

Commit

Permalink
Fix input shape build on load_preset
Browse files Browse the repository at this point in the history
  • Loading branch information
nkovela1 committed Dec 8, 2023
1 parent 876dcb7 commit 646ed18
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 1 addition & 2 deletions keras_cv/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,12 @@ def from_preset(
)
return cls(backbone=backbone, **kwargs)

if input_shape is not None:
kwargs.update({"input_shape": input_shape})

# Task case.
return load_from_preset(
preset,
load_weights=load_weights,
input_shape=input_shape,
config_overrides=kwargs,
)

Expand Down
3 changes: 3 additions & 0 deletions keras_cv/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def save_to_preset(
def load_from_preset(
preset,
load_weights=None,
input_shape=None,
config_file="config.json",
config_overrides={},
):
Expand All @@ -136,6 +137,8 @@ def load_from_preset(
config = json.load(config_file)
config["config"] = {**config["config"], **config_overrides}
layer = keras.saving.deserialize_keras_object(config)
if input_shape is not None:
layer.build(input_shape)

# Check load_weights flag does not violate preset config.
if load_weights is True and config["weights"] is None:
Expand Down

0 comments on commit 646ed18

Please sign in to comment.