Skip to content

Commit

Permalink
add lite0 variant
Browse files Browse the repository at this point in the history
  • Loading branch information
pkgoogle committed Oct 31, 2024
1 parent 7e39d97 commit 6a120f7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
37 changes: 19 additions & 18 deletions keras_hub/src/utils/timm/convert_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
"width_coefficient": 1.0,
"depth_coefficient": 1.1,
},
"lite0": {
"width_coefficient": 1.0,
"depth_coefficient": 1.0,
"stackwise_squeeze_and_excite_ratios": [0] * 7,
"activation": "relu6",
},
}


Expand All @@ -31,15 +37,7 @@ def convert_backbone_config(timm_config):
"stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
"stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
"stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
"stackwise_squeeze_and_excite_ratios": [
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
0.25,
],
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
"stackwise_block_types": ["v1"] * 7,
"min_depth": None,
"include_stem_padding": True,
Expand Down Expand Up @@ -145,6 +143,8 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
math.ceil(VARIANT_MAP[variant]["depth_coefficient"] * repeats)
)

se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][stack_index]

for block_idx in range(repeats):

conv_pw_count = 0
Expand Down Expand Up @@ -184,15 +184,16 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
)
bn_count += 1

# Squeeze and Excite
port_conv2d(
keras_block_prefix + "se_reduce",
hf_block_prefix + "se.conv_reduce",
)
port_conv2d(
keras_block_prefix + "se_expand",
hf_block_prefix + "se.conv_expand",
)
if 0 < se_ratio <= 1:
# Squeeze and Excite
port_conv2d(
keras_block_prefix + "se_reduce",
hf_block_prefix + "se.conv_reduce",
)
port_conv2d(
keras_block_prefix + "se_expand",
hf_block_prefix + "se.conv_expand",
)

# Output/Projection
port_conv2d(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
--preset efficientnet_b0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b0_ra_imagenet
python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
--preset efficientnet_b1_ft_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_b1_ft_imagenet
python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \
--preset efficientnet_lite0_ra_imagenet --upload_uri kaggle://kerashub/efficientnet/keras/efficientnet_lite0_ra_imagenet
"""

import os
Expand All @@ -23,6 +25,7 @@
PRESET_MAP = {
"efficientnet_b0_ra_imagenet": "timm/efficientnet_b0.ra_in1k",
"efficientnet_b1_ft_imagenet": "timm/efficientnet_b1.ft_in1k",
"efficientnet_lite0_ra_imagenet": "timm/efficientnet_lite0.ra_in1k",
}
FLAGS = flags.FLAGS

Expand Down

0 comments on commit 6a120f7

Please sign in to comment.