Skip to content

Commit

Permalink
Expunge include_rescaling from backbones
Browse files Browse the repository at this point in the history
Since our models include built in preprocessing, it is much clearer for
this rescaling to happen in the preprocessing layers.
  • Loading branch information
mattdangerw committed Sep 21, 2024
1 parent c4840fa commit 411a4bf
Show file tree
Hide file tree
Showing 32 changed files with 154 additions and 157 deletions.
62 changes: 56 additions & 6 deletions keras_hub/src/layers/preprocessing/resizing_image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.utils.keras_utils import standardize_data_format
from keras_hub.src.utils.tensor_utils import preprocessing_function


Expand All @@ -23,13 +25,23 @@ class ResizingImageConverter(ImageConverter):
"""An `ImageConverter` that simply resizes the input image.
The `ResizingImageConverter` is a subclass of `ImageConverter` for models
that simply need to resize image tensors before using them for modeling.
The layer will take as input a raw image tensor (batched or unbatched) in the
channels last or channels first format, and output a resize tensor.
that need to resize (and optionally rescale) image tensors before using them
for modeling. The layer will take as input a raw image tensor (batched or
unbatched) in the channels last or channels first format, and output a
resize tensor.
Args:
height: Integer, the height of the output shape.
width: Integer, the width of the output shape.
height: int, the height of the output shape.
width: int, the width of the output shape.
scale: float or `None`. If set, the image we be rescaled with a
`keras.layers.Rescaling` layer, multiplying the image by this
scale.
mean: tuples of floats per channel or `None`. If set, the image will be
normalized per channel by subtracting mean.
If set, also set `variance`.
variance: tuples of floats per channel or `None`. If set, the image will
be normalized per channel by dividing by `sqrt(variance)`.
If set, also set `mean`.
crop_to_aspect_ratio: If `True`, resize the images without aspect
ratio distortion. When the original aspect ratio differs
from the target aspect ratio, the output image will be
Expand Down Expand Up @@ -64,6 +76,9 @@ def __init__(
self,
height,
width,
scale=None,
mean=None,
variance=None,
crop_to_aspect_ratio=True,
interpolation="bilinear",
data_format=None,
Expand All @@ -78,15 +93,47 @@ def __init__(
crop_to_aspect_ratio=crop_to_aspect_ratio,
interpolation=interpolation,
data_format=data_format,
dtype=self.dtype_policy,
name="resizing",
)
if scale is not None:
self.rescaling = keras.layers.Rescaling(
scale=scale,
dtype=self.dtype_policy,
name="rescaling",
)
else:
self.rescaling = None
if (mean is not None) != (variance is not None):
raise ValueError(
"Both `mean` and `variance` should be set or `None`. Received "
f"`mean={mean}`, `variance={variance}`."
)
self.scale = scale
self.mean = mean
self.variance = variance
self.data_format = standardize_data_format(data_format)

def image_size(self):
"""Returns the preprocessed size of a single image."""
return (self.resizing.height, self.resizing.width)

@preprocessing_function
def call(self, inputs):
return self.resizing(inputs)
x = self.resizing(inputs)
if self.rescaling:
x = self.rescaling(x)
if self.mean is not None:
# Avoid `layers.Normalization` so this works batched and unbatched.
channels_first = self.data_format == "channels_first"
if len(ops.shape(inputs)) == 3:
broadcast_dims = (1, 2) if channels_first else (0, 1)
else:
broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
mean = ops.expand_dims(ops.array(self.mean), broadcast_dims)
std = ops.expand_dims(ops.sqrt(self.variance), broadcast_dims)
x = (x - mean) / std
return x

def get_config(self):
config = super().get_config()
Expand All @@ -96,6 +143,9 @@ def get_config(self):
"width": self.resizing.width,
"interpolation": self.resizing.interpolation,
"crop_to_aspect_ratio": self.resizing.crop_to_aspect_ratio,
"scale": self.scale,
"mean": self.mean,
"variance": self.variance,
}
)
return config
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,57 @@


class ResizingImageConverterTest(TestCase):
def test_resize_simple(self):
converter = ResizingImageConverter(height=4, width=4)
inputs = np.ones((10, 10, 3))
outputs = converter(inputs)
self.assertAllClose(outputs, ops.ones((4, 4, 3)))

def test_resize_one(self):
converter = ResizingImageConverter(22, 22)
test_image = np.random.rand(10, 10, 3) * 255
shape = ops.shape(converter(test_image))
self.assertEqual(shape, (22, 22, 3))
converter = ResizingImageConverter(
height=4,
width=4,
mean=(0.5, 0.7, 0.3),
variance=(0.25, 0.1, 0.5),
scale=1 / 255.0,
)
inputs = np.ones((10, 10, 3)) * 128
outputs = converter(inputs)
self.assertEqual(ops.shape(outputs), (4, 4, 3))
self.assertAllClose(outputs[:, :, 0], np.ones((4, 4)) * 0.003922)
self.assertAllClose(outputs[:, :, 1], np.ones((4, 4)) * -0.626255)
self.assertAllClose(outputs[:, :, 2], np.ones((4, 4)) * 0.285616)

def test_resize_batch(self):
converter = ResizingImageConverter(12, 12)
test_batch = np.random.rand(4, 10, 20, 3) * 255
shape = ops.shape(converter(test_batch))
self.assertEqual(shape, (4, 12, 12, 3))
converter = ResizingImageConverter(
height=4,
width=4,
mean=(0.5, 0.7, 0.3),
variance=(0.25, 0.1, 0.5),
scale=1 / 255.0,
)
inputs = np.ones((2, 10, 10, 3)) * 128
outputs = converter(inputs)
self.assertEqual(ops.shape(outputs), (2, 4, 4, 3))
self.assertAllClose(outputs[:, :, :, 0], np.ones((2, 4, 4)) * 0.003922)
self.assertAllClose(outputs[:, :, :, 1], np.ones((2, 4, 4)) * -0.626255)
self.assertAllClose(outputs[:, :, :, 2], np.ones((2, 4, 4)) * 0.285616)

def test_errors(self):
with self.assertRaises(ValueError):
ResizingImageConverter(
height=4,
width=4,
mean=(0.5, 0.7, 0.3),
)

def test_config(self):
converter = ResizingImageConverter(
width=12,
height=20,
mean=(0.5, 0.7, 0.3),
variance=(0.25, 0.1, 0.5),
scale=1 / 255.0,
crop_to_aspect_ratio=False,
interpolation="nearest",
)
Expand Down
12 changes: 1 addition & 11 deletions keras_hub/src/models/csp_darknet/csp_darknet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone):
level in the model.
stackwise_depth: A list of ints, the depth for each dark level in the
model.
include_rescaling: boolean. If `True`, rescale the input using
`Rescaling(1 / 255.0)` layer. If `False`, do nothing. Defaults to
`True`.
block_type: str. One of `"basic_block"` or `"depthwise_block"`.
Use `"depthwise_block"` for depthwise conv block
`"basic_block"` for basic conv block.
Expand All @@ -55,7 +52,6 @@ class CSPDarkNetBackbone(FeaturePyramidBackbone):
model = keras_hub.models.CSPDarkNetBackbone(
stackwise_num_filters=[128, 256, 512, 1024],
stackwise_depth=[3, 9, 9, 3],
include_rescaling=False,
)
model(input_data)
```
Expand All @@ -65,7 +61,6 @@ def __init__(
self,
stackwise_num_filters,
stackwise_depth,
include_rescaling=True,
block_type="basic_block",
image_shape=(None, None, 3),
**kwargs,
Expand All @@ -82,10 +77,7 @@ def __init__(
base_channels = stackwise_num_filters[0] // 2

image_input = layers.Input(shape=image_shape)
x = image_input
if include_rescaling:
x = layers.Rescaling(scale=1 / 255.0)(x)

x = image_input # Intermediate result.
x = apply_focus(channel_axis, name="stem_focus")(x)
x = apply_darknet_conv_block(
base_channels,
Expand Down Expand Up @@ -130,7 +122,6 @@ def __init__(
# === Config ===
self.stackwise_num_filters = stackwise_num_filters
self.stackwise_depth = stackwise_depth
self.include_rescaling = include_rescaling
self.block_type = block_type
self.image_shape = image_shape
self.pyramid_outputs = pyramid_outputs
Expand All @@ -141,7 +132,6 @@ def get_config(self):
{
"stackwise_num_filters": self.stackwise_num_filters,
"stackwise_depth": self.stackwise_depth,
"include_rescaling": self.include_rescaling,
"block_type": self.block_type,
"image_shape": self.image_shape,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class CSPDarkNetImageClassifier(ImageClassifier):
backbone = keras_hub.models.CSPDarkNetBackbone(
stackwise_num_filters=[128, 256, 512, 1024],
stackwise_depth=[3, 9, 9, 3],
include_rescaling=False,
block_type="basic_block",
image_shape = (224, 224, 3),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def setUp(self):
self.backbone = CSPDarkNetBackbone(
stackwise_num_filters=[2, 16, 16],
stackwise_depth=[1, 3, 3, 1],
include_rescaling=False,
block_type="basic_block",
image_shape=(16, 16, 3),
)
Expand Down
12 changes: 1 addition & 11 deletions keras_hub/src/models/densenet/densenet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
Args:
stackwise_num_repeats: list of ints, number of repeated convolutional
blocks per dense block.
include_rescaling: bool, whether to rescale the inputs. If set
to `True`, inputs will be passed through a `Rescaling(1/255.0)`
layer. Defaults to `True`.
image_shape: optional shape tuple, defaults to (None, None, 3).
compression_ratio: float, compression rate at transition layers,
defaults to 0.5.
Expand All @@ -51,7 +48,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
# Randomly initialized backbone with a custom config
model = keras_hub.models.DenseNetBackbone(
stackwise_num_repeats=[6, 12, 24, 16],
include_rescaling=False,
)
model(input_data)
```
Expand All @@ -60,7 +56,6 @@ class DenseNetBackbone(FeaturePyramidBackbone):
def __init__(
self,
stackwise_num_repeats,
include_rescaling=True,
image_shape=(None, None, 3),
compression_ratio=0.5,
growth_rate=32,
Expand All @@ -71,10 +66,7 @@ def __init__(
channel_axis = -1 if data_format == "channels_last" else 1
image_input = keras.layers.Input(shape=image_shape)

x = image_input
if include_rescaling:
x = keras.layers.Rescaling(1 / 255.0)(x)

x = image_input # Intermediate result.
x = keras.layers.Conv2D(
64,
7,
Expand Down Expand Up @@ -124,7 +116,6 @@ def __init__(

# === Config ===
self.stackwise_num_repeats = stackwise_num_repeats
self.include_rescaling = include_rescaling
self.compression_ratio = compression_ratio
self.growth_rate = growth_rate
self.image_shape = image_shape
Expand All @@ -135,7 +126,6 @@ def get_config(self):
config.update(
{
"stackwise_num_repeats": self.stackwise_num_repeats,
"include_rescaling": self.include_rescaling,
"compression_ratio": self.compression_ratio,
"growth_rate": self.growth_rate,
"image_shape": self.image_shape,
Expand Down
1 change: 0 additions & 1 deletion keras_hub/src/models/densenet/densenet_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class DenseNetImageClassifier(ImageClassifier):
backbone = keras_hub.models.DenseNetBackbone(
stackwise_num_filters=[128, 256, 512, 1024],
stackwise_depth=[3, 9, 9, 3],
include_rescaling=False,
block_type="basic_block",
image_shape = (224, 224, 3),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def setUp(self):
self.labels = [0, 3]
self.backbone = DenseNetBackbone(
stackwise_num_repeats=[6, 12, 24, 16],
include_rescaling=True,
compression_ratio=0.5,
growth_rate=32,
image_shape=(224, 224, 3),
Expand Down
17 changes: 3 additions & 14 deletions keras_hub/src/models/efficientnet/efficientnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
MBConvBlock, but instead of using a depthwise convolution and a 1x1
output convolution blocks fused blocks use a single 3x3 convolution
block.
include_rescaling: bool, whether to rescale the inputs. If set to
True, inputs will be passed through a `Rescaling(1/255.0)` layer.
min_depth: integer, minimum number of filters. Can be None and ignored
if use_depth_divisor_as_min_depth is set to True.
include_initial_padding: bool, whether to include initial zero padding
Expand Down Expand Up @@ -96,7 +94,6 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
stackwise_block_types=[["fused"] * 3 + ["unfused"] * 3],
width_coefficient=1.0,
depth_coefficient=1.0,
include_rescaling=False,
)
images = np.ones((1, 256, 256, 3))
outputs = efficientnet.predict(images)
Expand All @@ -116,7 +113,6 @@ def __init__(
stackwise_squeeze_and_excite_ratios,
stackwise_strides,
stackwise_block_types,
include_rescaling=True,
dropout=0.2,
depth_divisor=8,
min_depth=8,
Expand All @@ -129,14 +125,9 @@ def __init__(
batch_norm_momentum=0.9,
**kwargs,
):
img_input = keras.layers.Input(shape=input_shape)

x = img_input

if include_rescaling:
# Use common rescaling strategy across keras
x = keras.layers.Rescaling(scale=1.0 / 255.0)(x)
image_input = keras.layers.Input(shape=input_shape)

x = image_input # Intermediate result.
if include_initial_padding:
x = keras.layers.ZeroPadding2D(
padding=self._correct_pad_downsample(x, 3),
Expand Down Expand Up @@ -282,10 +273,9 @@ def __init__(
curr_pyramid_level += 1

# Create model.
super().__init__(inputs=img_input, outputs=x, **kwargs)
super().__init__(inputs=image_input, outputs=x, **kwargs)

# === Config ===
self.include_rescaling = include_rescaling
self.width_coefficient = width_coefficient
self.depth_coefficient = depth_coefficient
self.dropout = dropout
Expand Down Expand Up @@ -313,7 +303,6 @@ def get_config(self):
config = super().get_config()
config.update(
{
"include_rescaling": self.include_rescaling,
"width_coefficient": self.width_coefficient,
"depth_coefficient": self.depth_coefficient,
"dropout": self.dropout,
Expand Down
Loading

0 comments on commit 411a4bf

Please sign in to comment.