Skip to content

Commit

Permalink
Add dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Oct 2, 2024
1 parent b3d95a3 commit 7021003
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras_hub/src/models/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
preprocessor=None,
pooling="avg",
activation=None,
dropout=0.0,
head_dtype=None,
**kwargs,
):
Expand All @@ -121,6 +122,11 @@ def __init__(
"Unknown `pooling` type. Polling should be either `'avg'` or "
f"`'max'`. Received: pooling={pooling}."
)
self.output_dropout = keras.layers.Dropout(
dropout,
dtype=head_dtype,
name="output_dropout",
)
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
Expand All @@ -132,6 +138,7 @@ def __init__(
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
x = self.output_dropout(x)
outputs = self.output_dense(x)
super().__init__(
inputs=inputs,
Expand All @@ -143,6 +150,7 @@ def __init__(
self.num_classes = num_classes
self.activation = activation
self.pooling = pooling
self.dropout = dropout

def get_config(self):
# Backbone serialized in `super`
Expand All @@ -152,6 +160,7 @@ def get_config(self):
"num_classes": self.num_classes,
"pooling": self.pooling,
"activation": self.activation,
"dropout": self.dropout,
}
)
return config
Expand Down
9 changes: 9 additions & 0 deletions keras_hub/src/models/vgg/vgg_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def __init__(
pooling="flatten",
pooling_hidden_dim=4096,
activation=None,
dropout=0.0,
head_dtype=None,
**kwargs,
):
Expand Down Expand Up @@ -140,6 +141,11 @@ def __init__(
"Unknown `pooling` type. Polling should be either `'avg'` or "
f"`'max'`. Received: pooling={pooling}."
)
self.output_dropout = keras.layers.Dropout(
dropout,
dtype=head_dtype,
name="output_dropout",
)
self.output_dense = keras.layers.Dense(
num_classes,
activation=activation,
Expand All @@ -151,6 +157,7 @@ def __init__(
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
x = self.output_dropout(x)
outputs = self.output_dense(x)
# Skip the parent class functional model.
Task.__init__(
Expand All @@ -165,6 +172,7 @@ def __init__(
self.activation = activation
self.pooling = pooling
self.pooling_hidden_dim = pooling_hidden_dim
self.dropout = dropout

def get_config(self):
# Backbone serialized in `super`
Expand All @@ -175,6 +183,7 @@ def get_config(self):
"pooling": self.pooling,
"activation": self.activation,
"pooling_hidden_dim": self.pooling_hidden_dim,
"dropout": self.dropout,
}
)
return config

0 comments on commit 7021003

Please sign in to comment.