add keypoint_detection (keras-team#1589)
haifeng-jin authored Nov 9, 2023
1 parent 10fb133 commit 331e91b
Showing 7 changed files with 133 additions and 64 deletions.
[Several of the changed files could not be displayed in the diff view.]
35 changes: 24 additions & 11 deletions examples/vision/ipynb/keypoint_detection.ipynb
@@ -8,9 +8,9 @@
"source": [
"# Keypoint Detection with Transfer Learning\n",
"\n",
-"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
+"**Author:** [Sayak Paul](https://twitter.com/RisingSayak), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)<br>\n",
"**Date created:** 2021/05/02<br>\n",
-"**Last modified:** 2021/05/02<br>\n",
+"**Last modified:** 2023/07/19<br>\n",
"**Description:** Training a keypoint detector with data augmentation and transfer learning."
]
},
@@ -119,9 +119,8 @@
},
"outputs": [],
"source": [
-"from tensorflow.keras import layers\n",
-"from tensorflow import keras\n",
-"import tensorflow as tf\n",
+"from keras import layers\n",
+"import keras\n",
"\n",
"from imgaug.augmentables.kps import KeypointsOnImage\n",
"from imgaug.augmentables.kps import Keypoint\n",
@@ -277,6 +276,7 @@
"colours = [\"#\" + colour for colour in colours]\n",
"labels = keypoint_def[\"Name\"].values.tolist()\n",
"\n",
+"\n",
"# Utility for reading an image and for getting its annotations.\n",
"def get_dog(name):\n",
" data = json_dict[name]\n",
@@ -311,6 +311,7 @@
},
"outputs": [],
"source": [
+"\n",
"# Parts of this code come from here:\n",
"# https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb\n",
"def visualize_keypoints(images, keypoints):\n",
@@ -326,7 +327,12 @@
" if isinstance(current_keypoint, KeypointsOnImage):\n",
" for idx, kp in enumerate(current_keypoint.keypoints):\n",
" ax_all.scatter(\n",
-" [kp.x], [kp.y], c=colours[idx], marker=\"x\", s=50, linewidths=5\n",
+" [kp.x],\n",
+" [kp.y],\n",
+" c=colours[idx],\n",
+" marker=\"x\",\n",
+" s=50,\n",
+" linewidths=5,\n",
" )\n",
" else:\n",
" current_keypoint = np.array(current_keypoint)\n",
@@ -391,8 +397,9 @@
"outputs": [],
"source": [
"\n",
-"class KeyPointsDataset(keras.utils.Sequence):\n",
-" def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True):\n",
+"class KeyPointsDataset(keras.utils.PyDataset):\n",
+" def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):\n",
+" super().__init__(**kwargs)\n",
" self.image_keys = image_keys\n",
" self.aug = aug\n",
" self.batch_size = batch_size\n",
@@ -536,8 +543,12 @@
},
"outputs": [],
"source": [
-"train_dataset = KeyPointsDataset(train_keys, train_aug)\n",
-"validation_dataset = KeyPointsDataset(validation_keys, test_aug, train=False)\n",
+"train_dataset = KeyPointsDataset(\n",
+" train_keys, train_aug, workers=2, use_multiprocessing=True\n",
+")\n",
+"validation_dataset = KeyPointsDataset(\n",
+" validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True\n",
+")\n",
"\n",
"print(f\"Total batches in training set: {len(train_dataset)}\")\n",
"print(f\"Total batches in validation set: {len(validation_dataset)}\")\n",
@@ -578,7 +589,9 @@
"def get_model():\n",
" # Load the pre-trained weights of MobileNetV2 and freeze the weights\n",
" backbone = keras.applications.MobileNetV2(\n",
-" weights=\"imagenet\", include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3)\n",
+" weights=\"imagenet\",\n",
+" include_top=False,\n",
+" input_shape=(IMG_SIZE, IMG_SIZE, 3),\n",
" )\n",
" backbone.trainable = False\n",
"\n",
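For reference, below is a minimal, self-contained sketch (not part of this commit) of the keras.utils.PyDataset pattern the notebook is migrated to above: the subclass accepts **kwargs and forwards them to super().__init__(), which is how options such as workers and use_multiprocessing are handled in Keras 3. The class name and the toy x/y arrays are illustrative assumptions, not code from the example.

import numpy as np
import keras


class ToyKeypointDataset(keras.utils.PyDataset):
    """Illustrative only; mirrors the structure KeyPointsDataset adopts in this commit."""

    def __init__(self, x, y, batch_size=4, **kwargs):
        # Forward workers / use_multiprocessing / max_queue_size to the base class.
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Return one batch of (inputs, targets).
        lo = idx * self.batch_size
        hi = min(lo + self.batch_size, len(self.x))
        return self.x[lo:hi], self.y[lo:hi]

With keras.utils.Sequence, the equivalent options were passed to model.fit(); with PyDataset they belong to the dataset object, which is why the diff adds workers=2 and use_multiprocessing=True to the KeyPointsDataset constructor calls.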
34 changes: 22 additions & 12 deletions examples/vision/keypoint_detection.py
@@ -1,8 +1,8 @@
"""
Title: Keypoint Detection with Transfer Learning
-Author: [Sayak Paul](https://twitter.com/RisingSayak)
+Author: [Sayak Paul](https://twitter.com/RisingSayak), converted to Keras 3 by [Muhammad Anas Raza](https://anasrz.com)
Date created: 2021/05/02
-Last modified: 2021/05/02
+Last modified: 2023/07/19
Description: Training a keypoint detector with data augmentation and transfer learning.
Accelerator: GPU
"""
@@ -56,10 +56,8 @@
"""
## Imports
"""
-
-from tensorflow.keras import layers
-from tensorflow import keras
-import tensorflow as tf
+from keras import layers
+import keras

from imgaug.augmentables.kps import KeypointsOnImage
from imgaug.augmentables.kps import Keypoint
@@ -205,7 +203,12 @@ def visualize_keypoints(images, keypoints):
if isinstance(current_keypoint, KeypointsOnImage):
for idx, kp in enumerate(current_keypoint.keypoints):
ax_all.scatter(
-[kp.x], [kp.y], c=colours[idx], marker="x", s=50, linewidths=5
+[kp.x],
+[kp.y],
+c=colours[idx],
+marker="x",
+s=50,
+linewidths=5,
)
else:
current_keypoint = np.array(current_keypoint)
@@ -251,8 +254,9 @@ def visualize_keypoints(images, keypoints):
"""


-class KeyPointsDataset(keras.utils.Sequence):
-def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True):
+class KeyPointsDataset(keras.utils.PyDataset):
+def __init__(self, image_keys, aug, batch_size=BATCH_SIZE, train=True, **kwargs):
+super().__init__(**kwargs)
self.image_keys = image_keys
self.aug = aug
self.batch_size = batch_size
@@ -349,8 +353,12 @@ def __data_generation(self, image_keys_temp):
## Data generator investigation
"""

-train_dataset = KeyPointsDataset(train_keys, train_aug)
-validation_dataset = KeyPointsDataset(validation_keys, test_aug, train=False)
+train_dataset = KeyPointsDataset(
+train_keys, train_aug, workers=2, use_multiprocessing=True
+)
+validation_dataset = KeyPointsDataset(
+validation_keys, test_aug, train=False, workers=2, use_multiprocessing=True
+)

print(f"Total batches in training set: {len(train_dataset)}")
print(f"Total batches in validation set: {len(validation_dataset)}")
@@ -377,7 +385,9 @@ def __data_generation(self, image_keys_temp):
def get_model():
# Load the pre-trained weights of MobileNetV2 and freeze the weights
backbone = keras.applications.MobileNetV2(
-weights="imagenet", include_top=False, input_shape=(IMG_SIZE, IMG_SIZE, 3)
+weights="imagenet",
+include_top=False,
+input_shape=(IMG_SIZE, IMG_SIZE, 3),
)
backbone.trainable = False

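To make the loader change concrete end to end, here is a hedged, self-contained usage sketch. The condensed ToyDataset, the fake arrays, the tiny model, and the single training epoch are all illustrative assumptions and do not appear in this commit; only the construction pattern with workers=2 and use_multiprocessing=True mirrors the diff.

import numpy as np
import keras
from keras import layers


class ToyDataset(keras.utils.PyDataset):
    # Condensed version of the sketch shown after the notebook diff above.
    def __init__(self, x, y, batch_size=8, **kwargs):
        super().__init__(**kwargs)
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        lo = idx * self.batch_size
        return self.x[lo : lo + self.batch_size], self.y[lo : lo + self.batch_size]


# Fake inputs: 32 RGB images of size 64x64 with one regression target each.
x = np.random.rand(32, 64, 64, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# Parallel loading is configured on the dataset itself, exactly as the diff
# does for KeyPointsDataset, rather than through fit() arguments as in the
# old keras.utils.Sequence workflow.
train_ds = ToyDataset(x, y, workers=2, use_multiprocessing=True)

model = keras.Sequential(
    [keras.Input(shape=(64, 64, 3)), layers.Flatten(), layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse")
model.fit(train_ds, epochs=1)

Keras 3 accepts a PyDataset instance directly in fit(), just as tf.keras accepted a Sequence, which is why no corresponding change to the fit() call appears in this diff.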