From b3d95a32271205134f97e85d4dba5486055fd061 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 1 Oct 2024 23:21:38 -0700 Subject: [PATCH] Fix tests, address review --- .../layers/preprocessing/image_converter.py | 2 +- keras_hub/src/models/image_classifier.py | 17 +- keras_hub/src/models/task_test.py | 2 +- keras_hub/src/models/vgg/vgg_backbone.py | 14 +- .../src/models/vgg/vgg_image_classifier.py | 172 ++++++++++++++++++ .../models/vgg/vgg_image_classifier_test.py | 2 +- 6 files changed, 185 insertions(+), 24 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/image_converter.py b/keras_hub/src/layers/preprocessing/image_converter.py index f5c425feda..2aff224609 100644 --- a/keras_hub/src/layers/preprocessing/image_converter.py +++ b/keras_hub/src/layers/preprocessing/image_converter.py @@ -1,7 +1,7 @@ import math -import numpy as np import keras +import numpy as np from keras import ops from keras_hub.src.api_export import keras_hub_export diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index 23945cf755..15183ad788 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -15,6 +15,8 @@ class ImageClassifier(Task): To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` labels where `x` is a string and `y` is a integer from `[0, num_classes)`. + All `ImageClassifier` tasks include a `from_preset()` constructor which can + be used to load a pre-trained config and weights. Args: backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`. @@ -23,18 +25,13 @@ class ImageClassifier(Task): a `keras.Layer` instance, or a callable. If `None` no preprocessing will be applied to the inputs. pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone - output. Default to average pooling. + output. Defaults to average pooling. activation: `None`, str, or callable. The activation function to use on the `Dense` layer. 
Set `activation=None` to return the output logits. Defaults to `"softmax"`. head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The dtype to use for the classification head's computations and weights. - To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` - where `x` is a tensor and `y` is a integer from `[0, num_classes)`. - All `ImageClassifier` tasks include a `from_preset()` constructor which can - be used to load a pre-trained config and weights. - Examples: Call `predict()` to run inference. @@ -109,11 +106,15 @@ def __init__( self.preprocessor = preprocessor if pooling == "avg": self.pooler = keras.layers.GlobalAveragePooling2D( - data_format, dtype=head_dtype + data_format, + dtype=head_dtype, + name="pooler", ) elif pooling == "max": self.pooler = keras.layers.GlobalMaxPooling2D( - data_format, dtype=head_dtype + data_format, + dtype=head_dtype, + name="pooler", ) else: raise ValueError( diff --git a/keras_hub/src/models/task_test.py b/keras_hub/src/models/task_test.py index b5b8dfd5dc..bf57c88912 100644 --- a/keras_hub/src/models/task_test.py +++ b/keras_hub/src/models/task_test.py @@ -8,8 +8,8 @@ from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM -from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.preprocessor import Preprocessor from keras_hub.src.models.task import Task from keras_hub.src.models.text_classifier import TextClassifier from keras_hub.src.tests.test_case import TestCase diff --git a/keras_hub/src/models/vgg/vgg_backbone.py b/keras_hub/src/models/vgg/vgg_backbone.py index 902f392962..cf2638146e 100644 --- a/keras_hub/src/models/vgg/vgg_backbone.py +++ b/keras_hub/src/models/vgg/vgg_backbone.py @@ -21,17 +21,6 @@ class VGGBackbone(Backbone): blocks per VGG block. 
For both VGG16 and VGG19 this is [ 64, 128, 256, 512, 512]. image_shape: tuple, optional shape tuple, defaults to (224, 224, 3). - pooling: bool, Optional pooling mode for feature extraction - when `include_top` is `False`. - - `None` means that the output of the model will be - the 4D tensor output of the - last convolutional block. - - `avg` means that global average pooling - will be applied to the output of the - last convolutional block, and thus - the output of the model will be a 2D tensor. - - `max` means that global max pooling will - be applied. Examples: ```python @@ -46,7 +35,6 @@ class VGGBackbone(Backbone): stackwise_num_repeats = [2, 2, 3, 3, 3], stackwise_num_filters = [64, 128, 256, 512, 512], image_shape = (224, 224, 3), - pooling = "avg", ) model(input_data) ``` @@ -56,7 +44,7 @@ def __init__( self, stackwise_num_repeats, stackwise_num_filters, - image_shape=(224, 224, 3), + image_shape=(None, None, 3), **kwargs, ): diff --git a/keras_hub/src/models/vgg/vgg_image_classifier.py b/keras_hub/src/models/vgg/vgg_image_classifier.py index 40adf69911..8f658c2814 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier.py @@ -1,8 +1,180 @@ +import keras + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.task import Task from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone @keras_hub_export("keras_hub.models.VGGImageClassifier") class VGGImageClassifier(ImageClassifier): + """VGG image classification task. + + `VGGImageClassifier` tasks wrap a `keras_hub.models.VGGBackbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + image classification. `VGGImageClassifier` tasks take an additional + `num_classes` argument, controlling the number of predicted output classes. 
+ + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + where `x` is an image tensor and `y` is an integer from `[0, num_classes)`. + + Note that unlike `keras_hub.models.ImageClassifier`, the `VGGImageClassifier` + allows and defaults to `pooling="flatten"`, where inputs are flattened and + passed through two intermediate dense layers before the final output + projection. + + Args: + backbone: A `keras_hub.models.VGGBackbone` instance or a `keras.Model`. + num_classes: int. The number of classes to predict. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. + pooling: `"flatten"`, `"avg"`, or `"max"`. The type of pooling to apply + on backbone output. The default is flatten to match the original + VGG implementation, where backbone outputs will be flattened and + passed through two dense layers with a `"relu"` activation. + pooling_hidden_dim: int. The output feature size of the pooling dense layers. + This only applies when `pooling="flatten"`. + activation: `None`, str, or callable. The activation function to use on + the `Dense` layer. Set `activation=None` to return the output + logits. Defaults to `None`. + head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The + dtype to use for the classification head's computations and weights. + + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and run inference + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + classifier = keras_hub.models.VGGImageClassifier.from_preset( + "vgg_16_imagenet" + ) + classifier.predict(images) + ``` + + Call `fit()` on a single batch. 
+ ```python + # Load preset and train + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + labels = [0, 3] + classifier = keras_hub.models.VGGImageClassifier.from_preset( + "vgg_16_imagenet" + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and frozen backbone. + ```python + classifier = keras_hub.models.VGGImageClassifier.from_preset( + "vgg_16_imagenet" + ) + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + ) + classifier.backbone.trainable = False + classifier.fit(x=images, y=labels, batch_size=2) + ``` + + Custom backbone. + ```python + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + labels = [0, 3] + backbone = keras_hub.models.VGGBackbone( + stackwise_num_repeats = [2, 2, 3, 3, 3], + stackwise_num_filters = [64, 128, 256, 512, 512], + image_shape = (224, 224, 3), + ) + classifier = keras_hub.models.VGGImageClassifier( + backbone=backbone, + num_classes=4, + ) + classifier.fit(x=images, y=labels, batch_size=2) + ``` + """ + backbone_cls = VGGBackbone + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + pooling="flatten", + pooling_hidden_dim=4096, + activation=None, + head_dtype=None, + **kwargs, + ): + head_dtype = head_dtype or backbone.dtype_policy + data_format = getattr(backbone, "data_format", None) + + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + if pooling == "avg": + self.pooler = keras.layers.GlobalAveragePooling2D( + data_format, + dtype=head_dtype, + name="pooler", + ) + elif pooling == "max": + self.pooler = keras.layers.GlobalMaxPooling2D( + data_format, + dtype=head_dtype, + name="pooler", + ) + elif pooling == "flatten": + self.pooler = keras.Sequential( + [ + keras.layers.Flatten(name="flatten"), + keras.layers.Dense(pooling_hidden_dim, activation="relu"), + keras.layers.Dense(pooling_hidden_dim, activation="relu"), + ], + name="pooler", 
+ ) + else: + raise ValueError( + "Unknown `pooling` type. Polling should be either `'avg'` or " + f"`'max'`. Received: pooling={pooling}." + ) + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + x = self.pooler(x) + outputs = self.output_dense(x) + # Skip the parent class functional model. + Task.__init__( + self, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + self.pooling = pooling + self.pooling_hidden_dim = pooling_hidden_dim + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + "activation": self.activation, + "pooling_hidden_dim": self.pooling_hidden_dim, + } + ) + return config diff --git a/keras_hub/src/models/vgg/vgg_image_classifier_test.py b/keras_hub/src/models/vgg/vgg_image_classifier_test.py index ce7b63bad4..6d95ddaac5 100644 --- a/keras_hub/src/models/vgg/vgg_image_classifier_test.py +++ b/keras_hub/src/models/vgg/vgg_image_classifier_test.py @@ -20,7 +20,7 @@ def setUp(self): "backbone": self.backbone, "num_classes": 2, "activation": "softmax", - "pooling": "max", + "pooling": "flatten", } self.train_data = ( self.images,