diff --git a/guides/customizing_saving_and_serialization.py b/guides/customizing_saving_and_serialization.py index 7986d7e8dc..a5e7083cc7 100644 --- a/guides/customizing_saving_and_serialization.py +++ b/guides/customizing_saving_and_serialization.py @@ -4,6 +4,7 @@ Date created: 2023/03/15 Last modified: 2023/03/15 Description: A more advanced guide on customizing saving for your layers and models. +Accelerator: None """ """ @@ -78,7 +79,9 @@ def save_own_variables(self, store): class LayerWithCustomVariables(keras.layers.Dense): def __init__(self, units, **kwargs): super().__init__(units, **kwargs) - self.stored_variables = tf.Variable(np.random.random((10,)), name="special_arr", dtype=tf.float32) + self.stored_variables = tf.Variable( + np.random.random((10,)), name="special_arr", dtype=tf.float32 + ) def save_own_variables(self, store): super().save_own_variables(store) @@ -98,9 +101,7 @@ def call(self, inputs): return super().call(inputs) * self.stored_variables -model = keras.Sequential([ - LayerWithCustomVariables(1) -]) +model = keras.Sequential([LayerWithCustomVariables(1)]) ref_input = np.random.random((8, 10)) ref_output = np.random.random((8,)) @@ -112,7 +113,7 @@ def call(self, inputs): np.testing.assert_allclose( model.layers[0].stored_variables.numpy(), - restored_model.layers[0].stored_variables.numpy() + restored_model.layers[0].stored_variables.numpy(), ) """ @@ -149,9 +150,9 @@ def load_assets(self, inner_path): self.vocab = text.replace("", "little") -model = keras.Sequential([ - LayerWithCustomAssets(vocab='Mary had a lamb.', units=5) -]) +model = keras.Sequential( + [LayerWithCustomAssets(vocab="Mary had a lamb.", units=5)] +) x = np.random.random((10, 10)) y = model(x) @@ -338,4 +339,4 @@ def compile_from_config(self, config): - `get_compile_config` and `compile_from_config` save and restore the model's compiled states. -""" \ No newline at end of file +""" diff --git a/guides/ipynb/customizing_saving_and_serialization.ipynb b/guides/ipynb/customizing_saving_and_serialization.ipynb new file mode 100644 index 0000000000..e472965ba0 --- /dev/null +++ b/guides/ipynb/customizing_saving_and_serialization.ipynb @@ -0,0 +1,472 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Customizing Saving and Serialization\n", + "\n", + "**Author:** Neel Kovelamudi
\n", + "**Date created:** 2023/03/15
\n", + "**Last modified:** 2023/03/15
\n", + "**Description:** A more advanced guide on customizing saving for your layers and models." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Introduction\n", + "\n", + "This guide covers advanced methods that can be customized in Keras saving. For most\n", + "users, the methods outlined in the primary\n", + "[Serialize, save, and export guide](https://keras.io/guides/serialization_and_saving)\n", + "are sufficient." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### APIs\n", + "We will cover the following APIs:\n", + "\n", + "- `save_assets()` and `load_assets()`\n", + "- `save_own_variables()` and `load_own_variables()`\n", + "- `get_build_config()` and `build_from_config()`\n", + "- `get_compile_config()` and `compile_from_config()`\n", + "\n", + "When restoring a model, these get executed in the following order:\n", + "\n", + "- `build_from_config()`\n", + "- `compile_from_config()`\n", + "- `load_own_variables()`\n", + "- `load_assets()`" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import keras" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## State saving customization\n", + "\n", + "These methods determine how the state of your model's layers is saved when calling\n", + "`model.save()`. You can override them to take full control of the state saving process." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "\n", + "### `save_own_variables()` and `load_own_variables()`\n", + "\n", + "These methods save and load the state variables of the layer when `model.save()` and\n", + "`keras.models.load_model()` are called, respectively. By default, the state variables\n", + "saved and loaded are the weights of the layer (both trainable and non-trainable). Here is\n", + "the default implementation of `save_own_variables()`:\n", + "\n", + "```python\n", + "def save_own_variables(self, store):\n", + " all_vars = self._trainable_weights + self._non_trainable_weights\n", + " for i, v in enumerate(all_vars):\n", + " store[f\"{i}\"] = v.numpy()\n", + "```\n", + "\n", + "The store used by these methods is a dictionary that can be populated with the layer\n", + "variables. Let's take a look at an example customizing this.\n", + "\n", + "**Example:**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.utils.register_keras_serializable(package=\"my_custom_package\")\n", + "class LayerWithCustomVariables(keras.layers.Dense):\n", + " def __init__(self, units, **kwargs):\n", + " super().__init__(units, **kwargs)\n", + " self.stored_variables = tf.Variable(\n", + " np.random.random((10,)), name=\"special_arr\", dtype=tf.float32\n", + " )\n", + "\n", + " def save_own_variables(self, store):\n", + " super().save_own_variables(store)\n", + " # Stores the value of the `tf.Variable` upon saving\n", + " store[\"variables\"] = self.stored_variables.numpy()\n", + "\n", + " def load_own_variables(self, store):\n", + " # Assigns the value of the `tf.Variable` upon loading\n", + " self.stored_variables.assign(store[\"variables\"])\n", + " # Load the remaining weights\n", + " for i, v in enumerate(self.weights):\n", + " v.assign(store[f\"{i}\"])\n", + " # Note: You must specify how all variables (including layer weights)\n", + " # are loaded in `load_own_variables.`\n", + "\n", + " def call(self, inputs):\n", + " return super().call(inputs) * self.stored_variables\n", + "\n", + "\n", + "model = keras.Sequential([LayerWithCustomVariables(1)])\n", + "\n", + "ref_input = np.random.random((8, 10))\n", + "ref_output = np.random.random((8,))\n", + "model.compile(optimizer=\"adam\", loss=\"mean_squared_error\")\n", + "model.fit(ref_input, ref_output)\n", + "\n", + "model.save(\"custom_vars_model.keras\")\n", + "restored_model = keras.models.load_model(\"custom_vars_model.keras\")\n", + "\n", + "np.testing.assert_allclose(\n", + " model.layers[0].stored_variables.numpy(),\n", + " restored_model.layers[0].stored_variables.numpy(),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### `save_assets()` and `load_assets()`\n", + "\n", + "These methods can be added to your model class definition to store and load any\n", + "additional information that your model needs.\n", + "\n", + "For example, NLP domain layers such as TextVectorization layers and IndexLookup layers\n", + "may need to store their associated vocabulary (or lookup table) in a text file upon\n", + "saving.\n", + "\n", + "Let's take at the basics of this workflow with a simple file `assets.txt`.\n", + "\n", + "**Example:**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.saving.register_keras_serializable(package=\"my_custom_package\")\n", + "class LayerWithCustomAssets(keras.layers.Dense):\n", + " def __init__(self, vocab=None, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.vocab = vocab\n", + "\n", + " def save_assets(self, inner_path):\n", + " # Writes the vocab (sentence) to text file at save time.\n", + " with open(os.path.join(inner_path, \"vocabulary.txt\"), \"w\") as f:\n", + " f.write(self.vocab)\n", + "\n", + " def load_assets(self, inner_path):\n", + " # Reads the vocab (sentence) from text file at load time.\n", + " with open(os.path.join(inner_path, \"vocabulary.txt\"), \"r\") as f:\n", + " text = f.read()\n", + " self.vocab = text.replace(\"\", \"little\")\n", + "\n", + "\n", + "model = keras.Sequential(\n", + " [LayerWithCustomAssets(vocab=\"Mary had a lamb.\", units=5)]\n", + ")\n", + "\n", + "x = np.random.random((10, 10))\n", + "y = model(x)\n", + "\n", + "model.save(\"custom_assets_model.keras\")\n", + "restored_model = keras.models.load_model(\"custom_assets_model.keras\")\n", + "\n", + "np.testing.assert_string_equal(\n", + " restored_model.layers[0].vocab, \"Mary had a little lamb.\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## `build` and `compile` saving customization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### `get_build_config()` and `build_from_config()`\n", + "\n", + "These methods work together to save the layer's built states and restore them upon\n", + "loading.\n", + "\n", + "By default, this only includes a build config dictionary with the layer's input shape,\n", + "but overriding these methods can be used to include further Variables and Lookup Tables\n", + "that can be useful to restore for your built model.\n", + "\n", + "**Example:**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.saving.register_keras_serializable(package=\"my_custom_package\")\n", + "class LayerWithCustomBuild(keras.layers.Layer):\n", + " def __init__(self, units=32, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.units = units\n", + "\n", + " def call(self, inputs):\n", + " return tf.matmul(inputs, self.w) + self.b\n", + "\n", + " def get_config(self):\n", + " return dict(units=self.units, **super().get_config())\n", + "\n", + " def build(self, input_shape, layer_init):\n", + " # Note the customization in overriding `build()` adds an extra argument.\n", + " # Therefore, we will need to manually call build with `layer_init` argument\n", + " # before the first execution of `call()`.\n", + " super().build(input_shape)\n", + " self.w = self.add_weight(\n", + " shape=(input_shape[-1], self.units),\n", + " initializer=layer_init,\n", + " trainable=True,\n", + " )\n", + " self.b = self.add_weight(\n", + " shape=(self.units,),\n", + " initializer=layer_init,\n", + " trainable=True,\n", + " )\n", + " self.layer_init = layer_init\n", + "\n", + " def get_build_config(self):\n", + " build_config = super().get_build_config() # only gives `input_shape`\n", + " build_config.update(\n", + " {\"layer_init\": self.layer_init} # Stores our initializer for `build()`\n", + " )\n", + " return build_config\n", + "\n", + " def build_from_config(self, config):\n", + " # Calls `build()` with the parameters at loading time\n", + " self.build(config[\"input_shape\"], config[\"layer_init\"])\n", + "\n", + "\n", + "custom_layer = LayerWithCustomBuild(units=16)\n", + "custom_layer.build(input_shape=(8,), layer_init=\"random_normal\")\n", + "\n", + "model = keras.Sequential(\n", + " [\n", + " custom_layer,\n", + " keras.layers.Dense(1, activation=\"sigmoid\"),\n", + " ]\n", + ")\n", + "\n", + "x = np.random.random((16, 8))\n", + "y = model(x)\n", + "\n", + "model.save(\"custom_build_model.keras\")\n", + "restored_model = keras.models.load_model(\"custom_build_model.keras\")\n", + "\n", + "np.testing.assert_equal(restored_model.layers[0].layer_init, \"random_normal\")\n", + "np.testing.assert_equal(restored_model.built, True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### `get_compile_config()` and `compile_from_config()`\n", + "\n", + "These methods work together to save the information with which the model was compiled\n", + "(optimizers, losses, etc.) and restore and re-compile the model with this information.\n", + "\n", + "Overriding these methods can be useful for compiling the restored model with custom\n", + "optimizers, custom losses, etc., as these will need to be deserialized prior to calling\n", + "`model.compile` in `compile_from_config()`.\n", + "\n", + "Let's take a look at an example of this.\n", + "\n", + "**Example:**" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "@keras.saving.register_keras_serializable(package=\"my_custom_package\")\n", + "def small_square_sum_loss(y_true, y_pred):\n", + " loss = tf.math.squared_difference(y_pred, y_true)\n", + " loss = loss / 10.0\n", + " loss = tf.reduce_sum(loss, axis=1)\n", + " return loss\n", + "\n", + "\n", + "@keras.saving.register_keras_serializable(package=\"my_custom_package\")\n", + "def mean_pred(y_true, y_pred):\n", + " return tf.reduce_mean(y_pred)\n", + "\n", + "\n", + "@keras.saving.register_keras_serializable(package=\"my_custom_package\")\n", + "class ModelWithCustomCompile(keras.Model):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.dense1 = keras.layers.Dense(8, activation=\"relu\")\n", + " self.dense2 = keras.layers.Dense(4, activation=\"softmax\")\n", + "\n", + " def call(self, inputs):\n", + " x = self.dense1(inputs)\n", + " return self.dense2(x)\n", + "\n", + " def compile(self, optimizer, loss_fn, metrics):\n", + " super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)\n", + " self.model_optimizer = optimizer\n", + " self.loss_fn = loss_fn\n", + " self.loss_metrics = metrics\n", + "\n", + " def get_compile_config(self):\n", + " # These parameters will be serialized at saving time.\n", + " return {\n", + " \"model_optimizer\": self.model_optimizer,\n", + " \"loss_fn\": self.loss_fn,\n", + " \"metric\": self.loss_metrics,\n", + " }\n", + "\n", + " def compile_from_config(self, config):\n", + " # Deserializes the compile parameters (important, since many are custom)\n", + " optimizer = keras.utils.deserialize_keras_object(config[\"model_optimizer\"])\n", + " loss_fn = keras.utils.deserialize_keras_object(config[\"loss_fn\"])\n", + " metrics = keras.utils.deserialize_keras_object(config[\"metric\"])\n", + "\n", + " # Calls compile with the deserialized parameters\n", + " self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics)\n", + "\n", + "\n", + "model = ModelWithCustomCompile()\n", + "model.compile(\n", + " optimizer=\"SGD\", loss_fn=small_square_sum_loss, metrics=[\"accuracy\", mean_pred]\n", + ")\n", + "\n", + "x = np.random.random((4, 8))\n", + "y = np.random.random((4,))\n", + "\n", + "model.fit(x, y)\n", + "\n", + "model.save(\"custom_compile_model.keras\")\n", + "restored_model = keras.models.load_model(\"custom_compile_model.keras\")\n", + "\n", + "np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)\n", + "np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)\n", + "np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Conclusion\n", + "\n", + "Using the methods learned in this tutorial allows for a wide variety of use cases,\n", + "allowing the saving and loading of complex models with exotic assets and state\n", + "elements. To recap:\n", + "\n", + "- `save_own_variables` and `load_own_variables` determine how your states are saved\n", + "and loaded.\n", + "- `save_assets` and `load_assets` can be added to store and load any additional\n", + "information your model needs.\n", + "- `get_build_config` and `build_from_config` save and restore the model's built\n", + "states.\n", + "- `get_compile_config` and `compile_from_config` save and restore the model's\n", + "compiled states." + ] + } + ], + "metadata": { + "accelerator": "None", + "colab": { + "collapsed_sections": [], + "name": "customizing_saving_and_serialization", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/guides/md/customizing_saving_and_serialization.md b/guides/md/customizing_saving_and_serialization.md new file mode 100644 index 0000000000..f8272ea7d6 --- /dev/null +++ b/guides/md/customizing_saving_and_serialization.md @@ -0,0 +1,355 @@ +# Customizing Saving and Serialization + +**Author:** Neel Kovelamudi
+**Date created:** 2023/03/15
+**Last modified:** 2023/03/15
+**Description:** A more advanced guide on customizing saving for your layers and models. + + + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/customizing_saving_and_serialization.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/customizing_saving_and_serialization.py) + + + +--- +## Introduction + +This guide covers advanced methods that can be customized in Keras saving. For most +users, the methods outlined in the primary +[Serialize, save, and export guide](https://keras.io/guides/serialization_and_saving) +are sufficient. + +### APIs +We will cover the following APIs: + +- `save_assets()` and `load_assets()` +- `save_own_variables()` and `load_own_variables()` +- `get_build_config()` and `build_from_config()` +- `get_compile_config()` and `compile_from_config()` + +When restoring a model, these get executed in the following order: + +- `build_from_config()` +- `compile_from_config()` +- `load_own_variables()` +- `load_assets()` + +--- +## Setup + + +```python +import os +import numpy as np +import tensorflow as tf +import keras +``` + +--- +## State saving customization + +These methods determine how the state of your model's layers is saved when calling +`model.save()`. You can override them to take full control of the state saving process. + + +### `save_own_variables()` and `load_own_variables()` + +These methods save and load the state variables of the layer when `model.save()` and +`keras.models.load_model()` are called, respectively. By default, the state variables +saved and loaded are the weights of the layer (both trainable and non-trainable). Here is +the default implementation of `save_own_variables()`: + +```python +def save_own_variables(self, store): + all_vars = self._trainable_weights + self._non_trainable_weights + for i, v in enumerate(all_vars): + store[f"{i}"] = v.numpy() +``` + +The store used by these methods is a dictionary that can be populated with the layer +variables. Let's take a look at an example customizing this. + +**Example:** + + +```python + +@keras.utils.register_keras_serializable(package="my_custom_package") +class LayerWithCustomVariables(keras.layers.Dense): + def __init__(self, units, **kwargs): + super().__init__(units, **kwargs) + self.stored_variables = tf.Variable( + np.random.random((10,)), name="special_arr", dtype=tf.float32 + ) + + def save_own_variables(self, store): + super().save_own_variables(store) + # Stores the value of the `tf.Variable` upon saving + store["variables"] = self.stored_variables.numpy() + + def load_own_variables(self, store): + # Assigns the value of the `tf.Variable` upon loading + self.stored_variables.assign(store["variables"]) + # Load the remaining weights + for i, v in enumerate(self.weights): + v.assign(store[f"{i}"]) + # Note: You must specify how all variables (including layer weights) + # are loaded in `load_own_variables.` + + def call(self, inputs): + return super().call(inputs) * self.stored_variables + + +model = keras.Sequential([LayerWithCustomVariables(1)]) + +ref_input = np.random.random((8, 10)) +ref_output = np.random.random((8,)) +model.compile(optimizer="adam", loss="mean_squared_error") +model.fit(ref_input, ref_output) + +model.save("custom_vars_model.keras") +restored_model = keras.models.load_model("custom_vars_model.keras") + +np.testing.assert_allclose( + model.layers[0].stored_variables.numpy(), + restored_model.layers[0].stored_variables.numpy(), +) +``` + +
+``` +1/1 [==============================] - 0s 418ms/step - loss: 0.7132 + +``` +
+### `save_assets()` and `load_assets()` + +These methods can be added to your model class definition to store and load any +additional information that your model needs. + +For example, NLP domain layers such as TextVectorization layers and IndexLookup layers +may need to store their associated vocabulary (or lookup table) in a text file upon +saving. + +Let's take at the basics of this workflow with a simple file `assets.txt`. + +**Example:** + + +```python + +@keras.saving.register_keras_serializable(package="my_custom_package") +class LayerWithCustomAssets(keras.layers.Dense): + def __init__(self, vocab=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.vocab = vocab + + def save_assets(self, inner_path): + # Writes the vocab (sentence) to text file at save time. + with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f: + f.write(self.vocab) + + def load_assets(self, inner_path): + # Reads the vocab (sentence) from text file at load time. + with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f: + text = f.read() + self.vocab = text.replace("", "little") + + +model = keras.Sequential( + [LayerWithCustomAssets(vocab="Mary had a lamb.", units=5)] +) + +x = np.random.random((10, 10)) +y = model(x) + +model.save("custom_assets_model.keras") +restored_model = keras.models.load_model("custom_assets_model.keras") + +np.testing.assert_string_equal( + restored_model.layers[0].vocab, "Mary had a little lamb." +) +``` + +--- +## `build` and `compile` saving customization + +### `get_build_config()` and `build_from_config()` + +These methods work together to save the layer's built states and restore them upon +loading. + +By default, this only includes a build config dictionary with the layer's input shape, +but overriding these methods can be used to include further Variables and Lookup Tables +that can be useful to restore for your built model. + +**Example:** + + +```python + +@keras.saving.register_keras_serializable(package="my_custom_package") +class LayerWithCustomBuild(keras.layers.Layer): + def __init__(self, units=32, **kwargs): + super().__init__(**kwargs) + self.units = units + + def call(self, inputs): + return tf.matmul(inputs, self.w) + self.b + + def get_config(self): + return dict(units=self.units, **super().get_config()) + + def build(self, input_shape, layer_init): + # Note the customization in overriding `build()` adds an extra argument. + # Therefore, we will need to manually call build with `layer_init` argument + # before the first execution of `call()`. + super().build(input_shape) + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer=layer_init, + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), + initializer=layer_init, + trainable=True, + ) + self.layer_init = layer_init + + def get_build_config(self): + build_config = super().get_build_config() # only gives `input_shape` + build_config.update( + {"layer_init": self.layer_init} # Stores our initializer for `build()` + ) + return build_config + + def build_from_config(self, config): + # Calls `build()` with the parameters at loading time + self.build(config["input_shape"], config["layer_init"]) + + +custom_layer = LayerWithCustomBuild(units=16) +custom_layer.build(input_shape=(8,), layer_init="random_normal") + +model = keras.Sequential( + [ + custom_layer, + keras.layers.Dense(1, activation="sigmoid"), + ] +) + +x = np.random.random((16, 8)) +y = model(x) + +model.save("custom_build_model.keras") +restored_model = keras.models.load_model("custom_build_model.keras") + +np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal") +np.testing.assert_equal(restored_model.built, True) +``` + +### `get_compile_config()` and `compile_from_config()` + +These methods work together to save the information with which the model was compiled +(optimizers, losses, etc.) and restore and re-compile the model with this information. + +Overriding these methods can be useful for compiling the restored model with custom +optimizers, custom losses, etc., as these will need to be deserialized prior to calling +`model.compile` in `compile_from_config()`. + +Let's take a look at an example of this. + +**Example:** + + +```python + +@keras.saving.register_keras_serializable(package="my_custom_package") +def small_square_sum_loss(y_true, y_pred): + loss = tf.math.squared_difference(y_pred, y_true) + loss = loss / 10.0 + loss = tf.reduce_sum(loss, axis=1) + return loss + + +@keras.saving.register_keras_serializable(package="my_custom_package") +def mean_pred(y_true, y_pred): + return tf.reduce_mean(y_pred) + + +@keras.saving.register_keras_serializable(package="my_custom_package") +class ModelWithCustomCompile(keras.Model): + def __init__(self): + super().__init__() + self.dense1 = keras.layers.Dense(8, activation="relu") + self.dense2 = keras.layers.Dense(4, activation="softmax") + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + def compile(self, optimizer, loss_fn, metrics): + super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics) + self.model_optimizer = optimizer + self.loss_fn = loss_fn + self.loss_metrics = metrics + + def get_compile_config(self): + # These parameters will be serialized at saving time. + return { + "model_optimizer": self.model_optimizer, + "loss_fn": self.loss_fn, + "metric": self.loss_metrics, + } + + def compile_from_config(self, config): + # Deserializes the compile parameters (important, since many are custom) + optimizer = keras.utils.deserialize_keras_object(config["model_optimizer"]) + loss_fn = keras.utils.deserialize_keras_object(config["loss_fn"]) + metrics = keras.utils.deserialize_keras_object(config["metric"]) + + # Calls compile with the deserialized parameters + self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics) + + +model = ModelWithCustomCompile() +model.compile( + optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred] +) + +x = np.random.random((4, 8)) +y = np.random.random((4,)) + +model.fit(x, y) + +model.save("custom_compile_model.keras") +restored_model = keras.models.load_model("custom_compile_model.keras") + +np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer) +np.testing.assert_equal(model.loss_fn, restored_model.loss_fn) +np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics) +``` + +
+``` +1/1 [==============================] - 0s 386ms/step - loss: 0.0704 - accuracy: 0.0000e+00 - mean_pred: 0.2500 + +WARNING:absl:Skipping variable loading for optimizer 'SGD', because it has 1 variables whereas the saved optimizer has 5 variables. + +``` +
+--- +## Conclusion + +Using the methods learned in this tutorial allows for a wide variety of use cases, +allowing the saving and loading of complex models with exotic assets and state +elements. To recap: + +- `save_own_variables` and `load_own_variables` determine how your states are saved +and loaded. +- `save_assets` and `load_assets` can be added to store and load any additional +information your model needs. +- `get_build_config` and `build_from_config` save and restore the model's built +states. +- `get_compile_config` and `compile_from_config` save and restore the model's +compiled states.