diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 6b85148caf..17cd5b77bc 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -43,6 +43,10 @@ from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index 371277465a..856e25ba92 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -167,6 +167,10 @@ from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
@@ -233,6 +237,13 @@ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
     ResNetImageClassifierPreprocessor,
 )
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_object_detector import (
+    RetinaNetObjectDetector,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
 from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
diff --git a/keras_hub/src/bounding_box/__init__.py b/keras_hub/src/bounding_box/__init__.py
index e69de29bb2..78f451fd0d 100644
--- a/keras_hub/src/bounding_box/__init__.py
+++ b/keras_hub/src/bounding_box/__init__.py
@@ -0,0 +1,2 @@
+# TODO: Once all bounding box utilities are moved to the Keras repository,
+# remove the bounding box folder.
diff --git a/keras_hub/src/bounding_box/converters.py b/keras_hub/src/bounding_box/converters.py
index 263cd6df33..92ef27c15d 100644
--- a/keras_hub/src/bounding_box/converters.py
+++ b/keras_hub/src/bounding_box/converters.py
@@ -20,29 +20,74 @@ class RequiresImagesException(Exception):
 ALL_AXES = 4
 
 
-def _encode_box_to_deltas(
+def encode_box_to_deltas(
     anchors,
     boxes,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoding_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from `center_yxhw` to delta format."""
+    """Encodes bounding boxes relative to anchors as deltas.
+
+    This function calculates the deltas that represent the difference between
+    bounding boxes and provided anchors. Deltas encode the offsets and scaling
+    factors to apply to anchors to obtain the target boxes.
+
+    Boxes and anchors are first converted to the specified `encoding_format`
+    (defaulting to `center_yxhw`) for consistent delta representation.
+
+    Args:
+        anchors: `Tensors`. Anchor boxes with shape of `(N, 4)` where N is the
+            number of anchors.
+        boxes: `Tensors`. Bounding boxes to encode. Boxes can be of shape
+            `(B, N, 4)` or `(N, 4)`.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., "xyxy", "xywh", etc.).
+        box_format: str. The format of the input `boxes`
+            (e.g., "xyxy", "xywh", etc.).
+        encoding_format: str. The intermediate format to which boxes and anchors
+            are converted before delta calculation. Defaults to "center_yxhw".
+        variance: `List[float]`. A 4-element array/tensor representing variance
+            factors to scale the box deltas. If provided, the calculated deltas
+            are divided by the variance. Defaults to None.
+        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).
+            When using a relative bounding box format for `box_format`, the
+            `image_shape` is used for normalization.
+    Returns:
+        Encoded box deltas. The return type matches the `encoding_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoding_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoding_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            "`encoding_format` should be one of 'center_xywh' or 'center_yxhw', "
+            f"got {encoding_format}"
+        )
+
     encoded_anchors = convert_format(
         anchors,
         source=anchor_format,
-        target="center_yxhw",
+        target=encoding_format,
         image_shape=image_shape,
     )
     boxes = convert_format(
-        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
+        boxes,
+        source=box_format,
+        target=encoding_format,
+        image_shape=image_shape,
     )
     anchor_dimensions = ops.maximum(
         encoded_anchors[..., 2:], keras.backend.epsilon()
@@ -61,15 +106,54 @@ def _encode_box_to_deltas(
     return boxes_delta
 
 
-def _decode_deltas_to_boxes(
+def decode_deltas_to_boxes(
     anchors,
     boxes_delta,
-    anchor_format: str,
-    box_format: str,
+    anchor_format,
+    box_format,
+    encoded_format="center_yxhw",
     variance=None,
     image_shape=None,
 ):
-    """Converts bounding_boxes from delta format to `center_yxhw`."""
+    """Converts bounding boxes from delta format to the specified `box_format`.
+
+    This function decodes bounding box deltas relative to anchors to obtain the
+    final bounding box coordinates. The boxes are encoded in a specific
+    `encoded_format` (center_yxhw by default) during the decoding process.
+    This allows flexibility in how the deltas are applied to the anchors.
+
+    Args:
+        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level
+            indices and values are corresponding anchor boxes.
+            The shape of the array/tensor should be `(N, 4)` where N is the
+            number of anchors.
+        boxes_delta: Can be `Tensors` or `Dict[Tensors]`. Bounding box deltas
+            must have the same type and structure as `anchors`. The
+            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is
+            the number of boxes.
+        anchor_format: str. The format of the input `anchors`
+            (e.g., `"xyxy"`, `"xywh"`, etc.).
+        box_format: str. The desired format for the output boxes
+            (e.g., `"xyxy"`, `"xywh"`, etc.).
+        encoded_format: str. Raw output format from the regression head. Defaults
+            to `"center_yxhw"`.
+        variance: `List[float]`. A 4-element array/tensor representing
+            variance factors to scale the box deltas. If provided, the deltas
+            are multiplied by the variance before being applied to the anchors.
+            Defaults to None.
+        image_shape: The shape of the image (height, width). This is needed
+            if normalization to image size is required when converting between
+            formats. Defaults to None.
+
+    Returns:
+        Decoded box coordinates. The return type matches the `box_format`.
+
+    Raises:
+        ValueError: If `variance` is not None and its length is not 4.
+        ValueError: If `encoded_format` is not `"center_xywh"` or
+            `"center_yxhw"`.
+
+    """
     if variance is not None:
         variance = ops.convert_to_tensor(variance, "float32")
         var_len = variance.shape[-1]
@@ -77,11 +161,17 @@ def _decode_deltas_to_boxes(
         if var_len != 4:
             raise ValueError(f"`variance` must be length 4, got {variance}")
+
+    if encoded_format not in ["center_xywh", "center_yxhw"]:
+        raise ValueError(
+            f"`encoded_format` should be 'center_xywh' or 'center_yxhw', "
+            f"but got '{encoded_format}'."
+        )
+
     def decode_single_level(anchor, box_delta):
         encoded_anchor = convert_format(
             anchor,
             source=anchor_format,
-            target="center_yxhw",
+            target=encoded_format,
             image_shape=image_shape,
         )
         if variance is not None:
@@ -97,7 +187,7 @@ def decode_single_level(anchor, box_delta):
         )
         box = convert_format(
             box,
-            source="center_yxhw",
+            source=encoded_format,
             target=box_format,
             image_shape=image_shape,
         )
diff --git a/keras_hub/src/models/image_object_detector.py b/keras_hub/src/models/image_object_detector.py
new file mode 100644
index 0000000000..aa8a54dc3e
--- /dev/null
+++ b/keras_hub/src/models/image_object_detector.py
@@ -0,0 +1,87 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetector")
+class ImageObjectDetector(Task):
+    """Base class for all image object detection tasks.
+
+    The `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    object detection. `ImageObjectDetector` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`,
+    where `x` is an image or batch of images and `y` is a dictionary with
+    `boxes` and `classes`.
+
+    All `ImageObjectDetector` tasks include a `from_preset()` constructor which
+    can be used to load a pre-trained config and weights.
+    """
+
+    def compile(
+        self,
+        optimizer="auto",
+        box_loss="auto",
+        classification_loss="auto",
+        metrics=None,
+        **kwargs,
+    ):
+        """Configures the `ImageObjectDetector` task for training.
+
+        The `ImageObjectDetector` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            box_loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.Huber` loss will be
+                applied for the object detector task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            classification_loss: `"auto"`, a loss name, or a `keras.losses.Loss`
+                instance. Defaults to `"auto"`, where a
+                `keras.losses.BinaryFocalCrossentropy` loss will be
+                applied for the object detector task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: A list of metrics to be evaluated by
+                the model during training and testing. Defaults to `None`.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if box_loss == "auto":
+            box_loss = keras.losses.Huber(reduction="sum")
+        if classification_loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.sigmoid
+            classification_loss = keras.losses.BinaryFocalCrossentropy(
+                from_logits=from_logits, reduction="sum"
+            )
+        if metrics is not None:
+            raise ValueError("User metrics not yet supported")
+
+        losses = {
+            "bbox_regression": box_loss,
+            "cls_logits": classification_loss,
+        }
+
+        super().compile(
+            optimizer=optimizer,
+            loss=losses,
+            metrics=metrics,
+            **kwargs,
+        )
diff --git a/keras_hub/src/models/image_object_detector_preprocessor.py b/keras_hub/src/models/image_object_detector_preprocessor.py
new file mode 100644
index 0000000000..581a10d6d9
--- /dev/null
+++ b/keras_hub/src/models/image_object_detector_preprocessor.py
@@ -0,0 +1,57 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor")
+class ImageObjectDetectorPreprocessor(Preprocessor):
+    """Base class for object detector preprocessing layers.
+
+    `ImageObjectDetectorPreprocessor` tasks wrap a
+    `keras_hub.layers.Preprocessor` to create a preprocessing layer for
+    object detection tasks. It is intended to be paired with a
+    `keras_hub.models.ImageObjectDetector` task.
+
+    All `ImageObjectDetectorPreprocessor` layers take three inputs: `x`, `y`, and
+    `sample_weight`. `x`, the first input, should always be included. It can
+    be an image or batch of images. See examples below. `y` and `sample_weight`
+    are optional inputs that will be passed through unaltered. Usually, `y` will
+    be a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
+    "classes": (batch_size, num_boxes)}`.
+
+    The layer will return either `x`, an `(x, y)` tuple if labels were provided,
+    or an `(x, y, sample_weight)` tuple if labels and sample weight were
+    provided. `x` will be the input images after all model preprocessing has
+    been applied.
+
+    All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
+    constructor which can be used to load a pre-trained config.
+    You can call the `from_preset()` constructor directly on this base class, in
+    which case the correct class for your model will be automatically
+    instantiated.
+
+    Args:
+        image_converter: Preprocessing pipeline for images.
+
+    Examples:
+ ```python + preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset( + "retinanet_resnet50", + ) + """ + + def __init__( + self, + image_converter=None, + **kwargs, + ): + super().__init__(**kwargs) + self.image_converter = image_converter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/retinanet/anchor_generator.py b/keras_hub/src/models/retinanet/anchor_generator.py index bb46988926..a3c3800c49 100644 --- a/keras_hub/src/models/retinanet/anchor_generator.py +++ b/keras_hub/src/models/retinanet/anchor_generator.py @@ -3,9 +3,13 @@ import keras from keras import ops +from keras_hub.src.api_export import keras_hub_export + +# TODO: https://github.com/keras-team/keras-hub/issues/1965 from keras_hub.src.bounding_box.converters import convert_format +@keras_hub_export("keras_hub.layers.AnchorGenerator") class AnchorGenerator(keras.layers.Layer): """Generates anchor boxes for object detection tasks. @@ -81,6 +85,7 @@ def __init__( self.num_scales = num_scales self.aspect_ratios = aspect_ratios self.anchor_size = anchor_size + self.num_base_anchors = num_scales * len(aspect_ratios) self.built = True def call(self, inputs): @@ -92,60 +97,61 @@ def call(self, inputs): image_shape = tuple(image_shape) - multilevel_boxes = {} + multilevel_anchors = {} for level in range(self.min_level, self.max_level + 1): - boxes_l = [] # Calculate the feature map size for this level feat_size_y = math.ceil(image_shape[0] / 2**level) feat_size_x = math.ceil(image_shape[1] / 2**level) # Calculate the stride (step size) for this level - stride_y = ops.cast(image_shape[0] / feat_size_y, "float32") - stride_x = ops.cast(image_shape[1] / feat_size_x, "float32") + stride_y = image_shape[0] // feat_size_y + stride_x = image_shape[1] // feat_size_x # Generate anchor center points # Start from stride/2 to center anchors on pixels - cx = ops.arange(stride_x / 2, image_shape[1], stride_x) - cy = ops.arange(stride_y / 2, image_shape[0], stride_y) + cx = ops.arange(0, feat_size_x, dtype="float32") * stride_x + cy = ops.arange(0, feat_size_y, dtype="float32") * stride_y # Create a grid of anchor centers - cx_grid, cy_grid = ops.meshgrid(cx, cy) - - for scale in range(self.num_scales): - for aspect_ratio in self.aspect_ratios: - # Calculate the intermediate scale factor - intermidate_scale = 2 ** (scale / self.num_scales) - # Calculate the base anchor size for this level and scale - base_anchor_size = ( - self.anchor_size * 2**level * intermidate_scale - ) - # Adjust anchor dimensions based on aspect ratio - aspect_x = aspect_ratio**0.5 - aspect_y = aspect_ratio**-0.5 - half_anchor_size_x = base_anchor_size * aspect_x / 2.0 - half_anchor_size_y = base_anchor_size * aspect_y / 2.0 - - # Generate anchor boxes (y1, x1, y2, x2 format) - boxes = ops.stack( - [ - cy_grid - half_anchor_size_y, - cx_grid - half_anchor_size_x, - cy_grid + half_anchor_size_y, - cx_grid + half_anchor_size_x, - ], - axis=-1, - ) - boxes_l.append(boxes) - # Concat anchors on the same level to tensor shape HxWx(Ax4) - boxes_l = ops.concatenate(boxes_l, axis=-1) - boxes_l = ops.reshape(boxes_l, (-1, 4)) - # Convert to user defined - multilevel_boxes[f"P{level}"] = convert_format( - boxes_l, - source="yxyx", + cy_grid, cx_grid = ops.meshgrid(cy, cx, indexing="ij") + cy_grid = ops.reshape(cy_grid, (-1,)) + cx_grid = ops.reshape(cx_grid, (-1,)) + + shifts = 
ops.stack((cx_grid, cy_grid, cx_grid, cy_grid), axis=1) + sizes = [ + int( + 2**level * self.anchor_size * 2 ** (scale / self.num_scales) + ) + for scale in range(self.num_scales) + ] + + base_anchors = self.generate_base_anchors( + sizes=sizes, aspect_ratios=self.aspect_ratios + ) + shifts = ops.reshape(shifts, (-1, 1, 4)) + base_anchors = ops.reshape(base_anchors, (1, -1, 4)) + + anchors = shifts + base_anchors + anchors = ops.reshape(anchors, (-1, 4)) + multilevel_anchors[f"P{level}"] = convert_format( + anchors, + source="xyxy", target=self.bounding_box_format, ) - return multilevel_boxes + return multilevel_anchors + + def generate_base_anchors(self, sizes, aspect_ratios): + sizes = ops.convert_to_tensor(sizes, dtype="float32") + aspect_ratios = ops.convert_to_tensor(aspect_ratios) + h_ratios = ops.sqrt(aspect_ratios) + w_ratios = 1 / h_ratios + + ws = ops.reshape(w_ratios[:, None] * sizes[None, :], (-1,)) + hs = ops.reshape(h_ratios[:, None] * sizes[None, :], (-1,)) + + base_anchors = ops.stack([-1 * ws, -1 * hs, ws, hs], axis=1) / 2 + base_anchors = ops.round(base_anchors) + return base_anchors def compute_output_shape(self, input_shape): multilevel_boxes_shape = {} @@ -156,18 +162,11 @@ def compute_output_shape(self, input_shape): for i in range(self.min_level, self.max_level + 1): multilevel_boxes_shape[f"P{i}"] = ( - (image_height // 2 ** (i)) - * (image_width // 2 ** (i)) - * self.anchors_per_location, + int( + math.ceil(image_height / 2 ** (i)) + * math.ceil(image_width // 2 ** (i)) + * self.num_base_anchors + ), 4, ) return multilevel_boxes_shape - - @property - def anchors_per_location(self): - """ - The `anchors_per_location` property returns the number of anchors - generated per pixel location, which is equal to - `num_scales * len(aspect_ratios)`. 
- """ - return self.num_scales * len(self.aspect_ratios) diff --git a/keras_hub/src/models/retinanet/anchor_generator_test.py b/keras_hub/src/models/retinanet/anchor_generator_test.py index c843c32f27..0b71630843 100644 --- a/keras_hub/src/models/retinanet/anchor_generator_test.py +++ b/keras_hub/src/models/retinanet/anchor_generator_test.py @@ -2,7 +2,6 @@ from absl.testing import parameterized from keras import ops -from keras_hub.src.bounding_box.converters import convert_format from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.tests.test_case import TestCase @@ -18,7 +17,7 @@ def test_layer_behaviors(self): "max_level": 7, "num_scales": 3, "aspect_ratios": [0.5, 1.0, 2.0], - "anchor_size": 8, + "anchor_size": 4, }, input_data=np.random.uniform(size=images_shape), expected_output_shape={ @@ -40,58 +39,13 @@ def test_layer_behaviors(self): + ( { "P5": [ - [-16.0, -16.0, 48.0, 48.0], - [-16.0, 16.0, 48.0, 80.0], - [16.0, -16.0, 80.0, 48.0], - [16.0, 16.0, 80.0, 80.0], + [-32.0, -32.0, 32.0, 32.0], + [-32.0, 0, 32.0, 64.0], + [0.0, -32.0, 64.0, 32.0], + [0.0, 0.0, 64.0, 64.0], ] }, ), - # Multi scale anchor - ("xywh", 5, 6, 1, [1.0], 2.0, [64, 64]) - + ( - { - "P5": [ - [-16.0, -16.0, 48.0, 48.0], - [-16.0, 16.0, 48.0, 80.0], - [16.0, -16.0, 80.0, 48.0], - [16.0, 16.0, 80.0, 80.0], - ], - "P6": [[-32, -32, 96, 96]], - }, - ), - # Multi aspect ratio anchor - ("xyxy", 6, 6, 1, [1.0, 4.0, 0.25], 2.0, [64, 64]) - + ( - { - "P6": [ - [-32.0, -32.0, 96.0, 96.0], - [0.0, -96.0, 64.0, 160.0], - [-96.0, 0.0, 160.0, 64.0], - ] - }, - ), - # Intermidate scales - ("yxyx", 5, 5, 2, [1.0], 1.0, [32, 32]) - + ( - { - "P5": [ - [0.0, 0.0, 32.0, 32.0], - [ - 16 - 16 * 2**0.5, - 16 - 16 * 2**0.5, - 16 + 16 * 2**0.5, - 16 + 16 * 2**0.5, - ], - ] - }, - ), - # Non-square - ("xywh", 5, 5, 1, [1.0], 1.0, [64, 32]) - + ({"P5": [[0, 0, 32, 32], [32, 0, 64, 32]]},), - # Indivisible by 2^level - ("xyxy", 5, 5, 1, [1.0], 1.0, [40, 32]) - + ({"P5": [[-6, 0, 26, 32], [14, 0, 46, 32]]},), ) def test_anchor_generator( self, @@ -116,9 +70,4 @@ def test_anchor_generator( multilevel_boxes = anchor_generator(images) for key in expected_boxes: expected_boxes[key] = ops.convert_to_tensor(expected_boxes[key]) - expected_boxes[key] = convert_format( - expected_boxes[key], - source="yxyx", - target=bounding_box_format, - ) self.assertAllClose(expected_boxes, multilevel_boxes) diff --git a/keras_hub/src/models/retinanet/feature_pyramid.py b/keras_hub/src/models/retinanet/feature_pyramid.py index 5c0bbb906c..ea8b13af75 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid.py +++ b/keras_hub/src/models/retinanet/feature_pyramid.py @@ -1,5 +1,9 @@ +import math + import keras +from keras_hub.src.utils.keras_utils import standardize_data_format + class FeaturePyramid(keras.layers.Layer): """A Feature Pyramid Network (FPN) layer. @@ -37,14 +41,18 @@ class FeaturePyramid(keras.layers.Layer): Args: min_level: int. The minimum level of the feature pyramid. max_level: int. The maximum level of the feature pyramid. + use_p5: bool. If True, uses the output of the last layer (`P5` from + Feature Pyramid Network) as input for creating coarser convolution + layers (`P6`, `P7`). If False, uses the direct input `P5` + for creating coarser convolution layers. num_filters: int. The number of filters in each feature map. activation: string or `keras.activations`. The activation function to be used in network. Defaults to `"relu"`. - kernel_initializer: `str` or `keras.initializers` initializer. 
+ kernel_initializer: `str` or `keras.initializers`. The kernel initializer for the convolution layers. Defaults to `"VarianceScaling"`. - bias_initializer: `str` or `keras.initializers` initializer. + bias_initializer: `str` or `keras.initializers`. The bias initializer for the convolution layers. Defaults to `"zeros"`. batch_norm_momentum: float. @@ -53,10 +61,10 @@ class FeaturePyramid(keras.layers.Layer): batch_norm_epsilon: float. The epsilon for the batch normalization layers. Defaults to `0.001`. - kernel_regularizer: `str` or `keras.regularizers` regularizer. + kernel_regularizer: `str` or `keras.regularizers`. The kernel regularizer for the convolution layers. Defaults to `None`. - bias_regularizer: `str` or `keras.regularizers` regularizer. + bias_regularizer: `str` or `keras.regularizers`. The bias regularizer for the convolution layers. Defaults to `None`. use_batch_norm: bool. Whether to use batch normalization. @@ -69,6 +77,7 @@ def __init__( self, min_level, max_level, + use_p5, num_filters=256, activation="relu", kernel_initializer="VarianceScaling", @@ -78,6 +87,7 @@ def __init__( kernel_regularizer=None, bias_regularizer=None, use_batch_norm=False, + data_format=None, **kwargs, ): super().__init__(**kwargs) @@ -89,6 +99,7 @@ def __init__( self.min_level = min_level self.max_level = max_level self.num_filters = num_filters + self.use_p5 = use_p5 self.activation = keras.activations.get(activation) self.kernel_initializer = keras.initializers.get(kernel_initializer) self.bias_initializer = keras.initializers.get(bias_initializer) @@ -103,8 +114,8 @@ def __init__( self.bias_regularizer = keras.regularizers.get(bias_regularizer) else: self.bias_regularizer = None - self.data_format = keras.backend.image_data_format() - self.batch_norm_axis = -1 if self.data_format == "channels_last" else 1 + self.data_format = standardize_data_format(data_format) + self.batch_norm_axis = -1 if data_format == "channels_last" else 1 def build(self, input_shapes): input_shapes = { @@ -117,7 +128,6 @@ def build(self, input_shapes): } input_levels = [int(level[1]) for level in input_shapes] backbone_max_level = min(max(input_levels), self.max_level) - # Build lateral layers self.lateral_conv_layers = {} for i in range(self.min_level, backbone_max_level + 1): @@ -134,7 +144,11 @@ def build(self, input_shapes): dtype=self.dtype_policy, name=f"lateral_conv_{level}", ) - self.lateral_conv_layers[level].build(input_shapes[level]) + self.lateral_conv_layers[level].build( + (None, None, None, input_shapes[level][-1]) + if self.data_format == "channels_last" + else (None, input_shapes[level][1], None, None) + ) self.lateral_batch_norm_layers = {} if self.use_batch_norm: @@ -149,9 +163,9 @@ def build(self, input_shapes): ) ) self.lateral_batch_norm_layers[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # Build output layers @@ -171,9 +185,9 @@ def build(self, input_shapes): name=f"output_conv_{level}", ) self.output_conv_layers[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # Build coarser layers @@ -192,11 +206,18 @@ def build(self, input_shapes): dtype=self.dtype_policy, name=f"coarser_{level}", ) - self.output_conv_layers[level].build( - (None, None, None, 256) - if self.data_format == "channels_last" - else (None, 
256, None, None) - ) + if i == backbone_max_level + 1 and self.use_p5: + self.output_conv_layers[level].build( + (None, None, None, input_shapes[f"P{i-1}"][-1]) + if self.data_format == "channels_last" + else (None, input_shapes[f"P{i-1}"][1], None, None) + ) + else: + self.output_conv_layers[level].build( + (None, None, None, self.num_filters) + if self.data_format == "channels_last" + else (None, self.num_filters, None, None) + ) # Build batch norm layers self.output_batch_norms = {} @@ -212,9 +233,9 @@ def build(self, input_shapes): ) ) self.output_batch_norms[level].build( - (None, None, None, 256) + (None, None, None, self.num_filters) if self.data_format == "channels_last" - else (None, 256, None, None) + else (None, self.num_filters, None, None) ) # The same upsampling layer is used for all levels @@ -273,7 +294,11 @@ def call(self, inputs): for i in range(backbone_max_level + 1, self.max_level + 1): level = f"P{i}" - feats_in = output_features[f"P{i-1}"] + feats_in = ( + inputs[f"P{i-1}"] + if i == backbone_max_level + 1 and self.use_p5 + else output_features[f"P{i-1}"] + ) if i > backbone_max_level + 1: feats_in = self.activation(feats_in) output_features[level] = ( @@ -283,7 +308,10 @@ def call(self, inputs): if self.use_batch_norm else self.output_conv_layers[level](feats_in) ) - + output_features = { + f"P{i}": output_features[f"P{i}"] + for i in range(self.min_level, self.max_level + 1) + } return output_features def get_config(self): @@ -293,7 +321,9 @@ def get_config(self): "min_level": self.min_level, "max_level": self.max_level, "num_filters": self.num_filters, + "use_p5": self.use_p5, "use_batch_norm": self.use_batch_norm, + "data_format": self.data_format, "activation": keras.activations.serialize(self.activation), "kernel_initializer": keras.initializers.serialize( self.kernel_initializer @@ -320,34 +350,51 @@ def get_config(self): def compute_output_shape(self, input_shapes): output_shape = {} - print(input_shapes) input_levels = [int(level[1]) for level in input_shapes] backbone_max_level = min(max(input_levels), self.max_level) for i in range(self.min_level, backbone_max_level + 1): level = f"P{i}" if self.data_format == "channels_last": - output_shape[level] = input_shapes[level][:-1] + (256,) + output_shape[level] = input_shapes[level][:-1] + ( + self.num_filters, + ) else: output_shape[level] = ( input_shapes[level][0], - 256, + self.num_filters, ) + input_shapes[level][1:3] intermediate_shape = input_shapes[f"P{backbone_max_level}"] intermediate_shape = ( ( intermediate_shape[0], - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, - 256, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + self.num_filters, ) if self.data_format == "channels_last" else ( intermediate_shape[0], - 256, - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, + self.num_filters, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), ) ) @@ -357,16 +404,32 @@ def compute_output_shape(self, input_shapes): intermediate_shape = ( ( intermediate_shape[0], - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, - 256, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + 
if intermediate_shape[1] is not None + else None + ), + self.num_filters, ) if self.data_format == "channels_last" else ( intermediate_shape[0], - 256, - intermediate_shape[1] // 2, - intermediate_shape[2] // 2, + self.num_filters, + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), + ( + int(math.ceil(intermediate_shape[1] / 2)) + if intermediate_shape[1] is not None + else None + ), ) ) diff --git a/keras_hub/src/models/retinanet/feature_pyramid_test.py b/keras_hub/src/models/retinanet/feature_pyramid_test.py index 728233c6ae..b9b62e62ac 100644 --- a/keras_hub/src/models/retinanet/feature_pyramid_test.py +++ b/keras_hub/src/models/retinanet/feature_pyramid_test.py @@ -18,6 +18,7 @@ def test_layer_behaviors(self): "batch_norm_epsilon": 0.0001, "kernel_initializer": "HeNormal", "bias_initializer": "Zeros", + "use_p5": False, }, input_data={ "P3": random.uniform(shape=(2, 64, 64, 4)), @@ -40,12 +41,14 @@ def test_layer_behaviors(self): "equal_resolutions", 3, 7, + False, {"P3": (2, 16, 16, 3), "P4": (2, 8, 8, 3), "P5": (2, 4, 4, 3)}, ), ( "different_resolutions", 2, 6, + True, { "P2": (2, 64, 128, 4), "P3": (2, 32, 64, 8), @@ -54,8 +57,14 @@ def test_layer_behaviors(self): }, ), ) - def test_layer_output_shapes(self, min_level, max_level, input_shapes): - layer = FeaturePyramid(min_level=min_level, max_level=max_level) + def test_layer_output_shapes( + self, min_level, max_level, use_p5, input_shapes + ): + layer = FeaturePyramid( + min_level=min_level, + max_level=max_level, + use_p5=use_p5, + ) inputs = { level: ops.ones(input_shapes[level]) for level in input_shapes diff --git a/keras_hub/src/models/retinanet/non_max_supression.py b/keras_hub/src/models/retinanet/non_max_supression.py index 9e52479f1d..5ca52b4dfc 100644 --- a/keras_hub/src/models/retinanet/non_max_supression.py +++ b/keras_hub/src/models/retinanet/non_max_supression.py @@ -3,6 +3,7 @@ import keras from keras import ops +# TODO: https://github.com/keras-team/keras-hub/issues/1965 from keras_hub.src.bounding_box import converters from keras_hub.src.bounding_box import utils from keras_hub.src.bounding_box import validate_format diff --git a/keras_hub/src/models/retinanet/prediction_head.py b/keras_hub/src/models/retinanet/prediction_head.py new file mode 100644 index 0000000000..007d4f32bd --- /dev/null +++ b/keras_hub/src/models/retinanet/prediction_head.py @@ -0,0 +1,192 @@ +import keras + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class PredictionHead(keras.layers.Layer): + """A head for classification or bounding box regression predictions. + + Args: + output_filters: int. The umber of convolution filters in the final layer. + The number of output channels determines the prediction type: + - **Classification**: + `output_filters = num_anchors * num_classes` + Predicts class probabilities for each anchor. + - **Bounding Box Regression**: + `output_filters = num_anchors * 4` Predicts bounding box + offsets (x1, y1, x2, y2) for each anchor. + num_filters: int. The number of convolution filters to use in the base + layer. + num_conv_layers: int. The number of convolution layers before the final + layer. + use_prior_probability: bool. Set to True to use prior probability in the + bias initializer for the final convolution layer. + Defaults to `False`. + prior_probability: float. The prior probability value to use for + initializing the bias. Only used if `use_prior_probability` is + `True`. Defaults to `0.01`. 
+ kernel_initializer: `str` or `keras.initializers`. The kernel + initializer for the convolution layers. Defaults to + `"random_normal"`. + bias_initializer: `str` or `keras.initializers`. The bias initializer + for the convolution layers. Defaults to `"zeros"`. + kernel_regularizer: `str` or `keras.regularizers`. The kernel + regularizer for the convolution layers. Defaults to `None`. + bias_regularizer: `str` or `keras.regularizers`. The bias regularizer + for the convolution layers. Defaults to `None`. + use_group_norm: bool. Whether to use Group Normalization after + the convolution layers. Defaults to `False`. + + Returns: + A function representing either the classification + or the box regression head depending on `output_filters`. + """ + + def __init__( + self, + output_filters, + num_filters, + num_conv_layers, + use_prior_probability=False, + prior_probability=0.01, + activation="relu", + kernel_initializer="random_normal", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + use_group_norm=False, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.output_filters = output_filters + self.num_filters = num_filters + self.num_conv_layers = num_conv_layers + self.use_prior_probability = use_prior_probability + self.prior_probability = prior_probability + self.activation = keras.activations.get(activation) + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + if kernel_regularizer is not None: + self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) + else: + self.kernel_regularizer = None + if bias_regularizer is not None: + self.bias_regularizer = keras.regularizers.get(bias_regularizer) + else: + self.bias_regularizer = None + self.use_group_norm = use_group_norm + self.data_format = standardize_data_format(data_format) + + def build(self, input_shape): + intermediate_shape = input_shape + self.conv_layers = [] + self.group_norm_layers = [] + for idx in range(self.num_conv_layers): + conv = keras.layers.Conv2D( + self.num_filters, + kernel_size=3, + padding="same", + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + use_bias=not self.use_group_norm, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"conv2d_{idx}", + ) + conv.build(intermediate_shape) + self.conv_layers.append(conv) + intermediate_shape = ( + input_shape[:-1] + (self.num_filters,) + if self.data_format == "channels_last" + else (input_shape[0], self.num_filters) + (input_shape[1:-1]) + ) + if self.use_group_norm: + group_norm = keras.layers.GroupNormalization( + groups=32, + axis=-1 if self.data_format == "channels_last" else 1, + dtype=self.dtype_policy, + name=f"group_norm_{idx}", + ) + group_norm.build(intermediate_shape) + self.group_norm_layers.append(group_norm) + prior_probability = keras.initializers.Constant( + -1 + * keras.ops.log( + (1 - self.prior_probability) / self.prior_probability + ) + ) + self.prediction_layer = keras.layers.Conv2D( + self.output_filters, + kernel_size=3, + strides=1, + padding="same", + kernel_initializer=self.kernel_initializer, + bias_initializer=( + prior_probability + if self.use_prior_probability + else self.bias_initializer + ), + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + dtype=self.dtype_policy, + name="logits_layer", + ) + 
self.prediction_layer.build( + (None, None, None, self.num_filters) + if self.data_format == "channels_last" + else (None, self.num_filters, None, None) + ) + self.built = True + + def call(self, input): + x = input + for idx in range(self.num_conv_layers): + x = self.conv_layers[idx](x) + if self.use_group_norm: + x = self.group_norm_layers[idx](x) + x = self.activation(x) + + output = self.prediction_layer(x) + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "output_filters": self.output_filters, + "num_filters": self.num_filters, + "num_conv_layers": self.num_conv_layers, + "use_group_norm": self.use_group_norm, + "use_prior_probability": self.use_prior_probability, + "prior_probability": self.prior_probability, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "kernel_regularizer": ( + keras.regularizers.serialize(self.kernel_regularizer) + if self.kernel_regularizer is not None + else None + ), + "bias_regularizer": ( + keras.regularizers.serialize(self.bias_regularizer) + if self.bias_regularizer is not None + else None + ), + } + ) + return config + + def compute_output_shape(self, input_shape): + return ( + input_shape[:-1] + (self.output_filters,) + if self.data_format == "channels_last" + else (input_shape[0],) + (self.output_filters,) + input_shape[1:-1] + ) diff --git a/keras_hub/src/models/retinanet/prediction_head_test.py b/keras_hub/src/models/retinanet/prediction_head_test.py new file mode 100644 index 0000000000..111c92ee7a --- /dev/null +++ b/keras_hub/src/models/retinanet/prediction_head_test.py @@ -0,0 +1,27 @@ +from absl.testing import parameterized +from keras import random + +from keras_hub.src.models.retinanet.prediction_head import PredictionHead +from keras_hub.src.tests.test_case import TestCase + + +class PredictionHeadTest(TestCase): + @parameterized.named_parameters( + ("without_group_normalization", False, 10), + ("with_group_normalization", True, 14), + ) + def test_layer_behaviors( + self, use_group_norm, expected_num_trainable_weights + ): + self.run_layer_test( + cls=PredictionHead, + init_kwargs={ + "output_filters": 9 * 4, # anchors_per_location * box length(4) + "num_filters": 256, + "num_conv_layers": 4, + "use_group_norm": use_group_norm, + }, + input_data=random.uniform(shape=(2, 64, 64, 256)), + expected_output_shape=(2, 64, 64, 36), + expected_num_trainable_weights=expected_num_trainable_weights, + ) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone.py b/keras_hub/src/models/retinanet/retinanet_backbone.py new file mode 100644 index 0000000000..c6ebff9ef2 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_backbone.py @@ -0,0 +1,146 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone +from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.RetinaNetBackbone") +class RetinaNetBackbone(FeaturePyramidBackbone): + """RetinaNet Backbone. + + Combines a CNN backbone (e.g., ResNet, MobileNet) with a feature pyramid + network (FPN)to extract multi-scale features for object detection. + + Args: + image_encoder: `keras.Model`. 
The backbone model (e.g., ResNet50,
+            MobileNetV2) used to extract features from the input image.
+            It should have pyramid outputs (i.e., a dictionary mapping level
+            names like `"P2"`, `"P3"`, etc. to their corresponding feature
+            tensors).
+        min_level: int. The minimum level of the feature pyramid (e.g., 3).
+            This determines the finest (highest resolution) level of features used.
+        max_level: int. The maximum level of the feature pyramid (e.g., 7).
+            This determines the coarsest (lowest resolution) level of features used.
+        use_p5: bool. Determines the input source for creating coarser
+            feature pyramid levels. If `True`, the output of the last backbone
+            layer (typically `'P5'` in an FPN) is used as input to create
+            higher-level feature maps (e.g., `'P6'`, `'P7'`) through
+            additional convolutional layers. If `False`, the feature pyramid's
+            own processed `'P5'` output is used as input for
+            creating the coarser levels, so `'P6'` and `'P7'` are built from
+            the already-processed `'P5'` feature map. Defaults to `False`.
+        use_fpn_batch_norm: bool. Whether to use batch normalization in the
+            feature pyramid network. Defaults to `False`.
+        image_shape: tuple. The shape of the input image (H, W, C).
+            The height and width can be `None` if they are variable.
+        data_format: str. The data format of the input image
+            (channels_first or channels_last).
+        dtype: str. The dtype policy used for the model's layers and computations.
+        **kwargs: Additional keyword arguments passed to the base class.
+
+    Raises:
+        ValueError: If `min_level` is greater than `max_level`.
+        ValueError: If `backbone_max_level` is less than 5 and `max_level` is greater than or equal to 5.
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        min_level,
+        max_level,
+        use_p5,
+        use_fpn_batch_norm=False,
+        image_shape=(None, None, 3),
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+
+        # === Layers ===
+        if min_level > max_level:
+            raise ValueError(
+                f"Minimum level ({min_level}) must be less than or equal to "
+                f"maximum level ({max_level})."
+            )
+
+        data_format = standardize_data_format(data_format)
+        input_levels = [
+            int(level[1]) for level in image_encoder.pyramid_outputs
+        ]
+        backbone_max_level = min(max(input_levels), max_level)
+
+        if backbone_max_level < 5 and max_level >= 5:
+            raise ValueError(
+                f"Backbone maximum level ({backbone_max_level}) is less than "
+                f"the desired maximum level ({max_level}). "
+                f"Please ensure that the backbone can generate features up to "
+                f"the specified maximum level."
+ ) + feature_extractor = keras.Model( + inputs=image_encoder.inputs, + outputs={ + f"P{level}": image_encoder.pyramid_outputs[f"P{level}"] + for level in range(min_level, backbone_max_level + 1) + }, + name="backbone", + ) + + feature_pyramid = FeaturePyramid( + min_level=min_level, + max_level=max_level, + use_p5=use_p5, + name="fpn", + dtype=dtype, + data_format=data_format, + use_batch_norm=use_fpn_batch_norm, + ) + + # === Functional model === + image_input = keras.layers.Input(image_shape, name="inputs") + feature_extractor_outputs = feature_extractor(image_input) + feature_pyramid_outputs = feature_pyramid(feature_extractor_outputs) + + super().__init__( + inputs=image_input, + outputs=feature_pyramid_outputs, + dtype=dtype, + **kwargs, + ) + + # === config === + self.min_level = min_level + self.max_level = max_level + self.use_p5 = use_p5 + self.use_fpn_batch_norm = use_fpn_batch_norm + self.image_encoder = image_encoder + self.feature_pyramid = feature_pyramid + self.image_shape = image_shape + self.pyramid_outputs = feature_pyramid_outputs + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": keras.layers.serialize(self.image_encoder), + "min_level": self.min_level, + "max_level": self.max_level, + "use_p5": self.use_p5, + "use_fpn_batch_norm": self.use_fpn_batch_norm, + "image_shape": self.image_shape, + } + ) + return config + + @classmethod + def from_config(cls, config): + config.update( + { + "image_encoder": keras.layers.deserialize( + config["image_encoder"] + ), + } + ) + + return super().from_config(config) diff --git a/keras_hub/src/models/retinanet/retinanet_backbone_test.py b/keras_hub/src/models/retinanet/retinanet_backbone_test.py new file mode 100644 index 0000000000..524374447b --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_backbone_test.py @@ -0,0 +1,63 @@ +import pytest +from keras import ops + +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetBackboneTest(TestCase): + def setUp(self): + resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 128, 256, 512], + "stackwise_num_blocks": [3, 4, 6, 3], + "stackwise_num_strides": [1, 2, 2, 2], + "block_type": "bottleneck_block", + "use_pre_activation": False, + } + image_encoder = ResNetBackbone(**resnet_kwargs) + + self.init_kwargs = { + "image_encoder": image_encoder, + "min_level": 3, + "max_level": 7, + "use_p5": True, + } + + self.input_size = 256 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + def test_backbone_basics_channels_first(self): + self.run_vision_backbone_test( + cls=RetinaNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "P3": (2, 32, 32, 256), + "P4": (2, 16, 16, 256), + "P5": (2, 8, 8, 256), + "P6": (2, 4, 4, 256), + "P7": (2, 2, 2, 256), + }, + expected_pyramid_output_keys=["P3", "P4", "P5", "P6", "P7"], + expected_pyramid_image_sizes=[ + (32, 32), + (16, 16), + (8, 8), + (4, 4), + (2, 2), + ], + run_mixed_precision_check=False, + run_quantization_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=RetinaNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git 
a/keras_hub/src/models/retinanet/retinanet_image_converter.py b/keras_hub/src/models/retinanet/retinanet_image_converter.py new file mode 100644 index 0000000000..1e137a94fd --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_image_converter.py @@ -0,0 +1,62 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.utils.tensor_utils import convert_preprocessing_inputs +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras_hub_export("keras_hub.layers.RetinaNetImageConverter") +class RetinaNetImageConverter(ImageConverter): + backbone_cls = RetinaNetBackbone + + def __init__( + self, + scale=None, + offset=None, + norm_mean=[0.485, 0.456, 0.406], + norm_std=[0.229, 0.224, 0.225], + **kwargs + ): + super().__init__(**kwargs) + self.scale = scale + self.offset = offset + self.norm_mean = norm_mean + self.norm_std = norm_std + self.built = True + + @preprocessing_function + def call(self, inputs): + # TODO: https://github.com/keras-team/keras-hub/issues/1965 + x = inputs + # Rescaling Image + if self.scale is not None: + x = x * convert_preprocessing_inputs( + self._expand_non_channel_dims(self.scale, x) + ) + if self.offset is not None: + x = x + convert_preprocessing_inputs( + self._expand_non_channel_dims(self.offset, x) + ) + + # By default normalize using imagenet mean and std + if self.norm_mean: + x = x - convert_preprocessing_inputs( + self._expand_non_channel_dims(self.norm_mean, x) + ) + + if self.norm_std: + x = x / convert_preprocessing_inputs( + self._expand_non_channel_dims(self.norm_std, x) + ) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "norm_mean": self.norm_mean, + "norm_std": self.norm_std, + } + ) + return config diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder.py b/keras_hub/src/models/retinanet/retinanet_label_encoder.py index a5bf475b29..5ac0a3d690 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder.py @@ -1,9 +1,12 @@ +import math + import keras from keras import ops -from keras_hub.src.bounding_box.converters import _encode_box_to_deltas +# TODO: https://github.com/keras-team/keras-hub/issues/1965 +from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.bounding_box.converters import encode_box_to_deltas from keras_hub.src.bounding_box.iou import compute_iou -from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.box_matcher import BoxMatcher from keras_hub.src.utils import tensor_utils @@ -24,17 +27,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer): consistency during training, regardless of the input format. Args: - bounding_box_format: str. The format of bounding boxes of input dataset. - Refer TODO: Add link to Keras Core Docs. - min_level: int. Minimum level of the output feature pyramid. - max_level: int. Maximum level of the output feature pyramid. - num_scales: int. Number of intermediate scales added on each level. - For example, num_scales=2 adds one additional intermediate anchor - scale [2^0, 2^0.5] on each level. - aspect_ratios: List[float]. Aspect ratios of anchors added on - each level. Each number indicates the ratio of width to height. - anchor_size: float. 
Scale of size of the base anchor relative to the - feature stride 2^level. + anchor_generator: A `keras_hub.layers.AnchorGenerator`. + bounding_box_format: str. Ground truth format of bounding boxes. + encoding_format: str. The desired target encoding format for the boxes. + TODO: https://github.com/keras-team/keras-hub/issues/1907 positive_threshold: float. the threshold to set an anchor to positive match to gt box. Values above it are positive matches. Defaults to `0.5` @@ -43,7 +39,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer): Defaults to `0.4` box_variance: List[float]. The scaling factors used to scale the bounding box targets. - Defaults to `[0.1, 0.1, 0.2, 0.2]`. + Defaults to `[1.0, 1.0, 1.0, 1.0]`. background_class: int. The class ID used for the background class, Defaults to `-1`. ignore_class: int. The class ID used for the ignore class, @@ -63,15 +59,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer): def __init__( self, + anchor_generator, bounding_box_format, - min_level, - max_level, - num_scales, - aspect_ratios, - anchor_size, + encoding_format="center_yxhw", positive_threshold=0.5, negative_threshold=0.4, - box_variance=[0.1, 0.1, 0.2, 0.2], + box_variance=[1.0, 1.0, 1.0, 1.0], background_class=-1.0, ignore_class=-2.0, box_matcher_match_values=[-1, -2, 1], @@ -79,27 +72,15 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.anchor_generator = anchor_generator self.bounding_box_format = bounding_box_format - self.min_level = min_level - self.max_level = max_level - self.num_scales = num_scales - self.aspect_ratios = aspect_ratios - self.anchor_size = anchor_size + self.encoding_format = encoding_format self.positive_threshold = positive_threshold self.box_variance = box_variance self.negative_threshold = negative_threshold self.background_class = background_class self.ignore_class = ignore_class - self.anchor_generator = AnchorGenerator( - bounding_box_format=bounding_box_format, - min_level=min_level, - max_level=max_level, - num_scales=num_scales, - aspect_ratios=aspect_ratios, - anchor_size=anchor_size, - ) - self.box_matcher = BoxMatcher( thresholds=[negative_threshold, positive_threshold], match_values=box_matcher_match_values, @@ -116,7 +97,7 @@ def call(self, images, gt_boxes, gt_classes): images: A Tensor. The input images argument should be of shape `[B, H, W, C]` or `[B, C, H, W]`. gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`. - gt_labels: A Tensor with shape of `[B, num_boxes, num_classes]` + gt_classes: A Tensor with shape of `[B, num_boxes, num_classes]` Returns: box_targets: A Tensor of shape `[batch_size, num_anchors, 4]` @@ -174,7 +155,12 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): Encoded boudning boxes in the format of `center_yxwh` and corresponding labels for each encoded bounding box. 
""" - + anchor_boxes = convert_format( + anchor_boxes, + source=self.anchor_generator.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) iou_matrix = compute_iou( anchor_boxes, gt_boxes, @@ -193,11 +179,12 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) ) - box_target = _encode_box_to_deltas( + box_targets = encode_box_to_deltas( anchors=anchor_boxes, boxes=matched_gt_boxes, anchor_format=self.bounding_box_format, box_format=self.bounding_box_format, + encoding_format=self.encoding_format, variance=self.box_variance, image_shape=image_shape, ) @@ -205,16 +192,16 @@ def _encode_sample(self, gt_boxes, gt_classes, anchor_boxes, image_shape): matched_gt_cls_ids = tensor_utils.target_gather( gt_classes, matched_gt_idx ) - cls_target = ops.where( + classs_targets = ops.where( ops.not_equal(positive_mask, 1.0), self.background_class, matched_gt_cls_ids, ) - cls_target = ops.where( - ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target + classs_targets = ops.where( + ops.equal(ignore_mask, 1.0), self.ignore_class, classs_targets ) label = ops.concatenate( - [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1 + [box_targets, ops.cast(classs_targets, box_targets.dtype)], axis=-1 ) # In the case that a box in the corner of an image matches with an all @@ -234,12 +221,11 @@ def get_config(self): config = super().get_config() config.update( { + "anchor_generator": keras.layers.serialize( + self.anchor_generator + ), "bounding_box_format": self.bounding_box_format, - "min_level": self.min_level, - "max_level": self.max_level, - "num_scales": self.num_scales, - "aspect_ratios": self.aspect_ratios, - "anchor_size": self.anchor_size, + "encoding_format": self.encoding_format, "positive_threshold": self.positive_threshold, "box_variance": self.box_variance, "negative_threshold": self.negative_threshold, @@ -249,6 +235,18 @@ def get_config(self): ) return config + @classmethod + def from_config(cls, config): + config.update( + { + "anchor_generator": keras.layers.deserialize( + config["anchor_generator"] + ), + } + ) + + return super().from_config(config) + def compute_output_shape( self, images_shape, gt_boxes_shape, gt_classes_shape ): @@ -258,10 +256,10 @@ def compute_output_shape( total_num_anchors = 0 for i in range(min_level, max_level + 1): - total_num_anchors += ( - (image_H // 2 ** (i)) - * (image_W // 2 ** (i)) - * self.anchor_generator.anchors_per_location + total_num_anchors += int( + math.ceil(image_H / 2 ** (i)) + * math.ceil(image_W / 2 ** (i)) + * self.anchor_generator.num_base_anchors ) return (batch_size, total_num_anchors, 4), ( diff --git a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py index de329685a8..d05bf5a99a 100644 --- a/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py +++ b/keras_hub/src/models/retinanet/retinanet_label_encoder_test.py @@ -1,6 +1,7 @@ import numpy as np from keras import ops +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator from keras_hub.src.models.retinanet.retinanet_label_encoder import ( RetinaNetLabelEncoder, ) @@ -8,20 +9,27 @@ class RetinaNetLabelEncoderTest(TestCase): + def setUp(self): + anchor_generator = AnchorGenerator( + bounding_box_format="xyxy", + min_level=3, + max_level=7, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=8, + ) + self.init_kwargs = { + "anchor_generator": 
anchor_generator, + "bounding_box_format": "xyxy", + } + def test_layer_behaviors(self): images_shape = (8, 128, 128, 3) boxes_shape = (8, 10, 4) classes_shape = (8, 10) self.run_layer_test( cls=RetinaNetLabelEncoder, - init_kwargs={ - "bounding_box_format": "xyxy", - "min_level": 3, - "max_level": 7, - "num_scales": 3, - "aspect_ratios": [0.5, 1.0, 2.0], - "anchor_size": 8, - }, + init_kwargs=self.init_kwargs, input_data={ "images": np.random.uniform(size=images_shape), "gt_boxes": np.random.uniform( @@ -48,12 +56,7 @@ def test_label_encoder_output_shapes(self): classes = np.random.uniform(size=classes_shape, low=0, high=5) encoder = RetinaNetLabelEncoder( - bounding_box_format="xyxy", - min_level=3, - max_level=7, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=8, + **self.init_kwargs, ) box_targets, class_targets = encoder(images, boxes, classes) @@ -71,12 +74,7 @@ def test_all_negative_1(self): classes = -np.ones(shape=classes_shape, dtype="float32") encoder = RetinaNetLabelEncoder( - bounding_box_format="xyxy", - min_level=3, - max_level=7, - num_scales=3, - aspect_ratios=[0.5, 1.0, 2.0], - anchor_size=8, + **self.init_kwargs, ) box_targets, class_targets = encoder(images, boxes, classes) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector.py b/keras_hub/src/models/retinanet/retinanet_object_detector.py new file mode 100644 index 0000000000..92dbe293ca --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector.py @@ -0,0 +1,376 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + +# TODO: https://github.com/keras-team/keras-hub/issues/1965 +from keras_hub.src.bounding_box.converters import convert_format +from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes +from keras_hub.src.models.image_object_detector import ImageObjectDetector +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression +from keras_hub.src.models.retinanet.prediction_head import PredictionHead +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) + + +@keras_hub_export("keras_hub.models.RetinaNetObjectDetector") +class RetinaNetObjectDetector(ImageObjectDetector): + """RetinaNet object detector model. + + This class implements the RetinaNet object detection architecture. + It consists of a feature extractor backbone, a feature pyramid network(FPN), + and two prediction heads (for classification and bounding box regression). + + Args: + backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class, + defining the backbone network architecture. Provides feature maps + for detection. + anchor_generator: A `keras_hub.layers.AnchorGenerator` instance. + Generates anchor boxes at different scales and aspect ratios + across the image. If None, a default `AnchorGenerator` is + created with the following parameters: + - `bounding_box_format`: Same as the model's + `bounding_box_format`. + - `min_level`: The backbone's `min_level`. + - `max_level`: The backbone's `max_level`. + - `num_scales`: 3. + - `aspect_ratios`: [0.5, 1.0, 2.0]. + - `anchor_size`: 4.0. 
+ You can create a custom `AnchorGenerator` by instantiating the + `keras_hub.layers.AnchorGenerator` class and passing the desired + arguments. + num_classes: int. The number of object classes to be detected. + bounding_box_format: str. Dataset bounding box format (e.g., "xyxy", + "yxyx"). For the list of supported formats, refer to + TODO: https://github.com/keras-team/keras-hub/issues/1907 + label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes + ground truth boxes and classes into training targets. It matches + ground truth boxes to anchors based on IoU and encodes box + coordinates as offsets. See the + `keras_hub.src.models.retinanet.retinanet_label_encoder.RetinaNetLabelEncoder` + class for details. If `None`, a default encoder is created with the + following parameters: + - `anchor_generator`: Same as the model's. + - `bounding_box_format`: Same as the model's + `bounding_box_format`. + - `positive_threshold`: 0.5 + - `negative_threshold`: 0.4 + - `encoding_format`: "center_xywh" + - `box_variance`: [1.0, 1.0, 1.0, 1.0] + - `background_class`: -1 + - `ignore_class`: -2 + use_prediction_head_norm: bool. Whether to use Group Normalization after + the convolution layers in the prediction heads. Defaults to `False`. + classification_head_prior_probability: float. Prior probability for the + classification head (used for focal loss). Defaults to 0.01. + pre_logits_num_conv_layers: int. The number of convolutional layers in + each prediction head before the final linear (logits) layer that + produces the output predictions (bounding box regressions, + classification scores). + preprocessor: Optional. An instance of + `RetinaNetObjectDetectorPreprocessor` or a custom preprocessor. + Handles image preprocessing before feeding into the backbone. + activation: Optional. The activation function to be used in the + classification head. If None, sigmoid is used. + dtype: Optional. The data type for the prediction heads. Defaults to the + backbone's dtype policy. + prediction_decoder: Optional. A `keras.layers.Layer` instance + responsible for transforming RetinaNet predictions + (box regressions and classifications) into final bounding boxes and + classes with confidence scores. Defaults to a `NonMaxSuppression` + instance.
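+
+    Example: a minimal construction and inference sketch. The small ResNet
+    encoder and anchor settings below are illustrative only (they mirror the
+    unit tests in this change); any compatible backbone configuration works.
+    ```python
+    import numpy as np
+    import keras_hub
+
+    # Small ResNet image encoder (illustrative configuration).
+    image_encoder = keras_hub.models.ResNetBackbone(
+        input_conv_filters=[64],
+        input_conv_kernel_sizes=[7],
+        stackwise_num_filters=[64, 64, 64],
+        stackwise_num_blocks=[2, 2, 2],
+        stackwise_num_strides=[1, 2, 2],
+        block_type="bottleneck_block",
+        use_pre_activation=False,
+        image_shape=(None, None, 3),
+    )
+    # Feature pyramid backbone built on top of the image encoder.
+    backbone = keras_hub.models.RetinaNetBackbone(
+        image_encoder=image_encoder,
+        min_level=3,
+        max_level=4,
+        use_p5=False,
+    )
+    # Optional custom anchor generator; omit it to use the defaults
+    # described above.
+    anchor_generator = keras_hub.layers.AnchorGenerator(
+        bounding_box_format="yxyx",
+        min_level=3,
+        max_level=4,
+        num_scales=3,
+        aspect_ratios=[0.5, 1.0, 2.0],
+        anchor_size=4,
+    )
+    preprocessor = keras_hub.models.RetinaNetObjectDetectorPreprocessor(
+        image_converter=keras_hub.layers.RetinaNetImageConverter(
+            scale=1 / 255.0
+        )
+    )
+    detector = keras_hub.models.RetinaNetObjectDetector(
+        backbone=backbone,
+        anchor_generator=anchor_generator,
+        num_classes=10,
+        bounding_box_format="yxyx",
+        preprocessor=preprocessor,
+    )
+    images = np.random.uniform(0, 255, size=(1, 512, 512, 3)).astype("float32")
+    # Returns a dict with "boxes", "classes", "confidence" and
+    # "num_detections" keys.
+    predictions = detector.predict(images)
+    ```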
+ """ + + backbone_cls = RetinaNetBackbone + preprocessor_cls = RetinaNetObjectDetectorPreprocessor + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + label_encoder=None, + use_prediction_head_norm=False, + classification_head_prior_probability=0.01, + pre_logits_num_conv_layers=4, + preprocessor=None, + activation=None, + dtype=None, + prediction_decoder=None, + **kwargs, + ): + # === Layers === + image_input = keras.layers.Input(backbone.image_shape, name="images") + head_dtype = dtype or backbone.dtype_policy + + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format, + min_level=backbone.min_level, + max_level=backbone.max_level, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4, + ) + # As weights are ported from torch they use encoded format + # as "center_xywh" + label_encoder = label_encoder or RetinaNetLabelEncoder( + anchor_generator, + bounding_box_format=bounding_box_format, + encoding_format="center_xywh", + ) + + box_head = PredictionHead( + output_filters=anchor_generator.num_base_anchors * 4, + num_conv_layers=pre_logits_num_conv_layers, + num_filters=256, + use_group_norm=use_prediction_head_norm, + use_prior_probability=True, + prior_probability=classification_head_prior_probability, + dtype=head_dtype, + name="box_head", + ) + classification_head = PredictionHead( + output_filters=anchor_generator.num_base_anchors * num_classes, + num_conv_layers=pre_logits_num_conv_layers, + num_filters=256, + use_group_norm=use_prediction_head_norm, + dtype=head_dtype, + name="classification_head", + ) + + # === Functional Model === + feature_map = backbone(image_input) + + cls_pred = [] + box_pred = [] + + # Iterate through the feature pyramid levels (e.g., P3, P4, P5, P6, P7). + for level in feature_map: + box_pred.append( + keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")( + box_head(feature_map[level]) + ) + ) + cls_pred.append( + keras.layers.Reshape( + (-1, num_classes), name=f"cls_pred_{level}" + )(classification_head(feature_map[level])) + ) + + # Concatenate predictions from all FPN levels. + cls_pred = keras.layers.Concatenate(axis=1, name="cls_logits")(cls_pred) + # box_pred is always in "center_xywh" delta-encoded no matter what + # format you pass in. 
+ box_pred = keras.layers.Concatenate(axis=1, name="bbox_regression")( + box_pred + ) + + outputs = {"bbox_regression": box_pred, "cls_logits": cls_pred} + + super().__init__( + inputs=image_input, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.bounding_box_format = bounding_box_format + self.use_prediction_head_norm = use_prediction_head_norm + self.num_classes = num_classes + self.backbone = backbone + self.preprocessor = preprocessor + self.activation = activation + self.pre_logits_num_conv_layers = pre_logits_num_conv_layers + self.box_head = box_head + self.classification_head = classification_head + self.anchor_generator = anchor_generator + self.label_encoder = label_encoder + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + from_logits=(activation != keras.activations.sigmoid), + bounding_box_format=bounding_box_format, + ) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + y_for_label_encoder = convert_format( + y, + source=self.bounding_box_format, + target=self.label_encoder.bounding_box_format, + images=x, + ) + + boxes, classes = self.label_encoder( + images=x, + gt_boxes=y_for_label_encoder["boxes"], + gt_classes=y_for_label_encoder["classes"], + ) + + box_pred = y_pred["bbox_regression"] + cls_pred = y_pred["cls_logits"] + + if boxes.shape[-1] != 4: + raise ValueError( + "boxes should have shape (None, None, 4). Got " + f"boxes.shape={tuple(boxes.shape)}" + ) + + if box_pred.shape[-1] != 4: + raise ValueError( + "box_pred should have shape (None, None, 4). Got " + f"box_pred.shape={tuple(box_pred.shape)}. Does your model's " + "`num_classes` parameter match your loss's `num_classes` " + "parameter?" + ) + if cls_pred.shape[-1] != self.num_classes: + raise ValueError( + "cls_pred should have shape (None, None, num_classes). Got " + f"cls_pred.shape={tuple(cls_pred.shape)}. Does your model's " + "`num_classes` parameter match your loss's `num_classes` " + "parameter?" + ) + + cls_labels = ops.one_hot( + ops.cast(classes, "int32"), self.num_classes, dtype="float32" + ) + positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32") + normalizer = ops.sum(positive_mask) + cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32") + cls_weights /= normalizer + box_weights = positive_mask / normalizer + + y_true = { + "bbox_regression": boxes, + "cls_logits": cls_labels, + } + sample_weights = { + "bbox_regression": box_weights, + "cls_logits": cls_weights, + } + zero_weight = { + "bbox_regression": ops.zeros_like(box_weights), + "cls_logits": ops.zeros_like(cls_weights), + } + + sample_weight = ops.cond( + normalizer == 0, + lambda: zero_weight, + lambda: sample_weights, + ) + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weight, **kwargs + ) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + if isinstance(outputs, tuple): + return self.decode_predictions(outputs[0], args[-1]), outputs[1] + return self.decode_predictions(outputs, *args) + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and `RetinaNet` to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`."
+ ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def decode_predictions(self, predictions, data): + box_pred = predictions["bbox_regression"] + cls_pred = predictions["cls_logits"] + # box_pred holds "center_xywh" deltas; decode and convert to the target format. + if isinstance(data, list) or isinstance(data, tuple): + images, _ = data + else: + images = data + image_shape = ops.shape(images)[1:] + anchor_boxes = self.anchor_generator(images) + anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0) + box_pred = decode_deltas_to_boxes( + anchors=anchor_boxes, + boxes_delta=box_pred, + encoded_format="center_xywh", + anchor_format=self.anchor_generator.bounding_box_format, + box_format=self.bounding_box_format, + image_shape=image_shape, + ) + # box_pred is now in "self.bounding_box_format" format + box_pred = convert_format( + box_pred, + source=self.bounding_box_format, + target=self.prediction_decoder.bounding_box_format, + image_shape=image_shape, + ) + y_pred = self.prediction_decoder( + box_pred, cls_pred, image_shape=image_shape + ) + y_pred["boxes"] = convert_format( + y_pred["boxes"], + source=self.prediction_decoder.bounding_box_format, + target=self.bounding_box_format, + image_shape=image_shape, + ) + return y_pred + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "use_prediction_head_norm": self.use_prediction_head_norm, + "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers, + "bounding_box_format": self.bounding_box_format, + "anchor_generator": keras.layers.serialize( + self.anchor_generator + ), + "label_encoder": keras.layers.serialize(self.label_encoder), + "prediction_decoder": keras.layers.serialize( + self._prediction_decoder + ), + } + ) + + return config + + @classmethod + def from_config(cls, config): + if "label_encoder" in config and isinstance( + config["label_encoder"], dict + ): + config["label_encoder"] = keras.layers.deserialize( + config["label_encoder"] + ) + + if "anchor_generator" in config and isinstance( + config["anchor_generator"], dict + ): + config["anchor_generator"] = keras.layers.deserialize( + config["anchor_generator"] + ) + + if "prediction_decoder" in config and isinstance( + config["prediction_decoder"], dict + ): + config["prediction_decoder"] = keras.layers.deserialize( + config["prediction_decoder"] + ) + + return super().from_config(config) diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py b/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py new file mode 100644 index 0000000000..8bc6d1f796 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_object_detector_preprocessor import ( + ImageObjectDetectorPreprocessor, ) +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, ) + + +@keras_hub_export("keras_hub.models.RetinaNetObjectDetectorPreprocessor") +class RetinaNetObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor): + backbone_cls = RetinaNetBackbone + image_converter_cls = RetinaNetImageConverter diff --git a/keras_hub/src/models/retinanet/retinanet_object_detector_test.py
b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py new file mode 100644 index 0000000000..995c839d23 --- /dev/null +++ b/keras_hub/src/models/retinanet/retinanet_object_detector_test.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest + +from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone +from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator +from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone +from keras_hub.src.models.retinanet.retinanet_image_converter import ( + RetinaNetImageConverter, +) +from keras_hub.src.models.retinanet.retinanet_label_encoder import ( + RetinaNetLabelEncoder, +) +from keras_hub.src.models.retinanet.retinanet_object_detector import ( + RetinaNetObjectDetector, +) +from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import ( + RetinaNetObjectDetectorPreprocessor, +) +from keras_hub.src.tests.test_case import TestCase + + +class RetinaNetObjectDetectorTest(TestCase): + def setUp(self): + resnet_kwargs = { + "input_conv_filters": [64], + "input_conv_kernel_sizes": [7], + "stackwise_num_filters": [64, 64, 64], + "stackwise_num_blocks": [2, 2, 2], + "stackwise_num_strides": [1, 2, 2], + "image_shape": (None, None, 3), + "block_type": "bottleneck_block", + "use_pre_activation": False, + } + image_encoder = ResNetBackbone(**resnet_kwargs) + + retinanet_backbone_kwargs = { + "image_encoder": image_encoder, + "min_level": 3, + "max_level": 4, + "use_p5": False, + } + + feature_extractor = RetinaNetBackbone(**retinanet_backbone_kwargs) + anchor_generator = AnchorGenerator( + bounding_box_format="yxyx", + min_level=3, + max_level=4, + num_scales=3, + aspect_ratios=[0.5, 1.0, 2.0], + anchor_size=4, + ) + label_encoder = RetinaNetLabelEncoder( + bounding_box_format="yxyx", anchor_generator=anchor_generator + ) + + image_converter = RetinaNetImageConverter(scale=1 / 255.0) + + preprocessor = RetinaNetObjectDetectorPreprocessor( + image_converter=image_converter + ) + + self.init_kwargs = { + "backbone": feature_extractor, + "anchor_generator": anchor_generator, + "label_encoder": label_encoder, + "num_classes": 10, + "bounding_box_format": "yxyx", + "preprocessor": preprocessor, + } + + self.input_size = 512 + self.images = np.random.uniform( + low=0, high=255, size=(1, self.input_size, self.input_size, 3) + ).astype("float32") + self.labels = { + "boxes": np.array( + [[[20.0, 10.0, 12.0, 11.0], [30.0, 20.0, 40.0, 12.0]]] + ), + "classes": np.array([[0, 2]]), + } + + self.train_data = (self.images, self.labels) + + def test_detection_basics(self): + self.run_task_test( + cls=RetinaNetObjectDetector, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape={ + "boxes": (1, 100, 4), + "classes": (1, 100), + "confidence": (1, 100), + "num_detections": (1,), + }, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=RetinaNetObjectDetector, + init_kwargs=self.init_kwargs, + input_data=self.images, + )
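As exercised by `test_detection_basics` above, training targets are a dict of dense `boxes`/`classes` arrays in the model's `bounding_box_format`, and `predict` returns padded detections. A minimal fit/predict sketch, continuing from the `detector` built in the docstring example and assuming the task's default `compile()` configuration (the data values are illustrative, copied from the unit test):

    import numpy as np

    images = np.random.uniform(0, 255, size=(1, 512, 512, 3)).astype("float32")
    labels = {
        # Boxes follow the detector's `bounding_box_format` ("yxyx" here);
        # classes are integer ids (the default background class is -1).
        "boxes": np.array(
            [[[20.0, 10.0, 12.0, 11.0], [30.0, 20.0, 40.0, 12.0]]]
        ),
        "classes": np.array([[0, 2]]),
    }

    # Uses the task's default compile() setup, as in run_task_test.
    detector.fit(images, labels, epochs=1)
    outputs = detector.predict(images)
    # outputs["boxes"]: (1, 100, 4), outputs["classes"]: (1, 100),
    # outputs["confidence"]: (1, 100), outputs["num_detections"]: (1,)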