Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RetinaNet] Image Converter and ObjectDetector #1906

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
c1d7955
Rebased phase 1 changes
sineeli Sep 26, 2024
deaeac4
Rebased phase 1 changes
sineeli Sep 26, 2024
1cdd164
Merge branch 'sineeli/add-retinanet-phase-2' of https://github.com/si…
sineeli Sep 27, 2024
f90add8
nit
sineeli Sep 27, 2024
fb0c733
Merge remote-tracking branch 'upstream/master' into sineeli/add-retin…
sineeli Oct 2, 2024
6c26534
Retina Phase 2
sineeli Oct 3, 2024
baee6e2
nit
sineeli Oct 3, 2024
5ee905e
Expose Anchor Generator as layer, docstring correction and test corre…
sineeli Oct 4, 2024
84533d4
nit
sineeli Oct 4, 2024
b6ceb8f
Add missing args for prediction heads
sineeli Oct 4, 2024
4c7a28b
- Use FeaturePyramidBackbone cls for RetinaNet backbone.
sineeli Oct 8, 2024
3f915dc
fix decoding error
sineeli Oct 8, 2024
f0da549
- Add ground truth arg for RetinaNet model and remove source and targ…
sineeli Oct 8, 2024
05fdefe
nit
sineeli Oct 9, 2024
3b26d3a
Subclass Imageconverter and overload call method for object detection…
sineeli Oct 9, 2024
0df121a
Revert "Subclass Imageconverter and overload call method for object d…
sineeli Oct 9, 2024
8697240
add names to layers
sineeli Oct 9, 2024
394faf0
correct fpn coarser level as per torch retinanet model
sineeli Oct 9, 2024
33d81e9
nit
sineeli Oct 9, 2024
79502d9
Polish Prediction head and fpn layers to include flags and norm layers
sineeli Oct 9, 2024
72a02c4
nit
sineeli Oct 9, 2024
a28a033
nit
sineeli Oct 9, 2024
50686e0
add prior probability flag for prediction head to use it for classifi…
sineeli Oct 9, 2024
8dc5483
compute_shape seems redudant here and correct layers for channels_first
sineeli Oct 9, 2024
9f7d8ef
keep compute_output_shape for fpn
sineeli Oct 9, 2024
6801789
nit
sineeli Oct 10, 2024
7e57cf1
Change AnchorGen Implementation as per torch
sineeli Oct 10, 2024
8ac617c
correct the source format of anchors format
sineeli Oct 10, 2024
03efed5
use plain rescaling and normalization no resizing for od models as it…
sineeli Oct 11, 2024
5704950
use single bbox format for model
sineeli Oct 11, 2024
7c1d1de
- Add arg for encoding format
sineeli Oct 11, 2024
2414f00
make anchor generator optional
sineeli Oct 11, 2024
064c971
init as layers for anchor generator and label encoder and as one more…
sineeli Oct 11, 2024
4ff8f13
nit
sineeli Oct 11, 2024
c4f752d
- only consider levels from min level to backbone maxlevel fro featur…
sineeli Oct 12, 2024
bde84b9
nit
sineeli Oct 12, 2024
caacc99
nit
sineeli Oct 15, 2024
eb555ca
update resizing as per new keras3 resizing layer for bboxes
sineeli Oct 25, 2024
de8233e
Revert "update resizing as per new keras3 resizing layer for bboxes"
sineeli Oct 30, 2024
1ca10b9
Add TODO's for keras bounding box ops
sineeli Oct 30, 2024
5ec65fd
Use keras layers to rescale and normalize
sineeli Nov 4, 2024
dd00bdf
check with plain values
sineeli Nov 4, 2024
581d152
use convert_preprocessing_inputs function for basic operations as bac…
sineeli Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
from keras_hub.src.models.retinanet.retinanet_image_converter import (
RetinaNetImageConverter,
)
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
Expand Down
11 changes: 11 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.image_object_detector import ImageObjectDetector
from keras_hub.src.models.image_object_detector_preprocessor import (
ImageObjectDetectorPreprocessor,
)
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
Expand Down Expand Up @@ -233,6 +237,13 @@
from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
ResNetImageClassifierPreprocessor,
)
from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
from keras_hub.src.models.retinanet.retinanet_object_detector import (
RetinaNetObjectDetector,
)
from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
RetinaNetObjectDetectorPreprocessor,
)
from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_hub.src.models.roberta.roberta_masked_lm import RobertaMaskedLM
from keras_hub.src.models.roberta.roberta_masked_lm_preprocessor import (
Expand Down
84 changes: 84 additions & 0 deletions keras_hub/src/models/image_object_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.ImageObjectDetector")
class ImageObjectDetector(Task):
"""Base class for all image classification tasks.
sineeli marked this conversation as resolved.
Show resolved Hide resolved

`ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and
sineeli marked this conversation as resolved.
Show resolved Hide resolved
a `keras_hub.models.Preprocessor` to create a model that can be used for
image classification. `ImageObjectDetector` tasks take an additional
sineeli marked this conversation as resolved.
Show resolved Hide resolved
`num_classes` argument, controlling the number of predicted output classes.

To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
labels where `x` is a string and `y` is dictionary with `boxes` and
`classes`.

All `ImageObjectDetector` tasks include a `from_preset()` constructor which
can be used to load a pre-trained config and weights.
"""

def compile(
self,
optimizer="auto",
box_loss="auto",
classification_loss="auto",
metrics=None,
**kwargs,
):
"""Configures the `ImageSegmenter` task for training.
sineeli marked this conversation as resolved.
Show resolved Hide resolved

The `ImageSegmenter` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.

Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
box_loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.Huber` loss will be
applied for the object detector task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
classification_loss: `"auto"`, a loss name, or a `keras.losses.Loss`
instance. Defaults to `"auto"`, where a
`keras.losses.BinaryFocalCrossentropy` loss will be
applied for the object detector task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `a list of metrics to be evaluated by
the model during training and testing. Defaults to `None`.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(5e-5)
if box_loss == "auto":
box_loss = keras.losses.Huber(reduction="sum")
if classification_loss == "auto":
activation = getattr(self, "activation", None)
activation = keras.activations.get(activation)
from_logits = activation != keras.activations.sigmoid
classification_loss = keras.losses.BinaryFocalCrossentropy(
from_logits=from_logits, reduction="sum"
)
if metrics is not None:
raise ValueError("User metrics not yet supported")

losses = {"box": box_loss, "classification": classification_loss}

super().compile(
optimizer=optimizer,
loss=losses,
metrics=metrics,
**kwargs,
)
117 changes: 117 additions & 0 deletions keras_hub/src/models/image_object_detector_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.bounding_box.converters import convert_format
from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor")
class ImageObjectDetectorPreprocessor(Preprocessor):
"""Base class for object detector preprocessing layers.

`ImageObjectDetectorPreprocessor` tasks wraps a
`keras_hub.layers.Preprocessor` to create a preprocessing layer for
object detection tasks. It is intended to be paired with a
`keras_hub.models.ImageObjectDetector` task.

All `ImageObjectDetectorPreprocessor` take inputs three inputs, `x`, `y`, and
sineeli marked this conversation as resolved.
Show resolved Hide resolved
`sample_weight`. `x`, the first input, should always be included. It can
be a image or batch of images. See examples below. `y` and `sample_weight`
are optional inputs that will be passed through unaltered. Usually, `y` will
be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
"classes": (batch_size, num_boxes)}.

The layer will output either `x`, an `(x, y)` tuple if labels were provided,
sineeli marked this conversation as resolved.
Show resolved Hide resolved
or an `(x, y, sample_weight)` tuple if labels and sample weight were
provided. `x` will be the input images after all model preprocessing has
been applied.

All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
constructor which can be used to load a pre-trained config and vocabularies.
You can call the `from_preset()` constructor directly on this base class, in
which case the correct class for your model will be automatically
instantiated.

Args:
image_converter: Preprocessing pipeline for images.
source_bounding_box_format: str. The format of the source bounding boxes.
supported formats include:
- `"rel_yxyx"`
- `"rel_xyxy"`
- `"rel_xywh"
Defaults to `"rel_yxyx"`.
target_bounding_box_format: str. TODO Add link to keras-core bounding
sineeli marked this conversation as resolved.
Show resolved Hide resolved
box formats page.


Examples.
```python
preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
"retinanet_resnet50",
)

# Resize a single image for resnet 50.
x = np.ones((512, 512, 3))
x = preprocessor(x)

# Resize a labeled image.
x, y = np.ones((512, 512, 3)), 1
x, y = preprocessor(x, y)

# Resize a batch of labeled images.
x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
x, y = preprocessor(x, y)

# Use a `tf.data.Dataset`.
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def __init__(
self,
target_bounding_box_format,
source_bounding_box_format="rel_yxyx",
image_converter=None,
**kwargs,
):
super().__init__(**kwargs)
if "rel" not in source_bounding_box_format:
raise ValueError(
f"Only relative bounding box formats are supported "
sineeli marked this conversation as resolved.
Show resolved Hide resolved
f"but received source_bounding_box_format="
f"`{source_bounding_box_format}` "
f"please provide source bounding box format from one of these "
f"`rel_xyxy` or `rel_yxyx` or `rel_xywh`. Make sure provided "
f"ground truth bounding boxes are normalized/relative to image."
)
self.source_bounding_box_format = source_bounding_box_format
self.target_bounding_box_format = target_bounding_box_format
self.image_converter = image_converter

@preprocessing_function
def call(self, x, y=None, sample_weight=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case, the bounding box updates needs to be addressed.

if self.image_converter:
x = self.image_converter(x)

if y is not None:
y = convert_format(
y,
source=self.source_bounding_box_format,
target=self.target_bounding_box_format,
images=x,
)

return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

def get_config(self):
config = super().get_config()
config.update(
{
"source_bounding_box_format": self.source_bounding_box_format,
"target_bounding_box_format": self.target_bounding_box_format,
}
)

return config
63 changes: 40 additions & 23 deletions keras_hub/src/models/retinanet/feature_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def build(self, input_shapes):
)
)
self.lateral_batch_norm_layers[level].build(
(None, None, None, 256)
(None, None, None, self.num_filters)
if self.data_format == "channels_last"
else (None, 256, None, None)
else (None, self.num_filters, None, None)
)

# Build output layers
Expand All @@ -171,9 +171,9 @@ def build(self, input_shapes):
name=f"output_conv_{level}",
)
self.output_conv_layers[level].build(
(None, None, None, 256)
(None, None, None, self.num_filters)
if self.data_format == "channels_last"
else (None, 256, None, None)
else (None, self.num_filters, None, None)
)

# Build coarser layers
Expand All @@ -193,9 +193,9 @@ def build(self, input_shapes):
name=f"coarser_{level}",
)
self.output_conv_layers[level].build(
(None, None, None, 256)
(None, None, None, self.num_filters)
if self.data_format == "channels_last"
else (None, 256, None, None)
else (None, self.num_filters, None, None)
)

# Build batch norm layers
Expand All @@ -212,9 +212,9 @@ def build(self, input_shapes):
)
)
self.output_batch_norms[level].build(
(None, None, None, 256)
(None, None, None, self.num_filters)
if self.data_format == "channels_last"
else (None, 256, None, None)
else (None, self.num_filters, None, None)
)

# The same upsampling layer is used for all levels
Expand Down Expand Up @@ -320,34 +320,35 @@ def get_config(self):

def compute_output_shape(self, input_shapes):
output_shape = {}
print(input_shapes)
input_levels = [int(level[1]) for level in input_shapes]
backbone_max_level = min(max(input_levels), self.max_level)

for i in range(self.min_level, backbone_max_level + 1):
level = f"P{i}"
if self.data_format == "channels_last":
output_shape[level] = input_shapes[level][:-1] + (256,)
output_shape[level] = input_shapes[level][:-1] + (
self.num_filters,
)
else:
output_shape[level] = (
input_shapes[level][0],
256,
self.num_filters,
) + input_shapes[level][1:3]

intermediate_shape = input_shapes[f"P{backbone_max_level}"]
intermediate_shape = (
(
intermediate_shape[0],
intermediate_shape[1] // 2,
intermediate_shape[2] // 2,
256,
intermediate_shape[1] // 2 if intermediate_shape[1] else None,
intermediate_shape[2] // 2 if intermediate_shape[1] else None,
self.num_filters,
)
if self.data_format == "channels_last"
else (
intermediate_shape[0],
256,
intermediate_shape[1] // 2,
intermediate_shape[2] // 2,
self.num_filters,
intermediate_shape[1] // 2 if intermediate_shape[1] else None,
intermediate_shape[2] // 2 if intermediate_shape[1] else None,
)
)

Expand All @@ -357,16 +358,32 @@ def compute_output_shape(self, input_shapes):
intermediate_shape = (
(
intermediate_shape[0],
intermediate_shape[1] // 2,
intermediate_shape[2] // 2,
256,
(
intermediate_shape[1] // 2
if intermediate_shape[1]
else None
),
(
intermediate_shape[2] // 2
if intermediate_shape[1]
else None
),
self.num_filters,
)
if self.data_format == "channels_last"
else (
intermediate_shape[0],
256,
intermediate_shape[1] // 2,
intermediate_shape[2] // 2,
self.num_filters,
(
intermediate_shape[1] // 2
if intermediate_shape[1]
else None
),
(
intermediate_shape[2] // 2
if intermediate_shape[1]
else None
),
)
)

Expand Down
Loading
Loading