From 5dd3022297b9583d4ac3bea3846676e5eea15fdb Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Mon, 22 Jul 2024 10:55:34 +0100 Subject: [PATCH] Add tutorial for creating embeddings over AOI --- docs/_toc.yml | 6 +- .../create-embeddings-over-aoi.ipynb | 698 ++++++++++++++++++ 2 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 docs/tutorials/create-embeddings-over-aoi.ipynb diff --git a/docs/_toc.yml b/docs/_toc.yml index 7fb8573a..9dc54363 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -28,14 +28,16 @@ parts: file: tutorials/reconstruction - caption: Finetune examples chapters: + - title: Finetuning on an embeddings database + file: finetune/finetune-on-embeddings - title: Segmentation using Chesapeake file: finetune/segment - title: Classification using Eurosat file: finetune/classify - title: Regression using Biomasters file: finetune/regression - - title: Finetunig on an embeddings database - file: finetune/finetune-on-embeddings + - title: Create embeddings over AOI + file: tutorials/create-embeddings-over-aoi.ipynb - caption: About Clay chapters: - title: GitHub diff --git a/docs/tutorials/create-embeddings-over-aoi.ipynb b/docs/tutorials/create-embeddings-over-aoi.ipynb new file mode 100644 index 00000000..6db9a2f9 --- /dev/null +++ b/docs/tutorials/create-embeddings-over-aoi.ipynb @@ -0,0 +1,698 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "yEwfOrRNwYmh" + }, + "source": [ + "# Embedding creation\n", + "\n", + "This tutorial shows how to create embeddings using Clay and store them in geoparquet.\n", + "Creating embeddings is useful for similarity search applications, and when\n", + "training classification heads on top of the embeddings, as shown in the\n", + "[Finetuning on an embeddings database](../finetune/finetune-on-embeddings) tutorial.\n", + "\n", + "Creating embeddings consists of three simple steps:\n", + "\n", + "1. Search for imagery to be used\n", + "2. 
Create chips dynamically from the source data with [stacchip](https://clay-foundation.github.io/stacchip/)\n", + "3. Pass chips to Clay and store the output as geoparquet\n", + "\n", + "Lets look at these one by one, but first ensure that stacchip is installed,\n", + "a library we are going to use to generate dynamic chips to pass to Clay." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "m5ioh-DTh6Kg", + "outputId": "5ccdfb4d-5db8-477f-eb3d-033715a09fa4" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: stacchip in /usr/local/lib/python3.10/dist-packages (0.1.35)\n", + "Requirement already satisfied: boto3>=1.29.0 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.34.145)\n", + "Requirement already satisfied: geoarrow-pyarrow>=0.1.2 in /usr/local/lib/python3.10/dist-packages (from stacchip) (0.1.2)\n", + "Requirement already satisfied: geopandas>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.0.1)\n", + "Requirement already satisfied: numpy<2.0,>=1.26.0 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.26.4)\n", + "Requirement already satisfied: planetary-computer>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.0.0)\n", + "Requirement already satisfied: pyarrow>=14.0.1 in /usr/local/lib/python3.10/dist-packages (from stacchip) (14.0.2)\n", + "Requirement already satisfied: pystac-client>=0.7.5 in /usr/local/lib/python3.10/dist-packages (from stacchip) (0.8.2)\n", + "Requirement already satisfied: pystac>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.10.1)\n", + "Requirement already satisfied: rasterio>=1.3.9 in /usr/local/lib/python3.10/dist-packages (from stacchip) (1.3.10)\n", + "Requirement already satisfied: rio-stac>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from stacchip) (0.9.0)\n", + "Requirement 
already satisfied: botocore<1.35.0,>=1.34.145 in /usr/local/lib/python3.10/dist-packages (from boto3>=1.29.0->stacchip) (1.34.145)\n", + "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from boto3>=1.29.0->stacchip) (1.0.1)\n", + "Requirement already satisfied: s3transfer<0.11.0,>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from boto3>=1.29.0->stacchip) (0.10.2)\n", + "Requirement already satisfied: geoarrow-c in /usr/local/lib/python3.10/dist-packages (from geoarrow-pyarrow>=0.1.2->stacchip) (0.1.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from geoarrow-pyarrow>=0.1.2->stacchip) (0.6)\n", + "Requirement already satisfied: pyogrio>=0.7.2 in /usr/local/lib/python3.10/dist-packages (from geopandas>=0.14.1->stacchip) (0.9.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from geopandas>=0.14.1->stacchip) (24.1)\n", + "Requirement already satisfied: pandas>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from geopandas>=0.14.1->stacchip) (2.0.3)\n", + "Requirement already satisfied: pyproj>=3.3.0 in /usr/local/lib/python3.10/dist-packages (from geopandas>=0.14.1->stacchip) (3.6.1)\n", + "Requirement already satisfied: shapely>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from geopandas>=0.14.1->stacchip) (2.0.5)\n", + "Requirement already satisfied: click>=7.1 in /usr/local/lib/python3.10/dist-packages (from planetary-computer>=1.0.0->stacchip) (8.1.7)\n", + "Requirement already satisfied: pydantic>=1.7.3 in /usr/local/lib/python3.10/dist-packages (from planetary-computer>=1.0.0->stacchip) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.5 in /usr/local/lib/python3.10/dist-packages (from planetary-computer>=1.0.0->stacchip) (2023.4)\n", + "Requirement already satisfied: requests>=2.25.1 in /usr/local/lib/python3.10/dist-packages (from planetary-computer>=1.0.0->stacchip) (2.31.0)\n", + "Requirement 
already satisfied: python-dotenv in /usr/local/lib/python3.10/dist-packages (from planetary-computer>=1.0.0->stacchip) (1.0.1)\n", + "Requirement already satisfied: python-dateutil>=2.7.0 in /usr/local/lib/python3.10/dist-packages (from pystac>=1.9.0->stacchip) (2.8.2)\n", + "Requirement already satisfied: affine in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (2.4.0)\n", + "Requirement already satisfied: attrs in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (23.2.0)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (2024.7.4)\n", + "Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (0.7.2)\n", + "Requirement already satisfied: snuggs>=1.4.1 in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (1.4.7)\n", + "Requirement already satisfied: click-plugins in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (1.1.1)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from rasterio>=1.3.9->stacchip) (67.7.2)\n", + "Requirement already satisfied: urllib3!=2.2.0,<3,>=1.25.4 in /usr/local/lib/python3.10/dist-packages (from botocore<1.35.0,>=1.34.145->boto3>=1.29.0->stacchip) (2.0.7)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.0->geopandas>=0.14.1->stacchip) (2024.1)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.7.3->planetary-computer>=1.0.0->stacchip) (0.7.0)\n", + "Requirement already satisfied: pydantic-core==2.20.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.7.3->planetary-computer>=1.0.0->stacchip) (2.20.1)\n", + "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from 
pydantic>=1.7.3->planetary-computer>=1.0.0->stacchip) (4.12.2)\n", + "Requirement already satisfied: jsonschema~=4.18 in /usr/local/lib/python3.10/dist-packages (from pystac>=1.9.0->stacchip) (4.19.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7.0->pystac>=1.9.0->stacchip) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.1->planetary-computer>=1.0.0->stacchip) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.25.1->planetary-computer>=1.0.0->stacchip) (3.7)\n", + "Requirement already satisfied: pyparsing>=2.1.6 in /usr/local/lib/python3.10/dist-packages (from snuggs>=1.4.1->rasterio>=1.3.9->stacchip) (3.1.2)\n", + "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema~=4.18->pystac>=1.9.0->stacchip) (2023.12.1)\n", + "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema~=4.18->pystac>=1.9.0->stacchip) (0.35.1)\n", + "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema~=4.18->pystac>=1.9.0->stacchip) (0.19.0)\n" + ] + } + ], + "source": [ + "! 
pip install stacchip" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "id": "v8Mvx05TXyYS" + }, + "outputs": [], + "source": [ + "import math\n", + "\n", + "import geopandas as gpd\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pystac_client\n", + "import torch\n", + "import yaml\n", + "from box import Box\n", + "from matplotlib import pyplot as plt\n", + "from rasterio.enums import Resampling\n", + "from shapely import Point\n", + "from torchvision.transforms import v2\n", + "import numpy as np\n", + "import math\n", + "import geoarrow.pyarrow as ga\n", + "import numpy as np\n", + "import pyarrow as pa\n", + "import pyarrow.parquet as pq\n", + "\n", + "import pystac_client\n", + "from stacchip.indexer import Sentinel2Indexer\n", + "from stacchip.chipper import Chipper\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import requests" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J8O7mQSoyK_-" + }, + "source": [ + "### Note: This notebook requires CUDA\n", + "\n", + "This is because we are using the Clay encoder from a [torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html), which was compiled using CUDA." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "0qx3gkZiTSpe" + }, + "outputs": [], + "source": [ + "if not torch.cuda.is_available():\n", + " raise ValueError(\"The compiled version of Clay needs CUDA\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZZLc4ePFw1LN" + }, + "source": [ + "## Find data for AOI\n", + "\n", + "The first step is to find STAC items of imagery that we want to use\n", + "to create embeddings. 
In this example we are going to use\n", + "Sentinel-2 L2A imagery from the [Element 84 Earth Search](https://earth-search.aws.element84.com/v1)\n", + "STAC API, which is the catalog queried in the search cell below.\n", + "\n", + "We are also going to create embeddings over time so that we have multiple\n", + "embeddings for the same location at different moments in time." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "id": "ec1SxsCbU8y5" + }, + "outputs": [], + "source": [ + "# Point over Monchique Portugal\n", + "lat, lon = 37.30939, -8.57207\n", + "\n", + "# Dates of a large forest fire\n", + "start = \"2018-07-01\"\n", + "end = \"2018-09-01\"" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "StAZ4Fp8VHSj", + "outputId": "0f9cd0ec-fcee-48d2-9f86-ed794e865687" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/pystac_client/item_search.py:851: FutureWarning: get_all_items() is deprecated, use item_collection() instead.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 12 items\n" + ] + } + ], + "source": [ + "# Optimize GDAL settings for cloud optimized reading\n", + "os.environ[\"GDAL_DISABLE_READDIR_ON_OPEN\"] = \"EMPTY_DIR\"\n", + "os.environ[\"AWS_REQUEST_PAYER\"] = \"requester\"\n", + "\n", + "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", + "COLLECTION = \"sentinel-2-l2a\"\n", + "\n", + "# Search the catalogue\n", + "catalog = pystac_client.Client.open(STAC_API)\n", + "search = catalog.search(\n", + "    collections=[COLLECTION],\n", + "    datetime=f\"{start}/{end}\",\n", + "    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", + "    max_items=100,\n", + "    query={\"eo:cloud_cover\": {\"lt\": 80}},\n", + ")\n", + "\n", + "all_items = search.item_collection()\n", + "\n", 
+ "# Reduce to one per date (there might be some duplicates\n", + "# based on the location)\n", + "items = []\n", + "dates = []\n", + "for item in all_items:\n", + "    if item.datetime.date() not in dates:\n", + "        items.append(item)\n", + "        dates.append(item.datetime.date())\n", + "\n", + "print(f\"Found {len(items)} items\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "azUDzttEy9fp" + }, + "source": [ + "To speed up processing in this example, we limit the number of chips to 3 per Sentinel-2 scene. Remove this limit in a real use case." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3qyq0XYZVIk4", + "outputId": "cdbb1507-cf19-4ad7-ab7e-9005c8bb264f" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n", + "Working on \n" + ] + } + ], + "source": [ + "chips = []\n", + "datetimes = []\n", + "bboxs = []\n", + "chip_ids = []\n", + "item_ids = []\n", + "\n", + "for item in items:\n", + "    print(f\"Working on {item.id}\")\n", + "\n", + "    # Index the chips in the item\n", + "    indexer = Sentinel2Indexer(item)\n", + "\n", + "    # Instantiate the chipper\n", + "    chipper = Chipper(indexer, assets=[\"red\", \"green\", \"blue\", \"nir\", \"scl\"])\n", + "\n", + "    # Keep only the first three chips of each scene\n", + "    for idx, (x, y, chip) in enumerate(chipper):\n", + "        if idx > 2:\n", + "            break\n", + "        del chip[\"scl\"]\n", + "        chips.append(chip)\n", + "        datetimes.append(item.datetime)\n", + "        bboxs.append(indexer.get_chip_bbox(x, y))\n", + "        chip_ids.append((x, y))\n", + "        item_ids.append(item.id)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + 
"id": "LefH0VduWs53", + "outputId": "88aa9fb9-b219-41e3-f51d-9235f2df40f5" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(36, 4, 256, 256)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pixels = np.array([np.array(list(chip.values())).squeeze() for chip in chips])\n", + "pixels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "Ql-2Cw4mXlW0" + }, + "outputs": [], + "source": [ + "# Extract mean, std, and wavelengths from metadata\n", + "platform = \"sentinel-2-l2a\"\n", + "# Retrieve the file content from the URL\n", + "\n", + "url = (\n", + " \"https://raw.githubusercontent.com/Clay-foundation/model/main/configs/metadata.yaml\"\n", + ")\n", + "response = requests.get(url, allow_redirects=True)\n", + "\n", + "# Convert bytes to string\n", + "content = response.content.decode(\"utf-8\")\n", + "\n", + "# Load the yaml\n", + "content = yaml.safe_load(content)\n", + "\n", + "metadata = Box(content)\n", + "mean = []\n", + "std = []\n", + "waves = []\n", + "# Use the band names to get the correct values in the correct order.\n", + "for band in chips[0].keys():\n", + " mean.append(metadata[platform].bands.mean[band])\n", + " std.append(metadata[platform].bands.std[band])\n", + " waves.append(metadata[platform].bands.wavelength[band])\n", + "\n", + "# Prepare the normalization transform function using the mean and std values.\n", + "transform = v2.Compose(\n", + " [\n", + " v2.Normalize(mean=mean, std=std),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "id": "X0PDwp_8VL-c" + }, + "outputs": [], + "source": [ + "def normalize_timestamp(date):\n", + " week = date.isocalendar().week * 2 * np.pi / 52\n", + " hour = date.hour * 2 * np.pi / 24\n", + "\n", + " return (math.sin(week), math.cos(week)), (math.sin(hour), math.cos(hour))\n", + "\n", + "\n", + "times = [normalize_timestamp(dat) for dat in 
datetimes]\n", + "week_norm = [dat[0] for dat in times]\n", + "hour_norm = [dat[1] for dat in times]\n", + "\n", + "\n", + "# Prep lat/lon embedding using the centroid of each chip's bounding box\n", + "def normalize_latlon(lat, lon):\n", + "    lat = lat * np.pi / 180\n", + "    lon = lon * np.pi / 180\n", + "\n", + "    return (math.sin(lat), math.cos(lat)), (math.sin(lon), math.cos(lon))\n", + "\n", + "\n", + "latlons = [normalize_latlon(bbox.centroid.y, bbox.centroid.x) for bbox in bboxs]\n", + "lat_norm = [dat[0] for dat in latlons]\n", + "lon_norm = [dat[1] for dat in latlons]\n", + "\n", + "# Prep gsd\n", + "gsd = [10]\n", + "\n", + "# Normalize pixels (v2 transforms pass numpy arrays through untouched, so use a float tensor)\n", + "pixels_norm = transform(torch.from_numpy(pixels.astype(np.float32)))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "id": "_qGIb7MSbRz-" + }, + "outputs": [], + "source": [ + "datacube = (\n", + "    pixels_norm.to(\"cuda\"),\n", + "    torch.tensor(np.hstack((week_norm, hour_norm)), dtype=torch.float32, device=\"cuda\"),\n", + "    torch.tensor(np.hstack((lat_norm, lon_norm)), dtype=torch.float32, device=\"cuda\"),\n", + "    torch.tensor(waves, dtype=torch.float32, device=\"cuda\"),\n", + "    torch.tensor(gsd, dtype=torch.float32, device=\"cuda\"),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "38HpZBD3eqcr", + "outputId": "1f39b47a-f906-4b51-8a2b-211347de672c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[torch.Size([36, 4, 256, 256]),\n", + " torch.Size([36, 4]),\n", + " torch.Size([36, 4]),\n", + " torch.Size([4]),\n", + " torch.Size([1])]" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[dat.shape for dat in datacube]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rp9ahz_Cw761" + }, + "source": [ + "## Generate embeddings using the Clay encoder\n", + "\n", + "We are going to download the compiled version of the Clay\n", + "encoder, which has been prepared 
using torchscript." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LedqWjYlUEE7", + "outputId": "b7cc4d23-b2c9-44fd-c07c-229d27cbe87a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2024-07-22 09:35:43-- https://huggingface.co/made-with-clay/Clay/resolve/main/clay-v1-encoder.pt\n", + "Resolving huggingface.co (huggingface.co)... 18.239.50.49, 18.239.50.103, 18.239.50.80, ...\n", + "Connecting to huggingface.co (huggingface.co)|18.239.50.49|:443... connected.\n", + "HTTP request sent, awaiting response... 302 Found\n", + "Location: https://cdn-lfs-us-1.huggingface.co/repos/9e/5f/9e5f70717de49e5e8fb94cc66c7c40e24e6800ae6dbf377099154c19eafdc5f6/6efe1d94fde51e88de4d2d6df699fb9f055a57ea8f1bc31c7a25fb1b7796f5ad?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27clay-v1-encoder.pt%3B+filename%3D%22clay-v1-encoder.pt%22%3B&Expires=1721900143&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMTkwMDE0M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzllLzVmLzllNWY3MDcxN2RlNDllNWU4ZmI5NGNjNjZjN2M0MGUyNGU2ODAwYWU2ZGJmMzc3MDk5MTU0YzE5ZWFmZGM1ZjYvNmVmZTFkOTRmZGU1MWU4OGRlNGQyZDZkZjY5OWZiOWYwNTVhNTdlYThmMWJjMzFjN2EyNWZiMWI3Nzk2ZjVhZD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=PzjbwOuAz8uOpJVTYnObx52HmTQyDuPl7KO31set-bAVJMY504ilUdU7LfA2Bfv1e3ER%7EJXq%7ESkk-ICmeRCf059ZBlAUuA5blZz3f2zS0ipBrrM9deTAFOwsU8U5iE9ZBJwBUXYhpJliqypqRsMjfNEpwjaMubC7FffBXhj81w6P8M8yz%7EM1oBStQCBEPOLOwdtJNWBZx%7E0uhFO6Exh3CwJUyKC%7EMgqZDgbYowxajScahX2mz7NyIiSVcijTURuvugaR1kiQW6ruwBnxpm3Vcb0kMf4VxjWrt%7EP4j8sb04t6j-rIY%7EV-dB7w1oh5KgMWLuPSe4hbqIXmPPuNPaHTcg__&Key-Pair-Id=K24J24Z295AEI9 [following]\n", + "--2024-07-22 09:35:43-- 
https://cdn-lfs-us-1.huggingface.co/repos/9e/5f/9e5f70717de49e5e8fb94cc66c7c40e24e6800ae6dbf377099154c19eafdc5f6/6efe1d94fde51e88de4d2d6df699fb9f055a57ea8f1bc31c7a25fb1b7796f5ad?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27clay-v1-encoder.pt%3B+filename%3D%22clay-v1-encoder.pt%22%3B&Expires=1721900143&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMTkwMDE0M319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzllLzVmLzllNWY3MDcxN2RlNDllNWU4ZmI5NGNjNjZjN2M0MGUyNGU2ODAwYWU2ZGJmMzc3MDk5MTU0YzE5ZWFmZGM1ZjYvNmVmZTFkOTRmZGU1MWU4OGRlNGQyZDZkZjY5OWZiOWYwNTVhNTdlYThmMWJjMzFjN2EyNWZiMWI3Nzk2ZjVhZD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=PzjbwOuAz8uOpJVTYnObx52HmTQyDuPl7KO31set-bAVJMY504ilUdU7LfA2Bfv1e3ER%7EJXq%7ESkk-ICmeRCf059ZBlAUuA5blZz3f2zS0ipBrrM9deTAFOwsU8U5iE9ZBJwBUXYhpJliqypqRsMjfNEpwjaMubC7FffBXhj81w6P8M8yz%7EM1oBStQCBEPOLOwdtJNWBZx%7E0uhFO6Exh3CwJUyKC%7EMgqZDgbYowxajScahX2mz7NyIiSVcijTURuvugaR1kiQW6ruwBnxpm3Vcb0kMf4VxjWrt%7EP4j8sb04t6j-rIY%7EV-dB7w1oh5KgMWLuPSe4hbqIXmPPuNPaHTcg__&Key-Pair-Id=K24J24Z295AEI9\n", + "Resolving cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)... 18.239.94.84, 18.239.94.40, 18.239.94.3, ...\n", + "Connecting to cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)|18.239.94.84|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 369855883 (353M) [binary/octet-stream]\n", + "Saving to: ‘clay-v1-encoder.pt.1’\n", + "\n", + "clay-v1-encoder.pt. 100%[===================>] 352.72M 30.2MB/s in 5.5s \n", + "\n", + "2024-07-22 09:35:49 (64.2 MB/s) - ‘clay-v1-encoder.pt.1’ saved [369855883/369855883]\n", + "\n" + ] + } + ], + "source": [ + "!wget https://huggingface.co/made-with-clay/Clay/resolve/main/clay-v1-encoder.pt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Nj-EVMUyBVv" + }, + "source": [ + "Load the packaged encoder using pytorch." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "EJ12oU88Ta0h", + "outputId": "170d8118-49ad-4964-d13e-9c58f08b91e4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py:58: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer\n", + " getattr_node = gm.graph.get_attr(lifted_node)\n", + "/usr/local/lib/python3.10/dist-packages/torch/fx/graph.py:1460: UserWarning: Node _lifted_tensor_constant0_1 target _lifted_tensor_constant0 _lifted_tensor_constant0 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target\n", + " warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '\n" + ] + } + ], + "source": [ + "clay_encoder = torch.export.load(\"clay-v1-encoder.pt\").module()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mTbYwSafhEI8" + }, + "source": [ + "Run the encoder and extract the class embedding, which is the\n", + "main embedding vector that can be used for image classification\n", + "or similarity search." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b6v9uhP_WGe7", + "outputId": "97d34ff1-2fcc-44e1-fbd0-7c54fad5c197" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([36, 768])" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run the clay encoder\n", + "with torch.no_grad():\n", + " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = clay_encoder(*datacube)\n", + "# Get class embeddings\n", + "cls_embedding = unmsk_patch[:, 0, :]\n", + "# Print shape of class embeddings\n", + "cls_embedding.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r_q1pvlGxdqc" + }, + "source": [ + "## Store the results in a geoparquet table\n", + "\n", + "We create a table containing the embeddings, bounding box, the STAC item ID, the datetime of the image capture, and the chip x and y ids. Then we save that data to disk." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5CkXQwnNhfdn", + "outputId": "106b5d8c-cf2d-4e51-c9eb-e87cdbb2b394" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "pyarrow.Table\n", + "datetimes: timestamp[us, tz=UTC]\n", + "chip_ids: list\n", + " child 0, item: int64\n", + "item_ids: string\n", + "emeddings: list\n", + " child 0, item: float\n", + "geometry: extension>\n", + "----\n", + "datetimes: [[2018-08-28 11:30:56.771000,2018-08-28 11:30:56.771000,2018-08-28 11:30:56.771000,2018-08-23 11:30:50.574000,2018-08-23 11:30:50.574000,...,2018-07-09 11:24:55.535000,2018-07-09 11:24:55.535000,2018-07-04 11:30:35.271000,2018-07-04 11:30:35.271000,2018-07-04 11:30:35.271000]]\n", + "chip_ids: [[[0,0],[1,0],...,[1,0],[2,0]]]\n", + "item_ids: [[\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2A_29SNB_20180828_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",\"S2B_29SNB_20180823_1_L2A\",...,\"S2A_29SNB_20180709_0_L2A\",\"S2A_29SNB_20180709_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\",\"S2B_29SNB_20180704_0_L2A\"]]\n", + "emeddings: [[[-0.14773352,0.08466569,0.13797817,0.11150878,0.06517958,...,0.03668152,-0.092160314,0.025934448,-0.124962896,-0.034070194],[-0.14430065,0.08585757,0.1383917,0.109635465,0.065273784,...,0.03711327,-0.09153647,0.02631683,-0.12422915,-0.03333639],...,[-0.09626352,0.062443335,0.24817088,0.012715766,0.04309366,...,0.011770157,-0.037860356,0.027813742,-0.11962944,-0.022464497],[-0.10004054,0.06320579,0.24851678,0.012129238,0.04335017,...,0.011444268,-0.037332665,0.027787171,-0.12139091,-0.02108902]]]\n", + "geometry: [[[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.825403979293151,-8.825730459265694,-9.000227209792856,-9.000227635454767,-8.825403979293151]\n", + " -- child 1 type: double\n", + "[37.947460030545635,37.809019655564406,37.809148556380286,37.947589571562965,37.947460030545635]],[ -- 
is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", + " -- child 1 type: double\n", + "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],...,[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.650582567535476,-8.651235936821893,-8.825730459265694,-8.825403979293151,-8.650582567535476]\n", + " -- child 1 type: double\n", + "[37.94707073614538,37.80863228507305,37.809019655564406,37.947460030545635,37.94707073614538]],[ -- is_valid: all not null\n", + " -- child 0 type: double\n", + "[-8.475765647330832,-8.476745873271028,-8.651235936821893,-8.650582567535476,-8.475765647330832]\n", + " -- child 1 type: double\n", + "[37.94642170369997,37.80798646012822,37.80863228507305,37.94707073614538,37.94642170369997]]]]" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Write data to pyarrow table\n", + "index = {\n", + "    \"datetimes\": datetimes,\n", + "    \"chip_ids\": chip_ids,\n", + "    \"item_ids\": item_ids,\n", + "    \"embeddings\": [np.ascontiguousarray(dat) for dat in cls_embedding.cpu().numpy()],\n", + "    \"geometry\": ga.as_geoarrow([dat.wkt for dat in bboxs]),\n", + "}\n", + "table = pa.table(index)\n", + "table" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "id": "VrmZd3vIDnVl" + }, + "outputs": [], + "source": [ + "pq.write_table(table, \"clay_embeddings.parquet\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}