diff --git a/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
new file mode 100644
index 000000000000..455b5cccd5d5
--- /dev/null
+++ b/examples/notebooks/beam-ml/bigtable_enrichment_transform.ipynb
@@ -0,0 +1,854 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "fFjof1NgAJwu",
+ "cellView": "form"
+ },
+ "outputs": [],
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A8xNRyZMW1yK"
+ },
+ "source": [
+ "# Use Apache Beam and Bigtable to enrich data\n",
+ "\n",
+ "
\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HrCtxslBGK8Z"
+ },
+ "source": [
+ "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [Bigtable](https://cloud.google.com/bigtable/docs/overview). The enrichment transform is a turnkey transform in Apache Beam that lets you enrich data by using a key-value lookup. This transform has the following features:\n",
+ "\n",
+ "- The transform has a built-in Apache Beam handler that interacts with Bigtable to get data to use in the enrichment.\n",
+ "- The enrichment transform uses client-side throttling to manage rate limiting the requests. The requests are exponentially backed off with a default retry strategy. You can configure rate limiting to suit your use case."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This notebook demonstrates the following ecommerce use case:\n",
+ "\n",
+ "A stream of online transaction from [Pub/Sub](https://cloud.google.com/pubsub/docs/guides) contains the following fields: `sale_id`, `product_id`, `customer_id`, `quantity`, and `price`. Additional customer demographic data is stored in a separate Bigtable cluster. The demographic data is used to enrich the event stream from Pub/Sub. Then, the enriched data is used to predict the next product to recommended to a customer."
+ ],
+ "metadata": {
+ "id": "ltn5zrBiGS9C"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gVCtGOKTHMm4"
+ },
+ "source": [
+ "## Before you begin\n",
+ "Set up your environment and download dependencies."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YDHPlMjZRuY0"
+ },
+ "source": [
+ "### Install Apache Beam\n",
+ "To use the enrichment transform with the built-in Bigtable handler, install the Apache Beam SDK version 2.54.0 or later."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jBakpNZnAhqk"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install torch\n",
+ "!pip install apache_beam[interactive,gcp]==2.54.0 --quiet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import datetime\n",
+ "import json\n",
+ "import math\n",
+ "\n",
+ "from typing import Any\n",
+ "from typing import Dict\n",
+ "\n",
+ "import torch\n",
+ "from google.cloud import pubsub_v1\n",
+ "from google.cloud.bigtable import Client\n",
+ "from google.cloud.bigtable import column_family\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "import apache_beam.runners.interactive.interactive_beam as ib\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
+ "from apache_beam.options import pipeline_options\n",
+ "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
+ "from apache_beam.transforms.enrichment import Enrichment\n",
+ "from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler"
+ ],
+ "metadata": {
+ "id": "SiJii48A2Rnb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "X80jy3FqHjK4"
+ },
+ "source": [
+ "### Authenticate with Google Cloud\n",
+ "This notebook reads data from Pub/Sub and Bigtable. To use your Google Cloud account, authenticate this notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Kz9sccyGBqz3"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import auth\n",
+ "auth.authenticate_user()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Replace ``, ``, and `` with the appropriate values for your setup. These fields are used with Bigtable."
+ ],
+ "metadata": {
+ "id": "nAmGgUMt48o9"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "wEXucyi2liij"
+ },
+ "outputs": [],
+ "source": [
+ "PROJECT_ID = \"\"\n",
+ "INSTANCE_ID = \"\"\n",
+ "TABLE_ID = \"\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Train the model\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "RpqZFfFfA_Dt"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Create sample data by using the format `[product_id, quantity, price, customer_id, customer_location, recommend_product_id]`."
+ ],
+ "metadata": {
+ "id": "8cUpV7mkB_xE"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "data = [\n",
+ " [3, 5, 127, 9, 'China', 7], [1, 6, 167, 5, 'Peru', 4], [5, 4, 91, 2, 'USA', 8], [7, 2, 52, 1, 'India', 4], [1, 8, 118, 3, 'UK', 8], [4, 6, 132, 8, 'Mexico', 2],\n",
+ " [6, 3, 154, 6, 'Brazil', 3], [4, 7, 163, 1, 'India', 7], [5, 2, 80, 4, 'Egypt', 9], [9, 4, 107, 7, 'Bangladesh', 1], [2, 9, 192, 8, 'Mexico', 4], [4, 5, 116, 5, 'Peru', 8],\n",
+ " [8, 1, 195, 1, 'India', 7], [8, 6, 153, 5, 'Peru', 1], [5, 3, 120, 6, 'Brazil', 2], [2, 7, 187, 7, 'Bangladesh', 4], [1, 8, 103, 6, 'Brazil', 8], [2, 9, 181, 1, 'India', 8],\n",
+ " [6, 5, 166, 3, 'UK', 5], [3, 4, 115, 8, 'Mexico', 1], [4, 7, 170, 4, 'Egypt', 2], [9, 3, 141, 7, 'Bangladesh', 3], [9, 3, 157, 1, 'India', 2], [7, 6, 128, 9, 'China', 1],\n",
+ " [1, 8, 102, 3, 'UK', 4], [5, 2, 107, 4, 'Egypt', 6], [6, 5, 164, 8, 'Mexico', 9], [4, 7, 188, 5, 'Peru', 1], [8, 1, 184, 1, 'India', 2], [8, 6, 198, 2, 'USA', 5],\n",
+ " [5, 3, 105, 6, 'Brazil', 7], [2, 7, 162, 7, 'Bangladesh', 7], [1, 8, 133, 9, 'China', 3], [2, 9, 173, 1, 'India', 7], [6, 5, 183, 5, 'Peru', 8], [3, 4, 191, 3, 'UK', 6],\n",
+ " [4, 7, 123, 2, 'USA', 5], [9, 3, 159, 8, 'Mexico', 2], [9, 3, 146, 4, 'Egypt', 8], [7, 6, 194, 1, 'India', 8], [3, 5, 112, 6, 'Brazil', 1], [4, 6, 101, 7, 'Bangladesh', 2],\n",
+ " [8, 1, 192, 4, 'Egypt', 4], [7, 2, 196, 5, 'Peru', 6], [9, 4, 124, 9, 'China', 7], [3, 4, 129, 5, 'Peru', 6], [6, 3, 151, 8, 'Mexico', 9], [5, 7, 114, 7, 'Bangladesh', 4],\n",
+ " [4, 7, 175, 6, 'Brazil', 5], [1, 8, 121, 1, 'India', 2], [4, 6, 187, 2, 'USA', 5], [6, 5, 144, 9, 'China', 9], [9, 4, 103, 5, 'Peru', 3], [5, 3, 84, 3, 'UK', 1],\n",
+ " [3, 5, 193, 2, 'USA', 4], [4, 7, 135, 1, 'India', 1], [7, 6, 148, 8, 'Mexico', 8], [1, 6, 160, 5, 'Peru', 7], [8, 6, 155, 6, 'Brazil', 9], [5, 7, 183, 7, 'Bangladesh', 2],\n",
+ " [2, 9, 125, 4, 'Egypt', 4], [6, 3, 111, 9, 'China', 9], [5, 2, 132, 3, 'UK', 3], [4, 5, 104, 7, 'Bangladesh', 7], [2, 7, 177, 8, 'Mexico', 7]]"
+ ],
+ "metadata": {
+ "id": "TpxDHGObBEsj"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "countries_to_id = {'India': 1, 'USA': 2, 'UK': 3, 'Egypt': 4, 'Peru': 5,\n",
+ " 'Brazil': 6, 'Bangladesh': 7, 'Mexico': 8, 'China': 9}"
+ ],
+ "metadata": {
+ "id": "bQt1cB4-CSBd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Preprocess the data:\n",
+ "\n",
+ "1. Convert the lists to tensors.\n",
+ "2. Separate the features from the expected prediction."
+ ],
+ "metadata": {
+ "id": "Y0Duet4nCdN1"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "X = [torch.tensor(item[:4]+[countries_to_id[item[4]]], dtype=torch.float) for item in data]\n",
+ "Y = [torch.tensor(item[-1], dtype=torch.float) for item in data]"
+ ],
+ "metadata": {
+ "id": "7TT1O7sBCaZN"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Define a simple model that has five input features and predicts a single value."
+ ],
+ "metadata": {
+ "id": "q6wB_ZsXDjjd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def build_model(n_inputs, n_outputs):\n",
+ " \"\"\"build_model builds and returns a model that takes\n",
+ " `n_inputs` features and predicts `n_outputs` value\"\"\"\n",
+ " return torch.nn.Sequential(\n",
+ " torch.nn.Linear(n_inputs, 8),\n",
+ " torch.nn.ReLU(),\n",
+ " torch.nn.Linear(8, 16),\n",
+ " torch.nn.ReLU(),\n",
+ " torch.nn.Linear(16, n_outputs))"
+ ],
+ "metadata": {
+ "id": "nphNfhUnESES"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Train the model."
+ ],
+ "metadata": {
+ "id": "_sBSzDllEmCz"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = build_model(n_inputs=5, n_outputs=1)\n",
+ "\n",
+ "loss_fn = torch.nn.MSELoss()\n",
+ "optimizer = torch.optim.Adam(model.parameters())\n",
+ "\n",
+ "for epoch in range(1000):\n",
+ " print(f'Epoch {epoch}: ---')\n",
+ " optimizer.zero_grad()\n",
+ " for i in range(len(X)):\n",
+ " pred = model(X[i])\n",
+ " loss = loss_fn(pred, Y[i])\n",
+ " loss.backward()\n",
+ " optimizer.step()"
+ ],
+ "metadata": {
+ "id": "CaYrplaPDayp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Save the model to the `STATE_DICT_PATH` variable."
+ ],
+ "metadata": {
+ "id": "_rJYv8fFFPYb"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "STATE_DICT_PATH = './model.pth'\n",
+ "\n",
+ "torch.save(model.state_dict(), STATE_DICT_PATH)"
+ ],
+ "metadata": {
+ "id": "W4t260o9FURP"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Set up the Bigtable table\n",
+ "\n",
+ "Create a sample Bigtable table for this notebook."
+ ],
+ "metadata": {
+ "id": "ouMQZ4sC4zuO"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Connect to the Bigtable instance. If you don't have admin access, then drop `admin=True`.\n",
+ "client = Client(project=PROJECT_ID, admin=True)\n",
+ "instance = client.instance(INSTANCE_ID)\n",
+ "\n",
+ "# Create a column family.\n",
+ "column_family_id = 'demograph'\n",
+ "max_versions_rule = column_family.MaxVersionsGCRule(2)\n",
+ "column_families = {column_family_id: max_versions_rule}\n",
+ "\n",
+ "# Create a table.\n",
+ "table = instance.table(TABLE_ID)\n",
+ "\n",
+ "# You need admin access to use `.exists()`. If you don't have the admin access, then\n",
+ "# comment out the if-else block.\n",
+ "if not table.exists():\n",
+ " table.create(column_families=column_families)\n",
+ "else:\n",
+ " print(\"Table %s already exists in %s:%s\" % (TABLE_ID, PROJECT_ID, INSTANCE_ID))"
+ ],
+ "metadata": {
+ "id": "E7Y4ipuL5kFD"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Add rows to the table for the enrichment example."
+ ],
+ "metadata": {
+ "id": "eQLkSg3p7WAm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Define column names for the table.\n",
+ "customer_id = 'customer_id'\n",
+ "customer_name = 'customer_name'\n",
+ "customer_location = 'customer_location'\n",
+ "\n",
+ "# The following data is sample data to insert into Bigtable.\n",
+ "customers = [\n",
+ " {\n",
+ " 'customer_id': 1, 'customer_name': 'Sam', 'customer_location': 'India'\n",
+ " },\n",
+ " {\n",
+ " 'customer_id': 2, 'customer_name': 'John', 'customer_location': 'USA'\n",
+ " },\n",
+ " {\n",
+ " 'customer_id': 3, 'customer_name': 'Travis', 'customer_location': 'UK'\n",
+ " },\n",
+ "]\n",
+ "\n",
+ "for customer in customers:\n",
+ " row_key = str(customer[customer_id]).encode()\n",
+ " row = table.direct_row(row_key)\n",
+ " row.set_cell(\n",
+ " column_family_id,\n",
+ " customer_id.encode(),\n",
+ " str(customer[customer_id]),\n",
+ " timestamp=datetime.datetime.utcnow())\n",
+ " row.set_cell(\n",
+ " column_family_id,\n",
+ " customer_name.encode(),\n",
+ " customer[customer_name],\n",
+ " timestamp=datetime.datetime.utcnow())\n",
+ " row.set_cell(\n",
+ " column_family_id,\n",
+ " customer_location.encode(),\n",
+ " customer[customer_location],\n",
+ " timestamp=datetime.datetime.utcnow())\n",
+ " row.commit()\n",
+ " print('Inserted row for key: %s' % customer[customer_id])"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LI6oYkZ97Vtu",
+ "outputId": "c72b28b5-8692-40f5-f8da-85622437d8f7"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Inserted row for key: 1\n",
+ "Inserted row for key: 2\n",
+ "Inserted row for key: 3\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Publish messages to Pub/Sub\n",
+ "\n",
+ "Use the Pub/Sub Python client to publish messages.\n"
+ ],
+ "metadata": {
+ "id": "pHODouJDwc60"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Replace with the name of your Pub/Sub topic.\n",
+ "TOPIC = \"\"\n",
+ "\n",
+ "# Replace with the subscription for your topic.\n",
+ "SUBSCRIPTION = \"\"\n"
+ ],
+ "metadata": {
+ "id": "QKCuwDioxw-f"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "messages = [\n",
+ " {'sale_id': i, 'customer_id': i, 'product_id': i, 'quantity': i, 'price': i*100}\n",
+ " for i in range(1,4)\n",
+ " ]\n",
+ "\n",
+ "publisher = pubsub_v1.PublisherClient()\n",
+ "topic_name = publisher.topic_path(PROJECT_ID, TOPIC)\n",
+ "\n",
+ "for message in messages:\n",
+ " data = json.dumps(message).encode('utf-8')\n",
+ " publish_future = publisher.publish(topic_name, data)"
+ ],
+ "metadata": {
+ "id": "MaCJwaPexPKZ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use the Bigtable enrichment handler\n",
+ "\n",
+ "The [`BigTableEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.54.0 and later."
+ ],
+ "metadata": {
+ "id": "zPSFEMm02omi"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To establish a client for the Bigtable enrichment handler, replace ``, ``, and `` with the appropriate values for those fields. The `row_key` variable is the field name from the input row that contains the row key to use when querying Bigtable.\n",
+ "\n",
+ "To convert a `string` type to a `byte` type or a `byte` type to a `string` type from Bigtable, you can configure additional options, such as [`app_profile_id`](https://cloud.google.com/bigtable/docs/app-profiles), [`row_filter`](https://cloud.google.com/python/docs/reference/bigtable/latest/row-filters), and [`encoding`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.BigTableEnrichmentHandler:~:text=for%20more%20details.-,encoding,-(str)%20%E2%80%93%20encoding) type.\n",
+ "\n",
+ "The default `encoding` type is `utf-8`.\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "K41xhvmA5yQk"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "row_key = 'customer_id'"
+ ],
+ "metadata": {
+ "id": "3dB26jhI45gd"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "bigtable_handler = BigTableEnrichmentHandler(project_id=PROJECT_ID,\n",
+ " instance_id=INSTANCE_ID,\n",
+ " table_id=TABLE_ID,\n",
+ " row_key=row_key)"
+ ],
+ "metadata": {
+ "id": "cr1j_DHK4gA4"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The `BigTableEnrichmentHandler` returns the latest value from the table without its associated timestamp for the `row_key` that you provide. If you want to fetch the `timestamp` associated with the `row_key` value, then pass `include_timestamp=True` to the handler.\n",
+ "\n",
+ "**Note:** When exceptions occur, by default, the logging severity is set to warning ([`ExceptionLevel.WARN`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.WARN)). To configure the severity to raise exceptions, set `exception_level` to [`ExceptionLevel.RAISE`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.RAISE). To ignore exceptions, set `exception_level` to [`ExceptionLevel.QUIET`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigtable.html#apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel.QUIET)."
+ ],
+ "metadata": {
+ "id": "yFMcaf8i7TbI"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use the enrichment transform\n",
+ "\n",
+ "To use the [enrichment transform](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.Enrichment), the [`EnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.EnrichmentSourceHandler) parameter is required. You can also use a configuration parameter to specify a `lambda` for a join function, a timeout, a throttler, and a repeater (retry strategy).\n",
+ "\n",
+ "\n",
+ "* `join_fn`: A lambda function that takes dictionaries as input and returns an enriched row (`Callable[[Dict[str, Any], Dict[str, Any]], beam.Row]`). The enriched row specifies how to join the data fetched from the API. Defaults to a [cross-join](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join).\n",
+ "* `timeout`: The number of seconds to wait for the request to be completed by the API before timing out. Defaults to 30 seconds.\n",
+ "* `throttler`: Specifies the throttling mechanism. The only supported option is default client-side adaptive throttling.\n",
+ "* `repeater`: Specifies the retry strategy when errors like `TooManyRequests` and `TimeoutException` occur. Defaults to [`ExponentialBackOffRepeater`](https://beam.apache.org/releases/pydoc/current/apache_beam.io.requestresponse.html#apache_beam.io.requestresponse.ExponentialBackOffRepeater).\n"
+ ],
+ "metadata": {
+ "id": "-Lvo8O2V-0Ey"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The following example demonstrates the code needed to add this transform to your pipeline.\n",
+ "\n",
+ "\n",
+ "```\n",
+ "with beam.Pipeline() as p:\n",
+ " output = (p\n",
+ " ...\n",
+ " | \"Enrich with BigTable\" >> Enrichment(bigtable_handler, timeout=10)\n",
+ " | \"RunInference\" >> RunInference(model_handler)\n",
+ " ...\n",
+ " )\n",
+ "```\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "xJTCfSmiV1kv"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "To make a prediction, use the following fields: `product_id`, `quantity`, `price`, `customer_id`, and `customer_location`. Retrieve the value of the `customer_location` field from Bigtable.\n",
+ "\n",
+ "Because the enrichment transform performs a [`cross_join`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment.html#apache_beam.transforms.enrichment.cross_join) by default, design the custom join to enrich the input data. This design ensures that the join includes only the specified fields."
+ ],
+ "metadata": {
+ "id": "F-xjiP_pHWZr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def custom_join(left: Dict[str, Any], right: Dict[str, Any]):\n",
+ " enriched = {}\n",
+ " enriched['product_id'] = left['product_id']\n",
+ " enriched['quantity'] = left['quantity']\n",
+ " enriched['price'] = left['price']\n",
+ " enriched['customer_id'] = left['customer_id']\n",
+ " enriched['customer_location'] = right['demograph']['customer_location']\n",
+ " return beam.Row(**enriched)"
+ ],
+ "metadata": {
+ "id": "8LnCnEPNIPtg"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Use the `PyTorchModelHandlerTensor` interface to run inference\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "CX9Cqybu6scV"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Because the enrichment transform outputs data in the format `beam.Row`, to make it compatible with the [`PyTorchModelHandlerTensor`](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.pytorch_inference.html#apache_beam.ml.inference.pytorch_inference.PytorchModelHandlerTensor) interface, convert it to `torch.tensor`. Additionally, the enriched field `customer_location` is a `string` type, but the model requires a `float` type. Convert the `customer_location` field to a `float` type."
+ ],
+ "metadata": {
+ "id": "zy5Jl7_gLklX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def convert_row_to_tensor(element: beam.Row):\n",
+ " row_dict = element._asdict()\n",
+ " row_dict['customer_location'] = countries_to_id[row_dict['customer_location']]\n",
+ " return torch.tensor(list(row_dict.values()), dtype=torch.float)"
+ ],
+ "metadata": {
+ "id": "KBKoB06nL4LF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Initialize the model handler with the preprocessing function."
+ ],
+ "metadata": {
+ "id": "-tGHyB_vL3rJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model_handler = PytorchModelHandlerTensor(state_dict_path=STATE_DICT_PATH,\n",
+ " model_class=build_model,\n",
+ " model_params={'n_inputs':5, 'n_outputs':1}\n",
+ " ).with_preprocess_fn(convert_row_to_tensor)"
+ ],
+ "metadata": {
+ "id": "VqUUEwcU-r2e"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Define a `DoFn` to format the output."
+ ],
+ "metadata": {
+ "id": "vNHI4gVgNec2"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class PostProcessor(beam.DoFn):\n",
+ " def process(self, element, *args, **kwargs):\n",
+ " print('Customer %d who bought product %d is recommended to buy product %d' % (element.example[3], element.example[0], math.ceil(element.inference[0])))"
+ ],
+ "metadata": {
+ "id": "rkN-_Yf4Nlwy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0a1zerXycQ0z"
+ },
+ "source": [
+ "## Run the pipeline\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Configure the pipeline to run in streaming mode."
+ ],
+ "metadata": {
+ "id": "WrwY0_gV_IDK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "options = pipeline_options.PipelineOptions()\n",
+ "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True"
+ ],
+ "metadata": {
+ "id": "t0425sYBsYtB"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`."
+ ],
+ "metadata": {
+ "id": "DBNijQDY_dRe"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class DecodeBytes(beam.DoFn):\n",
+ " \"\"\"\n",
+ " The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n",
+ " First, decode the encoded string. Convert the output to\n",
+ " a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n",
+ " \"\"\"\n",
+ " def process(self, element, *args, **kwargs):\n",
+ " element_dict = json.loads(element.decode('utf-8'))\n",
+ " yield beam.Row(**element_dict)"
+ ],
+ "metadata": {
+ "id": "sRw9iL8pKP5O"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Use the following code to run the pipeline.\n",
+ "\n",
+ "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run."
+ ],
+ "metadata": {
+ "id": "xofUJym-_GuB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "St07XoibcQSb",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "outputId": "34e0a603-fb77-455c-e40b-d15b672edeb2"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/javascript": [
+ "\n",
+ " if (typeof window.interactive_beam_jquery == 'undefined') {\n",
+ " var jqueryScript = document.createElement('script');\n",
+ " jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n",
+ " jqueryScript.type = 'text/javascript';\n",
+ " jqueryScript.onload = function() {\n",
+ " var datatableScript = document.createElement('script');\n",
+ " datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n",
+ " datatableScript.type = 'text/javascript';\n",
+ " datatableScript.onload = function() {\n",
+ " window.interactive_beam_jquery = jQuery.noConflict(true);\n",
+ " window.interactive_beam_jquery(document).ready(function($){\n",
+ " \n",
+ " });\n",
+ " }\n",
+ " document.head.appendChild(datatableScript);\n",
+ " };\n",
+ " document.head.appendChild(jqueryScript);\n",
+ " } else {\n",
+ " window.interactive_beam_jquery(document).ready(function($){\n",
+ " \n",
+ " });\n",
+ " }"
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Customer 1 who bought product 1 is recommended to buy product 3\n",
+ "Customer 2 who bought product 2 is recommended to buy product 5\n",
+ "Customer 3 who bought product 3 is recommended to buy product 7\n"
+ ]
+ }
+ ],
+ "source": [
+ "with beam.Pipeline(options=options) as p:\n",
+ " _ = (p\n",
+ " | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=SUBSCRIPTION)\n",
+ " | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n",
+ " | \"Enrichment\" >> Enrichment(bigtable_handler, join_fn=custom_join)\n",
+ " | \"RunInference\" >> RunInference(model_handler)\n",
+ " | \"Format Output\" >> beam.ParDo(PostProcessor())\n",
+ " )\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file