From 411a4bf64e35c3254f676da0ee3faec49a26bd3f Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Fri, 20 Sep 2024 18:14:25 -0700 Subject: [PATCH] Expunge include_rescaling from backbones Since our models include built in preprocessing, it is much clearer for this rescaling to happen in the preprocessing layers. --- .../preprocessing/resizing_image_converter.py | 62 +++++++++++++++++-- .../resizing_image_converter_test.py | 51 ++++++++++++--- .../csp_darknet/csp_darknet_backbone.py | 12 +--- .../csp_darknet_image_classifier.py | 1 - .../csp_darknet_image_classifier_test.py | 1 - .../src/models/densenet/densenet_backbone.py | 12 +--- .../densenet/densenet_image_classifier.py | 1 - .../densenet_image_classifier_test.py | 1 - .../efficientnet/efficientnet_backbone.py | 17 +---- .../efficientnet_backbone_test.py | 8 --- .../mix_transformer_backbone.py | 12 +--- .../mix_transformer_backbone_test.py | 1 - .../mix_transformer_classifier.py | 1 - .../mix_transformer_classifier_test.py | 1 - .../models/mobilenet/mobilenet_backbone.py | 17 +---- .../mobilenet/mobilenet_backbone_test.py | 1 - .../mobilenet/mobilenet_image_classifier.py | 1 - .../mobilenet_image_classifier_test.py | 1 - .../src/models/pali_gemma/pali_gemma_vit.py | 3 + .../src/models/resnet/resnet_backbone.py | 22 +------ .../models/resnet/resnet_image_classifier.py | 1 - .../resnet/resnet_image_classifier_test.py | 3 +- keras_hub/src/models/resnet/resnet_presets.py | 12 ++-- keras_hub/src/models/vgg/vgg_backbone.py | 8 --- keras_hub/src/models/vgg/vgg_backbone_test.py | 1 - .../src/models/vgg/vgg_image_classifier.py | 1 - .../models/vgg/vgg_image_classifier_test.py | 1 - .../src/models/vit_det/vit_det_backbone.py | 9 --- .../models/vit_det/vit_det_backbone_test.py | 1 - keras_hub/src/utils/timm/convert_resnet.py | 8 --- keras_hub/src/utils/timm/preset_loader.py | 17 ++++- .../convert_resnet_checkpoints.py | 23 +++++-- 32 files changed, 154 insertions(+), 157 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/resizing_image_converter.py b/keras_hub/src/layers/preprocessing/resizing_image_converter.py index cfce694b65..f5e044c886 100644 --- a/keras_hub/src/layers/preprocessing/resizing_image_converter.py +++ b/keras_hub/src/layers/preprocessing/resizing_image_converter.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import keras +from keras import ops from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.utils.keras_utils import standardize_data_format from keras_hub.src.utils.tensor_utils import preprocessing_function @@ -23,13 +25,23 @@ class ResizingImageConverter(ImageConverter): """An `ImageConverter` that simply resizes the input image. The `ResizingImageConverter` is a subclass of `ImageConverter` for models - that simply need to resize image tensors before using them for modeling. - The layer will take as input a raw image tensor (batched or unbatched) in the - channels last or channels first format, and output a resize tensor. + that need to resize (and optionally rescale) image tensors before using them + for modeling. The layer will take as input a raw image tensor (batched or + unbatched) in the channels last or channels first format, and output a + resize tensor. Args: - height: Integer, the height of the output shape. - width: Integer, the width of the output shape. + height: int, the height of the output shape. 
+ width: int, the width of the output shape. + scale: float or `None`. If set, the image we be rescaled with a + `keras.layers.Rescaling` layer, multiplying the image by this + scale. + mean: tuples of floats per channel or `None`. If set, the image will be + normalized per channel by subtracting mean. + If set, also set `variance`. + variance: tuples of floats per channel or `None`. If set, the image will + be normalized per channel by dividing by `sqrt(variance)`. + If set, also set `mean`. crop_to_aspect_ratio: If `True`, resize the images without aspect ratio distortion. When the original aspect ratio differs from the target aspect ratio, the output image will be @@ -64,6 +76,9 @@ def __init__( self, height, width, + scale=None, + mean=None, + variance=None, crop_to_aspect_ratio=True, interpolation="bilinear", data_format=None, @@ -78,7 +93,26 @@ def __init__( crop_to_aspect_ratio=crop_to_aspect_ratio, interpolation=interpolation, data_format=data_format, + dtype=self.dtype_policy, + name="resizing", ) + if scale is not None: + self.rescaling = keras.layers.Rescaling( + scale=scale, + dtype=self.dtype_policy, + name="rescaling", + ) + else: + self.rescaling = None + if (mean is not None) != (variance is not None): + raise ValueError( + "Both `mean` and `variance` should be set or `None`. Received " + f"`mean={mean}`, `variance={variance}`." + ) + self.scale = scale + self.mean = mean + self.variance = variance + self.data_format = standardize_data_format(data_format) def image_size(self): """Returns the preprocessed size of a single image.""" @@ -86,7 +120,20 @@ def image_size(self): @preprocessing_function def call(self, inputs): - return self.resizing(inputs) + x = self.resizing(inputs) + if self.rescaling: + x = self.rescaling(x) + if self.mean is not None: + # Avoid `layers.Normalization` so this works batched and unbatched. 
+ channels_first = self.data_format == "channels_first" + if len(ops.shape(inputs)) == 3: + broadcast_dims = (1, 2) if channels_first else (0, 1) + else: + broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2) + mean = ops.expand_dims(ops.array(self.mean), broadcast_dims) + std = ops.expand_dims(ops.sqrt(self.variance), broadcast_dims) + x = (x - mean) / std + return x def get_config(self): config = super().get_config() @@ -96,6 +143,9 @@ def get_config(self): "width": self.resizing.width, "interpolation": self.resizing.interpolation, "crop_to_aspect_ratio": self.resizing.crop_to_aspect_ratio, + "scale": self.scale, + "mean": self.mean, + "variance": self.variance, } ) return config diff --git a/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py b/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py index 857cf578a8..b54b0a0d94 100644 --- a/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py +++ b/keras_hub/src/layers/preprocessing/resizing_image_converter_test.py @@ -22,22 +22,57 @@ class ResizingImageConverterTest(TestCase): + def test_resize_simple(self): + converter = ResizingImageConverter(height=4, width=4) + inputs = np.ones((10, 10, 3)) + outputs = converter(inputs) + self.assertAllClose(outputs, ops.ones((4, 4, 3))) + def test_resize_one(self): - converter = ResizingImageConverter(22, 22) - test_image = np.random.rand(10, 10, 3) * 255 - shape = ops.shape(converter(test_image)) - self.assertEqual(shape, (22, 22, 3)) + converter = ResizingImageConverter( + height=4, + width=4, + mean=(0.5, 0.7, 0.3), + variance=(0.25, 0.1, 0.5), + scale=1 / 255.0, + ) + inputs = np.ones((10, 10, 3)) * 128 + outputs = converter(inputs) + self.assertEqual(ops.shape(outputs), (4, 4, 3)) + self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.003922) + self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * -0.626255) + self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.285616) def test_resize_batch(self): - converter = ResizingImageConverter(12, 12) - test_batch = np.random.rand(4, 10, 20, 3) * 255 - shape = ops.shape(converter(test_batch)) - self.assertEqual(shape, (4, 12, 12, 3)) + converter = ResizingImageConverter( + height=4, + width=4, + mean=(0.5, 0.7, 0.3), + variance=(0.25, 0.1, 0.5), + scale=1 / 255.0, + ) + inputs = np.ones((2, 10, 10, 3)) * 128 + outputs = converter(inputs) + self.assertEqual(ops.shape(outputs), (2, 4, 4, 3)) + self.assertAllClose(outputs[:, :, :, 0], np.ones((2, 4, 4)) * 0.003922) + self.assertAllClose(outputs[:, :, :, 1], np.ones((2, 4, 4)) * -0.626255) + self.assertAllClose(outputs[:, :, :, 2], np.ones((2, 4, 4)) * 0.285616) + + def test_errors(self): + with self.assertRaises(ValueError): + ResizingImageConverter( + height=4, + width=4, + mean=(0.5, 0.7, 0.3), + ) def test_config(self): converter = ResizingImageConverter( width=12, height=20, + mean=(0.5, 0.7, 0.3), + variance=(0.25, 0.1, 0.5), + scale=1 / 255.0, crop_to_aspect_ratio=False, interpolation="nearest", ) diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py b/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py index ab33823405..bf12c9d7f3 100644 --- a/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +++ b/keras_hub/src/models/csp_darknet/csp_darknet_backbone.py @@ -31,9 +31,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone): level in the model. stackwise_depth: A list of ints, the depth for each dark level in the model. - include_rescaling: boolean. 
If `True`, rescale the input using - `Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to - `True`. block_type: str. One of `"basic_block"` or `"depthwise_block"`. Use `"depthwise_block"` for depthwise conv block `"basic_block"` for basic conv block. @@ -55,7 +52,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone): model = keras_hub.models.CSPDarkNetBackbone( stackwise_num_filters=[128, 256, 512, 1024], stackwise_depth=[3, 9, 9, 3], - include_rescaling=False, ) model(input_data) ``` @@ -65,7 +61,6 @@ def __init__( self, stackwise_num_filters, stackwise_depth, - include_rescaling=True, block_type="basic_block", image_shape=(None, None, 3), **kwargs, @@ -82,10 +77,7 @@ def __init__( base_channels = stackwise_num_filters[0] // 2 image_input = layers.Input(shape=image_shape) - x = image_input - if include_rescaling: - x = layers.Rescaling(scale=1 / 255.0)(x) - + x = image_input # Intermediate result. x = apply_focus(channel_axis, name="stem_focus")(x) x = apply_darknet_conv_block( base_channels, @@ -130,7 +122,6 @@ def __init__( # === Config === self.stackwise_num_filters = stackwise_num_filters self.stackwise_depth = stackwise_depth - self.include_rescaling = include_rescaling self.block_type = block_type self.image_shape = image_shape self.pyramid_outputs = pyramid_outputs @@ -141,7 +132,6 @@ def get_config(self): { "stackwise_num_filters": self.stackwise_num_filters, "stackwise_depth": self.stackwise_depth, - "include_rescaling": self.include_rescaling, "block_type": self.block_type, "image_shape": self.image_shape, } diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py index 28069e7d9f..4a7d4719e3 100644 --- a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +++ b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py @@ -76,7 +76,6 @@ class CSPDarkNetImageClassifier(ImageClassifier): backbone = keras_hub.models.CSPDarkNetBackbone( stackwise_num_filters=[128, 256, 512, 1024], stackwise_depth=[3, 9, 9, 3], - include_rescaling=False, block_type="basic_block", image_shape = (224, 224, 3), ) diff --git a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py index f3735be2fe..c67685b763 100644 --- a/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py +++ b/keras_hub/src/models/csp_darknet/csp_darknet_image_classifier_test.py @@ -31,7 +31,6 @@ def setUp(self): self.backbone = CSPDarkNetBackbone( stackwise_num_filters=[2, 16, 16], stackwise_depth=[1, 3, 3, 1], - include_rescaling=False, block_type="basic_block", image_shape=(16, 16, 3), ) diff --git a/keras_hub/src/models/densenet/densenet_backbone.py b/keras_hub/src/models/densenet/densenet_backbone.py index 2b840011d1..8778f27d76 100644 --- a/keras_hub/src/models/densenet/densenet_backbone.py +++ b/keras_hub/src/models/densenet/densenet_backbone.py @@ -31,9 +31,6 @@ class DenseNetBackbone(FeaturePyramidBackbone): Args: stackwise_num_repeats: list of ints, number of repeated convolutional blocks per dense block. - include_rescaling: bool, whether to rescale the inputs. If set - to `True`, inputs will be passed through a `Rescaling(1/255.0)` - layer. Defaults to `True`. image_shape: optional shape tuple, defaults to (None, None, 3). compression_ratio: float, compression rate at transition layers, defaults to 0.5. 
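With `include_rescaling` removed from these backbones, the `[0, 255]` to `[0, 1]` step is expected to come from the preprocessing layer instead. A minimal sketch of that flow, using the new `ResizingImageConverter` arguments added above (the 224x224 size and random input are purely illustrative, not tied to any preset):

```python
import numpy as np

from keras_hub.src.layers.preprocessing.resizing_image_converter import (
    ResizingImageConverter,
)

# Resize and rescale in preprocessing; the backbone no longer applies
# `Rescaling(1 / 255.0)` itself.
converter = ResizingImageConverter(height=224, width=224, scale=1 / 255.0)
images = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
backbone_input = converter(images)  # Shape (2, 224, 224, 3), values in [0, 1].
```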
@@ -51,7 +48,6 @@ class DenseNetBackbone(FeaturePyramidBackbone): # Randomly initialized backbone with a custom config model = keras_hub.models.DenseNetBackbone( stackwise_num_repeats=[6, 12, 24, 16], - include_rescaling=False, ) model(input_data) ``` @@ -60,7 +56,6 @@ class DenseNetBackbone(FeaturePyramidBackbone): def __init__( self, stackwise_num_repeats, - include_rescaling=True, image_shape=(None, None, 3), compression_ratio=0.5, growth_rate=32, @@ -71,10 +66,7 @@ def __init__( channel_axis = -1 if data_format == "channels_last" else 1 image_input = keras.layers.Input(shape=image_shape) - x = image_input - if include_rescaling: - x = keras.layers.Rescaling(1 / 255.0)(x) - + x = image_input # Intermediate result. x = keras.layers.Conv2D( 64, 7, @@ -124,7 +116,6 @@ def __init__( # === Config === self.stackwise_num_repeats = stackwise_num_repeats - self.include_rescaling = include_rescaling self.compression_ratio = compression_ratio self.growth_rate = growth_rate self.image_shape = image_shape @@ -135,7 +126,6 @@ def get_config(self): config.update( { "stackwise_num_repeats": self.stackwise_num_repeats, - "include_rescaling": self.include_rescaling, "compression_ratio": self.compression_ratio, "growth_rate": self.growth_rate, "image_shape": self.image_shape, diff --git a/keras_hub/src/models/densenet/densenet_image_classifier.py b/keras_hub/src/models/densenet/densenet_image_classifier.py index 6bd7bbbaa1..c727106f42 100644 --- a/keras_hub/src/models/densenet/densenet_image_classifier.py +++ b/keras_hub/src/models/densenet/densenet_image_classifier.py @@ -74,7 +74,6 @@ class DenseNetImageClassifier(ImageClassifier): backbone = keras_hub.models.DenseNetBackbone( stackwise_num_filters=[128, 256, 512, 1024], stackwise_depth=[3, 9, 9, 3], - include_rescaling=False, block_type="basic_block", image_shape = (224, 224, 3), ) diff --git a/keras_hub/src/models/densenet/densenet_image_classifier_test.py b/keras_hub/src/models/densenet/densenet_image_classifier_test.py index b4bb19d35a..da3fb20d1b 100644 --- a/keras_hub/src/models/densenet/densenet_image_classifier_test.py +++ b/keras_hub/src/models/densenet/densenet_image_classifier_test.py @@ -28,7 +28,6 @@ def setUp(self): self.labels = [0, 3] self.backbone = DenseNetBackbone( stackwise_num_repeats=[6, 12, 24, 16], - include_rescaling=True, compression_ratio=0.5, growth_rate=32, image_shape=(224, 224, 3), diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone.py b/keras_hub/src/models/efficientnet/efficientnet_backbone.py index 2cb7a82f8b..405ea7bce0 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_backbone.py +++ b/keras_hub/src/models/efficientnet/efficientnet_backbone.py @@ -67,8 +67,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone): MBConvBlock, but instead of using a depthwise convolution and a 1x1 output convolution blocks fused blocks use a single 3x3 convolution block. - include_rescaling: bool, whether to rescale the inputs. If set to - True, inputs will be passed through a `Rescaling(1/255.0)` layer. min_depth: integer, minimum number of filters. Can be None and ignored if use_depth_divisor_as_min_depth is set to True. 
include_initial_padding: bool, whether to include initial zero padding @@ -96,7 +94,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone): stackwise_block_types=[["fused"] * 3 + ["unfused"] * 3], width_coefficient=1.0, depth_coefficient=1.0, - include_rescaling=False, ) images = np.ones((1, 256, 256, 3)) outputs = efficientnet.predict(images) @@ -116,7 +113,6 @@ def __init__( stackwise_squeeze_and_excite_ratios, stackwise_strides, stackwise_block_types, - include_rescaling=True, dropout=0.2, depth_divisor=8, min_depth=8, @@ -129,14 +125,9 @@ def __init__( batch_norm_momentum=0.9, **kwargs, ): - img_input = keras.layers.Input(shape=input_shape) - - x = img_input - - if include_rescaling: - # Use common rescaling strategy across keras - x = keras.layers.Rescaling(scale=1.0 / 255.0)(x) + image_input = keras.layers.Input(shape=input_shape) + x = image_input # Intermediate result. if include_initial_padding: x = keras.layers.ZeroPadding2D( padding=self._correct_pad_downsample(x, 3), @@ -282,10 +273,9 @@ def __init__( curr_pyramid_level += 1 # Create model. - super().__init__(inputs=img_input, outputs=x, **kwargs) + super().__init__(inputs=image_input, outputs=x, **kwargs) # === Config === - self.include_rescaling = include_rescaling self.width_coefficient = width_coefficient self.depth_coefficient = depth_coefficient self.dropout = dropout @@ -313,7 +303,6 @@ def get_config(self): config = super().get_config() config.update( { - "include_rescaling": self.include_rescaling, "width_coefficient": self.width_coefficient, "depth_coefficient": self.depth_coefficient, "dropout": self.dropout, diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py index aab9f6dc69..918bc8087d 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py +++ b/keras_hub/src/models/efficientnet/efficientnet_backbone_test.py @@ -42,7 +42,6 @@ def setUp(self): "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3, "width_coefficient": 1.0, "depth_coefficient": 1.0, - "include_rescaling": False, } self.input_data = keras.ops.ones(shape=(8, 224, 224, 3)) @@ -86,7 +85,6 @@ def test_valid_call_original_v1(self): ], "width_coefficient": 1.0, "depth_coefficient": 1.0, - "include_rescaling": False, "stackwise_block_types": ["v1"] * 7, "min_depth": None, "include_initial_padding": True, @@ -98,12 +96,6 @@ def test_valid_call_original_v1(self): model = EfficientNetBackbone(**original_v1_kwargs) model(self.input_data) - def test_valid_call_with_rescaling(self): - test_kwargs = self.init_kwargs.copy() - test_kwargs["include_rescaling"] = True - model = EfficientNetBackbone(**test_kwargs) - model(self.input_data) - def test_feature_pyramid_outputs(self): backbone = EfficientNetBackbone(**self.init_kwargs) model = keras.Model( diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py b/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py index 5127bd357b..6986be7c45 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_backbone.py @@ -36,7 +36,6 @@ def __init__( end_value, patch_sizes, strides, - include_rescaling=True, image_shape=(None, None, 3), hidden_dims=None, **kwargs, @@ -60,9 +59,6 @@ def __init__( value projections. If set to > 1, a `Conv2D` layer is used to reduce the length of the sequence. end_value: The end value of the sequence. - include_rescaling: bool, whether to rescale the inputs. 
If set - to `True`, inputs will be passed through a `Rescaling(1/255.0)` - layer. Defaults to `True`. image_shape: optional shape tuple, defaults to (None, None, 3). hidden_dims: the embedding dims per hierarchical layer, used as the levels of the feature pyramid. @@ -123,11 +119,7 @@ def __init__( # === Functional Model === image_input = keras.layers.Input(shape=image_shape) - x = image_input - - if include_rescaling: - x = keras.layers.Rescaling(scale=1 / 255)(x) - + x = image_input # Intermediate result. pyramid_outputs = {} for i in range(num_layers): # Compute new height/width after the `proj` @@ -151,7 +143,6 @@ def __init__( # === Config === self.depths = depths - self.include_rescaling = include_rescaling self.image_shape = image_shape self.hidden_dims = hidden_dims self.pyramid_outputs = pyramid_outputs @@ -167,7 +158,6 @@ def get_config(self): config.update( { "depths": self.depths, - "include_rescaling": self.include_rescaling, "hidden_dims": self.hidden_dims, "image_shape": self.image_shape, "num_layers": self.num_layers, diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py b/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py index 9cab12b7bb..bab58103b4 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_backbone_test.py @@ -25,7 +25,6 @@ class MiTBackboneTest(TestCase): def setUp(self): self.init_kwargs = { "depths": [2, 2], - "include_rescaling": True, "image_shape": (16, 16, 3), "hidden_dims": [4, 8], "num_layers": 2, diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py index c6ff3fba1e..7de8aea880 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py @@ -76,7 +76,6 @@ class MiTImageClassifier(ImageClassifier): backbone = keras_hub.models.MiTBackbone( stackwise_num_filters=[128, 256, 512, 1024], stackwise_depth=[3, 9, 9, 3], - include_rescaling=False, block_type="basic_block", image_shape = (224, 224, 3), ) diff --git a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py b/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py index e17071229a..1d5d4ec444 100644 --- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py +++ b/keras_hub/src/models/mix_transformer/mix_transformer_classifier_test.py @@ -30,7 +30,6 @@ def setUp(self): self.labels = [0, 3] self.backbone = MiTBackbone( depths=[2, 2, 2, 2], - include_rescaling=True, image_shape=(16, 16, 3), hidden_dims=[4, 8], num_layers=2, diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py index 27072ddf37..ff83364472 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py @@ -54,9 +54,6 @@ class MobileNetBackbone(Backbone): model. 0 if dont want to add Squeeze and Excite layer. stackwise_activation: list of activation functions, for each inverted residual block in the model. - include_rescaling: bool, whether to rescale the inputs. If set to True, - inputs will be passed through a `Rescaling(scale=1 / 255)` - layer. image_shape: optional shape tuple, defaults to (224, 224, 3). depth_multiplier: float, controls the width of the network. 
- If `depth_multiplier` < 1.0, proportionally decreases the number @@ -92,7 +89,6 @@ class MobileNetBackbone(Backbone): stackwise_num_strides=[2, 2, 1], stackwise_se_ratio=[0.25, None, 0.25], stackwise_activation=["relu", "relu6", "hard_swish"], - include_rescaling=False, output_num_filters=1280, input_activation='hard_swish', output_activation='hard_swish', @@ -111,7 +107,6 @@ def __init__( stackwise_num_strides, stackwise_se_ratio, stackwise_activation, - include_rescaling, output_num_filters, inverted_res_block, image_shape=(224, 224, 3), @@ -126,12 +121,8 @@ def __init__( -1 if keras.config.image_data_format() == "channels_last" else 1 ) - inputs = keras.layers.Input(shape=image_shape) - x = inputs - - if include_rescaling: - x = keras.layers.Rescaling(scale=1 / 255)(x) - + image_input = keras.layers.Input(shape=image_shape) + x = image_input # Intermediate result. input_num_filters = adjust_channels(input_num_filters) x = keras.layers.Conv2D( input_num_filters, @@ -195,7 +186,7 @@ def __init__( )(x) x = keras.layers.Activation(output_activation)(x) - super().__init__(inputs=inputs, outputs=x, **kwargs) + super().__init__(inputs=image_input, outputs=x, **kwargs) # === Config === self.stackwise_expansion = stackwise_expansion @@ -204,7 +195,6 @@ def __init__( self.stackwise_num_strides = stackwise_num_strides self.stackwise_se_ratio = stackwise_se_ratio self.stackwise_activation = stackwise_activation - self.include_rescaling = include_rescaling self.depth_multiplier = depth_multiplier self.input_num_filters = input_num_filters self.output_num_filters = output_num_filters @@ -223,7 +213,6 @@ def get_config(self): "stackwise_num_strides": self.stackwise_num_strides, "stackwise_se_ratio": self.stackwise_se_ratio, "stackwise_activation": self.stackwise_activation, - "include_rescaling": self.include_rescaling, "image_shape": self.image_shape, "depth_multiplier": self.depth_multiplier, "input_num_filters": self.input_num_filters, diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py index 32d1c27c47..cf49194c5c 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_backbone_test.py @@ -28,7 +28,6 @@ def setUp(self): "stackwise_num_strides": [2, 2, 1], "stackwise_se_ratio": [0.25, None, 0.25], "stackwise_activation": ["relu", "relu", "hard_swish"], - "include_rescaling": False, "output_num_filters": 1280, "input_activation": "hard_swish", "output_activation": "hard_swish", diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index b744e7c40f..407feac11a 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -56,7 +56,6 @@ class MobileNetImageClassifier(ImageClassifier): stackwise_stride = [2, 2, 1], stackwise_se_ratio = [ 0.25, None, 0.25], stackwise_activation = ["relu", "relu", "hard_swish"], - include_rescaling = False, output_filter=1280, activation="hard_swish", inverted_res_block=True, diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 0fbcca7675..b16d1b92af 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -33,7 +33,6 @@ def setUp(self): stackwise_num_strides=[2, 2, 1], 
stackwise_se_ratio=[0.25, None, 0.25], stackwise_activation=["relu", "relu", "hard_swish"], - include_rescaling=False, output_num_filters=1280, input_activation="hard_swish", output_activation="hard_swish", diff --git a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py index e9da150c08..c47507703a 100644 --- a/keras_hub/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_hub/src/models/pali_gemma/pali_gemma_vit.py @@ -476,6 +476,9 @@ def __init__( shape=(image_size, image_size, 3), name="images" ) x = image_input # Intermediate result. + # TODO we have moved this rescaling to preprocessing layers for most + # models. We should consider removing it here, though it would break + # compatibility. if include_rescaling: rescaling = keras.layers.Rescaling( scale=1.0 / 127.5, offset=-1.0, name="rescaling" diff --git a/keras_hub/src/models/resnet/resnet_backbone.py b/keras_hub/src/models/resnet/resnet_backbone.py index 7f585ba1f2..638ccb8079 100644 --- a/keras_hub/src/models/resnet/resnet_backbone.py +++ b/keras_hub/src/models/resnet/resnet_backbone.py @@ -44,9 +44,6 @@ class ResNetBackbone(FeaturePyramidBackbone): additional pooling operation rather than performing downsampling within the convolutional layers themselves. - Note that `ResNetBackbone` expects the inputs to be images with a value - range of `[0, 255]` when `include_rescaling=True`. - Args: input_conv_filters: list of ints. The number of filters of the initial convolution(s). @@ -65,9 +62,6 @@ class ResNetBackbone(FeaturePyramidBackbone): variants. use_pre_activation: boolean. Whether to use pre-activation or not. `True` for ResNetV2, `False` for ResNet. - include_rescaling: boolean. If `True`, rescale the input using - `Rescaling` and `Normalization` layers. If `False`, do nothing. - Defaults to `True`. image_shape: tuple. The input shape without the batch size. Defaults to `(None, None, 3)`. pooling: `None` or str. Pooling mode for feature extraction. Defaults @@ -124,7 +118,6 @@ def __init__( stackwise_num_strides, block_type, use_pre_activation=False, - include_rescaling=True, image_shape=(None, None, 3), data_format=None, dtype=None, @@ -170,18 +163,7 @@ def __init__( # === Functional Model === image_input = layers.Input(shape=image_shape) - if include_rescaling: - x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input) - x = layers.Normalization( - axis=bn_axis, - mean=(0.485, 0.456, 0.406), - variance=(0.229**2, 0.224**2, 0.225**2), - dtype=dtype, - name="normalization", - )(x) - else: - x = image_input - + x = image_input # Intermediate result. # The padding between torch and tensorflow/jax differs when `strides>1`. # Therefore, we need to manually pad the tensor. 
x = layers.ZeroPadding2D( @@ -299,7 +281,6 @@ def __init__( self.stackwise_num_strides = stackwise_num_strides self.block_type = block_type self.use_pre_activation = use_pre_activation - self.include_rescaling = include_rescaling self.image_shape = image_shape self.pyramid_outputs = pyramid_outputs self.data_format = data_format @@ -315,7 +296,6 @@ def get_config(self): "stackwise_num_strides": self.stackwise_num_strides, "block_type": self.block_type, "use_pre_activation": self.use_pre_activation, - "include_rescaling": self.include_rescaling, "image_shape": self.image_shape, } ) diff --git a/keras_hub/src/models/resnet/resnet_image_classifier.py b/keras_hub/src/models/resnet/resnet_image_classifier.py index a7456cb85b..4440ef145c 100644 --- a/keras_hub/src/models/resnet/resnet_image_classifier.py +++ b/keras_hub/src/models/resnet/resnet_image_classifier.py @@ -85,7 +85,6 @@ class ResNetImageClassifier(ImageClassifier): stackwise_num_strides=[1, 2, 2], block_type="basic_block", use_pre_activation=True, - include_rescaling=False, pooling="avg", ) classifier = keras_hub.models.ResNetImageClassifier( diff --git a/keras_hub/src/models/resnet/resnet_image_classifier_test.py b/keras_hub/src/models/resnet/resnet_image_classifier_test.py index d9de3719ac..5706a7b5a6 100644 --- a/keras_hub/src/models/resnet/resnet_image_classifier_test.py +++ b/keras_hub/src/models/resnet/resnet_image_classifier_test.py @@ -34,7 +34,6 @@ def setUp(self): block_type="basic_block", use_pre_activation=True, image_shape=(16, 16, 3), - include_rescaling=False, ) self.init_kwargs = { "backbone": self.backbone, @@ -62,7 +61,7 @@ def test_head_dtype(self): @pytest.mark.large def test_smallest_preset(self): # Test that our forward pass is stable! - image_batch = self.load_test_image()[None, ...] + image_batch = self.load_test_image()[None, ...] 
/ 255.0 self.run_preset_test( cls=ResNetImageClassifier, preset="resnet_18_imagenet", diff --git a/keras_hub/src/models/resnet/resnet_presets.py b/keras_hub/src/models/resnet/resnet_presets.py index 99e448f24d..7264558a7a 100644 --- a/keras_hub/src/models/resnet/resnet_presets.py +++ b/keras_hub/src/models/resnet/resnet_presets.py @@ -25,7 +25,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/3", }, "resnet_50_imagenet": { "metadata": { @@ -38,7 +38,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/3", }, "resnet_101_imagenet": { "metadata": { @@ -51,7 +51,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/3", }, "resnet_152_imagenet": { "metadata": { @@ -64,7 +64,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/3", }, "resnet_v2_50_imagenet": { "metadata": { @@ -77,7 +77,7 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/3", }, "resnet_v2_101_imagenet": { "metadata": { @@ -90,6 +90,6 @@ "path": "resnet", "model_card": "https://arxiv.org/abs/2110.00476", }, - "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/2", + "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/3", }, } diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py index 541b3600ef..771c45ce5e 100644 --- a/keras_hub/src/models/vgg/vgg_backbone.py +++ b/keras_hub/src/models/vgg/vgg_backbone.py @@ -33,8 +33,6 @@ class VGGBackbone(Backbone): stackwise_num_filters: list of ints, filter size for convolutional blocks per VGG block. For both VGG16 and VGG19 this is [ 64, 128, 256, 512, 512]. - include_rescaling: bool, whether to rescale the inputs. If set to - True, inputs will be passed through a `Rescaling(1/255.0)` layer. image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). pooling: bool, Optional pooling mode for feature extraction when `include_top` is `False`. 
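The ResNet preset version bumps above track this change: the ImageNet normalization that `ResNetBackbone` previously applied through its internal `Rescaling` and `Normalization` layers now has to be configured on the converter. A hedged sketch of an equivalent converter setup, reusing the mean/variance constants the removed backbone code hard-coded (the 224x224 size is illustrative):

```python
from keras_hub.src.layers.preprocessing.resizing_image_converter import (
    ResizingImageConverter,
)

# Same constants the removed in-backbone Normalization layer used.
converter = ResizingImageConverter(
    height=224,
    width=224,
    scale=1 / 255.0,
    mean=(0.485, 0.456, 0.406),
    variance=(0.229**2, 0.224**2, 0.225**2),
)
```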
@@ -61,7 +59,6 @@ class VGGBackbone(Backbone): stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], image_shape = (224, 224, 3), - include_rescaling = False, pooling = "avg", ) model(input_data) @@ -72,7 +69,6 @@ def __init__( self, stackwise_num_repeats, stackwise_num_filters, - include_rescaling, image_shape=(224, 224, 3), pooling="avg", **kwargs, @@ -82,8 +78,6 @@ def __init__( img_input = keras.layers.Input(shape=image_shape) x = img_input - if include_rescaling: - x = layers.Rescaling(scale=1 / 255.0)(x) for stack_index in range(len(stackwise_num_repeats) - 1): x = apply_vgg_block( x=x, @@ -105,7 +99,6 @@ def __init__( # === Config === self.stackwise_num_repeats = stackwise_num_repeats self.stackwise_num_filters = stackwise_num_filters - self.include_rescaling = include_rescaling self.image_shape = image_shape self.pooling = pooling @@ -113,7 +106,6 @@ def get_config(self): return { "stackwise_num_repeats": self.stackwise_num_repeats, "stackwise_num_filters": self.stackwise_num_filters, - "include_rescaling": self.include_rescaling, "image_shape": self.image_shape, "pooling": self.pooling, } diff --git a/keras_hub/src/models/vgg/vgg_backbone_test.py b/keras_hub/src/models/vgg/vgg_backbone_test.py index 38f7d03606..76b279dc73 100644 --- a/keras_hub/src/models/vgg/vgg_backbone_test.py +++ b/keras_hub/src/models/vgg/vgg_backbone_test.py @@ -25,7 +25,6 @@ def setUp(self): "stackwise_num_repeats": [2, 3, 3], "stackwise_num_filters": [8, 64, 64], "image_shape": (16, 16, 3), - "include_rescaling": False, "pooling": "avg", } self.input_data = np.ones((2, 16, 16, 3), dtype="float32") diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index 6b9733c250..2e3c42285a 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -66,7 +66,6 @@ class VGGImageClassifier(ImageClassifier): stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], image_shape = (224, 224, 3), - include_rescaling = False, pooling = "avg", ) classifier = keras_hub.models.VGGImageClassifier( diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index 83ec811bbf..b62e56ae99 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -28,7 +28,6 @@ def setUp(self): stackwise_num_repeats=[2, 4, 4], stackwise_num_filters=[2, 16, 16], image_shape=(4, 4, 3), - include_rescaling=False, pooling="max", ) self.init_kwargs = { diff --git a/keras_hub/src/models/vit_det/vit_det_backbone.py b/keras_hub/src/models/vit_det/vit_det_backbone.py index 0aed62fd11..b634f0936e 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone.py @@ -46,9 +46,6 @@ class ViTDetBackbone(Backbone): global attention. image_shape (tuple[int], optional): The size of the input image in `(H, W, C)` format. Defaults to `(1024, 1024, 3)`. - include_rescaling (bool, optional): Whether to rescale the inputs. If - set to `True`, inputs will be passed through a - `Rescaling(1/255.0)` layer. Defaults to `False`. patch_size (int, optional): the patch size to be supplied to the Patching layer to turn input images into a flattened sequence of patches. Defaults to `16`. 
@@ -96,7 +93,6 @@ def __init__( intermediate_dim, num_heads, global_attention_layer_indices, - include_rescaling=True, image_shape=(1024, 1024, 3), patch_size=16, num_output_channels=256, @@ -123,9 +119,6 @@ def __init__( ) img_size = img_input.shape[-3] x = img_input - if include_rescaling: - # Use common rescaling strategy across keras_cv - x = keras.layers.Rescaling(1.0 / 255.0)(x) # VITDet scales inputs based on the standard ImageNet mean/stddev. x = (x - ops.array([0.485, 0.456, 0.406], dtype=x.dtype)) / ( ops.array([0.229, 0.224, 0.225], dtype=x.dtype) @@ -179,14 +172,12 @@ def __init__( self.window_size = window_size self.global_attention_layer_indices = global_attention_layer_indices self.layer_norm_epsilon = layer_norm_epsilon - self.include_rescaling = include_rescaling def get_config(self): config = super().get_config() config.update( { "image_shape": self.image_shape, - "include_rescaling": self.include_rescaling, "patch_size": self.patch_size, "hidden_size": self.hidden_size, "num_layers": self.num_layers, diff --git a/keras_hub/src/models/vit_det/vit_det_backbone_test.py b/keras_hub/src/models/vit_det/vit_det_backbone_test.py index d8c1b2d24c..5bd3e0622b 100644 --- a/keras_hub/src/models/vit_det/vit_det_backbone_test.py +++ b/keras_hub/src/models/vit_det/vit_det_backbone_test.py @@ -22,7 +22,6 @@ class ViTDetBackboneTest(TestCase): def setUp(self): self.init_kwargs = { - "include_rescaling": True, "image_shape": (16, 16, 3), "patch_size": 2, "hidden_size": 4, diff --git a/keras_hub/src/utils/timm/convert_resnet.py b/keras_hub/src/utils/timm/convert_resnet.py index 8042d5f5f1..f5dc10e822 100644 --- a/keras_hub/src/utils/timm/convert_resnet.py +++ b/keras_hub/src/utils/timm/convert_resnet.py @@ -151,14 +151,6 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): if version == "v2": port_batch_normalization("post_bn", "norm") - # Rebuild normalization layer with pretrained mean & std - mean = timm_config["pretrained_cfg"]["mean"] - std = timm_config["pretrained_cfg"]["std"] - normalization_layer = backbone.get_layer("normalization") - normalization_layer.input_mean = mean - normalization_layer.input_variance = [s**2 for s in std] - normalization_layer.build(normalization_layer._build_input_shape) - def convert_head(task, loader, timm_config): v2 = "resnetv2_" in timm_config["architecture"] diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 123cdf9674..0993a9a5d6 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -62,5 +62,20 @@ def load_image_converter(self, cls, **kwargs): pretrained_cfg = self.config.get("pretrained_cfg", None) if not pretrained_cfg or "input_size" not in pretrained_cfg: return None + # This assumes the same basic setup for all timm preprocessing, and that + # all our image conversion will be via a `ResizingImageConverter. We may + # need to extend this as we cover more model types. input_size = pretrained_cfg["input_size"] - return cls(width=input_size[1], height=input_size[2]) + mean = pretrained_cfg["mean"] + variance = [s**2 for s in pretrained_cfg["std"]] + interpolation = pretrained_cfg["interpolation"] + if interpolation not in ("bilinear", "nearest", "bicubic"): + interpolation = "bilinear" # Unsupported interpolation type. 
+ return cls( + width=input_size[1], + height=input_size[2], + scale=1 / 255.0, + mean=mean, + variance=variance, + interpolation=interpolation, + ) diff --git a/tools/checkpoint_conversion/convert_resnet_checkpoints.py b/tools/checkpoint_conversion/convert_resnet_checkpoints.py index eae4554256..530d285f5b 100644 --- a/tools/checkpoint_conversion/convert_resnet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_resnet_checkpoints.py @@ -75,21 +75,36 @@ def validate_output(keras_model, timm_model): image = PIL.Image.open(file) batch = np.array([image]) - # Call with Timm. - timm_batch = keras_model.preprocessor(batch) - timm_batch = keras.ops.transpose(timm_batch, axes=(0, 3, 1, 2)) / 255.0 + # Preprocess with Timm. + data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + + # Preprocess with Keras. + keras_preprocessed = keras_model.preprocessor(batch) + + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) timm_batch = torch.from_numpy(np.array(timm_batch)) timm_outputs = timm_model(timm_batch).detach().numpy() timm_label = np.argmax(timm_outputs[0]) + # Call with Keras. keras_outputs = keras_model.predict(batch) keras_label = np.argmax(keras_outputs[0]) print("🔶 Keras output:", keras_outputs[0, :10]) print("🔶 TIMM output:", timm_outputs[0, :10]) - print("🔶 Difference:", np.mean(np.abs(keras_outputs - timm_outputs))) print("🔶 Keras label:", keras_label) print("🔶 TIMM label:", timm_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed)) + print("🔶 Preprocessing difference:", preprocessing_diff) def main(_):
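For reference, the `preset_loader.py` change above maps a timm `pretrained_cfg` onto the converter roughly as follows. The config values shown are the standard ImageNet ones and are illustrative, not read from a real preset:

```python
from keras_hub.src.layers.preprocessing.resizing_image_converter import (
    ResizingImageConverter,
)

# Illustrative timm-style config.
pretrained_cfg = {
    "input_size": (3, 224, 224),  # timm orders this as (C, H, W).
    "mean": (0.485, 0.456, 0.406),
    "std": (0.229, 0.224, 0.225),
    "interpolation": "bicubic",
}
_, height, width = pretrained_cfg["input_size"]
converter = ResizingImageConverter(
    height=height,
    width=width,
    scale=1 / 255.0,
    mean=pretrained_cfg["mean"],
    # The converter expects per-channel variance, so timm's std is squared.
    variance=[s**2 for s in pretrained_cfg["std"]],
    interpolation=pretrained_cfg["interpolation"],
)
```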