-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update VGG model to be compatible with HF and add conversion scripts
- Loading branch information
Showing
10 changed files
with
265 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone | ||
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier | ||
|
||
# Backbone class associated with this converter module.
# NOTE(review): presumably read by the preset loader to decide which backbone
# to instantiate -- confirm against the caller.
backbone_cls = VGGBackbone


# Per-architecture list of repeat counts, one entry per block. Looked up with
# `timm_config["architecture"]` and used as `stackwise_num_repeats`.
REPEATS_BY_SIZE = {
    "vgg11": [1, 1, 2, 2, 2],
    "vgg13": [2, 2, 2, 2, 2],
    "vgg16": [2, 2, 3, 3, 3],
    "vgg19": [2, 2, 4, 4, 4],
}
|
||
|
||
def convert_backbone_config(timm_config):
    """Translate a timm config into backbone configuration values.

    Args:
        timm_config: Mapping with an `"architecture"` entry naming one of the
            keys of `REPEATS_BY_SIZE`.

    Returns:
        A dict with `stackwise_num_repeats` (taken from `REPEATS_BY_SIZE`) and
        a fixed `stackwise_num_filters` list.

    Raises:
        KeyError: If `"architecture"` is missing or is not a key of
            `REPEATS_BY_SIZE`.
    """
    repeats = REPEATS_BY_SIZE[timm_config["architecture"]]
    return {
        "stackwise_num_repeats": repeats,
        "stackwise_num_filters": [64, 128, 256, 512, 512],
    }
|
||
|
||
def convert_conv2d(
    model,
    loader,
    keras_layer_name: str,
    hf_layer_name: str,
):
    """Port the kernel and bias of one convolution layer.

    Args:
        model: Object exposing `get_layer(name)`; the returned layer must have
            `kernel` and `bias` attributes.
        loader: Object exposing `port_weight(variable, hf_weight_key=...,
            hook_fn=...)`.
        keras_layer_name: Name of the destination layer inside `model`.
        hf_layer_name: Prefix of the source weight keys; `.weight` and `.bias`
            are appended to it.
    """
    weight_key = f"{hf_layer_name}.weight"
    bias_key = f"{hf_layer_name}.bias"

    def reorder_kernel_axes(weights, _):
        # Moves source axes (0, 1, 2, 3) to positions (3, 2, 0, 1).
        # NOTE(review): presumably (out, in, h, w) -> (h, w, in, out) -- confirm.
        return np.transpose(weights, (2, 3, 1, 0))

    # The kernel needs its axes reordered; the bias is ported unchanged.
    loader.port_weight(
        model.get_layer(keras_layer_name).kernel,
        hf_weight_key=weight_key,
        hook_fn=reorder_kernel_axes,
    )
    loader.port_weight(
        model.get_layer(keras_layer_name).bias,
        hf_weight_key=bias_key,
    )
|
||
|
||
def convert_weights(
    backbone: VGGBackbone,
    loader,
    timm_config: dict[str, Any],
):
    """Port every backbone convolution from the source checkpoint.

    Walks the blocks described by `REPEATS_BY_SIZE` and, for each convolution,
    ports the weights stored under `features.<index>` into the layer named
    `block<b>_conv<r>` (both 1-based).

    Args:
        backbone: Model whose convolution layers receive the weights.
        loader: Weight loader forwarded to `convert_conv2d`.
        timm_config: Mapping with an `"architecture"` entry naming one of the
            keys of `REPEATS_BY_SIZE`.

    Raises:
        KeyError: If the architecture is missing or unsupported.
    """
    architecture = timm_config["architecture"]
    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]

    # Index of the current convolution in the source `features` sequence.
    hf_index = 0
    for block_index, repeats_in_block in enumerate(stackwise_num_repeats, 1):
        for repeat_index in range(1, repeats_in_block + 1):
            convert_conv2d(
                backbone,
                loader,
                f"block{block_index}_conv{repeat_index}",
                f"features.{hf_index}",
            )
            hf_index += 2  # Conv + activation layers.
        hf_index += 1  # Pooling layer after blocks.
|
||
|
||
def convert_head(
    task: VGGImageClassifier,
    loader,
    timm_config: dict[str, Any],
):
    """Port the classifier head weights from the source checkpoint.

    Ports `pre_logits.fc1` and `pre_logits.fc2` as convolutions into the head
    layers `fc1` and `fc2`, then ports `head.fc` into the `predictions` layer.

    Args:
        task: Classifier whose `head` exposes `get_layer`.
        loader: Weight loader exposing `port_weight`.
        timm_config: Source config; not read here, accepted so the signature
            matches the other converter functions.
    """
    convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
    convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")

    def to_dense_kernel(weights, _):
        # Collapse everything after the first axis, then swap the two axes.
        # `np.squeeze` would also drop a size-1 first or second axis (e.g. a
        # single-class head), leaving a 1-D array that `transpose` cannot turn
        # back into a 2-D kernel.
        # NOTE(review): assumes the source weight is (out, in) or
        # (out, in, 1, 1) -- confirm.
        rows = np.shape(weights)[0]
        return np.transpose(np.reshape(weights, (rows, -1)))

    loader.port_weight(
        task.head.get_layer("predictions").kernel,
        hf_weight_key="head.fc.weight",
        hook_fn=to_dense_kernel,
    )
    loader.port_weight(
        task.head.get_layer("predictions").bias,
        hf_weight_key="head.fc.bias",
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Loads an external VGG model and saves it in Keras format.

Optionally uploads the saved preset to the location given by the
`--upload_uri` flag, if that flag is passed.

Usage:

python tools/checkpoint_conversion/convert_vgg_checkpoints.py \
    --preset vgg11 --upload_uri kaggle://kerashub/vgg/keras/vgg11
"""
|
||
import os | ||
import shutil | ||
|
||
import keras | ||
import numpy as np | ||
import PIL | ||
import timm | ||
import torch | ||
from absl import app | ||
from absl import flags | ||
|
||
import keras_hub | ||
|
||
# Maps each accepted `--preset` value to the timm model name that `main`
# loads and converts.
PRESET_MAP = {
    "vgg11": "timm/vgg11.tv_in1k",
    "vgg13": "timm/vgg13.tv_in1k",
    "vgg16": "timm/vgg16.tv_in1k",
    "vgg19": "timm/vgg19.tv_in1k",
    # TODO(jeffcarp): Add BN variants.
}
|
||
|
||
# Required: name of the preset to convert; `main` looks it up in `PRESET_MAP`
# and also uses it as the output directory name.
PRESET = flags.DEFINE_string(
    "preset",
    None,
    "Must be a valid `VGG` preset from KerasHub",
    required=True,
)
# Optional: destination URI; when set, `main` uploads the saved preset there.
# The help example mirrors the URI shown in the module docstring.
UPLOAD_URI = flags.DEFINE_string(
    "upload_uri",
    None,
    'Could be "kaggle://kerashub/vgg/keras/{preset}"',
)
|
||
|
||
def validate_output(keras_model, timm_model):
    """Print a side-by-side comparison of the Keras and timm models.

    Downloads a sample image, preprocesses it with both pipelines, runs both
    models on the Keras-preprocessed input and prints the first ten outputs,
    the argmax labels, and the mean absolute differences of the model outputs
    and of the preprocessed inputs. Nothing is returned or asserted.

    Args:
        keras_model: Converted model exposing `preprocessor` and `predict`.
        timm_model: Reference timm model, callable on a torch tensor.
    """
    # `import PIL` alone does not guarantee the `Image` submodule is loaded.
    import PIL.Image

    file = keras.utils.get_file(
        origin=(
            "https://storage.googleapis.com/keras-cv/"
            "models/paligemma/cow_beach_1.png"
        )
    )
    image = PIL.Image.open(file)
    batch = np.array([image])

    # Preprocess with Timm.
    data_config = timm.data.resolve_model_data_config(timm_model)
    data_config["crop_pct"] = 1.0  # Stop timm from cropping.
    transforms = timm.data.create_transform(**data_config, is_training=False)
    timm_preprocessed = transforms(image)
    # Move the first axis last and add a leading batch axis so the result
    # lines up with the Keras-preprocessed batch.
    timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
    timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)

    # Preprocess with Keras.
    keras_preprocessed = keras_model.preprocessor(batch)

    # Call with Timm. Use the keras preprocessed image so we can keep modeling
    # and preprocessing comparisons independent.
    timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
    # Go through `convert_to_numpy` so this works for any backend tensor, then
    # copy so torch receives a writable array.
    timm_batch = torch.from_numpy(
        np.array(keras.ops.convert_to_numpy(timm_batch))
    )
    timm_outputs = timm_model(timm_batch).detach().numpy()
    timm_label = np.argmax(timm_outputs[0])

    # Call with Keras.
    keras_outputs = keras_model.predict(batch)
    keras_label = np.argmax(keras_outputs[0])

    print("🔶 Keras output:", keras_outputs[0, :10])
    print("🔶 TIMM output:", timm_outputs[0, :10])
    print("🔶 Keras label:", keras_label)
    print("🔶 TIMM label:", timm_label)
    modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
    print("🔶 Modeling difference:", modeling_diff)
    preprocessing_diff = np.mean(
        np.abs(
            keras.ops.convert_to_numpy(keras_preprocessed)
            - keras.ops.convert_to_numpy(timm_preprocessed)
        )
    )
    print("🔶 Preprocessing difference:", preprocessing_diff)
|
||
|
||
def main(_):
    """Convert the requested preset, validate it, and optionally upload it.

    Loads the timm model named by `PRESET_MAP[preset]`, loads the same
    checkpoint through KerasHub, saves it to `./<preset>`, prints a numerical
    comparison, and uploads the saved preset when `--upload_uri` is set.

    Raises:
        ValueError: If `--preset` is not a key of `PRESET_MAP`.
    """
    preset = PRESET.value
    # Validate before touching the filesystem: the preset doubles as the
    # output directory, which is deleted below if it already exists.
    if preset not in PRESET_MAP:
        raise ValueError(
            f"Unknown preset {preset!r}. Expected one of: "
            f"{', '.join(sorted(PRESET_MAP))}."
        )
    timm_name = PRESET_MAP[preset]

    # Start from an empty output directory.
    if os.path.exists(preset):
        shutil.rmtree(preset)
    os.makedirs(preset)

    timm_model = timm.create_model(timm_name, pretrained=True)
    timm_model = timm_model.eval()
    print("✅ Loaded TIMM model.")
    print(timm_model)

    keras_model = keras_hub.models.ImageClassifier.from_preset(
        "hf://" + timm_name,
    )
    print("✅ Loaded KerasHub model.")

    keras_model.save_to_preset(f"./{preset}")
    print(f"🏁 Preset saved to ./{preset}")

    validate_output(keras_model, timm_model)

    upload_uri = UPLOAD_URI.value
    if upload_uri:
        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
        print(f"🏆 Preset uploaded to {upload_uri}")
|
||
|
||
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__":
    app.run(main)