diff --git a/learn_the_basics.rst b/learn_the_basics.rst index c6a2edba..67073f51 100644 --- a/learn_the_basics.rst +++ b/learn_the_basics.rst @@ -9,6 +9,11 @@ Learn the basics Transpiling Kornia functions to TensorFlow. + .. grid-item-card:: Transpiling Models from PyTorch to TensorFlow + :link: learn_the_basics/torch_to_tf_models.ipynb + + Transpiling PyTorch models to TensorFlow. + .. grid-item-card:: Trace Code :link: learn_the_basics/03_trace_code.ipynb @@ -29,6 +34,7 @@ Learn the basics :maxdepth: -1 learn_the_basics/torch_to_tf_functions.ipynb + learn_the_basics/torch_to_tf_models.ipynb learn_the_basics/03_trace_code.ipynb learn_the_basics/05_lazy_vs_eager.ipynb learn_the_basics/06_how_to_use_decorators.ipynb diff --git a/learn_the_basics/example_models.py b/learn_the_basics/example_models.py new file mode 100644 index 00000000..d6fdd783 --- /dev/null +++ b/learn_the_basics/example_models.py @@ -0,0 +1,16 @@ +import torch + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super(SimpleModel, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3) + self.relu = torch.nn.ReLU() + self.fc = torch.nn.Linear(3 * 26 * 26, 10) + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x diff --git a/learn_the_basics/torch_to_tf_models.ipynb b/learn_the_basics/torch_to_tf_models.ipynb new file mode 100644 index 00000000..246416f6 --- /dev/null +++ b/learn_the_basics/torch_to_tf_models.ipynb @@ -0,0 +1,160 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Transpiling Models from PyTorch to TensorFlow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can install the dependencies required for this notebook by running the cell below ⬇️, or check out the [Get Started](https://ivy.dev/docs/overview/get_started.html) section of the docs to find out more about installing ivy." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install ivy\n", + "!pip install torch\n", + "!pip install tensorflow" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we'll go through an example of how any model written in PyTorch can be converted, and used in, TensorFlow via `ivy.transpile`. First, lets import a simple torch model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from example_models import SimpleModel\n", + "\n", + "\"\"\"\n", + "This model is defined as follows:\n", + "\n", + "class SimpleModel(torch.nn.Module):\n", + " def __init__(self):\n", + " super(SimpleModel, self).__init__()\n", + " self.conv1 = torch.nn.Conv2d(1, 3, kernel_size=3)\n", + " self.relu = torch.nn.ReLU()\n", + " self.fc = torch.nn.Linear(3 * 26 * 26, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv1(x)\n", + " x = self.relu(x)\n", + " x = torch.flatten(x, 1)\n", + " x = self.fc(x)\n", + " return x\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can convert the model to tensorflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ivy\n", + "\n", + "TFSimpleModel = ivy.transpile(SimpleModel, source=\"torch\", target=\"tensorflow\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can use the model with TensorFlow" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorShape([1, 10])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import tensorflow as tf\n", + "\n", + "tf_model = TFSimpleModel()\n", + "tf_model(tf.random.normal((1, 1, 28, 28))).shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also take advantage of TensorFlow-specific features, such as `tf.function`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TensorShape([1, 10])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compiled_model = tf.function(tf_model)\n", + "compiled_model(tf.random.normal((1, 1, 28, 28))).shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}