From 6a120f75410e5c5dee4a81da2c37d7a7a65c54fd Mon Sep 17 00:00:00 2001 From: Piseth Ky Date: Thu, 31 Oct 2024 16:02:59 -0700 Subject: [PATCH] add lite0 variant --- .../src/utils/timm/convert_efficientnet.py | 37 ++++++++++--------- .../convert_efficientnet_checkpoints.py | 3 ++ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/keras_hub/src/utils/timm/convert_efficientnet.py b/keras_hub/src/utils/timm/convert_efficientnet.py index 609c26d35..9c14b5f5b 100644 --- a/keras_hub/src/utils/timm/convert_efficientnet.py +++ b/keras_hub/src/utils/timm/convert_efficientnet.py @@ -18,6 +18,12 @@ "width_coefficient": 1.0, "depth_coefficient": 1.1, }, + "lite0": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "stackwise_squeeze_and_excite_ratios": [0] * 7, + "activation": "relu6", + }, } @@ -31,15 +37,7 @@ def convert_backbone_config(timm_config): "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320], "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6], "stackwise_strides": [1, 2, 2, 2, 1, 2, 1], - "stackwise_squeeze_and_excite_ratios": [ - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - 0.25, - ], + "stackwise_squeeze_and_excite_ratios": [0.25] * 7, "stackwise_block_types": ["v1"] * 7, "min_depth": None, "include_stem_padding": True, @@ -145,6 +143,8 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): math.ceil(VARIANT_MAP[variant]["depth_coefficient"] * repeats) ) + se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][stack_index] + for block_idx in range(repeats): conv_pw_count = 0 @@ -184,15 +184,16 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): ) bn_count += 1 - # Squeeze and Excite - port_conv2d( - keras_block_prefix + "se_reduce", - hf_block_prefix + "se.conv_reduce", - ) - port_conv2d( - keras_block_prefix + "se_expand", - hf_block_prefix + "se.conv_expand", - ) + if 0 < se_ratio <= 1: + # Squeeze and Excite + port_conv2d( + keras_block_prefix + "se_reduce", + hf_block_prefix + "se.conv_reduce", + ) + port_conv2d( + keras_block_prefix + "se_expand", + hf_block_prefix + "se.conv_expand", + ) # Output/Projection port_conv2d( diff --git a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py index 5790d6130..86e92633c 100644 --- a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py @@ -5,6 +5,8 @@ --preset efficientnet_b0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b0_ra_imagenet python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ --preset efficientnet_b1_ft_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b1_ft_imagenet +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet_lite0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_lite0_ra_imagenet """ import os @@ -23,6 +25,7 @@ PRESET_MAP = { "efficientnet_b0_ra_imagenet": "timm/efficientnet_b0.ra_in1k", "efficientnet_b1_ft_imagenet": "timm/efficientnet_b1.ft_in1k", + "efficientnet_lite0_ra_imagenet": "timm/efficientnet_lite0.ra_in1k", } FLAGS = flags.FLAGS