diff --git a/.bazelignore b/.bazelignore index 75ea7dff..ad88eee7 100644 --- a/.bazelignore +++ b/.bazelignore @@ -1,3 +1,4 @@ bazel-kv-server tools/wasm_example/ google_internal/piper/ +node_modules diff --git a/.bazelrc b/.bazelrc index e3d6c35d..717b8ec1 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,9 +1,8 @@ build --announce_rc build --verbose_failures build --compilation_mode=opt -build --output_filter='^//((?!(third_party):).)*$'` +build --output_filter='^//((?!(third_party):).)*$' build --color=yes -build --@io_bazel_rules_docker//transitions:enable=false build --workspace_status_command="bash tools/get_workspace_status" build --copt=-Werror=thread-safety build --config=clang diff --git a/CHANGELOG.md b/CHANGELOG.md index 76b12c7c..29b3684e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,14 +2,97 @@ All notable changes to this project will be documented in this file. See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. -## 0.17.1 (2024-08-26) +## 1.0.0 (2024-10-14) + + +### ⚠ BREAKING CHANGES + +* GA release + +### Features + +* Add 64 bit int sets support to key value cache +* Add CBOR conversion for v2 objects +* Add CBOR support to multi-partition flows in V2 +* add CORS headers for envoy config +* Add data loading support for uint64 sets +* Add documentation for uint64 sets +* Add internal lookup rpc for uint64 sets +* Add parameter notifier to get parameter update notification +* Add partition-level metadata to UDF execution metadata +* Add runSetQueryUInt64 udf hook +* Add support for reading and writing uint64 sets to csv files +* Add uint64 bitset wrapper. 
+* CBOR conversion for Compression Group +* CborDecodeToProto implementation +* Convert http ContentType header to a custom header in GCP +* Download pre-built aws-otel-collector.rpm +* Encode cbor content as bytestring and add partitionOutputs to CBOR converter +* Fix release script +* Flag to control chaffing for sharding for nonprod +* GA release +* Implement CBOR for validator +* Implement internal GetUInt64ValueSet functionality +* Implement InternalRunSetQueryUInt64 rpc (local lookup) +* Implement InternalRunSetQueryUInt64 rpc (sharded lookup) +* multiple partition support +* Pass partition level metadata to UDF +* Process v2 padded requests +* Put server logs in the response DebugInfo for consented requests +* Refactor cache logic for bitsets into its own class +* Set up AWS terraform resources for logging verbosity parameter notification +* Start parameter notifier to get logging verbosity updates +* Support dataVersion field in PA partition output +* Support set operations for 64 bit int sets +* Update AWS sqs cleanup function to clean up sqs for parameter updates +* Update common repo and set the verbosity level for PS_VLOG with new API +* Update v2 contract +* Update v2 headers +* Upgrade common repo to 9c5c93e +* Upgrade rules_oci to 2.0 and deprecate rules_docker +* Use proper ohttp media types for encryption +* When using the wrong inline set type in query, resolve the result ### Bug Fixes +* Add missing include directive +* Add missing internal testing parameters +* Allow CORS OPTIONS for preflight +* Correct fork logic +* Correct output_filter typo +* Destroy terraform before doing perfgate exporting * Enable a second kv on aws deployment. * fix AppMesh health check. +* logMessage should use PS LOGS +* Make AL2023 work. +* Remove "k" from ReceivedLowLatencyNotificationsCount metric name +* Remove version from header +* Rename BUILD to BUILD.bazel * Resolve proxy subnet resources collision issue. 
+* Response partition id should come from the request +* Temporary GCP V2 HTTP envoy fix +* Update common repo to pick up the server crash fix +* Update V2 handler and docs with proper ohttp response label. +* Upgrade builders version to 0.69.0 +* Use specified release branch to cut release. +* V2 should not return error status on UDF failure + + +### Dependencies + +* **deps:** Upgrade build-system to 0.66.1 +* **deps:** Upgrade data-plane-shared-libraries to 144264c 2024-07-31 + + +### Documentation + +* Add aws update-function-code lambda update command to the AWS deployment doc +* Add readme doc for diagnostic tool +* Add screenshot for gcp server prod log location +* Update docs to use docker compose instead of docker-compose +* Update gcp deployment doc about console logging +* Update playbook ## 0.17.0 (2024-07-08) @@ -89,6 +172,11 @@ All notable changes to this project will be documented in this file. See [commit * Use aws_platform bazel config * Use local_{platform,instance} bazel configs +### Image digests and PCR0s + +GCP: sha256:d09d5a6d340a8829df03213b71b74d4b431e4d5a138525c77269c347a367b004 +AWS: {"PCR0":"1e28ac4b72600ea40d61e1756e14f453a3d923a1bf94c360ae48d9777bff0714923d9322ed380823591859e357d2f825"} + ## 0.16.0 (2024-04-05) diff --git a/README.md b/README.md index fcf6a733..e41c9fa8 100644 --- a/README.md +++ b/README.md @@ -11,47 +11,27 @@ --- -# ![Privacy Sandbox Logo](docs/assets/privacy_sandbox_logo.png) FLEDGE Key/Value service +# ![Privacy Sandbox Logo](docs/assets/privacy_sandbox_logo.png) Protected Auction Key/Value service -# Background - -FLEDGE API is a proposal to serve remarketing and other custom-audience ads without third-party -cookies. FLEDGE executes the ad auction between the buyers (DSP) and the sellers (SSP) locally, and -receives real-time signals from the FLEDGE K/V servers. 
To learn more about +# State of the project -- FLEDGE for the Web: [explainer](https://developer.chrome.com/en/docs/privacy-sandbox/fledge/) - and the [developer guide](https://developer.chrome.com/blog/fledge-api/). -- FLEDGE on Android: - [design proposal](https://developer.android.com/design-for-safety/privacy-sandbox/fledge) and - the - [developer guide](https://developer.android.com/design-for-safety/privacy-sandbox/guides/fledge). +The current codebase represents the implementation of the TEE-based Key/Value service by Privacy +Sandbox. -When the auction is executed, separate -[FLEDGE K/V servers](https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md) -are queried for the buyers and sellers. When a buyer is making a bid, the DSP K/V server can be -queried to receive real-time information to help determine the bid. To help the seller pick an -auction winner, the SSP K/V server can be queried to receive any information about the creative to -help score the ad. +For +[Protected Audience](https://developers.google.com/privacy-sandbox/private-advertising/protected-audience), +the service can be used as a BYOS KV server. Soon it can be used to communicate with Chrome and the +Bidding and Auction services using +[V2 protocol](https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md). -# State of the project +For +[Protected App Signals](https://developers.google.com/privacy-sandbox/private-advertising/protected-audience/android/protected-app-signals), +the service should be used as the ad retrieval server. -The current codebase represents the initial implementation and setup of the Key/Value server. 
It can -be integrated with Chrome and Android with the +It can be integrated with Chrome and Android with the [Privacy Sandbox unified origin trial](https://developer.chrome.com/blog/expanding-privacy-sandbox-testing/) and [Privacy Sandbox on Android Developer Preview](https://developer.android.com/design-for-safety/privacy-sandbox/program-overview). -Our goal is to present the foundation of the project in a publicly visible way for early feedback. -This feedback will help us shape the future versions. - -The implementation, and in particular the APIs, are in rapid development and may change as new -versions are released. The query API conforms to the -[API explainer](https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md). At the -moment, to load data, instead of calling the mutation API, you would place the data as files into a -location that can be directly read by the server. See more details in the -[data loading guide](/docs/data_loading/loading_data.md). - -Currently, this service can be deployed to 1 region of your choice. Multi-region configuration is up -to the service owner to configure. ## Current features @@ -120,6 +100,7 @@ products. +
@@ -193,6 +174,7 @@ The implementation supports live traffic at scale
+ @@ -270,14 +252,7 @@ The implementation supports live traffic at scale ## Breaking changes -While we make efforts to not introduce breaking changes, we expect that to happen occasionally. - -The release version follows the `[major change]-[minor change]-[patch]` scheme. All 0.x.x versions -may contain breaking changes without notice. Refer to the [release changelog](/CHANGELOG.md) for the -details of the breaking changes. - -At GA the version will become 1.0.0, we will establish additional channels for announcing breaking -changes and major version will always be incremented for breaking changes. +Backward-incompatible changes are expected to be rare and will result in a major version change. # Key documents @@ -304,8 +279,8 @@ changes and major version will always be incremented for breaking changes. Contributions are welcome, and we will publish more detailed guidelines soon. In the meantime, if you are interested, -[open a new Issue](https://github.com/privacysandbox/fledge-key-value-service/issues) in the GitHub -repository. +[open a new Issue](https://github.com/privacysandbox/protected-auction-key-value-service/issues) in +the GitHub repository. 
# Feedback diff --git a/WORKSPACE b/WORKSPACE index 2813868a..2b824476 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,17 +7,19 @@ local_repository( path = "testing/functionaltest-system", ) -load("//builders/bazel:deps.bzl", "python_deps") +load("//builders/bazel:deps.bzl", "python_deps", "python_register_toolchains") -python_deps("//builders/bazel") +python_deps() + +python_register_toolchains("//builders/bazel") http_archive( name = "google_privacysandbox_servers_common", - # commit 34445c1 2024-07-01 - sha256 = "ce300bc178b1eedd88d7545b89d1d672b3b9bfb62c138ab3f4a845f159436285", - strip_prefix = "data-plane-shared-libraries-37522d6ac55c8592060f636d68f50feddcb9598a", + # commit cc49da3 2024-10-09 + sha256 = "7a0337420161304c7429c727b1f82394bc27e1e2586d2da30e6d6100ba92b437", + strip_prefix = "data-plane-shared-libraries-158593616a63df924af1cb689f3915b8d32e9db1", urls = [ - "https://github.com/privacysandbox/data-plane-shared-libraries/archive/37522d6ac55c8592060f636d68f50feddcb9598a.zip", + "https://github.com/privacysandbox/data-plane-shared-libraries/archive/158593616a63df924af1cb689f3915b8d32e9db1.zip", ], ) @@ -51,28 +53,10 @@ load( cpp_repositories() -http_archive( - name = "io_bazel_rules_docker", - sha256 = "b1e80761a8a8243d03ebca8845e9cc1ba6c82ce7c5179ce2b295cd36f7e394bf", - urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.25.0/rules_docker-v0.25.0.tar.gz"], -) - -load("@io_bazel_rules_docker//repositories:repositories.bzl", container_repositories = "repositories") - -container_repositories() - -load("@io_bazel_rules_docker//repositories:deps.bzl", io_bazel_rules_docker_deps = "deps") - -io_bazel_rules_docker_deps() - load("//third_party_deps:container_deps.bzl", "container_deps") container_deps() -load("@io_bazel_rules_docker//go:image.bzl", go_image_repos = "repositories") - -go_image_repos() - # googleapis http_archive( name = "com_google_googleapis", # master branch from 26.04.2022 @@ -88,6 +72,16 @@ http_archive( urls = 
["https://github.com/google/distributed_point_functions/archive/45da5f54836c38b73a1392e846c9db999c548711.tar.gz"], ) +http_archive( + name = "libcbor", + build_file = "//third_party_deps:libcbor.BUILD", + patch_args = ["-p1"], + patches = ["//third_party_deps:libcbor.patch"], + sha256 = "9fec8ce3071d5c7da8cda397fab5f0a17a60ca6cbaba6503a09a47056a53a4d7", + strip_prefix = "libcbor-0.10.2/src", + urls = ["https://github.com/PJK/libcbor/archive/refs/tags/v0.10.2.zip"], +) + # Dependencies for Flex/Bison build rules http_archive( name = "rules_m4", @@ -132,6 +126,15 @@ latency_benchmark_install_deps() word2vec_install_deps() +http_archive( + name = "io_bazel_rules_go", + sha256 = "16e9fca53ed6bd4ff4ad76facc9b7b651a89db1689a2877d6fd7b82aa824e366", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.34.0/rules_go-v0.34.0.zip", + "https://github.com/bazelbuild/rules_go/releases/download/v0.34.0/rules_go-v0.34.0.zip", + ], +) + # Use nogo to run `go vet` with bazel load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") diff --git a/builders/.pre-commit-config.yaml b/builders/.pre-commit-config.yaml index e1ad672d..3f853aab 100644 --- a/builders/.pre-commit-config.yaml +++ b/builders/.pre-commit-config.yaml @@ -47,7 +47,7 @@ repos: - id: shellcheck - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.4 + rev: v18.1.5 hooks: - id: clang-format types_or: diff --git a/builders/.profiler.bazelrc b/builders/.profiler.bazelrc new file mode 100644 index 00000000..42524d63 --- /dev/null +++ b/builders/.profiler.bazelrc @@ -0,0 +1,5 @@ +build:profiler --compilation_mode=opt +build:profiler --dynamic_mode=off +build:profiler --copt=-gmlt +build:profiler --copt=-fno-omit-frame-pointer +build:profiler --strip=never diff --git a/builders/CHANGELOG.md b/builders/CHANGELOG.md index 3ced1112..af695dc8 100644 --- a/builders/CHANGELOG.md +++ b/builders/CHANGELOG.md @@ -2,6 +2,112 @@ All notable 
changes to this project will be documented in this file. See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. +## 0.69.0 (2024-09-15) + + +### Features + +* Pin build-debian ubuntu base image from 20.04 to focal-20240530 + +## 0.68.1 (2024-08-21) + +### Bug Fixes + +* Fix load bazel_tools import for Python deps + +## 0.68.0 (2024-08-21) + + +### Features + +* **deps:** Split python deps and registering toolchains +* **deps:** Update rules_python to 0.35.0 + +## 0.67.0 (2024-07-31) + + +### Bug Fixes + +* Add EXTRA_CBUILD_ARGS to tools/bazel-* scripts + + +### Dependencies + +* **deps:** Update buildozer to 6.1.1 +* **deps:** Upgrade amazonlinux2023 to 5.20240722.0 + +## 0.66.1 (2024-06-24) + + +### Bug Fixes + +* Add --compilation_mode=opt to build:profiler config + +## 0.66.0 (2024-06-20) + + +### Features + +* Add cpu-profiler flags to cbuild +* Add profiler config in .profiler.bazelrc + +## 0.65.1 (2024-06-04) + + +### Bug Fixes + +* Support multiple etc files in a single image + +## 0.65.0 (2024-06-04) + + +### Features + +* Add DOCKER_NETWORK env var for test-tools + +## 0.64.1 (2024-05-29) + + +### Bug Fixes + +* Support container reuse when --cmd not specified +* Use find to identify bazel symlinks + +## 0.64.0 (2024-05-27) + + +### Features + +* Support cmd-profiler mode with/without --cmd + + +### Bug Fixes + +* cbuild should find container with exact name match +* Ensure normalize-bazel-symlinks is in the workspace dir + + +### Dependencies + +* **deps:** Upgrade clang-format pre-commit hook + +## 0.63.0 (2024-05-26) + + +### Features + +* Support cmd-profiler mode with/without --cmd + + +### Bug Fixes + +* Ensure normalize-bazel-symlinks is in the workspace dir + + +### Dependencies + +* **deps:** Upgrade clang-format pre-commit hook + ## 0.62.0 (2024-05-10) diff --git a/builders/bazel/deps.bzl b/builders/bazel/deps.bzl index 0fe6743b..9e46f9ca 100644 --- a/builders/bazel/deps.bzl +++ 
b/builders/bazel/deps.bzl @@ -15,21 +15,24 @@ """Load definitions for use in WORKSPACE files.""" load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") -def python_deps(bazel_package): - """Load rules_python and register container-based python toolchain +def python_deps(): + """Load rules_python. Use python_register_toolchains to also register container-based python toolchain.""" + maybe( + http_archive, + name = "rules_python", + sha256 = "be04b635c7be4604be1ef20542e9870af3c49778ce841ee2d92fcb42f9d9516a", + strip_prefix = "rules_python-0.35.0", + url = "https://github.com/bazelbuild/rules_python/releases/download/0.35.0/rules_python-0.35.0.tar.gz", + ) + +def python_register_toolchains(bazel_package): + """Register container-based python toolchain. Note: the bazel_package arg will depend on the import/submodule location in your workspace Args: bazel_package: repo-relative bazel package to builders/bazel/BUILD eg. "//builders/bazel" """ - http_archive( - name = "rules_python", - sha256 = "0a8003b044294d7840ac7d9d73eef05d6ceb682d7516781a4ec62eeb34702578", - strip_prefix = "rules_python-0.24.0", - urls = [ - "https://github.com/bazelbuild/rules_python/releases/download/0.24.0/rules_python-0.24.0.tar.gz", - ], - ) native.register_toolchains("{}:py_toolchain".format(bazel_package)) diff --git a/builders/images/build-amazonlinux2023/Dockerfile b/builders/images/build-amazonlinux2023/Dockerfile index 268bbcba..c7d3509b 100644 --- a/builders/images/build-amazonlinux2023/Dockerfile +++ b/builders/images/build-amazonlinux2023/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM amazonlinux:2023.4.20240416.0 +FROM amazonlinux:2023.5.20240722.0 COPY /install_apps install_golang_apps install_go.sh generate_system_bazelrc .bazelversion /scripts/ COPY get_workspace_mount /usr/local/bin diff --git a/builders/images/build-amazonlinux2023/install_apps b/builders/images/build-amazonlinux2023/install_apps index 442021c4..16378fa5 100755 --- a/builders/images/build-amazonlinux2023/install_apps +++ b/builders/images/build-amazonlinux2023/install_apps @@ -47,8 +47,8 @@ function install_python() { function install_nitro() { dnf install -y \ - "aws-nitro-enclaves-cli-1.2.*" \ - "aws-nitro-enclaves-cli-devel-1.2.*" + "aws-nitro-enclaves-cli-1.3.*" \ + "aws-nitro-enclaves-cli-devel-1.3.*" } function install_gcc() { diff --git a/builders/images/build-debian/Dockerfile b/builders/images/build-debian/Dockerfile index 735d9b66..4b615819 100644 --- a/builders/images/build-debian/Dockerfile +++ b/builders/images/build-debian/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ARG BASE_IMAGE=ubuntu:20.04 +ARG BASE_IMAGE=ubuntu:focal-20240530 # ignore this hadolint error as BASE_IMAGE contains an image tag # hadolint ignore=DL3006 diff --git a/builders/images/presubmit/.bazelversion b/builders/images/presubmit/.bazelversion new file mode 120000 index 00000000..8f79a5af --- /dev/null +++ b/builders/images/presubmit/.bazelversion @@ -0,0 +1 @@ +../../etc/.bazelversion \ No newline at end of file diff --git a/builders/images/presubmit/Dockerfile b/builders/images/presubmit/Dockerfile index 53ee3778..827e2584 100644 --- a/builders/images/presubmit/Dockerfile +++ b/builders/images/presubmit/Dockerfile @@ -14,7 +14,7 @@ FROM ubuntu:24.04 -COPY install_apps install_go.sh .pre-commit-config.yaml /scripts/ +COPY install_apps install_go.sh install_golang_apps .bazelversion .pre-commit-config.yaml /scripts/ COPY gitconfig /etc ARG PRE_COMMIT_VENV_DIR=/usr/pre-commit-venv @@ -28,6 +28,7 @@ ENV BUILD_ARCH="${TARGETARCH}" \ RUN \ chmod 644 /etc/gitconfig && \ /usr/bin/env -v PRE_COMMIT_VENV_DIR=${PRE_COMMIT_VENV_DIR} /scripts/install_apps && \ + /scripts/install_golang_apps && \ rm -rf /scripts ENV PATH="${PATH}:/usr/local/go/bin" diff --git a/builders/images/presubmit/install_golang_apps b/builders/images/presubmit/install_golang_apps new file mode 120000 index 00000000..acc9d5a3 --- /dev/null +++ b/builders/images/presubmit/install_golang_apps @@ -0,0 +1 @@ +../install_golang_apps \ No newline at end of file diff --git a/builders/images/release/install_release_apps b/builders/images/release/install_release_apps index e6c12253..dc99134f 100755 --- a/builders/images/release/install_release_apps +++ b/builders/images/release/install_release_apps @@ -5,4 +5,4 @@ npm install --global commit-and-tag-version@10.1.0 # Install the GitHub CLI tool (https://cli.github.com/) apk add github-cli -GOBIN=/usr/local/go/bin go install github.com/bazelbuild/buildtools/buildozer@6.0.1 +GOBIN=/usr/local/go/bin go install github.com/bazelbuild/buildtools/buildozer@6.1.1 
diff --git a/builders/tests/data/hashes/build-amazonlinux2023 b/builders/tests/data/hashes/build-amazonlinux2023 index 5fe991bf..7d1e615d 100644 --- a/builders/tests/data/hashes/build-amazonlinux2023 +++ b/builders/tests/data/hashes/build-amazonlinux2023 @@ -1 +1 @@ -8d01333fe93d2ac2102dd8360a58717724b7b594d51fe4e412ec20aae181efce +59a82d2db8173784b0b49959c9f82ead6c2e6da78a6be21cdc78520aa43741e3 diff --git a/builders/tests/data/hashes/build-debian b/builders/tests/data/hashes/build-debian index 57095aed..b0a607dd 100644 --- a/builders/tests/data/hashes/build-debian +++ b/builders/tests/data/hashes/build-debian @@ -1 +1 @@ -c194dafd287978093f8fe6e16e981fb22028e37345e20a4d7ca84caa43f0d4c0 +83fab12505490f9ed41e5d8747f3c8844f6aee8740ad05c47eca61fd7b42a8d1 diff --git a/builders/tests/data/hashes/presubmit b/builders/tests/data/hashes/presubmit index b02b21b0..a3e4b6ff 100644 --- a/builders/tests/data/hashes/presubmit +++ b/builders/tests/data/hashes/presubmit @@ -1 +1 @@ -afaf1932764d07d480c4e833e6b08877f069abae87401bdac4782277c535a298 +560a5a1726e7b6fd2a507f72f2563eb3938d153a7d4b4aada575a7fe772873b0 diff --git a/builders/tests/data/hashes/release b/builders/tests/data/hashes/release index e5218ba2..affaf218 100644 --- a/builders/tests/data/hashes/release +++ b/builders/tests/data/hashes/release @@ -1 +1 @@ -d60fb40a53b1704f7ac353d0d036f49eac10bfd08ccb19f9b436acf8bdf2cb79 +398787e442bb10bcf7383bc3beec85ab27fb0145f80a23d8ee0eeb4992e5cb81 diff --git a/builders/tools/bazel-debian b/builders/tools/bazel-debian index 843b3b6e..c32cf7b9 100755 --- a/builders/tools/bazel-debian +++ b/builders/tools/bazel-debian @@ -19,6 +19,7 @@ # BAZEL_STARTUP_ARGS Additional startup arguments to pass to bazel invocations # BAZEL_EXTRA_ARGS Additional command arguments to pass to bazel invocations # EXTRA_DOCKER_RUN_ARGS Additional arguments to pass to docker run invocations +# EXTRA_CBUILD_ARGS Additional arguments to pass to tools/cbuild set -o pipefail set -o errexit @@ -63,7 +64,8 @@ 
declare -a APP_ARGS declare -r -a ARGLIST=("$@") partition_array ARGLIST BAZEL_ARGS APP_ARGS -"${CBUILD}" --seccomp-unconfined --image "${IMAGE}" --cmd " +# shellcheck disable=SC2086 +"${CBUILD}" ${EXTRA_CBUILD_ARGS} --seccomp-unconfined --image "${IMAGE}" --cmd " printf 'bazel output_base: [%s]\n' \"\$(bazel info output_base 2>/dev/null)\" bazel ${BAZEL_STARTUP_ARGS} ${BAZEL_ARGS[*]@Q} ${BAZEL_EXTRA_ARGS} ${APP_ARGS[*]@Q} " diff --git a/builders/tools/cbuild b/builders/tools/cbuild index 17da1336..31ee0fcb 100755 --- a/builders/tools/cbuild +++ b/builders/tools/cbuild @@ -36,7 +36,6 @@ function usage() { usage: $0 --cmd bash command string to execute within the docker container - --cmd-profiler enable profiler for the command --image Image name for the build runtime. Valid names: USAGE @@ -57,6 +56,11 @@ USAGE --seccomp-unconfined Run docker container without a seccomp profile --verbose Enable verbose output + Profiler flags: + --cmd-profiler enable profiler for the command + --cpu-profiler-signal unix signal to use to trigger profiler output. Default: ${CPU_PROFILER_SIGNAL} + --cpu-profiler-filename path for the cpu profiler output. 
Default: ${CPU_PROFILER_FILENAME} + Environment variables (all optional): WORKSPACE Full path to the workspace (repo root) WORKSPACE_MOUNT Full path to the workspace on the host filesystem @@ -79,6 +83,8 @@ DOCKER_NETWORK="${DOCKER_NETWORK:-bridge}" declare -i DOCKER_SECCOMP_UNCONFINED=0 declare -i KEEP_CONTAINER_RUNNING=0 declare LONG_RUNNING_CONTAINER_TIMEOUT=8h +declare CPU_PROFILER_FILENAME=benchmark.prof +declare -i CPU_PROFILER_SIGNAL=12 while [[ $# -gt 0 ]]; do case "$1" in @@ -90,6 +96,14 @@ while [[ $# -gt 0 ]]; do WITH_CMD_PROFILER=1 shift ;; + --cpu-profiler-filename) + CPU_PROFILER_FILENAME="$2" + shift 2 || usage + ;; + --cpu-profiler-signal) + CPU_PROFILER_SIGNAL=$2 + shift 2 || usage + ;; --env) ENV_VARS+=("$2") shift 2 || usage @@ -199,7 +213,8 @@ if [[ ${WITH_CMD_PROFILER} -eq 1 ]]; then fi DOCKER_RUN_ARGS+=( "--env=CMD_PROFILER=LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libprofiler.so" - "--env=CPUPROFILE=benchmark.prof" + "--env=CPUPROFILE=${CPU_PROFILER_FILENAME}" + "--env=CPUPROFILESIGNAL=${CPU_PROFILER_SIGNAL}" ) fi @@ -252,8 +267,8 @@ function running_container_for() { declare -r name="$1" declare -a docker_args=( container ls - --filter "name=${name}" - --format "{{print .Names}}" + "--filter=name=^${name}$" + "--format={{print .Names}}" ) local -r exited="$(docker "${docker_args[@]}" --all --filter "status=exited")" if [[ -n ${exited} ]]; then @@ -269,13 +284,7 @@ function long_running_container() { local -r docker_running_container="$(running_container_for "${container_name}")" if [[ -z ${docker_running_container} ]]; then printf "starting a new container [%s]\n" "${container_name}" &>/dev/stderr - if [[ -z ${CMD} ]]; then - # shellcheck disable=SC2068 - docker run \ - ${DOCKER_RUN_ARGS[@]} \ - "${DOCKER_EXEC_RUN_ARGS[@]}" \ - "${IMAGE_TAGGED}" - else + if [[ -n ${CMD} ]]; then # shellcheck disable=SC2068 docker run \ ${DOCKER_RUN_ARGS[@]} \ @@ -293,11 +302,26 @@ timeout ${LONG_RUNNING_CONTAINER_TIMEOUT} tail --pid=\${pid} -f /dev/null } if [[ 
${KEEP_CONTAINER_RUNNING} -eq 1 ]]; then - DOCKER_RUNNING_CONTAINER="$(long_running_container "${DOCKER_CONTAINER_NAME}")" - docker exec \ - "${DOCKER_EXEC_RUN_ARGS[@]}" \ - "${DOCKER_RUNNING_CONTAINER}" \ - /bin/bash -c "${CMD}" + if [[ -z ${CMD} ]]; then + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${IMAGE_TAGGED}" + else + DOCKER_RUNNING_CONTAINER="$(long_running_container "${DOCKER_CONTAINER_NAME}")" + if [[ ${WITH_CMD_PROFILER} -eq 1 ]]; then + docker exec \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${DOCKER_RUNNING_CONTAINER}" \ + /bin/bash -c "'${TOOLS_RELDIR}'/normalize-bazel-symlinks; env \${CMD_PROFILER} ${CMD:-/bin/sh}" + else + docker exec \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${DOCKER_RUNNING_CONTAINER}" \ + /bin/bash -c "${CMD:-/bin/sh}" + fi + fi else if [[ -z ${CMD} ]]; then # shellcheck disable=SC2068 @@ -319,6 +343,6 @@ else ${DOCKER_RUN_ARGS[@]} \ "${DOCKER_EXEC_RUN_ARGS[@]}" \ "${IMAGE_TAGGED}" \ - --login -c "$CMD" + --login -c "${CMD}" fi fi diff --git a/builders/tools/get-builder-image-tagged b/builders/tools/get-builder-image-tagged index 35371aea..bc14958a 100755 --- a/builders/tools/get-builder-image-tagged +++ b/builders/tools/get-builder-image-tagged @@ -199,7 +199,7 @@ function _tar_for_dir() { # shellcheck disable=SC2012 ls -A -1 "${FILEPATH}" "${ETC_DIR}" | sort | uniq -d ls -A -1 "${WORKSPACE}" - } | sort | uniq -d)" + } | sort | uniq -d | tr '\n' ' ')" # create a deterministic tarball of the collected files docker run \ --rm \ diff --git a/builders/tools/normalize-bazel-symlinks b/builders/tools/normalize-bazel-symlinks index 8506ac96..05e3e7ef 100755 --- a/builders/tools/normalize-bazel-symlinks +++ b/builders/tools/normalize-bazel-symlinks @@ -43,14 +43,11 @@ if [[ -f /.dockerenv ]]; then _normalize_fn=normalize_symlink_docker fi -declare -a -r LINK_DIRS=( - bazel-bin - bazel-out - bazel-testlogs - bazel-workspace -) -for link in "${LINK_DIRS[@]}"; do - if [[ -L ${link} ]]; 
then - ${_normalize_fn} "${link}" - fi +source "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")"/builder.sh +cd "${WORKSPACE}" || true + +declare -a links +mapfile -t links < <(find . -maxdepth 1 -type l -name "bazel-*" -exec basename {} \;) +for link in "${links[@]}"; do + ${_normalize_fn} "${link}" done diff --git a/builders/tools/test-tool b/builders/tools/test-tool index 6a275370..908249db 100755 --- a/builders/tools/test-tool +++ b/builders/tools/test-tool @@ -13,7 +13,9 @@ # limitations under the License. # environment variables (all optional): -# WORKSPACE repo root directory, must be an absolute path +# WORKSPACE repo root directory, must be an absolute path +# DOCKER_NETWORK docker run --network arg, defaults to "host", set to +# blank to avoid setting --network set -o errexit @@ -57,6 +59,9 @@ readonly REL_PWD WORKSPACE_MOUNT="$(builder::get_docker_workspace_mount)" readonly WORKSPACE_MOUNT +# respect an empty DOCKER_NETWORK value +DOCKER_NETWORK=${DOCKER_NETWORK-host} + declare -a DOCKER_RUN_ARGS=( "--rm" "--interactive" @@ -64,6 +69,11 @@ declare -a DOCKER_RUN_ARGS=( "--volume=${WORKSPACE_MOUNT}:/src/workspace" "--workdir=/src/workspace/${REL_PWD}" ) +if [[ -n ${DOCKER_NETWORK} ]]; then + DOCKER_RUN_ARGS+=( + "--network=${DOCKER_NETWORK}" + ) +fi if [[ -n ${EXTRA_DOCKER_RUN_ARGS} ]]; then # shellcheck disable=SC2207 DOCKER_RUN_ARGS+=( diff --git a/builders/version.txt b/builders/version.txt index 7e9253a3..acdcb836 100644 --- a/builders/version.txt +++ b/builders/version.txt @@ -1 +1 @@ -0.62.0 \ No newline at end of file +0.69.0 \ No newline at end of file diff --git a/components/aws/BUILD.bazel b/components/aws/BUILD.bazel index f35d8299..6bf9b9db 100644 --- a/components/aws/BUILD.bazel +++ b/components/aws/BUILD.bazel @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push") load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", - "container_push", + "@rules_pkg//pkg:mappings.bzl", + "pkg_attributes", + "pkg_files", ) +load("@rules_pkg//pkg:tar.bzl", "pkg_tar") load("@rules_python//python:defs.bzl", "py_library", "py_test") py_library( @@ -35,35 +36,58 @@ py_test( ], ) -container_layer( - name = "lambda_binary_layer", - directory = "/var/task", # ${LAMBDA_TASK_ROOT} - files = [ +pkg_files( + name = "lambda_binaries", + srcs = [ "sqs_cleanup.py", "sqs_cleanup_manager.py", ], + attributes = pkg_attributes(mode = "0555"), + prefix = "/var/task", # ${LAMBDA_TASK_ROOT} ) -container_image( - name = "sqs_lambda", +pkg_tar( + name = "lambda_binary_tar", + srcs = [":lambda_binaries"], +) + +oci_image( + name = "sqs_lambda_image", base = select({ - "@platforms//cpu:arm64": "@aws-lambda-python-arm64//image", - "@platforms//cpu:x86_64": "@aws-lambda-python-amd64//image", + "@platforms//cpu:arm64": "@aws-lambda-python-arm64", + "@platforms//cpu:x86_64": "@aws-lambda-python-amd64", }), cmd = ["sqs_cleanup.handler"], - layers = [ - ":lambda_binary_layer", + tars = [ + ":lambda_binary_tar", ], +) + +oci_load( + name = "sqs_lambda", + image = ":sqs_lambda_image", + repo_tags = ["bazel/components/aws:sqs_lambda"], +) + +filegroup( + name = "sqs_lambda_tarball_file", + srcs = [":sqs_lambda"], + output_group = "tarball", +) + +genrule( + name = "sqs_lambda_tarball", + srcs = [":sqs_lambda_tarball_file"], + outs = ["sqs_lambda.tar"], + cmd = "cp $< $@", visibility = ["//production/packaging:__subpackages__"], ) -container_push( +oci_push( name = "sqs_lambda_push_aws_ecr", - format = "Docker", - image = ":sqs_lambda", - registry = "$${AWS_ECR}", - repository = "sqs_lambda", - tag = "latest", + image = ":sqs_lambda_image", + remote_tags = ["latest"], + repository = "$${AWS_ECR}/sqs_lambda", ) exports_files( diff --git 
a/components/aws/sqs_cleanup.py b/components/aws/sqs_cleanup.py index 4717cbff..713a6c87 100644 --- a/components/aws/sqs_cleanup.py +++ b/components/aws/sqs_cleanup.py @@ -44,9 +44,27 @@ def handler(event, context): deleted_realtime_queues, deleted_realtime_subscriptions = find_and_cleanup( realtime_sns_topic, realtime_queue_prefix, timeout_secs ) + logging_verbosity_updates_sns_topic = event.get( + "logging_verbosity_updates_sns_topic" + ) + parameter_queue_prefix = event.get("parameter_queue_prefix") + if logging_verbosity_updates_sns_topic is None: + raise Exception("no logging verbosity updates topic") + if parameter_queue_prefix is None: + raise Exception("no parameter queue prefix") + + ( + deleted_logging_verbosity_parameter_queues, + deleted_logging_verbosity_parameter_subscriptions, + ) = find_and_cleanup( + logging_verbosity_updates_sns_topic, parameter_queue_prefix, timeout_secs + ) + return { "deleted_queues": deleted_queues, "deleted_subscriptions": deleted_subscriptions, "deleted_realtime_queues": deleted_realtime_queues, "deleted_realtime_subscriptions": deleted_realtime_subscriptions, + "deleted_logging_verbosity_parameter_queues": deleted_logging_verbosity_parameter_queues, + "deleted_logging_verbosity_parameter_subscriptions": deleted_logging_verbosity_parameter_subscriptions, } diff --git a/components/cloud_config/instance_client_gcp.cc b/components/cloud_config/instance_client_gcp.cc index 5ad5f919..b6d58ffc 100644 --- a/components/cloud_config/instance_client_gcp.cc +++ b/components/cloud_config/instance_client_gcp.cc @@ -250,9 +250,7 @@ class GcpInstanceClient : public InstanceClient { const ExecutionResult& result, const GetInstanceDetailsByResourceNameResponse& response) { if (result.Successful()) { - // TODO(b/342614468): Temporarily turn off this vlog until - // verbosity setting API in the common repo is fixed - // PS_VLOG(2, log_context_) << response.DebugString(); + PS_VLOG(2, log_context_) << response.DebugString(); instance_id_ = 
std::string{response.instance_details().instance_id()}; environment_ = diff --git a/components/cloud_config/parameter_client_local.cc b/components/cloud_config/parameter_client_local.cc index c051a649..377bb63c 100644 --- a/components/cloud_config/parameter_client_local.cc +++ b/components/cloud_config/parameter_client_local.cc @@ -55,6 +55,9 @@ ABSL_FLAG(std::string, data_loading_file_format, "possible values."); ABSL_FLAG(std::int32_t, logging_verbosity_level, 0, "Loggging verbosity level."); +ABSL_FLAG(std::int32_t, logging_verbosity_backup_poll_frequency_secs, 300, + "Loggging verbosity level back up poll frequency in seconds."); + ABSL_FLAG(absl::Duration, udf_timeout, absl::Seconds(5), "Timeout for one UDF invocation"); ABSL_FLAG(absl::Duration, udf_update_timeout, absl::Seconds(30), @@ -132,6 +135,9 @@ class LocalParameterClient : public ParameterClient { absl::GetFlag(FLAGS_udf_num_workers)}); int32_t_flag_values_.insert({"kv-server-local-logging-verbosity-level", absl::GetFlag(FLAGS_logging_verbosity_level)}); + int32_t_flag_values_.insert( + {"kv-server-local-logging-verbosity-backup-poll-frequency-secs", + absl::GetFlag(FLAGS_logging_verbosity_backup_poll_frequency_secs)}); int32_t_flag_values_.insert( {"kv-server-local-udf-timeout-millis", absl::ToInt64Milliseconds(absl::GetFlag(FLAGS_udf_timeout))}); diff --git a/components/cloud_config/parameter_update/BUILD.bazel b/components/cloud_config/parameter_update/BUILD.bazel new file mode 100644 index 00000000..747283e8 --- /dev/null +++ b/components/cloud_config/parameter_update/BUILD.bazel @@ -0,0 +1,44 @@ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = [ + "//components:__subpackages__", + "//tools:__subpackages__", +]) + +cc_library( + name = "parameter_notifier", + srcs = select({ + "//:aws_platform": ["parameter_notifier_aws.cc"], + "//:gcp_platform": ["parameter_notifier_gcp.cc"], + "//:local_platform": ["parameter_notifier_local.cc"], + }) + 
["parameter_notifier.cc"], + hdrs = [ + "parameter_notifier.h", + ], + deps = [ + "//components/data/common:change_notifier", + "//components/data/common:thread_manager", + "//components/util:sleepfor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", + ], +) + +cc_test( + name = "parameter_notifier_test", + size = "small", + srcs = select({ + "//:aws_platform": ["parameter_notifier_test_aws.cc"], + "//:gcp_platform": ["parameter_notifier_test_gcp.cc"], + "//:local_platform": ["parameter_notifier_test_local.cc"], + }), + deps = [ + ":parameter_notifier", + "//components/data/common:mocks", + "//components/util:sleepfor_mock", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/cloud_config/parameter_update/parameter_notifier.cc b/components/cloud_config/parameter_update/parameter_notifier.cc new file mode 100644 index 00000000..984a3320 --- /dev/null +++ b/components/cloud_config/parameter_update/parameter_notifier.cc @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/cloud_config/parameter_update/parameter_notifier.h" + +#include +#include + +namespace kv_server { + +using privacy_sandbox::server_common::ExpiringFlag; + +absl::Status ParameterNotifier::Stop() { + absl::Status status = sleep_for_->Stop(); + status.Update(thread_manager_->Stop()); + return status; +} + +bool ParameterNotifier::IsRunning() const { + return thread_manager_->IsRunning(); +} + +absl::StatusOr ParameterNotifier::ShouldGetParameter( + ExpiringFlag& expiring_flag) { + if (!expiring_flag.Get()) { + PS_VLOG(5, log_context_) + << "Backup poll on parameter update " << parameter_name_; + return true; + } + absl::StatusOr notification = + WaitForNotification(expiring_flag.GetTimeRemaining(), + [this]() { return thread_manager_->ShouldStop(); }); + + if (absl::IsDeadlineExceeded(notification.status())) { + // Deadline exceeded while waiting, trigger backup poll + PS_VLOG(5, log_context_) + << "Backup poll on parameter update " << parameter_name_; + return true; + } + if (!notification.ok()) { + return notification.status(); + } + return true; +} + +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier.h b/components/cloud_config/parameter_update/parameter_notifier.h new file mode 100644 index 00000000..e0d3e556 --- /dev/null +++ b/components/cloud_config/parameter_update/parameter_notifier.h @@ -0,0 +1,162 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_CLOUD_CONFIG_PARAMETER_UPDATE_PARAMETER_NOTIFIER_H_ +#define COMPONENTS_CLOUD_CONFIG_PARAMETER_UPDATE_PARAMETER_NOTIFIER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "components/data/common/change_notifier.h" +#include "components/data/common/notifier_metadata.h" +#include "components/data/common/thread_manager.h" +#include "components/util/sleepfor.h" +#include "src/logger/request_context_logger.h" + +namespace kv_server { + +// The ParameterNotifier watches the cloud pubsub notification on the +// changes of the value for a given parameter. When a notification is received +// or the poll period deadline is reached, this class executes the get parameter +// value callback to retrieve the updated parameter value, and executes the +// apply parameter value callback to use the retrieved parameter to do some +// operation (e.g. update logging verbosity with the updated verbosity level +// parameter value) +class ParameterNotifier { + public: + explicit ParameterNotifier( + std::unique_ptr notifier, std::string parameter_name, + const absl::Duration poll_frequency, std::unique_ptr sleep_for, + privacy_sandbox::server_common::SteadyClock& clock, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : notifier_(std::move(notifier)), + parameter_name_(std::move(parameter_name)), + thread_manager_( + ThreadManager::Create("Parameter Notifier " + parameter_name_)), + poll_frequency_(poll_frequency), + sleep_for_(std::move(sleep_for)), + clock_(clock), + log_context_(log_context) {} + virtual ~ParameterNotifier() = default; + // Starts watching the parameter updates. The ParamType is the data type of + // the value for the given parameter name. The data type can be int, bool or + // string etc, depending on how the callbacks are defined. 
+ template + absl::Status Start( + std::function(std::string_view param_name)> + get_param_callback, + std::function apply_param_callback); + + // Blocks until `IsRunning` is False. + virtual absl::Status Stop(); + + // Returns False before calling `Start` or after `Stop` is + // successful. + virtual bool IsRunning() const; + + static absl::StatusOr> Create( + NotifierMetadata notifier_metadata, std::string parameter_name, + const absl::Duration poll_frequency, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); + + private: + template + // Starts thread for watching the parameter updates + void Watch( + std::function(std::string_view param_name)> + get_param_callback, + std::function apply_param_callback); + // Gets notification from pubsub, returns error status or the notification + // message + absl::StatusOr WaitForNotification( + absl::Duration wait_duration, + const std::function& should_stop_callback); + absl::StatusOr ShouldGetParameter( + privacy_sandbox::server_common::ExpiringFlag& expiring_flag); + std::unique_ptr notifier_; + const std::string parameter_name_; + std::unique_ptr thread_manager_; + const absl::Duration poll_frequency_; + std::unique_ptr sleep_for_; + privacy_sandbox::server_common::SteadyClock& clock_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; +}; + +template +absl::Status ParameterNotifier::Start( + std::function(std::string_view)> + get_param_callback, + std::function apply_param_callback) { + return thread_manager_->Start( + [this, get_param_callback = std::move(get_param_callback), + apply_param_callback = std::move(apply_param_callback)]() { + Watch(std::move(get_param_callback), std::move(apply_param_callback)); + }); +} + +template +void ParameterNotifier::Watch( + std::function(std::string_view)> + get_param_callback, + std::function apply_param_callback) { + PS_LOG(INFO, log_context_) + << "Started to watch " << 
parameter_name_ << " parameter update"; + privacy_sandbox::server_common::ExpiringFlag expiring_flag(clock_); + uint32_t sequential_failures = 0; + while (!thread_manager_->ShouldStop()) { + const absl::StatusOr should_get_parameter = + ShouldGetParameter(expiring_flag); + if (!should_get_parameter.ok()) { + ++sequential_failures; + const absl::Duration backoff_time = + std::min(expiring_flag.GetTimeRemaining(), + ExponentialBackoffForRetry(sequential_failures)); + PS_LOG(ERROR, log_context_) + << "Failed to get parameter update notifications: " << parameter_name_ + << ", " << should_get_parameter.status() << ". Waiting for " + << backoff_time; + if (!sleep_for_->Duration(backoff_time)) { + PS_LOG(ERROR, log_context_) + << "Failed to sleep for " << backoff_time << ". SleepFor invalid."; + } + continue; + } + sequential_failures = 0; + if (!*should_get_parameter) { + continue; + } + expiring_flag.Set(poll_frequency_); + auto param_result = get_param_callback(parameter_name_); + if (param_result.ok()) { + apply_param_callback(std::move(*param_result)); + PS_VLOG(5, log_context_) << "Applied the callback on the parameter"; + } else { + PS_LOG(ERROR, log_context_) << "Failed to get parameter value for " + << parameter_name_ << param_result.status(); + } + } +} + +} // namespace kv_server + +#endif // COMPONENTS_CLOUD_CONFIG_PARAMETER_UPDATE_PARAMETER_NOTIFIER_H_ diff --git a/components/cloud_config/parameter_update/parameter_notifier_aws.cc b/components/cloud_config/parameter_update/parameter_notifier_aws.cc new file mode 100644 index 00000000..06b814d5 --- /dev/null +++ b/components/cloud_config/parameter_update/parameter_notifier_aws.cc @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/cloud_config/parameter_update/parameter_notifier.h" + +namespace kv_server { + +using privacy_sandbox::server_common::SteadyClock; + +absl::StatusOr ParameterNotifier::WaitForNotification( + absl::Duration wait_duration, + const std::function& should_stop_callback) { + absl::StatusOr> changes = + notifier_->GetNotifications(wait_duration, should_stop_callback); + if (!changes.ok()) { + return changes.status(); + } + if ((*changes).empty()) { + return absl::DataLossError("Empty message in the notification"); + } + // return the last element + PS_VLOG(5, log_context_) << "Received notification for parameter update"; + return std::string((*changes).back()); +} + +absl::StatusOr> ParameterNotifier::Create( + NotifierMetadata notifier_metadata, std::string parameter_name, + const absl::Duration poll_frequency, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + auto cloud_notifier_metadata = + std::get(notifier_metadata); + cloud_notifier_metadata.queue_prefix = "ParameterNotifier_"; + PS_ASSIGN_OR_RETURN( + auto notifier, + ChangeNotifier::Create(std::move(cloud_notifier_metadata), log_context)); + return std::make_unique( + std::move(notifier), std::move(parameter_name), poll_frequency, + std::make_unique(), SteadyClock::RealClock(), log_context); +} + +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier_gcp.cc b/components/cloud_config/parameter_update/parameter_notifier_gcp.cc new file mode 100644 index 00000000..62597722 --- /dev/null +++ 
b/components/cloud_config/parameter_update/parameter_notifier_gcp.cc @@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/cloud_config/parameter_update/parameter_notifier.h" + +namespace kv_server { + +using privacy_sandbox::server_common::SteadyClock; + +absl::StatusOr ParameterNotifier::WaitForNotification( + absl::Duration wait_duration, + const std::function& should_stop_callback) { + // TODO(b/356110894): Use change_notifier_gcp to get notifications from gcp + // pubsub once it is ready + sleep_for_->Duration(wait_duration); + return absl::DeadlineExceededError( + "Trigger backup poll before GCP change notifier is " + "implemented."); +} + +absl::StatusOr> ParameterNotifier::Create( + NotifierMetadata notifier_metadata, std::string parameter_name, + const absl::Duration poll_frequency, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + PS_ASSIGN_OR_RETURN( + auto notifier, + ChangeNotifier::Create(std::get(notifier_metadata), + log_context)); + return std::make_unique( + std::move(notifier), std::move(parameter_name), poll_frequency, + std::make_unique(), SteadyClock::RealClock(), log_context); +} + +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier_local.cc b/components/cloud_config/parameter_update/parameter_notifier_local.cc new file mode 100644 index 00000000..59d71712 --- /dev/null +++ 
b/components/cloud_config/parameter_update/parameter_notifier_local.cc @@ -0,0 +1,38 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/cloud_config/parameter_update/parameter_notifier.h" + +namespace kv_server { + +using privacy_sandbox::server_common::SteadyClock; + +absl::StatusOr ParameterNotifier::WaitForNotification( + absl::Duration wait_duration, + const std::function& should_stop_callback) { + sleep_for_->Duration(wait_duration); + return absl::DeadlineExceededError( + "Parameter pubsub notification" + "does not support local platform"); +} + +absl::StatusOr> ParameterNotifier::Create( + NotifierMetadata notifier_metadata, std::string parameter_name, + const absl::Duration poll_frequency, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique( + nullptr, std::move(parameter_name), poll_frequency, + std::make_unique(), SteadyClock::RealClock(), log_context); +} +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier_test_aws.cc b/components/cloud_config/parameter_update/parameter_notifier_test_aws.cc new file mode 100644 index 00000000..f99df5cc --- /dev/null +++ b/components/cloud_config/parameter_update/parameter_notifier_test_aws.cc @@ -0,0 +1,182 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the 
License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/synchronization/notification.h" +#include "components/cloud_config/parameter_update/parameter_notifier.h" +#include "components/data/common/mocks.h" +#include "components/util/sleepfor_mock.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using privacy_sandbox::server_common::SimulatedSteadyClock; +using testing::_; +using testing::AllOf; +using testing::Field; +using testing::Return; + +class ParameterNotifierAWSTest : public ::testing::Test { + protected: + void SetUp() override { + std::unique_ptr mock_change_notifier = + std::make_unique(); + change_notifier_ = mock_change_notifier.get(); + std::unique_ptr mock_sleep_for = + std::make_unique(); + sleep_for_ = mock_sleep_for.get(); + notifier_ = std::make_unique( + std::move(mock_change_notifier), parameter_name_, poll_frequency_, + std::move(mock_sleep_for), sim_clock_, + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); + } + std::unique_ptr notifier_; + MockChangeNotifier* change_notifier_; + MockSleepFor* sleep_for_; + SimulatedSteadyClock sim_clock_; + absl::Duration poll_frequency_ = absl::Minutes(5); + std::string parameter_name_ = "test_parameter"; +}; + +TEST_F(ParameterNotifierAWSTest, NotRunning) { + ASSERT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierAWSTest, StartsAndStops) { + absl::Status status = notifier_->Start( + [](std::string_view param_name) { return "test_value"; }, + [](std::string param_value) {}); + ASSERT_TRUE(status.ok()); + 
EXPECT_TRUE(notifier_->IsRunning()); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierAWSTest, NotifiesWithParameterUpdateIncludeFailures) { + std::string param_update_triggerred_by_notification = "value_n"; + std::string param_update_triggerred_by_backpoll = "value_b"; + EXPECT_CALL(*change_notifier_, GetNotifications(_, _)) + .WillOnce(Return(absl::InvalidArgumentError("stuff"))) + .WillOnce(Return(absl::InvalidArgumentError("stuff"))) + .WillOnce(Return(std::vector({"pubsub_update"}))) + .WillOnce(Return(absl::DeadlineExceededError("no message"))) + .WillRepeatedly(Return(std::vector())); + absl::Notification finished; + testing::MockFunction( + std::string_view param_name)> + get_parameter_callback; + EXPECT_CALL(get_parameter_callback, Call) + .Times(3) + // called from initial poll + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll; + }) + // called when notification is received + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_notification; + }) + // called when there is no message during the notification wait period + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll; + }); + testing::MockFunction apply_parameter_callback; + EXPECT_CALL(apply_parameter_callback, Call) + .Times(3) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll); + }) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_notification); + }) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll); + finished.Notify(); + }); + EXPECT_CALL(*sleep_for_, Duration(_)).WillRepeatedly(Return(true)); + absl::Status status = + 
notifier_->Start(get_parameter_callback.AsStdFunction(), + apply_parameter_callback.AsStdFunction()); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(notifier_->IsRunning()); + finished.WaitForNotification(); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierAWSTest, BackupPollOnly) { + std::string param_update_triggerred_by_backpoll_1 = "value_1"; + std::string param_update_triggerred_by_backpoll_2 = "value_2"; + std::string param_update_triggerred_by_backpoll_3 = "value_3"; + absl::Notification finished; + testing::MockFunction( + std::string_view param_name)> + get_parameter_callback; + EXPECT_CALL(get_parameter_callback, Call) + .Times(3) + // called from initial poll + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_1; + }) + // called due to expiring flag + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_2; + }) + // called due to timeout + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_3; + }); + testing::MockFunction apply_parameter_callback; + EXPECT_CALL(apply_parameter_callback, Call) + .Times(3) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_1); + }) + .WillOnce([&](std::string param_value) { + sim_clock_.AdvanceTime(poll_frequency_ + absl::Seconds(1)); + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_2); + }) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_3); + finished.Notify(); + }); + EXPECT_CALL(*change_notifier_, GetNotifications(_, _)) + .WillOnce(Return(absl::DeadlineExceededError("time out"))) + .WillRepeatedly(Return(std::vector())); + absl::Status status = + 
notifier_->Start(get_parameter_callback.AsStdFunction(), + apply_parameter_callback.AsStdFunction()); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(notifier_->IsRunning()); + finished.WaitForNotification(); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} +} // namespace + +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier_test_gcp.cc b/components/cloud_config/parameter_update/parameter_notifier_test_gcp.cc new file mode 100644 index 00000000..8f63e345 --- /dev/null +++ b/components/cloud_config/parameter_update/parameter_notifier_test_gcp.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "absl/synchronization/notification.h" +#include "components/cloud_config/parameter_update/parameter_notifier.h" +#include "components/data/common/mocks.h" +#include "components/util/sleepfor_mock.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using privacy_sandbox::server_common::SimulatedSteadyClock; +using testing::_; +using testing::AllOf; +using testing::Field; +using testing::Return; + +// TODO(b/356110894) Remove or combine this test with AWS test once +// change_notifier_gcp is ready to use +class ParameterNotifierGCPTest : public ::testing::Test { + protected: + void SetUp() override { + std::unique_ptr mock_change_notifier = + std::make_unique(); + change_notifier_ = mock_change_notifier.get(); + std::unique_ptr mock_sleep_for = + std::make_unique(); + sleep_for_ = mock_sleep_for.get(); + notifier_ = std::make_unique( + std::move(mock_change_notifier), parameter_name_, poll_frequency_, + std::move(mock_sleep_for), sim_clock_, + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); + } + std::unique_ptr notifier_; + MockChangeNotifier* change_notifier_; + MockSleepFor* sleep_for_; + SimulatedSteadyClock sim_clock_; + absl::Duration poll_frequency_ = absl::Minutes(5); + std::string parameter_name_ = "test_parameter"; +}; + +TEST_F(ParameterNotifierGCPTest, NotRunning) { + ASSERT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierGCPTest, StartsAndStops) { + absl::Status status = notifier_->Start( + [](std::string_view param_name) { return "test_value"; }, + [](std::string param_value) {}); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(notifier_->IsRunning()); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierGCPTest, BackupPollOnly) { + std::string param_update_triggerred_by_backpoll_1 = "value_1"; + std::string param_update_triggerred_by_backpoll_2 = "value_2"; + 
std::string param_update_triggerred_by_backpoll_3 = "value_3"; + absl::Notification finished; + testing::MockFunction( + std::string_view param_name)> + get_parameter_callback; + EXPECT_CALL(get_parameter_callback, Call) + // called from initial poll + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_1; + }) + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_2; + }) + .WillOnce([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return param_update_triggerred_by_backpoll_3; + }) + .WillRepeatedly([&](std::string_view param_name) { + EXPECT_EQ(param_name, parameter_name_); + return ""; + }); + testing::MockFunction apply_parameter_callback; + EXPECT_CALL(apply_parameter_callback, Call) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_1); + }) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_2); + }) + .WillOnce([&](std::string param_value) { + EXPECT_EQ(param_value, param_update_triggerred_by_backpoll_3); + finished.Notify(); + }) + .WillRepeatedly( + [&](std::string param_value) { EXPECT_EQ(param_value, ""); }); + EXPECT_CALL(*sleep_for_, Duration(_)).WillRepeatedly(Return(true)); + absl::Status status = + notifier_->Start(get_parameter_callback.AsStdFunction(), + apply_parameter_callback.AsStdFunction()); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(notifier_->IsRunning()); + finished.WaitForNotification(); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} +} // namespace +} // namespace kv_server diff --git a/components/cloud_config/parameter_update/parameter_notifier_test_local.cc b/components/cloud_config/parameter_update/parameter_notifier_test_local.cc new file mode 100644 index 00000000..6a1ed7a7 --- /dev/null +++ 
b/components/cloud_config/parameter_update/parameter_notifier_test_local.cc @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/synchronization/notification.h" +#include "components/cloud_config/parameter_update/parameter_notifier.h" +#include "components/data/common/mocks.h" +#include "components/util/sleepfor_mock.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using privacy_sandbox::server_common::SimulatedSteadyClock; +using testing::_; +using testing::AllOf; +using testing::Field; +using testing::Return; + +// The runtime parameter notification does not support local platform, +// we only need to test start and stop here. 
+ +class ParameterNotifierLocalTest : public ::testing::Test { + protected: + void SetUp() override { + notifier_ = std::make_unique( + nullptr, std::move(parameter_name_), poll_frequency_, + std::make_unique(), sim_clock_, + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); + } + std::unique_ptr notifier_; + MockSleepFor* sleep_for_; + SimulatedSteadyClock sim_clock_; + absl::Duration poll_frequency_ = absl::Minutes(5); + std::string parameter_name_ = "test_parameter"; +}; + +TEST_F(ParameterNotifierLocalTest, NotRunningSmokeTest) { + ASSERT_FALSE(notifier_->IsRunning()); +} + +TEST_F(ParameterNotifierLocalTest, StartsAndStopsSmokeTest) { + absl::Status status = notifier_->Start( + [](std::string_view param_name) { return "test_value"; }, + [](std::string param_value) {}); + ASSERT_TRUE(status.ok()); + EXPECT_TRUE(notifier_->IsRunning()); + status = notifier_->Stop(); + ASSERT_TRUE(status.ok()); + EXPECT_FALSE(notifier_->IsRunning()); +} +} // namespace +} // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc index 69bc4e8f..dd9863bd 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc @@ -55,14 +55,7 @@ class MockSqsClient : public ::Aws::SQS::SQSClient { // https://docs.aws.amazon.com/AmazonS3/latest/userguide/notification-content-structure.html class BlobStorageChangeNotifierS3Test : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } void CreateRequiredSqsCallExpectations() 
{ static const std::string mock_sqs_url("mock sqs url"); EXPECT_CALL(mock_message_service_, IsSetupComplete) diff --git a/components/data/blob_storage/blob_storage_client_gcp_test.cc b/components/data/blob_storage/blob_storage_client_gcp_test.cc index ac4b4900..6f6256ff 100644 --- a/components/data/blob_storage/blob_storage_client_gcp_test.cc +++ b/components/data/blob_storage/blob_storage_client_gcp_test.cc @@ -51,14 +51,7 @@ class GcpBlobStorageClientTest : public ::testing::Test { protected: PlatformInitializer initializer_; privacy_sandbox::server_common::log::NoOpContext no_op_context_; - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } }; TEST_F(GcpBlobStorageClientTest, DeleteBlobSucceeds) { diff --git a/components/data/blob_storage/blob_storage_client_s3_test.cc b/components/data/blob_storage/blob_storage_client_s3_test.cc index 938810b8..17b215f1 100644 --- a/components/data/blob_storage/blob_storage_client_s3_test.cc +++ b/components/data/blob_storage/blob_storage_client_s3_test.cc @@ -55,14 +55,7 @@ class MockS3Client : public ::Aws::S3::S3Client { class BlobStorageClientS3Test : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } privacy_sandbox::server_common::log::NoOpContext no_op_context_; private: diff --git 
a/components/data/blob_storage/seeking_input_streambuf_test.cc b/components/data/blob_storage/seeking_input_streambuf_test.cc index 845f6995..4ec47c36 100644 --- a/components/data/blob_storage/seeking_input_streambuf_test.cc +++ b/components/data/blob_storage/seeking_input_streambuf_test.cc @@ -66,14 +66,7 @@ class StringBlobInputStreambuf : public SeekingInputStreambuf { class SeekingInputStreambufTest : public testing::TestWithParam { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } StringBlobInputStreambuf CreateStringBlobStreambuf(std::string_view blob) { TelemetryProvider::Init("test", "test"); return StringBlobInputStreambuf(blob, GetParam()); diff --git a/components/data/common/change_notifier_aws_test.cc b/components/data/common/change_notifier_aws_test.cc index 8850a1ef..cb2d6542 100644 --- a/components/data/common/change_notifier_aws_test.cc +++ b/components/data/common/change_notifier_aws_test.cc @@ -55,14 +55,7 @@ class MockSqsClient : public ::Aws::SQS::SQSClient { class ChangeNotifierAwsTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } private: PlatformInitializer initializer_; diff --git a/components/data/converters/BUILD.bazel b/components/data/converters/BUILD.bazel new file mode 100644 index 00000000..ed98b203 --- /dev/null +++ 
b/components/data/converters/BUILD.bazel @@ -0,0 +1,80 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +cc_library( + name = "scoped_cbor", + hdrs = [ + "scoped_cbor.h", + ], + deps = [ + "@libcbor//:cbor", + ], +) + +cc_library( + name = "cbor_converter_utils", + srcs = [ + "cbor_converter_utils.cc", + ], + hdrs = [ + "cbor_converter_utils.h", + ], + deps = [ + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@libcbor//:cbor", + ], +) + +cc_library( + name = "cbor_converter", + srcs = [ + "cbor_converter.cc", + ], + hdrs = [ + "cbor_converter.h", + ], + visibility = [ + "//components/data_server:__subpackages__", + "//components/tools:__subpackages__", + "//infrastructure:__subpackages__", + ], + deps = [ + ":cbor_converter_utils", + ":scoped_cbor", + "//public/applications/pa:api_overlay_cc_proto", + "//public/applications/pa:response_utils", + "//public/query/v2:get_values_v2_cc_proto", + "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", + "@libcbor//:cbor", + "@nlohmann_json//:lib", + ], +) + +cc_test( + name = "cbor_converter_test", + size = "small", + srcs = ["cbor_converter_test.cc"], + deps = [ + ":cbor_converter", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + 
"@nlohmann_json//:lib", + ], +) diff --git a/components/data/converters/cbor_converter.cc b/components/data/converters/cbor_converter.cc new file mode 100644 index 00000000..29370ae5 --- /dev/null +++ b/components/data/converters/cbor_converter.cc @@ -0,0 +1,314 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data/converters/cbor_converter.h" + +#include +#include +#include + +#include "absl/log/log.h" +#include "components/data/converters/cbor_converter_utils.h" +#include "components/data/converters/scoped_cbor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/json_util.h" +#include "nlohmann/json.hpp" +#include "public/applications/pa/response_utils.h" +#include "public/query/v2/get_values_v2.pb.h" +#include "src/util/status_macro/status_macros.h" + +#include "cbor.h" + +namespace kv_server { + +namespace { +inline constexpr char kCompressionGroups[] = "compressionGroups"; +inline constexpr char kCompressionGroupId[] = "compressionGroupId"; +inline constexpr char kTtlMs[] = "ttlMs"; +inline constexpr char kContent[] = "content"; + +inline constexpr char kPartitionOutputs[] = "partitionOutputs"; +inline constexpr char kPartitionId[] = "id"; +inline constexpr char kKeyGroupOutputs[] = "keyGroupOutputs"; +inline constexpr char kTags[] = "tags"; +inline constexpr char kKeyValues[] = "keyValues"; +inline constexpr char kValue[] = "value"; + +absl::StatusOr EncodeCompressionGroup( + 
v2::CompressionGroup& compression_group) { + const int compressionGroupKeysNumber = 3; + auto* cbor_internal = cbor_new_definite_map(compressionGroupKeysNumber); + if (compression_group.has_ttl_ms()) { + PS_RETURN_IF_ERROR( + CborSerializeUInt(kTtlMs, compression_group.ttl_ms(), *cbor_internal)); + } + PS_RETURN_IF_ERROR(CborSerializeByteString( + kContent, std::move(compression_group.content()), *cbor_internal)); + PS_RETURN_IF_ERROR(CborSerializeUInt(kCompressionGroupId, + compression_group.compression_group_id(), + *cbor_internal)); + + return cbor_internal; +} + +absl::StatusOr EncodeCompressionGroups( + google::protobuf::RepeatedPtrField& + compression_groups) { + cbor_item_t* serialized_compression_groups = + cbor_new_definite_array(compression_groups.size()); + for (auto& compression_group : compression_groups) { + PS_ASSIGN_OR_RETURN(auto* serialized_compression_group, + EncodeCompressionGroup(compression_group)); + if (!cbor_array_push(serialized_compression_groups, + cbor_move(serialized_compression_group))) { + return absl::InternalError(absl::StrCat("Failed to serialize ", + kCompressionGroups, " to CBOR. ", + compression_group)); + } + } + + return serialized_compression_groups; +} + +absl::StatusOr EncodeKeyGroupOutput( + application_pa::KeyGroupOutput& key_group_output) { + const int keyGroupOutputKeysNumber = 2; + auto* cbor_internal = cbor_new_definite_map(keyGroupOutputKeysNumber); + // tags + cbor_item_t* serialized_tags = + cbor_new_definite_array(key_group_output.tags().size()); + + for (auto& tag : key_group_output.tags()) { + if (!cbor_array_push(serialized_tags, cbor_move(cbor_build_stringn( + tag.data(), tag.size())))) { + return absl::InternalError(absl::StrCat("Failed to serialize ", kTags, + " to CBOR. 
", key_group_output)); + } + } + struct cbor_pair serialized_serialized_tags_pair = { + .key = cbor_move(cbor_build_stringn(kTags, sizeof(kTags) - 1)), + .value = serialized_tags, + }; + if (!cbor_map_add(cbor_internal, serialized_serialized_tags_pair)) { + return absl::InternalError(absl::StrCat("Failed to serialize ", kTags, + " to CBOR. ", key_group_output)); + } + // key_values + cbor_item_t* serialized_key_values = + cbor_new_definite_map(key_group_output.key_values().size()); + std::vector> kv_vector; + for (auto&& [key, value] : *(key_group_output.mutable_key_values())) { + std::string value_str = std::move(value.mutable_value()->string_value()); + auto* cbor_internal_value = cbor_new_definite_map(1); + struct cbor_pair serialized_value_pair = { + .key = cbor_move(cbor_build_stringn(kValue, sizeof(kValue) - 1)), + .value = + cbor_move(cbor_build_stringn(value_str.c_str(), value_str.size())), + }; + + if (!cbor_map_add(cbor_internal_value, serialized_value_pair)) { + return absl::InternalError(absl::StrCat("Failed to serialize ", kValue, + " to CBOR. 
", key_group_output)); + } + struct cbor_pair serialized_key_value_pair = { + .key = cbor_move(cbor_build_stringn(key.c_str(), key.size())), + .value = cbor_internal_value, + }; + kv_vector.emplace_back(key, serialized_key_value_pair); + } + // Following the chromium implementation, we only need to check that + // the length and lexicographic order of the plaintext string + // https://chromium.googlesource.com/chromium/src/components/cbor/+/10d0a11b998d2cca774189ba26159ad4e1eacb7f/values.h#59 + // https://chromium.googlesource.com/chromium/src/components/cbor/+/10d0a11b998d2cca774189ba26159ad4e1eacb7f/values.cc#109 + std::sort(kv_vector.begin(), kv_vector.end(), [](auto& left, auto& right) { + const auto left_size = left.first.size(); + const auto& left_str = left.first; + const auto right_size = right.first.size(); + const auto& right_str = right.first; + return std::tie(left_size, left_str) < std::tie(right_size, right_str); + }); + for (auto&& [key, serialized_key_value_pair] : kv_vector) { + if (!cbor_map_add(serialized_key_values, serialized_key_value_pair)) { + return absl::InternalError(absl::StrCat( + "Failed to serialize ", kKeyValues, " to CBOR. ", key_group_output)); + } + } + struct cbor_pair serialized_key_values_pair = { + .key = cbor_move(cbor_build_stringn(kKeyValues, sizeof(kKeyValues) - 1)), + .value = serialized_key_values, + }; + if (!cbor_map_add(cbor_internal, serialized_key_values_pair)) { + return absl::InternalError(absl::StrCat("Failed to serialize ", kKeyValues, + " to CBOR. 
", key_group_output)); + } + return cbor_internal; +} + +absl::StatusOr EncodePartitionOutput( + application_pa::PartitionOutput& partition_output) { + const int partitionKeysNumber = 2; + auto* cbor_internal = cbor_new_definite_map(partitionKeysNumber); + PS_RETURN_IF_ERROR( + CborSerializeUInt(kPartitionId, partition_output.id(), *cbor_internal)); + cbor_item_t* serialized_key_group_outputs = + cbor_new_definite_array(partition_output.key_group_outputs().size()); + for (auto& key_group_output : + *(partition_output.mutable_key_group_outputs())) { + PS_ASSIGN_OR_RETURN(auto* serialized_key_group_output, + EncodeKeyGroupOutput(key_group_output)); + if (!cbor_array_push(serialized_key_group_outputs, + cbor_move(serialized_key_group_output))) { + return absl::InternalError(absl::StrCat("Failed to serialize ", + kPartitionOutputs, " to CBOR", + partition_output)); + } + } + struct cbor_pair serialized_key_group_outputs_pair = { + .key = cbor_move( + cbor_build_stringn(kKeyGroupOutputs, sizeof(kKeyGroupOutputs) - 1)), + .value = serialized_key_group_outputs, + }; + if (!cbor_map_add(cbor_internal, serialized_key_group_outputs_pair)) { + return absl::InternalError(absl::StrCat("Failed to serialize ", + kKeyGroupOutputs, " to CBOR. ", + partition_output)); + } + return cbor_internal; +} + +absl::Status EncodePartitionOutputs( + google::protobuf::RepeatedPtrField& + partition_outputs, + cbor_item_t* serialized_partition_outputs) { + for (auto& partition_output : partition_outputs) { + PS_ASSIGN_OR_RETURN(auto* serialized_partition_output, + EncodePartitionOutput(partition_output)); + if (!cbor_array_push(serialized_partition_outputs, + cbor_move(serialized_partition_output))) { + return absl::InternalError(absl::StrCat("Failed to serialize ", + kPartitionOutputs, " to CBOR. 
", + partition_output)); + } + } + return absl::OkStatus(); +} + +} // namespace +absl::StatusOr V2GetValuesResponseCborEncode( + v2::GetValuesResponse& response) { + if (response.has_single_partition()) { + return absl::InvalidArgumentError( + "single_partition is not supported for cbor content type"); + } + const int getValuesResponseKeysNumber = 1; + ScopedCbor root(cbor_new_definite_map(getValuesResponseKeysNumber)); + PS_ASSIGN_OR_RETURN( + auto* compression_groups, + EncodeCompressionGroups(*(response.mutable_compression_groups()))); + struct cbor_pair serialized_compression_groups = { + .key = cbor_move(cbor_build_stringn(kCompressionGroups, + sizeof(kCompressionGroups) - 1)), + .value = compression_groups, + }; + auto* cbor_internal = root.get(); + if (!cbor_map_add(cbor_internal, serialized_compression_groups)) { + return absl::InternalError(absl::StrCat( + "Failed to serialize ", kCompressionGroups, " to CBOR. ", response)); + } + return GetCborSerializedResult(*cbor_internal); +} + +absl::StatusOr V2GetValuesRequestJsonStringCborEncode( + std::string_view serialized_json) { + nlohmann::json json_req = nlohmann::json::parse(serialized_json, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + if (json_req.is_discarded()) { + return absl::InternalError(absl::StrCat( + "Unable to parse json req from string: ", serialized_json)); + } + std::vector cbor_vec = nlohmann::json::to_cbor(json_req); + return std::string(cbor_vec.begin(), cbor_vec.end()); +} + +absl::StatusOr V2GetValuesRequestProtoToCborEncode( + const v2::GetValuesRequest& proto_req) { + std::string json_req_string; + if (const auto json_status = google::protobuf::json::MessageToJsonString( + proto_req, &json_req_string); + !json_status.ok()) { + return absl::InternalError(absl::StrCat( + "Unable to convert proto request to json string: ", proto_req)); + } + return V2GetValuesRequestJsonStringCborEncode(json_req_string); +} + +absl::StatusOr PartitionOutputsCborEncode( + 
google::protobuf::RepeatedPtrField& + partition_outputs) { + ScopedCbor root(cbor_new_definite_array(partition_outputs.size())); + auto* cbor_internal = root.get(); + PS_RETURN_IF_ERROR(EncodePartitionOutputs(partition_outputs, cbor_internal)); + return GetCborSerializedResult(*cbor_internal); +} + +absl::StatusOr GetPartitionOutputsInJson( + const nlohmann::json& content_json) { + std::vector content_cbor = nlohmann::json::to_cbor(content_json); + std::string content_cbor_string = + std::string(content_cbor.begin(), content_cbor.end()); + struct cbor_load_result result; + cbor_item_t* cbor_bytestring = cbor_load( + reinterpret_cast(content_cbor_string.data()), + content_cbor_string.size(), &result); + auto partition_output_cbor = cbor_bytestring_handle(cbor_bytestring); + auto cbor_bytestring_len = cbor_bytestring_length(cbor_bytestring); + return nlohmann::json::from_cbor(std::vector( + partition_output_cbor, partition_output_cbor + cbor_bytestring_len)); +} + +absl::StatusOr V2CompressionGroupCborEncode( + application_pa::V2CompressionGroup& comp_group) { + const int getCompressionGroupKeysNumber = 1; + ScopedCbor root(cbor_new_definite_map(getCompressionGroupKeysNumber)); + cbor_item_t* partition_outputs = + cbor_new_definite_array(comp_group.partition_outputs().size()); + PS_RETURN_IF_ERROR(EncodePartitionOutputs( + *(comp_group.mutable_partition_outputs()), partition_outputs)); + struct cbor_pair serialized_partition_outputs = { + .key = cbor_move( + cbor_build_stringn(kPartitionOutputs, sizeof(kPartitionOutputs) - 1)), + .value = partition_outputs, + }; + auto* cbor_internal = root.get(); + if (!cbor_map_add(cbor_internal, serialized_partition_outputs)) { + return absl::InternalError(absl::StrCat( + "Failed to serialize ", kPartitionOutputs, " to CBOR. 
", comp_group)); + } + return GetCborSerializedResult(*cbor_internal); +} + +absl::Status CborDecodeToNonBytesProto(std::string_view cbor_raw, + google::protobuf::Message& message) { + // TODO(b/353537363): Skip intermediate JSON conversion step + nlohmann::json json_from_cbor = nlohmann::json::from_cbor( + cbor_raw, /*strict=*/true, /*allow_exceptions=*/false); + if (json_from_cbor.is_discarded()) { + return absl::InternalError("Failed to convert raw CBOR buffer to JSON"); + } + return google::protobuf::util::JsonStringToMessage(json_from_cbor.dump(), + &message); +} + +} // namespace kv_server diff --git a/components/data/converters/cbor_converter.h b/components/data/converters/cbor_converter.h new file mode 100644 index 00000000..b545c39b --- /dev/null +++ b/components/data/converters/cbor_converter.h @@ -0,0 +1,53 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_CONVERTER_H +#define COMPONENTS_DATA_CONVERTER_H + +#include + +#include "absl/status/statusor.h" +#include "nlohmann/json.hpp" +#include "public/applications/pa/api_overlay.pb.h" +#include "public/query/v2/get_values_v2.pb.h" + +namespace kv_server { + +absl::StatusOr V2GetValuesResponseCborEncode( + v2::GetValuesResponse& response); + +absl::StatusOr V2CompressionGroupCborEncode( + application_pa::V2CompressionGroup& comp_group); + +absl::StatusOr V2GetValuesRequestJsonStringCborEncode( + std::string_view serialized_json); + +absl::StatusOr V2GetValuesRequestProtoToCborEncode( + const v2::GetValuesRequest& proto_req); + +absl::StatusOr PartitionOutputsCborEncode( + google::protobuf::RepeatedPtrField& + partition_outputs); + +absl::StatusOr GetPartitionOutputsInJson( + const nlohmann::json& content_json); + +// Converts a CBOR serialized string to a proto that does not contain a `bytes` +// field. Will return error if the proto contains `bytes`. +absl::Status CborDecodeToNonBytesProto(std::string_view cbor_raw, + google::protobuf::Message& message); +} // namespace kv_server +#endif // COMPONENTS_DATA_CONVERTER_H diff --git a/components/data/converters/cbor_converter_test.cc b/components/data/converters/cbor_converter_test.cc new file mode 100644 index 00000000..3a34100a --- /dev/null +++ b/components/data/converters/cbor_converter_test.cc @@ -0,0 +1,549 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data/converters/cbor_converter.h" + +#include + +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "nlohmann/json.hpp" +#include "public/query/v2/get_values_v2.pb.h" +#include "public/test_util/proto_matcher.h" + +namespace kv_server { +namespace { + +using json = nlohmann::json; +using ordered_json = nlohmann::ordered_json; +using google::protobuf::TextFormat; + +TEST(CborConverterTest, V2GetValuesResponseCborEncodeSuccess) { + // "abc" -> [97,98,99] as byte array + ordered_json json_etalon = nlohmann::ordered_json::parse(R"({ + "compressionGroups": [ + { + "ttlMs": 2, + "content": {"bytes":[97,98,99],"subtype":null}, + "compressionGroupId": 1 + } + ] + })"); + v2::GetValuesResponse response; + TextFormat::ParseFromString( + R"pb( + compression_groups { compression_group_id: 1 ttl_ms: 2 content: "abc" } + )pb", + &response); + absl::StatusOr cbor_encoded_proto_maybe = + V2GetValuesResponseCborEncode(response); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()) + << cbor_encoded_proto_maybe.status(); + EXPECT_EQ( + json_etalon.dump(), + nlohmann::ordered_json::from_cbor(*cbor_encoded_proto_maybe).dump()); +} + +TEST(CborConverterTest, V2GetValuesResponseCborEncode_SinglePartition_Failure) { + v2::GetValuesResponse response; + TextFormat::ParseFromString( + R"pb( + single_partition {} + )pb", + &response); + absl::StatusOr cbor_encoded_proto_maybe = + V2GetValuesResponseCborEncode(response); + ASSERT_FALSE(cbor_encoded_proto_maybe.ok()) + << cbor_encoded_proto_maybe.status(); +} + +TEST(CborConverterTest, V2GetValuesResponseCborEncodeArrayMsSuccess) { + ordered_json json_etalon = nlohmann::ordered_json::parse(R"({ + "compressionGroups": [ + { + "content": {"bytes":[97,98,99], "subtype":null }, + "compressionGroupId": 1 + + }, + { + "ttlMs": 2, + "content": {"bytes":[97,98,99,100], 
"subtype":null }, + "compressionGroupId": 2 + } + ] + })"); + + v2::GetValuesResponse response; + TextFormat::ParseFromString( + R"pb( + compression_groups { compression_group_id: 1 content: "abc" } + compression_groups { compression_group_id: 2 ttl_ms: 2 content: "abcd" } + )pb", + &response); + absl::StatusOr cbor_encoded_proto_maybe = + V2GetValuesResponseCborEncode(response); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()) + << cbor_encoded_proto_maybe.status(); + EXPECT_EQ(json_etalon.dump(), + ordered_json::from_cbor(*cbor_encoded_proto_maybe).dump()); +} + +TEST(CborConverterTest, V2CompressionGroupCborEncodeSuccess) { + ordered_json json_etalon = ordered_json::parse(R"( + { + "partitionOutputs": [ + { + "id": 0, + "keyGroupOutputs": [ + { + "tags": [ + "custom", + "keys" + ], + "keyValues": { + "hello": { + "value": "world" + } + } + }, + { + "tags": [ + "structured", + "groupNames" + ], + "keyValues": { + "hello": { + "value": "world" + } + } + } + ] + }, + { + "id": 1, + "keyGroupOutputs": [ + { + "tags": [ + "custom", + "keys" + ], + "keyValues": { + "hello2": { + "value": "world2" + } + } + } + ] + } + ] + } +)"); + + application_pa::V2CompressionGroup compression_group; + TextFormat::ParseFromString( + R"pb(partition_outputs { + id: 0 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello" + value { value { string_value: "world" } } + } + } + key_group_outputs { + tags: "structured" + tags: "groupNames" + key_values { + key: "hello" + value { value { string_value: "world" } } + } + } + } + partition_outputs { + id: 1 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello2" + value { value { string_value: "world2" } } + } + } + })pb", + &compression_group); + absl::StatusOr cbor_encoded_proto_maybe = + V2CompressionGroupCborEncode(compression_group); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()); + EXPECT_EQ(json_etalon.dump(), + ordered_json::from_cbor(*cbor_encoded_proto_maybe).dump()); +} + 
+TEST(CborConverterTest, + V2CompressionGroupEmptyKeyGroupOutputsCborEncodeSuccess) { + json json_etalon = R"( + { + "partitionOutputs": [ + { + "id": 0, + "keyGroupOutputs": [] + } + ] + } +)"_json; + + application_pa::V2CompressionGroup compression_group; + TextFormat::ParseFromString(R"pb(partition_outputs { id: 0 })pb", + &compression_group); + absl::StatusOr cbor_encoded_proto_maybe = + V2CompressionGroupCborEncode(compression_group); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()); + ASSERT_TRUE(json_etalon == json::from_cbor(*cbor_encoded_proto_maybe)); +} + +TEST(CborConverterTest, V2GetValuesRequestJsonStringCborEncodeSuccess) { + std::string json_request = R"({ + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + } + ] + })"; + absl::StatusOr cbor_encoded_request_maybe = + V2GetValuesRequestJsonStringCborEncode(json_request); + ASSERT_TRUE(cbor_encoded_request_maybe.ok()) + << cbor_encoded_request_maybe.status(); + EXPECT_EQ(nlohmann::json::parse(json_request), + nlohmann::json::from_cbor(*cbor_encoded_request_maybe)); +} + +TEST(CborConverterTest, + V2GetValuesRequestJsonStringCborEncodeInvalidJsonFails) { + std::string json_request = R"({ + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + } + ], + })"; + absl::StatusOr cbor_encoded_request_maybe = + V2GetValuesRequestJsonStringCborEncode(json_request); + ASSERT_FALSE(cbor_encoded_request_maybe.ok()) + << cbor_encoded_request_maybe.status(); +} + +TEST(CborConverterTest, V2GetValuesRequestProtoToCborEncodeSuccess) { + v2::GetValuesRequest request; + TextFormat::ParseFromString( + R"pb(partitions { + id: 0 + compression_group_id: 1 + arguments { data { string_value: "hi" } } + + })pb", + &request); + absl::StatusOr cbor_encoded_request_maybe = + 
V2GetValuesRequestProtoToCborEncode(request); + ASSERT_TRUE(cbor_encoded_request_maybe.ok()) + << cbor_encoded_request_maybe.status(); +} + +TEST(CborConverterTest, CborDecodeToNonBytesProtoSuccess) { + v2::GetValuesRequest expected; + TextFormat::ParseFromString(R"pb( + client_version: "version1" + metadata { + fields { + key: "foo" + value { string_value: "bar1" } + } + } + partitions { + id: 1 + compression_group_id: 1 + metadata { + fields { + key: "partition_metadata" + value { string_value: "bar2" } + } + } + arguments { + tags { + values { string_value: "tag1" } + values { string_value: "tag2" } + } + + data { string_value: "bar4" } + } + } + )pb", + &expected); + nlohmann::json json_message = R"( + { + "clientVersion": "version1", + "metadata": { + "foo": "bar1" + }, + "partitions": [ + { + "id": 1, + "compressionGroupId": 1, + "metadata": { + "partition_metadata": "bar2" + }, + "arguments": { + "tags": [ + "tag1", + "tag2" + ], + "data": "bar4" + } + } + ] +} +)"_json; + ::kv_server::v2::GetValuesRequest actual; + std::vector v = json::to_cbor(json_message); + std::string cbor_raw(v.begin(), v.end()); + const auto status = CborDecodeToNonBytesProto(cbor_raw, actual); + ASSERT_TRUE(status.ok()); + EXPECT_THAT(actual, EqualsProto(expected)); +} + +TEST(CborConverterTest, CborDecodeToNonBytesProtoGetValuesResponseFailure) { + json json_etalon = R"({ + "compressionGroups": [ + { + "compressionGroupId": 1, + "content": {"bytes":[97,98,99],"subtype":null}, + "ttlMs": 2 + } + ] + })"_json; + v2::GetValuesResponse actual; + std::vector v = json::to_cbor(json_etalon); + std::string cbor_raw(v.begin(), v.end()); + const auto status = CborDecodeToNonBytesProto(cbor_raw, actual); + ASSERT_FALSE(status.ok()) << status; +} + +TEST(CborConverterTest, CborDecodeToNonBytesProtoFailure) { + nlohmann::json json_message = R"( + { + "clientVersion": "version1", + "metadata": { + "foo": "bar1" + }, + "partitions": [ + { + "id": 1, + "compressionGroupId": 1, + "metadata": { + 
"partition_metadata": "bar2" + }, + "arguments": { + "tags": [ + "tag1", + "tag2" + ], + "data": "bar4" + } + } + ] +} +)"_json; + ::kv_server::v2::GetValuesRequest actual; + std::vector v = json::to_cbor(json_message); + std::string cbor_raw(v.begin(), --v.end()); + const auto status = CborDecodeToNonBytesProto(cbor_raw, actual); + ASSERT_FALSE(status.ok()); +} + +TEST(CborConverterTest, PartitionOutputsCborEncodeSuccess) { + ordered_json json_etalon = ordered_json::parse(R"( + [ + { + "id": 0, + "keyGroupOutputs": [ + { + "tags": [ + "custom", + "keys" + ], + "keyValues": { + "hello": { + "value": "world" + } + } + } + ] + }, + { + "id": 1, + "keyGroupOutputs": [ + { + "tags": [ + "custom", + "keys" + ], + "keyValues": { + "hello2": { + "value": "world2" + } + } + } + ] + } + ] +)"); + + application_pa::V2CompressionGroup compression_group; + TextFormat::ParseFromString( + R"pb(partition_outputs { + id: 0 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello" + value { value { string_value: "world" } } + } + } + } + partition_outputs { + id: 1 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello2" + value { value { string_value: "world2" } } + } + } + })pb", + &compression_group); + absl::StatusOr cbor_encoded_proto_maybe = + PartitionOutputsCborEncode( + *compression_group.mutable_partition_outputs()); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()) + << cbor_encoded_proto_maybe.status(); + EXPECT_EQ(json_etalon, ordered_json::from_cbor(*cbor_encoded_proto_maybe)); +} + +TEST(CborConverterTest, PartitionOutputsCborEncodeEmptyKeyGroupOutputsSuccess) { + ordered_json json_etalon = nlohmann::ordered_json::parse(R"( + [ + { + "id": 0, + "keyGroupOutputs": [] + } + ] +)"); + + application_pa::V2CompressionGroup compression_group; + TextFormat::ParseFromString(R"pb(partition_outputs { id: 0 })pb", + &compression_group); + absl::StatusOr cbor_encoded_proto_maybe = + PartitionOutputsCborEncode( + 
*compression_group.mutable_partition_outputs()); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()); + ASSERT_TRUE(json_etalon == + ordered_json::from_cbor(*cbor_encoded_proto_maybe)); +} + +TEST(CborConverterTest, PartitionOutputsCborEncodeKeyValueMapOrderSuccess) { + ordered_json json_etalon = ordered_json::parse(R"( + [ + { + "id": 0, + "keyGroupOutputs": [ + { + "tags": [ + "custom", + "keys" + ], + "keyValues": { + "a": { + "value": "first" + }, + "b": { + "value": "second" + }, + "ab": { + "value": "third" + } + } + } + ] + } + ] +)"); + + application_pa::V2CompressionGroup compression_group; + TextFormat::ParseFromString( + R"pb(partition_outputs { + id: 0 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "b" + value { value { string_value: "second" } } + } + key_values { + key: "a" + value { value { string_value: "first" } } + } + key_values { + key: "ab" + value { value { string_value: "third" } } + } + } + })pb", + &compression_group); + absl::StatusOr cbor_encoded_proto_maybe = + PartitionOutputsCborEncode( + *compression_group.mutable_partition_outputs()); + ASSERT_TRUE(cbor_encoded_proto_maybe.ok()) + << cbor_encoded_proto_maybe.status(); + EXPECT_EQ(json_etalon, ordered_json::from_cbor(*cbor_encoded_proto_maybe)); +} + +} // namespace +} // namespace kv_server diff --git a/components/data/converters/cbor_converter_utils.cc b/components/data/converters/cbor_converter_utils.cc new file mode 100644 index 00000000..e941182f --- /dev/null +++ b/components/data/converters/cbor_converter_utils.cc @@ -0,0 +1,86 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data/converters/cbor_converter_utils.h" + +#include "absl/strings/str_cat.h" + +// TODO(b/355941387): move common methods to the common repo +namespace kv_server { +cbor_item_t* cbor_build_uint(uint32_t input) { + if (input <= 255) { + return cbor_build_uint8(input); + } else if (input <= 65535) { + return cbor_build_uint16(input); + } + return cbor_build_uint32(input); +} + +absl::Status CborSerializeUInt(absl::string_view key, uint32_t value, + cbor_item_t& root) { + struct cbor_pair kv = { + .key = cbor_move(cbor_build_stringn(key.data(), key.size())), + .value = cbor_build_uint(value)}; + if (!cbor_map_add(&root, kv)) { + return absl::InternalError( + absl::StrCat("Failed to serialize ", key, " to CBOR")); + } + return absl::OkStatus(); +} + +absl::Status CborSerializeString(absl::string_view key, absl::string_view value, + cbor_item_t& root) { + struct cbor_pair kv = { + .key = cbor_move(cbor_build_stringn(key.data(), key.size())), + .value = cbor_move(cbor_build_stringn(value.data(), value.size()))}; + if (!cbor_map_add(&root, kv)) { + return absl::InternalError( + absl::StrCat("Failed to serialize ", key, " to CBOR")); + } + + return absl::OkStatus(); +} + +absl::Status CborSerializeByteString(absl::string_view key, + absl::string_view value, + cbor_item_t& root) { + struct cbor_pair kv = { + .key = cbor_move(cbor_build_stringn(key.data(), key.size())), + .value = cbor_move(cbor_build_bytestring( + reinterpret_cast(value.data()), value.size()))}; + if (!cbor_map_add(&root, kv)) { + return absl::InternalError( + 
absl::StrCat("Failed to serialize ", key, " to CBOR")); + } + + return absl::OkStatus(); +} + +absl::StatusOr GetCborSerializedResult( + cbor_item_t& cbor_data_root) { + const size_t cbor_serialized_data_size = + cbor_serialized_size(&cbor_data_root); + if (!cbor_serialized_data_size) { + return absl::InternalError("Failed to serialize to CBOR (too large!)"); + } + std::string byte_string; + byte_string.resize(cbor_serialized_data_size); + if (cbor_serialize(&cbor_data_root, + reinterpret_cast(byte_string.data()), + cbor_serialized_data_size) == 0) { + return absl::InternalError("Failed to serialize to CBOR"); + } + return byte_string; +} +} // namespace kv_server diff --git a/components/data/converters/cbor_converter_utils.h b/components/data/converters/cbor_converter_utils.h new file mode 100644 index 00000000..978fe9be --- /dev/null +++ b/components/data/converters/cbor_converter_utils.h @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_CONVERTER_UTILS_H +#define COMPONENTS_DATA_CONVERTER_UTILS_H + +#include + +#include "absl/status/statusor.h" + +#include "cbor.h" + +namespace kv_server { +cbor_item_t* cbor_build_uint(uint32_t input); + +absl::Status CborSerializeUInt(absl::string_view key, uint32_t value, + cbor_item_t& root); + +absl::Status CborSerializeString(absl::string_view key, absl::string_view value, + cbor_item_t& root); + +absl::Status CborSerializeByteString(absl::string_view key, + absl::string_view value, + cbor_item_t& root); + +absl::StatusOr GetCborSerializedResult( + cbor_item_t& cbor_data_root); + +} // namespace kv_server +#endif // COMPONENTS_DATA_CONVERTER_UTILS_H diff --git a/components/data/converters/scoped_cbor.h b/components/data/converters/scoped_cbor.h new file mode 100644 index 00000000..2d76508f --- /dev/null +++ b/components/data/converters/scoped_cbor.h @@ -0,0 +1,64 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SERVICES_COMMON_UTIL_SCOPED_CBOR_H_ +#define SERVICES_COMMON_UTIL_SCOPED_CBOR_H_ + +#include "cbor.h" + +// TODO(b/355941387): move to the common repo +namespace kv_server { + +// Wrapper class for managing the ref-counted CBOR objects. +class ScopedCbor { + public: + ScopedCbor() {} + + // Expects a pointer to an object that is ref-incremented at the + // time of creation (e.g. an object created with cbor_build_string). 
+ explicit ScopedCbor(cbor_item_t* val) { ptr_ = val; } + virtual ~ScopedCbor() { + if (ptr_) { + cbor_decref(&ptr_); + } + } + ScopedCbor& operator=(cbor_item_t* val) { + if (ptr_) { + cbor_decref(&ptr_); + } + ptr_ = val; + return *this; + } + + cbor_item_t* operator->() const { return ptr_; } + cbor_item_t* operator*() const { return ptr_; } + cbor_item_t* get() const { return ptr_; } + cbor_item_t* release() { + cbor_item_t* tmp = ptr_; + ptr_ = nullptr; + return tmp; + } + + explicit operator bool() const { return ptr_ != nullptr; } + bool operator!() const { return ptr_ == nullptr; } + + private: + cbor_item_t* ptr_ = nullptr; +}; + +} // namespace kv_server + +#endif // SERVICES_COMMON_UTIL_SCOPED_CBOR_H_ diff --git a/components/data/realtime/delta_file_record_change_notifier_aws_test.cc b/components/data/realtime/delta_file_record_change_notifier_aws_test.cc index e2f91460..4bd036eb 100644 --- a/components/data/realtime/delta_file_record_change_notifier_aws_test.cc +++ b/components/data/realtime/delta_file_record_change_notifier_aws_test.cc @@ -56,12 +56,7 @@ constexpr std::string_view kX64EncodedMessage = class DeltaFileRecordChangeNotifierAwsTest : public ::testing::Test { protected: void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); + kv_server::InitMetricsContextMap(); mock_change_notifier_ = std::make_unique(); } std::unique_ptr mock_change_notifier_; diff --git a/components/data/realtime/realtime_notifier_aws_test.cc b/components/data/realtime/realtime_notifier_aws_test.cc index 238acdc1..f09bc0d5 100644 --- a/components/data/realtime/realtime_notifier_aws_test.cc +++ b/components/data/realtime/realtime_notifier_aws_test.cc @@ -36,14 +36,7 @@ using privacy_sandbox::server_common::GetTracer; class 
RealtimeNotifierAwsTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } std::unique_ptr change_notifier_ = std::make_unique(); std::unique_ptr mock_sleep_for_ = diff --git a/components/data/realtime/realtime_notifier_gcp_test.cc b/components/data/realtime/realtime_notifier_gcp_test.cc index fa6e5603..76e5ae33 100644 --- a/components/data/realtime/realtime_notifier_gcp_test.cc +++ b/components/data/realtime/realtime_notifier_gcp_test.cc @@ -46,14 +46,7 @@ using testing::Return; class RealtimeNotifierGcpTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } std::unique_ptr mock_sleep_for_ = std::make_unique(); std::shared_ptr mock_ = diff --git a/components/data/realtime/realtime_thread_pool_manager_aws_test.cc b/components/data/realtime/realtime_thread_pool_manager_aws_test.cc index 3ee12dc2..38eb0b34 100644 --- a/components/data/realtime/realtime_thread_pool_manager_aws_test.cc +++ b/components/data/realtime/realtime_thread_pool_manager_aws_test.cc @@ -34,14 +34,7 @@ using testing::Return; class RealtimeThreadPoolNotifierAwsTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - 
privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } int32_t thread_number_ = 4; }; diff --git a/components/data/realtime/realtime_thread_pool_manager_gcp_test.cc b/components/data/realtime/realtime_thread_pool_manager_gcp_test.cc index b56e4282..bb48b085 100644 --- a/components/data/realtime/realtime_thread_pool_manager_gcp_test.cc +++ b/components/data/realtime/realtime_thread_pool_manager_gcp_test.cc @@ -39,14 +39,7 @@ using testing::Return; class RealtimeThreadPoolNotifierGcpTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } int32_t thread_number_ = 4; std::unique_ptr mock_sleep_for_ = std::make_unique(); diff --git a/components/data_server/cache/BUILD.bazel b/components/data_server/cache/BUILD.bazel index 7def83d5..c6812308 100644 --- a/components/data_server/cache/BUILD.bazel +++ b/components/data_server/cache/BUILD.bazel @@ -20,9 +20,9 @@ package(default_visibility = [ ]) cc_library( - name = "uint32_value_set", - srcs = ["uint32_value_set.cc"], - hdrs = ["uint32_value_set.h"], + name = "uint_value_set", + srcs = ["uint_value_set.cc"], + hdrs = ["uint_value_set.h"], deps = [ "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -32,13 +32,13 @@ cc_library( ) cc_test( - name = "uint32_value_set_test", + name = "uint_value_set_test", size = "small", srcs = [ - "uint32_value_set_test.cc", + "uint_value_set_test.cc", ], deps = [ - ":uint32_value_set", + 
":uint_value_set", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@roaring_bitmap//:c_roaring", @@ -54,7 +54,7 @@ cc_library( "get_key_value_set_result.h", ], deps = [ - ":uint32_value_set", + ":uint_value_set", "//components/container:thread_safe_hash_map", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -74,6 +74,38 @@ cc_library( ], ) +cc_library( + name = "uint_value_set_cache", + hdrs = [ + "uint_value_set_cache.h", + ], + deps = [ + ":get_key_value_set_result_impl", + ":uint_value_set", + "//components/container:thread_safe_hash_map", + "//components/util:request_context", + "//public:base_types_cc_proto", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/log", + "@google_privacysandbox_servers_common//src/telemetry:telemetry_provider", + ], +) + +cc_test( + name = "uint_value_set_cache_test", + size = "small", + srcs = [ + "uint_value_set_cache_test.cc", + ], + deps = [ + ":uint_value_set_cache", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@roaring_bitmap//:c_roaring", + ], +) + cc_library( name = "key_value_cache", srcs = [ @@ -85,7 +117,8 @@ cc_library( deps = [ ":cache", ":get_key_value_set_result_impl", - ":uint32_value_set", + ":uint_value_set", + ":uint_value_set_cache", "//components/container:thread_safe_hash_map", "//public:base_types_cc_proto", "@com_google_absl//absl/base", @@ -120,7 +153,7 @@ cc_library( hdrs = ["mocks.h"], deps = [ ":cache", - ":uint32_value_set", + ":uint_value_set", "//components/container:thread_safe_hash_map", "@com_google_googletest//:gtest", ], diff --git a/components/data_server/cache/cache.h b/components/data_server/cache/cache.h index 06081599..31f510e5 100644 --- a/components/data_server/cache/cache.h +++ b/components/data_server/cache/cache.h @@ -49,6 +49,10 @@ class Cache { const RequestContext& request_context, const absl::flat_hash_set& key_set) 
const = 0; + virtual std::unique_ptr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const = 0; + // Inserts or updates the key with the new value for a given prefix virtual void UpdateKeyValue( privacy_sandbox::server_common::log::PSLogContext& log_context, @@ -69,6 +73,11 @@ class Cache { std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") = 0; + virtual void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; + // Deletes a particular (key, value) pair for a given prefix. virtual void DeleteKey( privacy_sandbox::server_common::log::PSLogContext& log_context, @@ -91,6 +100,11 @@ class Cache { std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") = 0; + virtual void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; + // Removes the values that were deleted before the specified // logical_commit_time for a given prefix. 
virtual void RemoveDeletedKeys( diff --git a/components/data_server/cache/get_key_value_set_result.h b/components/data_server/cache/get_key_value_set_result.h index f19aed41..b98246fe 100644 --- a/components/data_server/cache/get_key_value_set_result.h +++ b/components/data_server/cache/get_key_value_set_result.h @@ -22,7 +22,7 @@ #include "absl/container/flat_hash_set.h" #include "components/container/thread_safe_hash_map.h" -#include "components/data_server/cache/uint32_value_set.h" +#include "components/data_server/cache/uint_value_set.h" namespace kv_server { // Class that holds the data retrieved from cache lookup and read locks for @@ -36,6 +36,8 @@ class GetKeyValueSetResult { std::string_view key) const = 0; virtual const UInt32ValueSet* GetUInt32ValueSet( std::string_view key) const = 0; + virtual const UInt64ValueSet* GetUInt64ValueSet( + std::string_view key) const = 0; private: // Adds key, value_set to the result data map, mantains the lock on `key` @@ -43,14 +45,20 @@ class GetKeyValueSetResult { virtual void AddKeyValueSet( std::string_view key, absl::flat_hash_set value_set, std::unique_ptr key_lock) = 0; - virtual void AddUInt32ValueSet( + virtual void AddUIntValueSet( std::string_view key, ThreadSafeHashMap::ConstLockedNodePtr value_set_node) = 0; + virtual void AddUIntValueSet( + std::string_view key, + ThreadSafeHashMap::ConstLockedNodePtr + value_set_node) = 0; static std::unique_ptr Create(); friend class KeyValueCache; + template + friend class UIntValueSetCache; }; } // namespace kv_server diff --git a/components/data_server/cache/get_key_value_set_result_impl.cc b/components/data_server/cache/get_key_value_set_result_impl.cc index 8e3b44ce..ac5385d3 100644 --- a/components/data_server/cache/get_key_value_set_result_impl.cc +++ b/components/data_server/cache/get_key_value_set_result_impl.cc @@ -27,6 +27,8 @@ namespace { using UInt32ValueSetNodePtr = ThreadSafeHashMap::ConstLockedNodePtr; +using UInt64ValueSetNodePtr = + 
ThreadSafeHashMap::ConstLockedNodePtr; // Class that holds the data retrieved from cache lookup and read locks for // the lookup keys @@ -51,8 +53,16 @@ class GetKeyValueSetResultImpl : public GetKeyValueSetResult { } const UInt32ValueSet* GetUInt32ValueSet(std::string_view key) const override { - if (auto iter = uin32t_sets_map_.find(key); - iter != uin32t_sets_map_.end() && iter->second.is_present()) { + if (auto iter = uint32_sets_map_.find(key); + iter != uint32_sets_map_.end() && iter->second.is_present()) { + return iter->second.value(); + } + return nullptr; + } + + const UInt64ValueSet* GetUInt64ValueSet(std::string_view key) const override { + if (auto iter = uint64_sets_map_.find(key); + iter != uint64_sets_map_.end() && iter->second.is_present()) { return iter->second.value(); } return nullptr; @@ -68,15 +78,21 @@ class GetKeyValueSetResultImpl : public GetKeyValueSetResult { data_map_.emplace(key, std::move(value_set)); } - void AddUInt32ValueSet(std::string_view key, - UInt32ValueSetNodePtr value_set_ptr) override { - uin32t_sets_map_.emplace(key, std::move(value_set_ptr)); + void AddUIntValueSet(std::string_view key, + UInt32ValueSetNodePtr value_set_ptr) override { + uint32_sets_map_.emplace(key, std::move(value_set_ptr)); + } + + void AddUIntValueSet(std::string_view key, + UInt64ValueSetNodePtr value_set_ptr) override { + uint64_sets_map_.emplace(key, std::move(value_set_ptr)); } std::vector> read_locks_; absl::flat_hash_map> data_map_; - absl::flat_hash_map uin32t_sets_map_; + absl::flat_hash_map uint32_sets_map_; + absl::flat_hash_map uint64_sets_map_; }; } // namespace diff --git a/components/data_server/cache/key_value_cache.cc b/components/data_server/cache/key_value_cache.cc index fb2cf22a..5072c2f5 100644 --- a/components/data_server/cache/key_value_cache.cc +++ b/components/data_server/cache/key_value_cache.cc @@ -95,9 +95,31 @@ std::unique_ptr KeyValueCache::GetUInt32ValueSet( ScopeLatencyMetricsRecorder 
latency_recorder(request_context.GetInternalLookupMetricsContext()); - auto result = GetKeyValueSetResult::Create(); + auto result = uint32_sets_cache_.GetValueSet(request_context, key_set); + for (const auto& key : key_set) { + if (result->GetUInt32ValueSet(key) != nullptr) { + LogCacheAccessMetrics(request_context, kKeyValueSetCacheHit); + } else { + LogCacheAccessMetrics(request_context, kKeyValueSetCacheMiss); + } + } + return result; +} + +// Looks up and returns int64 value set result for the given key set. +std::unique_ptr KeyValueCache::GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetInternalLookupMetricsContext()); + auto result = uint64_sets_cache_.GetValueSet(request_context, key_set); for (const auto& key : key_set) { - result->AddUInt32ValueSet(key, uint32_sets_map_.CGet(key)); + if (result->GetUInt32ValueSet(key) != nullptr) { + LogCacheAccessMetrics(request_context, kKeyValueSetCacheHit); + } else { + LogCacheAccessMetrics(request_context, kKeyValueSetCacheMiss); + } } return result; } @@ -231,18 +253,23 @@ void KeyValueCache::UpdateKeyValueSet( ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); - if (auto prefix_max_time_node = - uint32_sets_max_cleanup_commit_time_map_.CGet(prefix); - prefix_max_time_node.is_present() && - logical_commit_time <= *prefix_max_time_node.value()) { - return; // Skip old updates. 
- } - auto cached_set_node = uint32_sets_map_.Get(key); - if (!cached_set_node.is_present()) { - auto result = uint32_sets_map_.PutIfAbsent(key, UInt32ValueSet()); - cached_set_node = std::move(result.first); - } - cached_set_node.value()->Add(value_set, logical_commit_time); + PS_VLOG(9, log_context) << "Received update for [" << key << "] at " + << logical_commit_time; + uint32_sets_cache_.UpdateSetValues(log_context, key, value_set, + logical_commit_time, prefix); +} + +void KeyValueCache::UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + ScopeLatencyMetricsRecorder + latency_recorder(KVServerContextMap()->SafeMetric()); + PS_VLOG(9, log_context) << "Received update for [" << key << "] at " + << logical_commit_time; + uint64_sets_cache_.UpdateSetValues(log_context, key, value_set, + logical_commit_time, prefix); } void KeyValueCache::DeleteKey( @@ -346,32 +373,23 @@ void KeyValueCache::DeleteValuesInSet( ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); - if (auto prefix_max_time_node = - uint32_sets_max_cleanup_commit_time_map_.CGet(prefix); - prefix_max_time_node.is_present() && - logical_commit_time <= *prefix_max_time_node.value()) { - return; // Skip old deletes. - } - { - auto cached_set_node = uint32_sets_map_.Get(key); - if (!cached_set_node.is_present()) { - auto result = uint32_sets_map_.PutIfAbsent(key, UInt32ValueSet()); - cached_set_node = std::move(result.first); - } - cached_set_node.value()->Remove(value_set, logical_commit_time); - } - { - // Mark set as having deleted elements. 
- auto prefix_deleted_sets_node = deleted_uint32_sets_map_.Get(prefix); - if (!prefix_deleted_sets_node.is_present()) { - auto result = deleted_uint32_sets_map_.PutIfAbsent( - prefix, absl::btree_map>()); - prefix_deleted_sets_node = std::move(result.first); - } - auto* commit_time_sets = - &(*prefix_deleted_sets_node.value())[logical_commit_time]; - commit_time_sets->insert(std::string(key)); - } + PS_VLOG(9, log_context) << "Received delete for [" << key << "] at " + << logical_commit_time; + uint32_sets_cache_.DeleteSetValues(log_context, key, value_set, + logical_commit_time, prefix); +} + +void KeyValueCache::DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + ScopeLatencyMetricsRecorder + latency_recorder(KVServerContextMap()->SafeMetric()); + PS_VLOG(9, log_context) << "Received delete for [" << key << "] at " + << logical_commit_time; + uint64_sets_cache_.DeleteSetValues(log_context, key, value_set, + logical_commit_time, prefix); } void KeyValueCache::RemoveDeletedKeys( @@ -382,7 +400,7 @@ void KeyValueCache::RemoveDeletedKeys( latency_recorder(KVServerContextMap()->SafeMetric()); CleanUpKeyValueMap(log_context, logical_commit_time, prefix); CleanUpKeyValueSetMap(log_context, logical_commit_time, prefix); - CleanUpUInt32SetMap(log_context, logical_commit_time, prefix); + CleanUpUIntSetMaps(log_context, logical_commit_time, prefix); } void KeyValueCache::CleanUpKeyValueMap( @@ -473,47 +491,17 @@ void KeyValueCache::CleanUpKeyValueSetMap( } } -void KeyValueCache::CleanUpUInt32SetMap( +void KeyValueCache::CleanUpUIntSetMaps( privacy_sandbox::server_common::log::PSLogContext& log_context, int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder + kCleanUpUIntSetMapLatency> latency_recorder(KVServerContextMap()->SafeMetric()); - { - if (auto max_cleanup_time_node = - 
uint32_sets_max_cleanup_commit_time_map_.PutIfAbsent( - prefix, logical_commit_time); - *max_cleanup_time_node.first.value() < logical_commit_time) { - *max_cleanup_time_node.first.value() = logical_commit_time; - } - } - absl::flat_hash_set cleanup_sets; - { - auto prefix_deleted_sets_node = deleted_uint32_sets_map_.Get(prefix); - if (!prefix_deleted_sets_node.is_present()) { - return; // nothing to cleanup for this prefix. - } - absl::flat_hash_set cleanup_commit_times; - for (const auto& [commit_time, deleted_sets] : - *prefix_deleted_sets_node.value()) { - if (commit_time > logical_commit_time) { - break; - } - cleanup_commit_times.insert(commit_time); - cleanup_sets.insert(deleted_sets.begin(), deleted_sets.end()); - } - for (auto commit_time : cleanup_commit_times) { - prefix_deleted_sets_node.value()->erase( - prefix_deleted_sets_node.value()->find(commit_time)); - } - } - { - for (const auto& set : cleanup_sets) { - if (auto set_node = uint32_sets_map_.Get(set); set_node.is_present()) { - set_node.value()->Cleanup(logical_commit_time); - } - } - } + PS_VLOG(9, log_context) + << "Cleaning up uint set maps with a new cutoff timestamp: " + << logical_commit_time; + uint32_sets_cache_.CleanUpValueSets(log_context, logical_commit_time); + uint64_sets_cache_.CleanUpValueSets(log_context, logical_commit_time); } void KeyValueCache::LogCacheAccessMetrics( diff --git a/components/data_server/cache/key_value_cache.h b/components/data_server/cache/key_value_cache.h index 61e08c91..9e05638c 100644 --- a/components/data_server/cache/key_value_cache.h +++ b/components/data_server/cache/key_value_cache.h @@ -26,10 +26,10 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "components/container/thread_safe_hash_map.h" #include "components/data_server/cache/cache.h" #include "components/data_server/cache/get_key_value_set_result.h" -#include "components/data_server/cache/uint32_value_set.h" 
+#include "components/data_server/cache/uint_value_set.h" +#include "components/data_server/cache/uint_value_set_cache.h" namespace kv_server { // In-memory datastore. @@ -51,6 +51,10 @@ class KeyValueCache : public Cache { const RequestContext& request_context, const absl::flat_hash_set& key_set) const override; + std::unique_ptr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override; + // Inserts or updates the key with the new value for a given prefix void UpdateKeyValue( privacy_sandbox::server_common::log::PSLogContext& log_context, @@ -71,6 +75,11 @@ class KeyValueCache : public Cache { std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") override; + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; + // Deletes a particular (key, value) pair for a given prefix. void DeleteKey(privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, int64_t logical_commit_time, @@ -92,6 +101,11 @@ class KeyValueCache : public Cache { std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") override; + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; + // Removes the values that were deleted before the specified // logical_commit_time for a given prefix. 
// TODO: b/267182790 -- Cache cleanup should be done periodically from a @@ -127,6 +141,25 @@ class KeyValueCache : public Cache { : last_logical_commit_time(logical_commit_time), is_deleted(deleted) {} }; + // Removes deleted keys from key-value map for a given prefix + void CleanUpKeyValueMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); + + // Removes deleted key-values from key-value_set map for a given prefix + void CleanUpKeyValueSetMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); + + void CleanUpUIntSetMaps( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); + + // Logs cache access metrics for cache hit or miss counts. The cache access + // event name is defined in server_definition.h file + void LogCacheAccessMetrics(const RequestContext& request_context, + std::string_view cache_access_event) const; + // mutex for key value map; mutable absl::Mutex mutex_; // mutex for key value set map; @@ -178,32 +211,8 @@ class KeyValueCache : public Cache { absl::flat_hash_map>>> deleted_set_nodes_map_ ABSL_GUARDED_BY(set_map_mutex_); - // Maps set key to its int32_t value set. 
- ThreadSafeHashMap uint32_sets_map_; - ThreadSafeHashMap - uint32_sets_max_cleanup_commit_time_map_; - ThreadSafeHashMap>> - deleted_uint32_sets_map_; - - // Removes deleted keys from key-value map for a given prefix - void CleanUpKeyValueMap( - privacy_sandbox::server_common::log::PSLogContext& log_context, - int64_t logical_commit_time, std::string_view prefix); - - // Removes deleted key-values from key-value_set map for a given prefix - void CleanUpKeyValueSetMap( - privacy_sandbox::server_common::log::PSLogContext& log_context, - int64_t logical_commit_time, std::string_view prefix); - - void CleanUpUInt32SetMap( - privacy_sandbox::server_common::log::PSLogContext& log_context, - int64_t logical_commit_time, std::string_view prefix); - - // Logs cache access metrics for cache hit or miss counts. The cache access - // event name is defined in server_definition.h file - void LogCacheAccessMetrics(const RequestContext& request_context, - std::string_view cache_access_event) const; + UIntValueSetCache uint32_sets_cache_; + UIntValueSetCache uint64_sets_cache_; friend class KeyValueCacheTestPeer; }; diff --git a/components/data_server/cache/key_value_cache_test.cc b/components/data_server/cache/key_value_cache_test.cc index be8620f2..83a264a9 100644 --- a/components/data_server/cache/key_value_cache_test.cc +++ b/components/data_server/cache/key_value_cache_test.cc @@ -51,16 +51,6 @@ class KeyValueCacheTestPeer { return c.map_; } - static const auto& ReadUint32Nodes(KeyValueCache& cache, - std::string_view prefix = "") { - return cache.uint32_sets_map_; - } - - static const auto& ReadDeletedUint32Nodes(KeyValueCache& cache, - std::string_view prefix = "") { - return cache.deleted_uint32_sets_map_; - } - static int GetDeletedSetNodesMapSize(const KeyValueCache& c, std::string prefix = "") { absl::MutexLock lock(&c.set_map_mutex_); @@ -1423,30 +1413,6 @@ TEST_F(CacheTest, VerifyCleaningUpUInt32Sets) { EXPECT_THAT(set->GetRemovedValues(), 
UnorderedElementsAreArray(delete_values)); } - const auto& deleted_nodes = - KeyValueCacheTestPeer::ReadDeletedUint32Nodes(*cache); - { - auto prefix_nodes = deleted_nodes.CGet(""); - ASSERT_TRUE(prefix_nodes.is_present()); - auto iter = prefix_nodes.value()->find(2); - ASSERT_NE(iter, prefix_nodes.value()->end()); - EXPECT_THAT(iter->first, 2); - EXPECT_TRUE(iter->second.contains("set1")); - } - { - cache->RemoveDeletedKeys(safe_path_log_context_, 3); - auto prefix_nodes = deleted_nodes.CGet(""); - ASSERT_TRUE(prefix_nodes.is_present()); - auto iter = prefix_nodes.value()->find(2); - ASSERT_EQ(iter, prefix_nodes.value()->end()); - } - { - auto result = cache->GetUInt32ValueSet(request_context, keys); - auto* set = result->GetUInt32ValueSet("set1"); - ASSERT_TRUE(set != nullptr); - EXPECT_THAT(set->GetValues(), UnorderedElementsAre(3, 4, 5)); - EXPECT_TRUE(set->GetRemovedValues().empty()); - } } } // namespace diff --git a/components/data_server/cache/mocks.h b/components/data_server/cache/mocks.h index 5661b6a7..17d892d5 100644 --- a/components/data_server/cache/mocks.h +++ b/components/data_server/cache/mocks.h @@ -20,7 +20,7 @@ #include "components/container/thread_safe_hash_map.h" #include "components/data_server/cache/cache.h" -#include "components/data_server/cache/uint32_value_set.h" +#include "components/data_server/cache/uint_value_set.h" #include "gmock/gmock.h" namespace kv_server { @@ -44,6 +44,10 @@ class MockCache : public Cache { (const RequestContext&, const absl::flat_hash_set&), (const, override)); + MOCK_METHOD((std::unique_ptr), GetUInt64ValueSet, + (const RequestContext&, + const absl::flat_hash_set&), + (const, override)); MOCK_METHOD(void, UpdateKeyValue, (privacy_sandbox::server_common::log::PSLogContext&, std::string_view, std::string_view, int64_t, std::string_view), @@ -58,6 +62,11 @@ class MockCache : public Cache { std::string_view, absl::Span, int64_t, std::string_view), (override)); + MOCK_METHOD(void, UpdateKeyValueSet, + 
(privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), + (override)); MOCK_METHOD(void, DeleteValuesInSet, (privacy_sandbox::server_common::log::PSLogContext&, std::string_view, absl::Span, int64_t, @@ -68,6 +77,11 @@ class MockCache : public Cache { std::string_view, absl::Span, int64_t, std::string_view), (override)); + MOCK_METHOD(void, DeleteValuesInSet, + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), + (override)); MOCK_METHOD(void, DeleteKey, (privacy_sandbox::server_common::log::PSLogContext&, std::string_view, int64_t, std::string_view), @@ -87,12 +101,19 @@ class MockGetKeyValueSetResult : public GetKeyValueSetResult { std::unique_ptr), (override)); MOCK_METHOD((const UInt32ValueSet*), GetUInt32ValueSet, (std::string_view), - (const override)); + (const, override)); MOCK_METHOD( - void, AddUInt32ValueSet, + void, AddUIntValueSet, (std::string_view, (ThreadSafeHashMap::ConstLockedNodePtr)), (override)); + MOCK_METHOD((const UInt64ValueSet*), GetUInt64ValueSet, (std::string_view), + (const, override)); + MOCK_METHOD( + void, AddUIntValueSet, + (std::string_view, + (ThreadSafeHashMap::ConstLockedNodePtr)), + (override)); }; } // namespace kv_server diff --git a/components/data_server/cache/noop_key_value_cache.h b/components/data_server/cache/noop_key_value_cache.h index aa79f7e1..ad26e313 100644 --- a/components/data_server/cache/noop_key_value_cache.h +++ b/components/data_server/cache/noop_key_value_cache.h @@ -39,6 +39,11 @@ class NoOpKeyValueCache : public Cache { const absl::flat_hash_set& key_set) const override { return std::make_unique(); } + std::unique_ptr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override { + return std::make_unique(); + } void UpdateKeyValue( privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, 
std::string_view value, int64_t logical_commit_time, @@ -51,6 +56,10 @@ class NoOpKeyValueCache : public Cache { privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") override {} + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override {} void DeleteKey(privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, int64_t logical_commit_time, std::string_view prefix) override {} @@ -62,6 +71,10 @@ class NoOpKeyValueCache : public Cache { privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, absl::Span value_set, int64_t logical_commit_time, std::string_view prefix = "") override {} + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override {} void RemoveDeletedKeys( privacy_sandbox::server_common::log::PSLogContext& log_context, int64_t logical_commit_time, std::string_view prefix) override {} @@ -79,13 +92,21 @@ class NoOpKeyValueCache : public Cache { std::string_view key) const override { return nullptr; } + const UInt64ValueSet* GetUInt64ValueSet( + std::string_view key) const override { + return nullptr; + } void AddKeyValueSet( std::string_view key, absl::flat_hash_set value_set, std::unique_ptr key_lock) override {} - void AddUInt32ValueSet( + void AddUIntValueSet( std::string_view key, ThreadSafeHashMap::ConstLockedNodePtr value_set_node) override {} + void AddUIntValueSet( + std::string_view key, + ThreadSafeHashMap::ConstLockedNodePtr + value_set_node) override {} }; }; diff --git a/components/data_server/cache/uint32_value_set.cc b/components/data_server/cache/uint32_value_set.cc deleted file 
mode 100644 index 9f6e02a4..00000000 --- a/components/data_server/cache/uint32_value_set.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "components/data_server/cache/uint32_value_set.h" - -#include - -namespace kv_server { - -absl::flat_hash_set UInt32ValueSet::GetValues() const { - absl::flat_hash_set values; - values.reserve(values_bitset_.cardinality()); - for (const auto& [value, metadata] : values_metadata_) { - if (!metadata.is_deleted) { - values.insert(value); - } - } - return values; -} - -const roaring::Roaring& UInt32ValueSet::GetValuesBitSet() const { - return values_bitset_; -} - -absl::flat_hash_set UInt32ValueSet::GetRemovedValues() const { - absl::flat_hash_set removed_values; - for (const auto& [_, values] : deleted_values_) { - for (auto value : values) { - removed_values.insert(value); - } - } - return removed_values; -} - -void UInt32ValueSet::AddOrRemove(absl::Span values, - int64_t logical_commit_time, bool is_deleted) { - for (auto value : values) { - auto* metadata = &values_metadata_[value]; - if (metadata->logical_commit_time >= logical_commit_time) { - continue; - } - metadata->logical_commit_time = logical_commit_time; - metadata->is_deleted = is_deleted; - if (is_deleted) { - values_bitset_.remove(value); - deleted_values_[logical_commit_time].insert(value); - } else { - values_bitset_.add(value); - 
deleted_values_[logical_commit_time].erase(value); - } - } - values_bitset_.runOptimize(); -} - -void UInt32ValueSet::Add(absl::Span values, - int64_t logical_commit_time) { - AddOrRemove(values, logical_commit_time, /*is_deleted=*/false); -} - -void UInt32ValueSet::Remove(absl::Span values, - int64_t logical_commit_time) { - AddOrRemove(values, logical_commit_time, /*is_deleted=*/true); -} - -void UInt32ValueSet::Cleanup(int64_t cutoff_logical_commit_time) { - for (const auto& [logical_commit_time, values] : deleted_values_) { - if (logical_commit_time > cutoff_logical_commit_time) { - break; - } - for (auto value : values) { - values_metadata_.erase(value); - } - } - deleted_values_.erase( - deleted_values_.begin(), - deleted_values_.upper_bound(cutoff_logical_commit_time)); -} - -absl::flat_hash_set BitSetToUint32Set( - const roaring::Roaring& bitset) { - auto num_values = bitset.cardinality(); - auto data = std::make_unique(num_values); - bitset.toUint32Array(data.get()); - return absl::flat_hash_set(data.get(), data.get() + num_values); -} - -} // namespace kv_server diff --git a/components/data_server/cache/uint32_value_set.h b/components/data_server/cache/uint32_value_set.h deleted file mode 100644 index 49b1844a..00000000 --- a/components/data_server/cache/uint32_value_set.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ -#define COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ - -#include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" - -#include "roaring.hh" - -namespace kv_server { - -// Stores a set of `uint32_t` values associated with a `logical_commit_time`. -// The `logical_commit_time` is used to support out of order set mutations, -// i.e., calling `Remove({1, 2, 3}, 5)` and then `Add({1, 2, 3}, 3)` will result -// in an empty set. -// -// The values in the set are also projected to `roaring::Roaring` bitset which -// can be used for efficient set operations such as union, intersection, .e.t.c. -class UInt32ValueSet { - public: - // Returns values not marked as removed from the set. - absl::flat_hash_set GetValues() const; - // Returns values not marked as removed from the set as a bitset. - const roaring::Roaring& GetValuesBitSet() const; - // Returns values marked as removed from the set. - absl::flat_hash_set GetRemovedValues() const; - - // Adds values associated with `logical_commit_time` to the set. If a value - // with the same or greater `logical_commit_time` already exists in the set, - // then this is a noop. - void Add(absl::Span values, int64_t logical_commit_time); - // Marks values associated with `logical_commit_time` as removed from the set. - // If a value with the same or greater `logical_commit_time` already exists in - // the set, then this is a noop. - void Remove(absl::Span values, int64_t logical_commit_time); - // Cleans up space occupied by values (including value metadata) matching the - // condition `logical_commit_time` <= `cutoff_logical_commit_time` and are - // marked as removed. 
- void Cleanup(int64_t cutoff_logical_commit_time); - - private: - struct ValueMetadata { - int64_t logical_commit_time; - bool is_deleted; - }; - - void AddOrRemove(absl::Span values, int64_t logical_commit_time, - bool is_deleted); - - roaring::Roaring values_bitset_; - absl::flat_hash_map values_metadata_; - absl::btree_map> deleted_values_; -}; - -absl::flat_hash_set BitSetToUint32Set(const roaring::Roaring& bitset); - -} // namespace kv_server - -#endif // COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ diff --git a/components/data_server/cache/uint_value_set.cc b/components/data_server/cache/uint_value_set.cc new file mode 100644 index 00000000..dd5b5f6d --- /dev/null +++ b/components/data_server/cache/uint_value_set.cc @@ -0,0 +1,31 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "components/data_server/cache/uint_value_set.h" + +namespace kv_server { + +absl::flat_hash_set BitSetToUint32Set( + const roaring::Roaring& bitset) { + return BitSetToUintSet(bitset); +} + +absl::flat_hash_set BitSetToUint64Set( + const roaring::Roaring64Map& bitset) { + return BitSetToUintSet(bitset); +} + +} // namespace kv_server diff --git a/components/data_server/cache/uint_value_set.h b/components/data_server/cache/uint_value_set.h new file mode 100644 index 00000000..b1f13d28 --- /dev/null +++ b/components/data_server/cache/uint_value_set.h @@ -0,0 +1,181 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_H_ +#define COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_H_ + +#include + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" + +#include "roaring.hh" +#include "roaring64map.hh" + +namespace kv_server { + +// Stores a set of unsigned int values associated with a `logical_commit_time`. +// The `logical_commit_time` is used to support out of order set mutations, +// i.e., calling `Remove({1, 2, 3}, 5)` and then `Add({1, 2, 3}, 3)` will result +// in an empty set. +// +// The values in the set are also projected to `roaring::Roaring` bitset which +// can be used for efficient set operations such as union, intersection, .e.t.c. 
+template +class UIntValueSet { + public: + using value_type = ValueType; + using bitset_type = BitsetType; + + // Returns values not marked as removed from the set. + absl::flat_hash_set GetValues() const; + // Returns values not marked as removed from the set as a bitset. + const BitsetType& GetValuesBitSet() const; + // Returns values marked as removed from the set. + absl::flat_hash_set GetRemovedValues() const; + + // Adds values associated with `logical_commit_time` to the set. If a value + // with the same or greater `logical_commit_time` already exists in the set, + // then this is a noop. + void Add(absl::Span values, int64_t logical_commit_time); + // Marks values associated with `logical_commit_time` as removed from the set. + // If a value with the same or greater `logical_commit_time` already exists in + // the set, then this is a noop. + void Remove(absl::Span values, int64_t logical_commit_time); + // Cleans up space occupied by values (including value metadata) matching the + // condition `logical_commit_time` <= `cutoff_logical_commit_time` and are + // marked as removed. + void Cleanup(int64_t cutoff_logical_commit_time); + + private: + struct ValueMetadata { + int64_t logical_commit_time; + bool is_deleted; + }; + + void AddOrRemove(absl::Span values, int64_t logical_commit_time, + bool is_deleted); + + BitsetType values_bitset_; + absl::flat_hash_map values_metadata_; + absl::btree_map> deleted_values_; +}; + +// Define specialized aliases for 32 and 64 bit unsigned int sets. 
+using UInt32ValueSet = UIntValueSet; +using UInt64ValueSet = UIntValueSet; + +template +absl::flat_hash_set UIntValueSet::GetValues() + const { + absl::flat_hash_set values; + values.reserve(values_bitset_.cardinality()); + for (const auto& [value, metadata] : values_metadata_) { + if (!metadata.is_deleted) { + values.insert(value); + } + } + return values; +} + +template +const BitsetType& UIntValueSet::GetValuesBitSet() const { + return values_bitset_; +} + +template +absl::flat_hash_set +UIntValueSet::GetRemovedValues() const { + absl::flat_hash_set removed_values; + for (const auto& [_, values] : deleted_values_) { + for (auto value : values) { + removed_values.insert(value); + } + } + return removed_values; +} + +template +void UIntValueSet::AddOrRemove( + absl::Span values, int64_t logical_commit_time, + bool is_deleted) { + for (auto value : values) { + auto* metadata = &values_metadata_[value]; + if (metadata->logical_commit_time >= logical_commit_time) { + continue; + } + metadata->logical_commit_time = logical_commit_time; + metadata->is_deleted = is_deleted; + if (is_deleted) { + values_bitset_.remove(value); + deleted_values_[logical_commit_time].insert(value); + } else { + values_bitset_.add(value); + deleted_values_[logical_commit_time].erase(value); + } + } + values_bitset_.runOptimize(); +} + +template +void UIntValueSet::Add(absl::Span values, + int64_t logical_commit_time) { + AddOrRemove(values, logical_commit_time, /*is_deleted=*/false); +} + +template +void UIntValueSet::Remove(absl::Span values, + int64_t logical_commit_time) { + AddOrRemove(values, logical_commit_time, /*is_deleted=*/true); +} + +template +void UIntValueSet::Cleanup( + int64_t cutoff_logical_commit_time) { + for (const auto& [logical_commit_time, values] : deleted_values_) { + if (logical_commit_time > cutoff_logical_commit_time) { + break; + } + for (auto value : values) { + values_metadata_.erase(value); + } + } + deleted_values_.erase( + deleted_values_.begin(), + 
deleted_values_.upper_bound(cutoff_logical_commit_time)); +} + +template +absl::flat_hash_set BitSetToUintSet(const BitsetType& bitset) { + auto num_values = bitset.cardinality(); + auto data = std::make_unique(num_values); + if constexpr (std::is_same_v) { + bitset.toUint32Array(data.get()); + } + if constexpr (std::is_same_v) { + bitset.toUint64Array(data.get()); + } + return absl::flat_hash_set(data.get(), data.get() + num_values); +} + +absl::flat_hash_set BitSetToUint32Set(const roaring::Roaring& bitset); +absl::flat_hash_set BitSetToUint64Set( + const roaring::Roaring64Map& bitset); + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_H_ diff --git a/components/data_server/cache/uint_value_set_cache.h b/components/data_server/cache/uint_value_set_cache.h new file mode 100644 index 00000000..4902bfda --- /dev/null +++ b/components/data_server/cache/uint_value_set_cache.h @@ -0,0 +1,209 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_CACHE_H_ +#define COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_CACHE_H_ + +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "components/container/thread_safe_hash_map.h" +#include "components/data_server/cache/get_key_value_set_result.h" +#include "components/util/request_context.h" +#include "src/logger/request_context_logger.h" + +namespace kv_server { + +template +class UIntValueSetCache { + public: + // Returns "uint" value set result for given set keys. + std::unique_ptr GetValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const; + + // Inserts or updates set values for a given key and prefix. If a value + // exists, updates its timestamp to the latest logical commit time. + void UpdateSetValues( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = ""); + + // Deletes set values for a given key and prefix. After the deletion, + // the values still exist and is marked "deleted", in case there are + // late-arriving updates to this value. + void DeleteSetValues( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = ""); + + // Removes the set values that were deleted before the specified + // logical_commit_time for a given prefix, i.e., actually reclaims the space + // used by deleted values. + void CleanUpValueSets( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix = ""); + + private: + // Maps set key to unsigned int value set per prefix. + ThreadSafeHashMap sets_map_; + // Maps prefix to maximum clean up commit time. Set updates for this prefix + // before maximum clean up commit time are ignored. 
+ ThreadSafeHashMap sets_max_cleanup_commit_time_map_; + // Maps prefix to list of set keys with deleted elements grouped by deletion + // timestamp. + ThreadSafeHashMap>> + deleted_sets_map_; + + // Allow a unit test class to access private members to verify correct + // deletion and clean up. + friend class UIntValueSetCacheTest; +}; + +template +std::unique_ptr UIntValueSetCache::GetValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const { + auto result = GetKeyValueSetResult::Create(); + for (auto key : key_set) { + PS_VLOG(8, request_context.GetPSLogContext()) << "Getting key: " << key; + result->AddUIntValueSet(key, sets_map_.CGet(key)); + } + return result; +} + +template +void UIntValueSetCache::UpdateSetValues( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + if (value_set.empty()) { + PS_VLOG(8, log_context) + << "Skipping the update as it has no value in the input set."; + return; + } + if (auto prefix_max_time_node = + sets_max_cleanup_commit_time_map_.CGet(prefix); + prefix_max_time_node.is_present() && + logical_commit_time <= *prefix_max_time_node.value()) { + PS_VLOG(8, log_context) + << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is older than the current cutoff time:" + << *prefix_max_time_node.value(); + return; // Skip old updates. 
+ } + auto cached_set_node = sets_map_.Get(key); + if (!cached_set_node.is_present()) { + auto result = sets_map_.PutIfAbsent(key, SetType()); + if (result.second) { + PS_VLOG(8, log_context) << "Added new key: [" << key << "] is a new key."; + } + cached_set_node = std::move(result.first); + } + cached_set_node.value()->Add(value_set, logical_commit_time); +} + +template +void UIntValueSetCache::DeleteSetValues( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + if (value_set.empty()) { + PS_VLOG(8, log_context) + << "Skipping the delete as it has no value in the input set."; + return; + } + if (auto prefix_max_time_node = + sets_max_cleanup_commit_time_map_.CGet(prefix); + prefix_max_time_node.is_present() && + logical_commit_time <= *prefix_max_time_node.value()) { + PS_VLOG(1, log_context) + << "Skipping the delete as its logical_commit_time: " + << logical_commit_time << " is older than the current cutoff time:" + << *prefix_max_time_node.value(); + return; // Skip old deletes. + } + { + auto cached_set_node = sets_map_.Get(key); + if (!cached_set_node.is_present()) { + auto result = sets_map_.PutIfAbsent(key, SetType()); + cached_set_node = std::move(result.first); + } + cached_set_node.value()->Remove(value_set, logical_commit_time); + } + { + // Mark set as having deleted elements. 
+ auto prefix_deleted_sets_node = deleted_sets_map_.Get(prefix); + if (!prefix_deleted_sets_node.is_present()) { + auto result = deleted_sets_map_.PutIfAbsent( + prefix, absl::btree_map>()); + prefix_deleted_sets_node = std::move(result.first); + } + auto* commit_time_sets = + &(*prefix_deleted_sets_node.value())[logical_commit_time]; + commit_time_sets->insert(std::string(key)); + } +} + +template +void UIntValueSetCache::CleanUpValueSets( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) { + { + if (auto max_cleanup_time_node = + sets_max_cleanup_commit_time_map_.PutIfAbsent(prefix, + logical_commit_time); + *max_cleanup_time_node.first.value() < logical_commit_time) { + *max_cleanup_time_node.first.value() = logical_commit_time; + } else if (logical_commit_time < *max_cleanup_time_node.first.value()) { + return; + } + } + absl::flat_hash_set cleanup_sets; + { + auto prefix_deleted_sets_node = deleted_sets_map_.Get(prefix); + if (!prefix_deleted_sets_node.is_present()) { + return; // nothing to cleanup for this prefix. 
+ } + absl::flat_hash_set cleanup_commit_times; + for (const auto& [commit_time, deleted_sets] : + *prefix_deleted_sets_node.value()) { + if (commit_time > logical_commit_time) { + break; + } + cleanup_commit_times.insert(commit_time); + cleanup_sets.insert(deleted_sets.begin(), deleted_sets.end()); + } + for (auto commit_time : cleanup_commit_times) { + prefix_deleted_sets_node.value()->erase( + prefix_deleted_sets_node.value()->find(commit_time)); + } + } + { + for (const auto& set : cleanup_sets) { + if (auto set_node = sets_map_.Get(set); set_node.is_present()) { + set_node.value()->Cleanup(logical_commit_time); + } + } + } +} + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_CACHE_UINT_VALUE_SET_CACHE_H_ diff --git a/components/data_server/cache/uint_value_set_cache_test.cc b/components/data_server/cache/uint_value_set_cache_test.cc new file mode 100644 index 00000000..c0ace4b4 --- /dev/null +++ b/components/data_server/cache/uint_value_set_cache_test.cc @@ -0,0 +1,210 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "components/data_server/cache/uint_value_set_cache.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { + +class SafePathTestLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + public: + SafePathTestLogContext() = default; +}; + +class UIntValueSetCacheTest : public ::testing::Test { + public: + UIntValueSetCacheTest() { + InitMetricsContextMap(); + request_context_ = std::make_shared(); + } + + template + static const auto& ReadDeletedNodes(UIntValueSetCache& cache, + std::string_view prefix = "") { + return cache.deleted_sets_map_; + } + + std::shared_ptr request_context_; + SafePathTestLogContext safe_path_log_context_; +}; + +namespace { + +using testing::UnorderedElementsAre; +using testing::UnorderedElementsAreArray; + +template +const SetType* GetValueSet(std::string_view key, + GetKeyValueSetResult& key_value_result) { + if constexpr (std::is_same_v) { + return key_value_result.GetUInt32ValueSet(key); + } + if constexpr (std::is_same_v) { + return key_value_result.GetUInt64ValueSet(key); + } +} + +template +void VerifyUpdatingSets(std::shared_ptr request_context, + SafePathTestLogContext& safe_path_log_context) { + UIntValueSetCache cache; + const auto keys = absl::flat_hash_set({"set1", "set2"}); + { + auto result = cache.GetValueSet(*request_context, keys); + for (const auto& key : keys) { + const auto* set = GetValueSet(key, *result); + EXPECT_EQ(set, nullptr); + } + } + auto max = std::numeric_limits::max(); + auto set1_values = std::vector( + {max - 1, max - 2, max - 3, max - 4, max - 5}); + auto logical_commit_time = 1; + { + // For uint64 sets, if we errorneously store values in uint32 sets, then + // this test would catch the overflow. 
+ cache.UpdateSetValues(safe_path_log_context, "set1", + absl::MakeSpan(set1_values), logical_commit_time); + auto result = cache.GetValueSet(*request_context, keys); + const auto* set = GetValueSet("set1", *result); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAreArray(set1_values)); + } + auto set2_values = std::vector( + {max - 6, max - 7, max - 8, max - 9, max - 10}); + { + cache.UpdateSetValues(safe_path_log_context, "set1", + absl::MakeSpan(set2_values), logical_commit_time); + auto result = cache.GetValueSet(*request_context, keys); + const auto* set = GetValueSet("set1", *result); + set1_values.insert(set1_values.end(), set2_values.begin(), + set2_values.end()); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAreArray(set1_values)); + } +} + +template +void VerifyDeletingSets(std::shared_ptr request_context, + SafePathTestLogContext& safe_path_log_context) { + UIntValueSetCache cache; + const auto keys = absl::flat_hash_set({"set1", "set2"}); + const auto max = std::numeric_limits::max(); + auto delete_values = std::vector( + {max - 1, max - 2, max - 6, max - 7}); + { + auto set1_values = std::vector( + {max - 1, max - 2, max - 3, max - 4, max - 5}); + cache.UpdateSetValues(safe_path_log_context, "set1", + absl::MakeSpan(set1_values), 1); + cache.DeleteSetValues(safe_path_log_context, "set1", + absl::MakeSpan(delete_values), 2); + auto result = cache.GetValueSet(*request_context, keys); + const auto* set = GetValueSet("set1", *result); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), + UnorderedElementsAre(max - 3, max - 4, max - 5)); + } + { + auto set2_values = std::vector( + {max - 6, max - 7, max - 8, max - 9, max - 10}); + cache.UpdateSetValues(safe_path_log_context, "set2", + absl::MakeSpan(set2_values), 1); + cache.DeleteSetValues(safe_path_log_context, "set2", + absl::MakeSpan(delete_values), 2); + auto result = cache.GetValueSet(*request_context, keys); + const auto* 
set = GetValueSet("set2", *result); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), + UnorderedElementsAre(max - 8, max - 9, max - 10)); + } +} + +template +void VerifyCleaningUpSets(std::shared_ptr request_context, + SafePathTestLogContext& safe_path_log_context) { + UIntValueSetCache cache; + const auto keys = absl::flat_hash_set({"set1"}); + const auto max = std::numeric_limits::max(); + auto set1_values = std::vector( + {max - 1, max - 2, max - 3, max - 4, max - 5}); + auto delete_values = + std::vector({max - 1, max - 2}); + { + cache.UpdateSetValues(safe_path_log_context, "set1", + absl::MakeSpan(set1_values), 1); + cache.DeleteSetValues(safe_path_log_context, "set1", + absl::MakeSpan(delete_values), 2); + auto result = cache.GetValueSet(*request_context, keys); + const auto* set = GetValueSet("set1", *result); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), + UnorderedElementsAre(max - 3, max - 4, max - 5)); + EXPECT_THAT(set->GetRemovedValues(), + UnorderedElementsAreArray(delete_values)); + } + const auto& deleted_nodes = + UIntValueSetCacheTest::ReadDeletedNodes(cache); + { + auto prefix_nodes = deleted_nodes.CGet(""); + ASSERT_TRUE(prefix_nodes.is_present()); + auto iter = prefix_nodes.value()->find(2); + ASSERT_NE(iter, prefix_nodes.value()->end()); + EXPECT_THAT(iter->first, 2); + EXPECT_TRUE(iter->second.contains("set1")); + } + { + cache.CleanUpValueSets(safe_path_log_context, 3); + auto prefix_nodes = deleted_nodes.CGet(""); + ASSERT_TRUE(prefix_nodes.is_present()); + auto iter = prefix_nodes.value()->find(2); + ASSERT_EQ(iter, prefix_nodes.value()->end()); + } + { + auto result = cache.GetValueSet(*request_context, keys); + const auto* set = GetValueSet("set1", *result); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), + UnorderedElementsAre(max - 3, max - 4, max - 5)); + EXPECT_TRUE(set->GetRemovedValues().empty()); + } +} + +TEST_F(UIntValueSetCacheTest, VerifyUpdatingSets) { + 
VerifyUpdatingSets(request_context_, safe_path_log_context_); + VerifyUpdatingSets(request_context_, safe_path_log_context_); +} + +TEST_F(UIntValueSetCacheTest, VerifyDeletingSets) { + VerifyDeletingSets(request_context_, safe_path_log_context_); + VerifyDeletingSets(request_context_, safe_path_log_context_); +} + +TEST_F(UIntValueSetCacheTest, VerifyCleaningUpSets) { + VerifyCleaningUpSets(request_context_, + safe_path_log_context_); + VerifyCleaningUpSets(request_context_, + safe_path_log_context_); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/cache/uint32_value_set_test.cc b/components/data_server/cache/uint_value_set_test.cc similarity index 97% rename from components/data_server/cache/uint32_value_set_test.cc rename to components/data_server/cache/uint_value_set_test.cc index 7af32601..3d769992 100644 --- a/components/data_server/cache/uint32_value_set_test.cc +++ b/components/data_server/cache/uint_value_set_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data_server/cache/uint32_value_set.h" +#include "components/data_server/cache/uint_value_set.h" #include diff --git a/components/data_server/data_loading/data_orchestrator.cc b/components/data_server/data_loading/data_orchestrator.cc index 1ca9d43b..d54e09cc 100644 --- a/components/data_server/data_loading/data_orchestrator.cc +++ b/components/data_server/data_loading/data_orchestrator.cc @@ -99,6 +99,13 @@ absl::Status ApplyUpdateMutation( record.logical_commit_time(), prefix); return absl::OkStatus(); } + if (record.value_type() == Value::UInt64Set) { + auto values = GetRecordValue>(record); + cache.UpdateKeyValueSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), + record.logical_commit_time(), prefix); + return absl::OkStatus(); + } return absl::InvalidArgumentError( absl::StrCat("Record with key: ", record.key()->string_view(), " has unsupported value type: ", record.value_type())); @@ -126,6 +133,13 @@ absl::Status ApplyDeleteMutation( record.logical_commit_time(), prefix); return absl::OkStatus(); } + if (record.value_type() == Value::UInt64Set) { + auto values = GetRecordValue>(record); + cache.DeleteValuesInSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), + record.logical_commit_time(), prefix); + return absl::OkStatus(); + } return absl::InvalidArgumentError( absl::StrCat("Record with key: ", record.key()->string_view(), " has unsupported value type: ", record.value_type())); diff --git a/components/data_server/data_loading/data_orchestrator_test.cc b/components/data_server/data_loading/data_orchestrator_test.cc index 7864a91b..74502e94 100644 --- a/components/data_server/data_loading/data_orchestrator_test.cc +++ b/components/data_server/data_loading/data_orchestrator_test.cc @@ -87,14 +87,7 @@ BlobStorageClient::DataLocation GetTestLocation( class DataOrchestratorTest : public ::testing::Test { protected: - void SetUp() override { - 
privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } DataOrchestratorTest() : options_(DataOrchestrator::Options{ .data_bucket = GetTestLocation().bucket, diff --git a/components/data_server/request_handler/BUILD.bazel b/components/data_server/request_handler/BUILD.bazel index 66a2f1cb..e25db06f 100644 --- a/components/data_server/request_handler/BUILD.bazel +++ b/components/data_server/request_handler/BUILD.bazel @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") -load("@rules_proto//proto:defs.bzl", "proto_library") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") package(default_visibility = [ + "//components/data:__subpackages__", "//components/data_server:__subpackages__", "//components/internal_server:__subpackages__", ]) @@ -70,17 +70,25 @@ cc_library( hdrs = [ "get_values_v2_handler.h", ], + visibility = [ + "//components/data_server:__subpackages__", + "//components/internal_server:__subpackages__", + "//components/tools:__subpackages__", + "//tools/request_simulation:__subpackages__", + ], deps = [ - ":compression", - ":framing_utils", ":get_values_v2_status", - ":ohttp_server_encryptor", + "//components/data/converters:cbor_converter", "//components/data_server/cache", + "//components/data_server/request_handler/compression", + "//components/data_server/request_handler/content_type:encoder", + "//components/data_server/request_handler/encryption:ohttp_server_encryptor", "//components/telemetry:server_definition", "//components/udf:udf_client", "//components/util:request_context", 
"//public:api_schema_cc_proto", "//public:base_types_cc_proto", + "//public/applications/pa:response_utils", "//public/query/v2:get_values_v2_cc_grpc", "@com_github_google_quiche//quiche:binary_http_unstable_api", "@com_github_grpc_grpc//:grpc++", @@ -90,73 +98,10 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", "@google_privacysandbox_servers_common//src/communication:encoding_utils", + "@google_privacysandbox_servers_common//src/communication:framing_utils", "@google_privacysandbox_servers_common//src/telemetry", "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", - ], -) - -cc_library( - name = "framing_utils", - srcs = [ - "framing_utils.cc", - ], - hdrs = [ - "framing_utils.h", - ], - deps = [ - "@com_google_absl//absl/numeric:bits", - ], -) - -cc_test( - name = "framing_utils_test", - size = "small", - srcs = [ - "framing_utils_test.cc", - ], - deps = [ - ":framing_utils", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "compression", - srcs = [ - "compression.cc", - "compression_brotli.cc", - "uncompressed.cc", - ], - hdrs = [ - "compression.h", - "compression_brotli.h", - "uncompressed.h", - ], - deps = [ - "@brotli//:brotlidec", - "@brotli//:brotlienc", - "@com_github_google_quiche//quiche:quiche_unstable_api", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "uncompressed_test", - srcs = ["uncompressed_test.cc"], - deps = [ - ":compression", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "compression_brotli_test", - srcs = ["compression_brotli_test.cc"], - deps = [ - ":compression", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", + "@nlohmann_json//:lib", ], ) @@ -197,7 +142,8 @@ cc_library( ], deps = [ ":get_values_v2_handler", - ":v2_response_data_cc_proto", + "//components/data_server/request_handler/content_type:encoder", + "//components/errors:error_tag", 
"//public:api_schema_cc_proto", "//public/applications/pa:api_overlay_cc_proto", "//public/applications/pa:response_utils", @@ -231,34 +177,6 @@ cc_test( ], ) -cc_test( - name = "v2_response_data_proto_test", - size = "small", - srcs = [ - "v2_response_data_proto_test.cc", - ], - deps = [ - ":v2_response_data_cc_proto", - "//public/test_util:proto_matcher", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - "@com_google_protobuf//:protobuf", - ], -) - -proto_library( - name = "v2_response_data_proto", - srcs = ["v2_response_data.proto"], - deps = [ - "@com_google_protobuf//:struct_proto", - ], -) - -cc_proto_library( - name = "v2_response_data_cc_proto", - deps = [":v2_response_data_proto"], -) - cc_library( name = "mocks", testonly = 1, @@ -269,59 +187,6 @@ cc_library( ], ) -cc_library( - name = "ohttp_client_encryptor", - srcs = [ - "ohttp_client_encryptor.cc", - ], - hdrs = [ - "ohttp_client_encryptor.h", - ], - visibility = [ - "//components/data_server:__subpackages__", - "//components/internal_server:__subpackages__", - "//components/tools:__subpackages__", - ], - deps = [ - "//public:constants", - "@com_github_google_quiche//quiche:oblivious_http_unstable_api", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@google_privacysandbox_servers_common//src/encryption/key_fetcher:key_fetcher_manager", - ], -) - -cc_library( - name = "ohttp_server_encryptor", - srcs = [ - "ohttp_server_encryptor.cc", - ], - hdrs = [ - "ohttp_server_encryptor.h", - ], - deps = [ - "//public:constants", - "@com_github_google_quiche//quiche:oblivious_http_unstable_api", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@google_privacysandbox_servers_common//src/encryption/key_fetcher:key_fetcher_manager", - ], -) - -cc_test( - name = "ohttp_encryptor_test", - size = "small", - srcs = [ - "ohttp_encryptor_test.cc", - ], - deps = [ - ":ohttp_client_encryptor", - ":ohttp_server_encryptor", - 
"@com_google_googletest//:gtest_main", - "@google_privacysandbox_servers_common//src/encryption/key_fetcher:fake_key_fetcher_manager", - ], -) - cc_library( name = "get_values_v2_status", srcs = select({ diff --git a/components/data_server/request_handler/compression/BUILD.bazel b/components/data_server/request_handler/compression/BUILD.bazel new file mode 100644 index 00000000..479a77f5 --- /dev/null +++ b/components/data_server/request_handler/compression/BUILD.bazel @@ -0,0 +1,62 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = [ + "//components/data:__subpackages__", + "//components/data_server:__subpackages__", + "//components/internal_server:__subpackages__", +]) + +cc_library( + name = "compression", + srcs = [ + "compression.cc", + "compression_brotli.cc", + "uncompressed.cc", + ], + hdrs = [ + "compression.h", + "compression_brotli.h", + "uncompressed.h", + ], + deps = [ + "@brotli//:brotlidec", + "@brotli//:brotlienc", + "@com_github_google_quiche//quiche:quiche_unstable_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "uncompressed_test", + srcs = ["uncompressed_test.cc"], + deps = [ + ":compression", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "compression_brotli_test", + srcs = ["compression_brotli_test.cc"], + deps = [ + ":compression", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/data_server/request_handler/compression.cc b/components/data_server/request_handler/compression/compression.cc similarity index 87% rename from components/data_server/request_handler/compression.cc rename to components/data_server/request_handler/compression/compression.cc index 7defc60e..1d127fda 100644 --- a/components/data_server/request_handler/compression.cc +++ b/components/data_server/request_handler/compression/compression.cc @@ -11,11 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data_server/request_handler/compression.h" +#include "components/data_server/request_handler/compression/compression.h" #include "absl/log/log.h" -#include "components/data_server/request_handler/compression_brotli.h" -#include "components/data_server/request_handler/uncompressed.h" +#include "components/data_server/request_handler/compression/compression_brotli.h" +#include "components/data_server/request_handler/compression/uncompressed.h" #include "quiche/common/quiche_data_writer.h" namespace kv_server { diff --git a/components/data_server/request_handler/compression.h b/components/data_server/request_handler/compression/compression.h similarity index 93% rename from components/data_server/request_handler/compression.h rename to components/data_server/request_handler/compression/compression.h index b0b8be2c..88031edf 100644 --- a/components/data_server/request_handler/compression.h +++ b/components/data_server/request_handler/compression/compression.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_H_ -#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_H_ +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_COMPRESSION_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_COMPRESSION_H_ #include #include @@ -97,4 +97,4 @@ class CompressedBlobReader { } // namespace kv_server -#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_H_ +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_COMPRESSION_COMPRESSION_H_ diff --git a/components/data_server/request_handler/compression_brotli.cc b/components/data_server/request_handler/compression/compression_brotli.cc similarity index 98% rename from components/data_server/request_handler/compression_brotli.cc rename to components/data_server/request_handler/compression/compression_brotli.cc index ba902ca4..fc505cc8 100644 --- a/components/data_server/request_handler/compression_brotli.cc +++ b/components/data_server/request_handler/compression/compression_brotli.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data_server/request_handler/compression_brotli.h" +#include "components/data_server/request_handler/compression/compression_brotli.h" #include #include diff --git a/components/data_server/request_handler/compression_brotli.h b/components/data_server/request_handler/compression/compression_brotli.h similarity index 94% rename from components/data_server/request_handler/compression_brotli.h rename to components/data_server/request_handler/compression/compression_brotli.h index f1e2ea4c..a24d95c7 100644 --- a/components/data_server/request_handler/compression_brotli.h +++ b/components/data_server/request_handler/compression/compression_brotli.h @@ -14,7 +14,7 @@ #include -#include "components/data_server/request_handler/compression.h" +#include "components/data_server/request_handler/compression/compression.h" namespace kv_server { diff --git a/components/data_server/request_handler/compression_brotli_test.cc b/components/data_server/request_handler/compression/compression_brotli_test.cc similarity index 95% rename from components/data_server/request_handler/compression_brotli_test.cc rename to components/data_server/request_handler/compression/compression_brotli_test.cc index 3c5daf5b..576b4b22 100644 --- a/components/data_server/request_handler/compression_brotli_test.cc +++ b/components/data_server/request_handler/compression/compression_brotli_test.cc @@ -12,13 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data_server/request_handler/compression_brotli.h" +#include "components/data_server/request_handler/compression/compression_brotli.h" #include #include #include "absl/log/log.h" -#include "components/data_server/request_handler/uncompressed.h" +#include "components/data_server/request_handler/compression/uncompressed.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/components/data_server/request_handler/uncompressed.cc b/components/data_server/request_handler/compression/uncompressed.cc similarity index 95% rename from components/data_server/request_handler/uncompressed.cc rename to components/data_server/request_handler/compression/uncompressed.cc index cbbe32d5..b24044e5 100644 --- a/components/data_server/request_handler/uncompressed.cc +++ b/components/data_server/request_handler/compression/uncompressed.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data_server/request_handler/uncompressed.h" +#include "components/data_server/request_handler/compression/uncompressed.h" #include diff --git a/components/data_server/request_handler/uncompressed.h b/components/data_server/request_handler/compression/uncompressed.h similarity index 94% rename from components/data_server/request_handler/uncompressed.h rename to components/data_server/request_handler/compression/uncompressed.h index da985bb0..b499a950 100644 --- a/components/data_server/request_handler/uncompressed.h +++ b/components/data_server/request_handler/compression/uncompressed.h @@ -15,7 +15,7 @@ #include #include "absl/status/statusor.h" -#include "components/data_server/request_handler/compression.h" +#include "components/data_server/request_handler/compression/compression.h" namespace kv_server { diff --git a/components/data_server/request_handler/uncompressed_test.cc b/components/data_server/request_handler/compression/uncompressed_test.cc similarity index 97% rename from components/data_server/request_handler/uncompressed_test.cc rename to components/data_server/request_handler/compression/uncompressed_test.cc index dbf883e7..204bbfde 100644 --- a/components/data_server/request_handler/uncompressed_test.cc +++ b/components/data_server/request_handler/compression/uncompressed_test.cc @@ -14,7 +14,7 @@ #include -#include "components/data_server/request_handler/compression.h" +#include "components/data_server/request_handler/compression/compression.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/components/data_server/request_handler/content_type/BUILD.bazel b/components/data_server/request_handler/content_type/BUILD.bazel new file mode 100644 index 00000000..13bd0b73 --- /dev/null +++ b/components/data_server/request_handler/content_type/BUILD.bazel @@ -0,0 +1,74 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = [ + "//components/data_server:__subpackages__", +]) + +cc_library( + name = "encoder", + srcs = [ + "cbor_encoder.cc", + "encoder.cc", + "json_encoder.cc", + "proto_encoder.cc", + ], + hdrs = [ + "cbor_encoder.h", + "encoder.h", + "json_encoder.h", + "proto_encoder.h", + ], + deps = [ + "//components/data/converters:cbor_converter", + "//components/util:request_context", + "//public/applications/pa:response_utils", + "//public/query/v2:get_values_v2_cc_grpc", + "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", + ], +) + +cc_test( + name = "cbor_encoder_test", + srcs = ["cbor_encoder_test.cc"], + deps = [ + ":encoder", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "json_encoder_test", + srcs = ["json_encoder_test.cc"], + deps = [ + ":encoder", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "proto_encoder_test", + srcs = ["proto_encoder_test.cc"], + deps = [ + ":encoder", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/data_server/request_handler/content_type/cbor_encoder.cc b/components/data_server/request_handler/content_type/cbor_encoder.cc new file mode 100644 index 00000000..8650f5af --- /dev/null +++ 
b/components/data_server/request_handler/content_type/cbor_encoder.cc @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data_server/request_handler/content_type/cbor_encoder.h" + +#include +#include +#include +#include + +#include "components/data/converters/cbor_converter.h" +#include "public/applications/pa/response_utils.h" +#include "src/util/status_macro/status_macros.h" + +namespace kv_server { + +absl::StatusOr CborV2EncoderDecoder::EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const { + PS_ASSIGN_OR_RETURN(std::string response, + V2GetValuesResponseCborEncode(response_proto)); + return response; +} + +absl::StatusOr CborV2EncoderDecoder::EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const { + google::protobuf::RepeatedPtrField + partition_outputs; + for (auto& partition_output_pair : partition_output_pairs) { + auto partition_output = + application_pa::PartitionOutputFromJson(partition_output_pair.second); + if (partition_output.ok()) { + partition_output.value().set_id(partition_output_pair.first); + *partition_outputs.Add() = partition_output.value(); + } else { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << partition_output.status(); + } + } + + if (partition_outputs.empty()) { + return absl::InternalError( + "Parsing partition output proto from json failed for all 
outputs"); + } + + const auto cbor_string = PartitionOutputsCborEncode(partition_outputs); + if (!cbor_string.ok()) { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << "CBOR encode failed for partition outputs"; + return cbor_string.status(); + } + return cbor_string.value(); +} + +absl::StatusOr +CborV2EncoderDecoder::DecodeToV2GetValuesRequestProto( + std::string_view request) const { + v2::GetValuesRequest request_proto; + PS_RETURN_IF_ERROR(CborDecodeToNonBytesProto(request, request_proto)); + return request_proto; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/cbor_encoder.h b/components/data_server/request_handler/content_type/cbor_encoder.h new file mode 100644 index 00000000..03d1a499 --- /dev/null +++ b/components/data_server/request_handler/content_type/cbor_encoder.h @@ -0,0 +1,44 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include "components/data_server/request_handler/content_type/encoder.h" + +namespace kv_server { + +// Handles CBOR encoding/decoding for V2 API requests/responses +class CborV2EncoderDecoder : public V2EncoderDecoder { + public: + CborV2EncoderDecoder() = default; + + absl::StatusOr EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const override; + + // Returns a serialized CBOR list of serialized CBOR partition outputs, + // as per + // https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md + absl::StatusOr EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const override; + + absl::StatusOr DecodeToV2GetValuesRequestProto( + std::string_view request) const override; +}; + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/cbor_encoder_test.cc b/components/data_server/request_handler/content_type/cbor_encoder_test.cc new file mode 100644 index 00000000..76a29fdb --- /dev/null +++ b/components/data_server/request_handler/content_type/cbor_encoder_test.cc @@ -0,0 +1,301 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/content_type/cbor_encoder.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "nlohmann/json.hpp" +#include "public/test_util/proto_matcher.h" + +namespace kv_server { +namespace { + +using json = nlohmann::json; +using google::protobuf::TextFormat; + +TEST(CborEncoderTest, EncodeV2GetValuesResponseCompressionGroupSuccess) { + // "abc" -> [97,98,99] as byte array + json json_v2_response = R"({ + "compressionGroups": [ + { + "compressionGroupId": 1, + "content": {"bytes":[97,98,99],"subtype":null}, + "ttlMs": 2 + } + ] + })"_json; + + v2::GetValuesResponse response_proto; + TextFormat::ParseFromString( + R"pb( + compression_groups { compression_group_id: 1 ttl_ms: 2 content: "abc" } + )pb", + &response_proto); + + CborV2EncoderDecoder encoder; + const auto maybe_cbor_response = + encoder.EncodeV2GetValuesResponse(response_proto); + ASSERT_TRUE(maybe_cbor_response.ok()) << maybe_cbor_response.status(); + EXPECT_EQ(json_v2_response.dump(), + json::from_cbor(*maybe_cbor_response).dump()); +} + +TEST(CborEncoderTest, EncodeV2GetValuesResponseSinglePartitionFailure) { + v2::GetValuesResponse response_proto; + TextFormat::ParseFromString( + R"pb( + single_partition { string_output: "abc" } + )pb", + &response_proto); + + CborV2EncoderDecoder encoder; + const auto maybe_cbor_response = + encoder.EncodeV2GetValuesResponse(response_proto); + ASSERT_FALSE(maybe_cbor_response.ok()) << maybe_cbor_response.status(); + EXPECT_EQ(maybe_cbor_response.status().message(), + "single_partition is not supported for cbor content type"); +} + +TEST(CborEncoderTest, EncodePartitionOutputsSuccess) { + InitMetricsContextMap(); + json json_partition_output1 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + })"_json; + 
json json_partition_output2 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello2": { + "value": "world2" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + } + )"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output1.dump()}, {2, json_partition_output2.dump()}}; + + auto request_context_factory = std::make_unique(); + CborV2EncoderDecoder encoder; + const auto maybe_cbor_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + json expected_partition_output1 = {{"id", 1}}; + expected_partition_output1.update(json_partition_output1); + json expected_partition_output2 = {{"id", 2}}; + expected_partition_output2.update(json_partition_output2); + json expected_partition_outputs = {expected_partition_output1, + expected_partition_output2}; + ASSERT_TRUE(maybe_cbor_content.ok()) << maybe_cbor_content.status(); + EXPECT_EQ(expected_partition_outputs, json::from_cbor(*maybe_cbor_content)); +} + +TEST(CborEncoderTest, EncodePartitionOutputsEmptyKeyGroupOutputSuccess) { + InitMetricsContextMap(); + json json_partition_output = R"( + { + "keyGroupOutputs": [] + })"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output.dump()}}; + + std::string content; + auto request_context_factory = std::make_unique(); + CborV2EncoderDecoder encoder; + const auto maybe_cbor_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + json expected_json_outputs = R"([ + { + "id": 1, + "keyGroupOutputs": [] + }])"_json; + ASSERT_TRUE(maybe_cbor_content.ok()) << maybe_cbor_content.status(); + EXPECT_EQ(expected_json_outputs, json::from_cbor(*maybe_cbor_content)); +} + +TEST(CborEncoderTest, EncodePartitionOutputs_OverwritesId) { + InitMetricsContextMap(); + json json_partition_output = R"( + { + "id": 100, + "keyGroupOutputs": [] + })"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output.dump()}}; + + std::string content; + auto 
request_context_factory = std::make_unique(); + CborV2EncoderDecoder encoder; + const auto maybe_cbor_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + json expected_json_outputs = R"([ + { + "id": 1, + "keyGroupOutputs": [] + }])"_json; + ASSERT_TRUE(maybe_cbor_content.ok()) << maybe_cbor_content.status(); + EXPECT_EQ(expected_json_outputs, json::from_cbor(*maybe_cbor_content)); +} + +TEST(CborEncoderTest, EncodePartitionOutputsInvalidPartitionOutputIgnored) { + InitMetricsContextMap(); + json json_partition_output_invalid = R"( + { + "keyGroupOtputs": [] + } + )"_json; + + json json_partition_output_valid = R"( + { + "keyGroupOutputs": [] + } + )"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output_invalid.dump()}, + {2, json_partition_output_valid.dump()}}; + + auto request_context_factory = std::make_unique(); + CborV2EncoderDecoder encoder; + const auto maybe_cbor_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + ASSERT_TRUE(maybe_cbor_content.ok()) << maybe_cbor_content.status(); + json partition_outputs_json = json::array(); + json expected_partition_output_valid = {{"id", 2}}; + expected_partition_output_valid.update(json_partition_output_valid); + partition_outputs_json.emplace_back(expected_partition_output_valid); + EXPECT_EQ(partition_outputs_json, json::from_cbor(*maybe_cbor_content)); +} + +TEST(CborEncoderTest, EncodePartitionOutputsAllInvalidPartitionOutputFails) { + InitMetricsContextMap(); + json json_partition_output_invalid = R"( + { + "keyGroupOtputs": [] + } + )"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output_invalid.dump()}}; + + std::string content; + auto request_context_factory = std::make_unique(); + CborV2EncoderDecoder encoder; + const auto maybe_cbor_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + ASSERT_FALSE(maybe_cbor_content.ok()) << 
maybe_cbor_content.status(); +} + +TEST(CborEncoderTest, DecodeToV2GetValuesRequestProtoEmptyStringSuccess) { + std::string request = ""; + CborV2EncoderDecoder encoder; + const auto maybe_request = encoder.DecodeToV2GetValuesRequestProto(request); + ASSERT_FALSE(maybe_request.ok()) << maybe_request.status(); +} + +TEST(CborEncoderTest, DecodeToV2GetValuesRequestSuccess) { + v2::GetValuesRequest expected; + TextFormat::ParseFromString(R"pb( + client_version: "version1" + metadata { + fields { + key: "foo" + value { string_value: "bar1" } + } + } + partitions { + id: 1 + compression_group_id: 1 + metadata { + fields { + key: "partition_metadata" + value { string_value: "bar2" } + } + } + arguments { + tags { + values { string_value: "tag1" } + values { string_value: "tag2" } + } + + data { string_value: "bar4" } + } + } + )pb", + &expected); + + nlohmann::json json_message = R"( + { + "clientVersion": "version1", + "metadata": { + "foo": "bar1" + }, + "partitions": [ + { + "id": 1, + "compressionGroupId": 1, + "metadata": { + "partition_metadata": "bar2" + }, + "arguments": { + "tags": [ + "tag1", + "tag2" + ], + "data": "bar4" + } + } + ] +} +)"_json; + std::vector v = json::to_cbor(json_message); + std::string cbor_raw(v.begin(), v.end()); + CborV2EncoderDecoder encoder; + const auto maybe_request = encoder.DecodeToV2GetValuesRequestProto(cbor_raw); + ASSERT_TRUE(maybe_request.ok()) << maybe_request.status(); + EXPECT_THAT(expected, EqualsProto(*maybe_request)); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/encoder.cc b/components/data_server/request_handler/content_type/encoder.cc new file mode 100644 index 00000000..0e303545 --- /dev/null +++ b/components/data_server/request_handler/content_type/encoder.cc @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data_server/request_handler/content_type/encoder.h" + +#include + +#include "components/data_server/request_handler/content_type/cbor_encoder.h" +#include "components/data_server/request_handler/content_type/json_encoder.h" +#include "components/data_server/request_handler/content_type/proto_encoder.h" + +namespace kv_server { + +std::unique_ptr V2EncoderDecoder::Create( + const V2EncoderDecoder::ContentType& content_type) { + switch (content_type) { + case V2EncoderDecoder::ContentType::kCbor: { + return std::make_unique(); + } + case V2EncoderDecoder::ContentType::kJson: { + return std::make_unique(); + } + case V2EncoderDecoder::ContentType::kProto: { + return std::make_unique(); + } + } +} + +V2EncoderDecoder::ContentType V2EncoderDecoder::GetContentType( + const std::multimap& headers, + V2EncoderDecoder::ContentType default_content_type) { + for (const auto& [header_name, header_value] : headers) { + if (absl::AsciiStrToLower(std::string_view( + header_name.data(), header_name.size())) == kKVContentTypeHeader) { + if (absl::AsciiStrToLower( + std::string_view(header_value.data(), header_value.size())) == + kContentEncodingProtoHeaderValue) { + return ContentType::kProto; + } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), + header_value.size())) == + kContentEncodingJsonHeaderValue) { + return ContentType::kJson; + } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), + header_value.size())) == + kContentEncodingCborHeaderValue) { + return ContentType::kCbor; + } + } + } + return 
default_content_type; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/encoder.h b/components/data_server/request_handler/content_type/encoder.h new file mode 100644 index 00000000..65d98bc9 --- /dev/null +++ b/components/data_server/request_handler/content_type/encoder.h @@ -0,0 +1,79 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_CONTENT_TYPE_ENCODER_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_CONTENT_TYPE_ENCODER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "components/util/request_context.h" +#include "public/query/v2/get_values_v2.grpc.pb.h" + +namespace kv_server { + +// Header in clear text http request/response that indicates which format is +// used by the payload. The more common "Content-Type" header is not used +// because most importantly that has CORS implications, and in addition, may not +// be forwarded by Envoy to gRPC. +inline constexpr std::string_view kKVContentTypeHeader = "kv-content-type"; + +// Protobuf Content Type Header Value. +inline constexpr std::string_view kContentEncodingProtoHeaderValue = + "message/ad-auction-trusted-signals-request+proto"; +// Json Content Type Header Value. 
+inline constexpr std::string_view kContentEncodingJsonHeaderValue = + "message/ad-auction-trusted-signals-request+json"; +inline constexpr std::string_view kContentEncodingCborHeaderValue = + "message/ad-auction-trusted-signals-request"; + +// Encodes and decodes V2 requests and responses. +class V2EncoderDecoder { + public: + enum class ContentType { kCbor = 0, kJson = 1, kProto = 2 }; + + static ContentType GetContentType( + const std::multimap& headers, + ContentType default_content_type); + + static std::unique_ptr Create(const ContentType& type); + + virtual ~V2EncoderDecoder() = default; + + // Encodes a V2 GetValuesResponse + virtual absl::StatusOr EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const = 0; + + // Encodes a list of and serializes it as a string + // If UDF partition output has an "id" field, it will be overwritten by the + // given id in the pair. + virtual absl::StatusOr EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const = 0; + + // Decodes the string to a V2 GetValuesRequest proto + virtual absl::StatusOr DecodeToV2GetValuesRequestProto( + std::string_view request) const = 0; +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_CONTENT_TYPE_ENCODER_H_ diff --git a/components/data_server/request_handler/content_type/json_encoder.cc b/components/data_server/request_handler/content_type/json_encoder.cc new file mode 100644 index 00000000..345b1a0f --- /dev/null +++ b/components/data_server/request_handler/content_type/json_encoder.cc @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data_server/request_handler/content_type/json_encoder.h" + +#include +#include +#include +#include + +#include "nlohmann/json.hpp" + +namespace kv_server { + +using google::protobuf::util::MessageToJsonString; + +absl::StatusOr JsonV2EncoderDecoder::EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const { + std::string response; + PS_RETURN_IF_ERROR(MessageToJsonString(response_proto, &response)); + return response; +} + +absl::StatusOr JsonV2EncoderDecoder::EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const { + nlohmann::json json_partition_output_list = nlohmann::json::array(); + for (auto&& partition_output_pair : partition_output_pairs) { + auto partition_output_json = + nlohmann::json::parse(partition_output_pair.second, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + if (partition_output_json.is_discarded()) { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << "json parse failed for " << partition_output_pair.second; + continue; + } + partition_output_json["id"] = partition_output_pair.first; + json_partition_output_list.emplace_back(partition_output_json); + } + if (json_partition_output_list.size() == 0) { + return absl::InvalidArgumentError( + "No partition outputs were added to compression group content"); + } + return json_partition_output_list.dump(); +} + +absl::StatusOr +JsonV2EncoderDecoder::DecodeToV2GetValuesRequestProto( + std::string_view request) const { + 
v2::GetValuesRequest request_proto; + PS_RETURN_IF_ERROR( + google::protobuf::util::JsonStringToMessage(request, &request_proto)); + return request_proto; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/json_encoder.h b/components/data_server/request_handler/content_type/json_encoder.h new file mode 100644 index 00000000..0d9e506a --- /dev/null +++ b/components/data_server/request_handler/content_type/json_encoder.h @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "components/data_server/request_handler/content_type/encoder.h" + +namespace kv_server { + +// Handles JSON encoding/decoding for V2 API requests/responses +class JsonV2EncoderDecoder : public V2EncoderDecoder { + public: + JsonV2EncoderDecoder() = default; + + absl::StatusOr EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const override; + + // Returns a serialized JSON array of partition outputs. + // A partition output is simply the return value of a UDF execution. 
+ absl::StatusOr EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const override; + + absl::StatusOr DecodeToV2GetValuesRequestProto( + std::string_view request) const override; +}; + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/json_encoder_test.cc b/components/data_server/request_handler/content_type/json_encoder_test.cc new file mode 100644 index 00000000..525efb97 --- /dev/null +++ b/components/data_server/request_handler/content_type/json_encoder_test.cc @@ -0,0 +1,218 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/content_type/json_encoder.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "nlohmann/json.hpp" +#include "public/test_util/proto_matcher.h" + +namespace kv_server { +namespace { + +using json = nlohmann::json; +using google::protobuf::TextFormat; + +TEST(JsonEncoderTest, EncodeV2GetValuesResponseCompressionGroupSuccess) { + // "abc" -> base64 encode => YWJj + json expected = R"({ + "compressionGroups": [ + { + "compressionGroupId": 1, + "content": "YWJj", + "ttlMs": 3 + } + ] + })"_json; + + v2::GetValuesResponse response_proto; + TextFormat::ParseFromString( + R"pb( + compression_groups { compression_group_id: 1 content: "abc" ttl_ms: 3 } + )pb", + &response_proto); + + JsonV2EncoderDecoder encoder; + const auto maybe_json_response = + encoder.EncodeV2GetValuesResponse(response_proto); + ASSERT_TRUE(maybe_json_response.ok()) << maybe_json_response.status(); + nlohmann::json json_response = nlohmann::json::parse(*maybe_json_response); + EXPECT_EQ(expected, json_response); +} + +TEST(JsonEncoderTest, EncodeV2GetValuesResponseSinglePartitionSuccess) { + json expected = R"({ + "singlePartition": { "stringOutput": "abc" } + })"_json; + + v2::GetValuesResponse response_proto; + TextFormat::ParseFromString( + R"pb( + single_partition { string_output: "abc" } + )pb", + &response_proto); + + JsonV2EncoderDecoder encoder; + const auto maybe_json_response = + encoder.EncodeV2GetValuesResponse(response_proto); + ASSERT_TRUE(maybe_json_response.ok()) << maybe_json_response.status(); + EXPECT_EQ(expected.dump(), *maybe_json_response); +} + +TEST(JsonEncoderTest, EncodePartitionOutputsSuccess) { + InitMetricsContextMap(); + json json_partition_output1 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + 
} + )"_json; + json json_partition_output2 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello2": { + "value": "world2" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + } + )"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output1.dump()}, {2, json_partition_output2.dump()}}; + + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder encoder; + const auto maybe_json_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + json expected_output1 = {{"id", 1}}; + expected_output1.update(json_partition_output1); + json expected_output2 = {{"id", 2}}; + expected_output2.update(json_partition_output2); + json expected = {expected_output1, expected_output2}; + ASSERT_TRUE(maybe_json_content.ok()) << maybe_json_content.status(); + EXPECT_EQ(expected.dump(), *maybe_json_content); +} + +TEST(JsonEncoderTest, EncodePartitionOutputsEmptyFails) { + InitMetricsContextMap(); + std::vector> partition_output_pairs = {}; + + std::string content; + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder encoder; + const auto maybe_json_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + ASSERT_FALSE(maybe_json_content.ok()) << maybe_json_content.status(); +} + +TEST(JsonEncoderTest, DecodeToV2GetValuesRequestProtoEmptyStringFailure) { + std::string request = ""; + JsonV2EncoderDecoder encoder; + const auto maybe_request = encoder.DecodeToV2GetValuesRequestProto(request); + ASSERT_FALSE(maybe_request.ok()) << maybe_request.status(); +} + +TEST(JsonEncoderTest, DecodeToV2GetValuesRequestSuccess) { + v2::GetValuesRequest expected; + TextFormat::ParseFromString(R"pb( + client_version: "version1" + metadata { + fields { + key: "foo" + value { string_value: "bar1" } + } + } + partitions { + id: 1 + compression_group_id: 1 + metadata { + fields { + key: "partition_metadata" + value { string_value: "bar2" } + } + } + arguments { 
+ tags { + values { string_value: "tag1" } + values { string_value: "tag2" } + } + + data { string_value: "bar4" } + } + } + )pb", + &expected); + + nlohmann::json json_message = R"( + { + "clientVersion": "version1", + "metadata": { + "foo": "bar1" + }, + "partitions": [ + { + "id": 1, + "compressionGroupId": 1, + "metadata": { + "partition_metadata": "bar2" + }, + "arguments": { + "tags": [ + "tag1", + "tag2" + ], + "data": "bar4" + } + } + ] +} +)"_json; + JsonV2EncoderDecoder encoder; + const auto maybe_request = + encoder.DecodeToV2GetValuesRequestProto(json_message.dump()); + ASSERT_TRUE(maybe_request.ok()) << maybe_request.status(); + EXPECT_THAT(expected, EqualsProto(*maybe_request)); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/proto_encoder.cc b/components/data_server/request_handler/content_type/proto_encoder.cc new file mode 100644 index 00000000..e0657879 --- /dev/null +++ b/components/data_server/request_handler/content_type/proto_encoder.cc @@ -0,0 +1,78 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/content_type/proto_encoder.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "nlohmann/json.hpp" +#include "public/applications/pa/api_overlay.pb.h" + +namespace kv_server { + +absl::StatusOr ProtoV2EncoderDecoder::EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const { + std::string response; + if (!response_proto.SerializeToString(&response)) { + auto error_message = "Cannot serialize the response as a proto."; + return absl::InvalidArgumentError(error_message); + } + return response; +} + +absl::StatusOr ProtoV2EncoderDecoder::EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const { + nlohmann::json json_partition_output_list = nlohmann::json::array(); + for (auto&& partition_output_pair : partition_output_pairs) { + auto partition_output_json = + nlohmann::json::parse(partition_output_pair.second, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + if (partition_output_json.is_discarded()) { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << "json parse failed for " << partition_output_pair.second; + continue; + } + partition_output_json["id"] = partition_output_pair.first; + json_partition_output_list.emplace_back(partition_output_json); + } + if (json_partition_output_list.size() == 0) { + return absl::InvalidArgumentError( + "No partition outputs were added to compression group content"); + } + return json_partition_output_list.dump(); +} + +absl::StatusOr +ProtoV2EncoderDecoder::DecodeToV2GetValuesRequestProto( + std::string_view request) const { + v2::GetValuesRequest request_proto; + if (request.empty()) { + return absl::InvalidArgumentError( + "Received empty request, not converting to v2::GetValuesRequest proto"); + } + if (!request_proto.ParseFromString(request)) { + auto error_message = absl::StrCat( + "Cannot parse request as a valid 
serialized proto object: ", request); + return absl::InvalidArgumentError(error_message); + } + return request_proto; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/proto_encoder.h b/components/data_server/request_handler/content_type/proto_encoder.h new file mode 100644 index 00000000..26700e8b --- /dev/null +++ b/components/data_server/request_handler/content_type/proto_encoder.h @@ -0,0 +1,43 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "components/data_server/request_handler/content_type/encoder.h" + +namespace kv_server { + +// Handles proto encoding/decoding for V2 API requests/responses +class ProtoV2EncoderDecoder : public V2EncoderDecoder { + public: + ProtoV2EncoderDecoder() = default; + + absl::StatusOr EncodeV2GetValuesResponse( + v2::GetValuesResponse& response_proto) const override; + + // Returns a serialized JSON array of partition outputs. + // A partition output is simply the return value of a UDF execution. 
+ absl::StatusOr EncodePartitionOutputs( + std::vector>& partition_output_pairs, + const RequestContextFactory& request_context_factory) const override; + + absl::StatusOr DecodeToV2GetValuesRequestProto( + std::string_view request) const override; +}; + +} // namespace kv_server diff --git a/components/data_server/request_handler/content_type/proto_encoder_test.cc b/components/data_server/request_handler/content_type/proto_encoder_test.cc new file mode 100644 index 00000000..eae69655 --- /dev/null +++ b/components/data_server/request_handler/content_type/proto_encoder_test.cc @@ -0,0 +1,164 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/content_type/proto_encoder.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "nlohmann/json.hpp" +#include "public/test_util/proto_matcher.h" + +namespace kv_server { +namespace { + +using google::protobuf::TextFormat; +using nlohmann::json; + +TEST(ProtoEncoderTest, EncodeV2GetValuesResponseSuccess) { + v2::GetValuesResponse response_proto; + TextFormat::ParseFromString( + R"pb( + compression_groups { compression_group_id: 1 content: "abc" ttl_ms: 3 } + single_partition { string_output: "abc" } + )pb", + &response_proto); + ProtoV2EncoderDecoder encoder; + const auto maybe_proto_response = + encoder.EncodeV2GetValuesResponse(response_proto); + ASSERT_TRUE(maybe_proto_response.ok()) << maybe_proto_response.status(); + std::string expected; + response_proto.SerializeToString(&expected); + EXPECT_EQ(expected, *maybe_proto_response); +} + +TEST(ProtoEncoderTest, EncodePartitionOutputsSuccess) { + InitMetricsContextMap(); + json json_partition_output1 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + })"_json; + json json_partition_output2 = R"( + { + "keyGroupOutputs": [ + { + "keyValues": { + "hello2": { + "value": "world2" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + } + )"_json; + std::vector> partition_output_pairs = { + {1, json_partition_output1.dump()}, {2, json_partition_output2.dump()}}; + + auto request_context_factory = std::make_unique(); + ProtoV2EncoderDecoder encoder; + const auto maybe_proto_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + json expected_output1 = {{"id", 1}}; + expected_output1.update(json_partition_output1); + json expected_output2 = {{"id", 2}}; + expected_output2.update(json_partition_output2); + 
json expected = {expected_output1, expected_output2}; + ASSERT_TRUE(maybe_proto_content.ok()) << maybe_proto_content.status(); + EXPECT_EQ(expected.dump(), *maybe_proto_content); +} + +TEST(JsonEncoderTest, EncodePartitionOutputsEmptyFails) { + InitMetricsContextMap(); + std::vector> partition_output_pairs = {}; + std::string content; + auto request_context_factory = std::make_unique(); + ProtoV2EncoderDecoder encoder; + const auto maybe_proto_content = encoder.EncodePartitionOutputs( + partition_output_pairs, *request_context_factory); + + ASSERT_FALSE(maybe_proto_content.ok()) << maybe_proto_content.status(); +} + +TEST(ProtoEncoderTest, DecodeToV2GetValuesRequestProtoEmptyStringFailure) { + std::string request = ""; + ProtoV2EncoderDecoder encoder; + const auto maybe_request = encoder.DecodeToV2GetValuesRequestProto(request); + ASSERT_FALSE(maybe_request.ok()) << maybe_request.status(); +} + +TEST(ProtoEncoderTest, DecodeToV2GetValuesRequestSuccess) { + v2::GetValuesRequest expected; + TextFormat::ParseFromString(R"pb( + client_version: "version1" + metadata { + fields { + key: "foo" + value { string_value: "bar1" } + } + } + partitions { + id: 1 + compression_group_id: 1 + metadata { + fields { + key: "partition_metadata" + value { string_value: "bar2" } + } + } + arguments { + tags { + values { string_value: "tag1" } + values { string_value: "tag2" } + } + + data { string_value: "bar4" } + } + } + )pb", + &expected); + ProtoV2EncoderDecoder encoder; + std::string serialized_request; + expected.SerializeToString(&serialized_request); + const auto maybe_request = + encoder.DecodeToV2GetValuesRequestProto(serialized_request); + ASSERT_TRUE(maybe_request.ok()) << maybe_request.status(); + EXPECT_THAT(expected, EqualsProto(*maybe_request)); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/request_handler/encryption/BUILD.bazel b/components/data_server/request_handler/encryption/BUILD.bazel new file mode 100644 index 
00000000..15703c12 --- /dev/null +++ b/components/data_server/request_handler/encryption/BUILD.bazel @@ -0,0 +1,70 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = [ + "//components/data_server:__subpackages__", + "//components/internal_server:__subpackages__", + "//components/tools:__subpackages__", +]) + +cc_library( + name = "ohttp_client_encryptor", + srcs = [ + "ohttp_client_encryptor.cc", + ], + hdrs = [ + "ohttp_client_encryptor.h", + ], + deps = [ + "//public:constants", + "@com_github_google_quiche//quiche:oblivious_http_unstable_api", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/encryption/key_fetcher:key_fetcher_manager", + ], +) + +cc_library( + name = "ohttp_server_encryptor", + srcs = [ + "ohttp_server_encryptor.cc", + ], + hdrs = [ + "ohttp_server_encryptor.h", + ], + deps = [ + "//public:constants", + "@com_github_google_quiche//quiche:oblivious_http_unstable_api", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/encryption/key_fetcher:key_fetcher_manager", + ], +) + +cc_test( + name = "ohttp_encryptor_test", + size = "small", + srcs = [ + "ohttp_encryptor_test.cc", + ], + deps = [ + 
":ohttp_client_encryptor", + ":ohttp_server_encryptor", + "@com_google_googletest//:gtest_main", + "@google_privacysandbox_servers_common//src/encryption/key_fetcher:fake_key_fetcher_manager", + ], +) diff --git a/components/data_server/request_handler/ohttp_client_encryptor.cc b/components/data_server/request_handler/encryption/ohttp_client_encryptor.cc similarity index 79% rename from components/data_server/request_handler/ohttp_client_encryptor.cc rename to components/data_server/request_handler/encryption/ohttp_client_encryptor.cc index 54cf955e..fe6d8134 100644 --- a/components/data_server/request_handler/ohttp_client_encryptor.cc +++ b/components/data_server/request_handler/encryption/ohttp_client_encryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_client_encryptor.h" #include @@ -53,15 +53,10 @@ absl::StatusOr OhttpClientEncryptor::EncryptRequest( << public_key_.key_id() << " uint8 key id " << *key_id << "public key " << public_key_.public_key(); absl::Base64Unescape(public_key_.public_key(), &public_key_string); - auto http_client_maybe = - quiche::ObliviousHttpClient::Create(public_key_string, *maybe_config); - if (!http_client_maybe.ok()) { - return absl::InternalError( - std::string(http_client_maybe.status().message())); - } - http_client_ = std::move(*http_client_maybe); auto encrypted_req = - http_client_->CreateObliviousHttpRequest(std::move(payload)); + quiche::ObliviousHttpRequest::CreateClientObliviousRequest( + std::move(payload), public_key_string, *std::move(maybe_config), + kKVOhttpRequestLabel); if (!encrypted_req.ok()) { return absl::InternalError(std::string(encrypted_req.status().message())); } @@ -74,13 +69,15 @@ absl::StatusOr OhttpClientEncryptor::EncryptRequest( absl::StatusOr 
OhttpClientEncryptor::DecryptResponse( std::string encrypted_payload, privacy_sandbox::server_common::log::PSLogContext& log_context) { - if (!http_client_.has_value() || !http_request_context_.has_value()) { + if (!http_request_context_.has_value()) { return absl::InternalError( - "Emtpy `http_client_` or `http_request_context_`. You should call " + "Emtpy `http_request_context_`. You should call " "`ClientEncryptRequest` first"); } - auto decrypted_response = http_client_->DecryptObliviousHttpResponse( - std::move(encrypted_payload), *http_request_context_); + auto decrypted_response = + quiche::ObliviousHttpResponse::CreateClientObliviousResponse( + std::move(encrypted_payload), *http_request_context_, + kKVOhttpResponseLabel); if (!decrypted_response.ok()) { return decrypted_response.status(); } diff --git a/components/data_server/request_handler/ohttp_client_encryptor.h b/components/data_server/request_handler/encryption/ohttp_client_encryptor.h similarity index 88% rename from components/data_server/request_handler/ohttp_client_encryptor.h rename to components/data_server/request_handler/encryption/ohttp_client_encryptor.h index 8e9fcf81..92b65e54 100644 --- a/components/data_server/request_handler/ohttp_client_encryptor.h +++ b/components/data_server/request_handler/encryption/ohttp_client_encryptor.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ -#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_CLIENT_ENCRYPTOR_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_CLIENT_ENCRYPTOR_H_ #include #include @@ -48,11 +48,10 @@ class OhttpClientEncryptor { privacy_sandbox::server_common::log::kNoOpContext)); private: - std::optional http_client_; std::optional http_request_context_; google::cmrt::sdk::public_key_service::v1::PublicKey& public_key_; }; } // namespace kv_server -#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_CLIENT_ENCRYPTOR_H_ diff --git a/components/data_server/request_handler/ohttp_encryptor_test.cc b/components/data_server/request_handler/encryption/ohttp_encryptor_test.cc similarity index 94% rename from components/data_server/request_handler/ohttp_encryptor_test.cc rename to components/data_server/request_handler/encryption/ohttp_encryptor_test.cc index 6cad5850..17a9aace 100644 --- a/components/data_server/request_handler/ohttp_encryptor_test.cc +++ b/components/data_server/request_handler/encryption/ohttp_encryptor_test.cc @@ -14,8 +14,8 @@ #include -#include "components/data_server/request_handler/ohttp_client_encryptor.h" -#include "components/data_server/request_handler/ohttp_server_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_client_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_server_encryptor.h" #include "gtest/gtest.h" #include "src/encryption/key_fetcher/fake_key_fetcher_manager.h" #include "src/encryption/key_fetcher/interface/key_fetcher_manager_interface.h" @@ -93,7 +93,7 @@ TEST(OhttpEncryptorTest, ClientDecryptResponseFails) { client_encryptor.DecryptResponse(kTestRequest); ASSERT_FALSE(request_encrypted_status.ok()); EXPECT_EQ( - "Emtpy 
`http_client_` or `http_request_context_`. You should call " + "Emtpy `http_request_context_`. You should call " "`ClientEncryptRequest` first", request_encrypted_status.status().message()); } diff --git a/components/data_server/request_handler/ohttp_server_encryptor.cc b/components/data_server/request_handler/encryption/ohttp_server_encryptor.cc similarity index 91% rename from components/data_server/request_handler/ohttp_server_encryptor.cc rename to components/data_server/request_handler/encryption/ohttp_server_encryptor.cc index a8893a4c..c2eb05fd 100644 --- a/components/data_server/request_handler/ohttp_server_encryptor.cc +++ b/components/data_server/request_handler/encryption/ohttp_server_encryptor.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "components/data_server/request_handler/ohttp_server_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_server_encryptor.h" #include @@ -54,8 +54,8 @@ absl::StatusOr OhttpServerEncryptor::DecryptRequest( return maybe_ohttp_gateway.status(); } ohttp_gateway_ = std::move(*maybe_ohttp_gateway); - auto decrypted_request_maybe = - ohttp_gateway_->DecryptObliviousHttpRequest(encrypted_payload); + auto decrypted_request_maybe = ohttp_gateway_->DecryptObliviousHttpRequest( + encrypted_payload, kKVOhttpRequestLabel); if (!decrypted_request_maybe.ok()) { return decrypted_request_maybe.status(); } @@ -73,7 +73,7 @@ absl::StatusOr OhttpServerEncryptor::EncryptResponse( } auto server_request_context = std::move(*decrypted_request_).ReleaseContext(); const auto encapsulate_resp = ohttp_gateway_->CreateObliviousHttpResponse( - std::move(payload), server_request_context); + std::move(payload), server_request_context, kKVOhttpResponseLabel); if (!encapsulate_resp.ok()) { return absl::InternalError( std::string(encapsulate_resp.status().message())); diff --git 
a/components/data_server/request_handler/ohttp_server_encryptor.h b/components/data_server/request_handler/encryption/ohttp_server_encryptor.h similarity index 90% rename from components/data_server/request_handler/ohttp_server_encryptor.h rename to components/data_server/request_handler/encryption/ohttp_server_encryptor.h index 644e89c3..a06c526b 100644 --- a/components/data_server/request_handler/ohttp_server_encryptor.h +++ b/components/data_server/request_handler/encryption/ohttp_server_encryptor.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ -#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_SERVER_ENCRYPTOR_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_SERVER_ENCRYPTOR_H_ #include #include @@ -60,4 +60,4 @@ class OhttpServerEncryptor { } // namespace kv_server -#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_ENCRYPTION_OHTTP_SERVER_ENCRYPTOR_H_ diff --git a/components/data_server/request_handler/framing_utils.cc b/components/data_server/request_handler/framing_utils.cc deleted file mode 100644 index 7db740bc..00000000 --- a/components/data_server/request_handler/framing_utils.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. -#include "components/data_server/request_handler/framing_utils.h" - -#include - -#include "absl/numeric/bits.h" - -namespace kv_server { - -// 1 byte for version + compression details. -constexpr int kVersionCompressionSize = 1; - -// 4-bytes specifying the size of the actual payload. -constexpr int kPayloadLength = 4; - -// Minimum size of the returned response in bytes. -// TODO: b/348613920 - Move framing utils to the common repo, and as part of -// that figure out if this needs to be inline with B&A. -inline constexpr size_t kMinResultBytes = 0; - -// Gets size of the complete payload including the preamble expected by -// android, which is: 1 byte (containing version, compression details), 4 bytes -// indicating the length of the actual encoded response and any other padding -// required to make the complete payload a power of 2. -size_t GetEncodedDataSize(size_t encapsulated_payload_size) { - size_t total_payload_size = - kVersionCompressionSize + kPayloadLength + encapsulated_payload_size; - // Ensure that the payload size is a power of 2. 
- return std::max(absl::bit_ceil(total_payload_size), kMinResultBytes); -} - -} // namespace kv_server diff --git a/components/data_server/request_handler/get_values_adapter.cc b/components/data_server/request_handler/get_values_adapter.cc index 72972d0d..219ff94d 100644 --- a/components/data_server/request_handler/get_values_adapter.cc +++ b/components/data_server/request_handler/get_values_adapter.cc @@ -24,7 +24,8 @@ #include "absl/log/log.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" -#include "components/data_server/request_handler/v2_response_data.pb.h" +#include "components/data_server/request_handler/content_type/encoder.h" +#include "components/errors/error_tag.h" #include "google/protobuf/util/json_util.h" #include "public/api_schema.pb.h" #include "public/applications/pa/api_overlay.pb.h" @@ -34,6 +35,13 @@ namespace kv_server { namespace { + +enum class ErrorTag : int { + kInvalidNumberOfTagsError = 1, + kNoNamespaceTagsFoundError = 2, + kNoSinglePartitionInResponseError = 3 +}; + using google::protobuf::RepeatedPtrField; using google::protobuf::Struct; using google::protobuf::Value; @@ -120,8 +128,9 @@ void ProcessKeyValues( // Find the namespace tag that is paired with the "custom" tag. 
absl::StatusOr FindNamespace(RepeatedPtrField tags) { if (tags.size() != 2) { - return absl::InvalidArgumentError( - absl::StrCat("Expected 2 tags, found ", tags.size())); + return StatusWithErrorTag(absl::InvalidArgumentError(absl::StrCat( + "Expected 2 tags, found ", tags.size())), + __FILE__, ErrorTag::kInvalidNumberOfTagsError); } bool has_custom_tag = false; @@ -137,7 +146,9 @@ absl::StatusOr FindNamespace(RepeatedPtrField tags) { if (has_custom_tag) { return maybe_namespace_tag; } - return absl::InvalidArgumentError("No namespace tags found"); + return StatusWithErrorTag( + absl::InvalidArgumentError("No namespace tags found"), __FILE__, + ErrorTag::kNoNamespaceTagsFoundError); } void ProcessKeyGroupOutput(application_pa::KeyGroupOutput key_group_output, @@ -170,14 +181,17 @@ void ProcessKeyGroupOutput(application_pa::KeyGroupOutput key_group_output, } // Converts a v2 response into v1 response. -absl::Status ConvertToV1Response(const v2::GetValuesResponse& v2_response, +absl::Status ConvertToV1Response(RequestContextFactory& request_context_factory, + const v2::GetValuesResponse& v2_response, v1::GetValuesResponse& v1_response) { if (!v2_response.has_single_partition()) { // This should not happen. V1 request always maps to 1 partition so the // output should always have 1 partition. - return absl::InternalError( - "Bug in KV server! response does not have single_partition set for V1 " - "response."); + return StatusWithErrorTag( + absl::InternalError("Bug in KV server! 
response does not have " + "single_partition set for V1 " + "response."), + __FILE__, ErrorTag::kNoSinglePartitionInResponseError); } if (v2_response.single_partition().has_status()) { return absl::Status(static_cast( @@ -187,9 +201,15 @@ absl::Status ConvertToV1Response(const v2::GetValuesResponse& v2_response, const std::string& string_output = v2_response.single_partition().string_output(); // string_output should be a JSON object - PS_ASSIGN_OR_RETURN(application_pa::KeyGroupOutputs outputs, - application_pa::KeyGroupOutputsFromJson(string_output)); - for (const auto& key_group_output : outputs.key_group_outputs()) { + PS_VLOG(7, request_context_factory.Get().GetPSLogContext()) + << "Received v2 response: " << v2_response.DebugString(); + const auto outputs = application_pa::PartitionOutputFromJson(string_output); + if (!outputs.ok()) { + PS_LOG(ERROR, request_context_factory.Get().GetPSLogContext()) + << outputs.status(); + return outputs.status(); + } + for (const auto& key_group_output : outputs->key_group_outputs()) { ProcessKeyGroupOutput(key_group_output, v1_response); } @@ -213,9 +233,11 @@ class GetValuesAdapterImpl : public GetValuesAdapter { << " to v2 request " << v2_request.DebugString(); v2::GetValuesResponse v2_response; ExecutionMetadata execution_metadata; - if (auto status = - v2_handler_->GetValues(request_context_factory, v2_request, - &v2_response, execution_metadata); + auto v2_codec = + V2EncoderDecoder::Create(V2EncoderDecoder::ContentType::kJson); + if (auto status = v2_handler_->GetValues( + request_context_factory, v2_request, &v2_response, + execution_metadata, /*single_partition_use_case=*/true, *v2_codec); !status.ok()) { return status; } @@ -227,7 +249,7 @@ class GetValuesAdapterImpl : public GetValuesAdapter { PS_VLOG(7, request_context_factory.Get().GetPSLogContext()) << "Received v2 response: " << v2_response.DebugString(); return privacy_sandbox::server_common::FromAbslStatus( - ConvertToV1Response(v2_response, v1_response)); + 
ConvertToV1Response(request_context_factory, v2_response, v1_response)); } private: diff --git a/components/data_server/request_handler/get_values_adapter_test.cc b/components/data_server/request_handler/get_values_adapter_test.cc index 71022f5e..36467e4c 100644 --- a/components/data_server/request_handler/get_values_adapter_test.cc +++ b/components/data_server/request_handler/get_values_adapter_test.cc @@ -79,7 +79,7 @@ TEST_F(GetValuesAdapterTest, EmptyRequestReturnsEmptyResponse) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); EXPECT_THAT(v1_response, EqualsProto(v1_expected)); @@ -109,7 +109,7 @@ data { } })", &arg); - application_pa::KeyGroupOutputs key_group_outputs; + application_pa::PartitionOutput partition_output; TextFormat::ParseFromString(R"( key_group_outputs: { tags: "custom" @@ -132,12 +132,12 @@ data { } } )", - &key_group_outputs); + &partition_output); EXPECT_CALL(mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), testing::ElementsAre(EqualsProto(arg)), testing::_)) .WillOnce(Return( - application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + application_pa::PartitionOutputToJson(partition_output).value())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); @@ -145,7 +145,7 @@ data { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( R"pb( @@ -194,7 +194,7 @@ data { } })", &arg); - application_pa::KeyGroupOutputs key_group_outputs; + application_pa::PartitionOutput partition_output; TextFormat::ParseFromString(R"( key_group_outputs: { tags: "custom" @@ -217,19 +217,19 @@ 
data { } } )", - &key_group_outputs); + &partition_output); EXPECT_CALL(mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), testing::ElementsAre(EqualsProto(arg)), testing::_)) .WillOnce(Return( - application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + application_pa::PartitionOutputToJson(partition_output).value())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1,key2"); v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( R"pb( @@ -292,7 +292,7 @@ data { } })", &arg2); - application_pa::KeyGroupOutputs key_group_outputs; + application_pa::PartitionOutput partition_output; TextFormat::ParseFromString(R"( key_group_outputs: { tags: "custom" @@ -319,14 +319,14 @@ data { } } )", - &key_group_outputs); + &partition_output); EXPECT_CALL( mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), testing::ElementsAre(EqualsProto(arg1), EqualsProto(arg2)), testing::_)) .WillOnce(Return( - application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + application_pa::PartitionOutputToJson(partition_output).value())); v1::GetValuesRequest v1_request; v1_request.add_render_urls("key1"); @@ -334,7 +334,7 @@ data { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb( render_urls { @@ -383,7 +383,7 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithEmptyKVsReturnsOk) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()) << status.error_message(); v1::GetValuesResponse 
v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); EXPECT_THAT(v1_response, EqualsProto(v1_expected)); @@ -405,7 +405,7 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithInvalidNamespaceTagIsIgnored) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()) << status.error_message(); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); EXPECT_THAT(v1_response, EqualsProto(v1_expected)); @@ -427,7 +427,7 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoCustomTagIsIgnored) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); EXPECT_THAT(v1_response, EqualsProto(v1_expected)); @@ -449,7 +449,7 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoNamespaceTagIsIgnored) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); EXPECT_THAT(v1_response, EqualsProto(v1_expected)); @@ -476,7 +476,7 @@ TEST_F(GetValuesAdapterTest, v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()) << status.error_message(); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb( keys { @@ -511,7 +511,7 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputHasDifferentValueTypesReturnsOk) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, 
v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( R"pb( @@ -592,7 +592,7 @@ TEST_F(GetValuesAdapterTest, ValueWithStatusSuccess) { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( R"pb( @@ -648,7 +648,7 @@ data { } })", &arg); - application_pa::KeyGroupOutputs key_group_outputs; + application_pa::PartitionOutput partition_output; TextFormat::ParseFromString(R"( key_group_outputs: { tags: "custom" @@ -671,12 +671,12 @@ data { } } )", - &key_group_outputs); + &partition_output); EXPECT_CALL(mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), testing::ElementsAre(EqualsProto(arg)), testing::_)) .WillOnce(Return( - application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + application_pa::PartitionOutputToJson(partition_output).value())); v1::GetValuesRequest v1_request; v1_request.add_interest_group_names("interestGroup1"); @@ -684,7 +684,7 @@ data { v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, v1_request, v1_response); - EXPECT_TRUE(status.ok()); + ASSERT_TRUE(status.ok()) << status.error_message(); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( R"pb( diff --git a/components/data_server/request_handler/get_values_v2_handler.cc b/components/data_server/request_handler/get_values_v2_handler.cc index 2268a038..d05266c2 100644 --- a/components/data_server/request_handler/get_values_v2_handler.cc +++ b/components/data_server/request_handler/get_values_v2_handler.cc @@ -24,24 +24,27 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" -#include "components/data_server/request_handler/framing_utils.h" +#include 
"components/data/converters/cbor_converter.h" +#include "components/data_server/request_handler/encryption/ohttp_server_encryptor.h" #include "components/data_server/request_handler/get_values_v2_status.h" -#include "components/data_server/request_handler/ohttp_server_encryptor.h" #include "components/telemetry/server_definition.h" #include "google/protobuf/util/json_util.h" #include "grpcpp/grpcpp.h" +#include "nlohmann/json.hpp" +#include "public/applications/pa/response_utils.h" #include "public/base_types.pb.h" #include "public/constants.h" #include "public/query/v2/get_values_v2.grpc.pb.h" -#include "quiche/binary_http/binary_http_message.h" #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" #include "quiche/oblivious_http/oblivious_http_gateway.h" #include "src/communication/encoding_utils.h" +#include "src/communication/framing_utils.h" #include "src/telemetry/telemetry.h" #include "src/util/status_macro/status_macros.h" namespace kv_server { namespace { +using google::protobuf::RepeatedPtrField; using google::protobuf::util::JsonStringToMessage; using google::protobuf::util::MessageToJsonString; using grpc::StatusCode; @@ -50,23 +53,60 @@ using v2::GetValuesHttpRequest; using v2::KeyValueService; using v2::ObliviousGetValuesRequest; -const std::string_view kOHTTPResponseContentType = "message/ohttp-res"; constexpr std::string_view kAcceptEncodingHeader = "accept-encoding"; constexpr std::string_view kContentEncodingHeader = "content-encoding"; constexpr std::string_view kBrotliAlgorithmHeader = "br"; +constexpr std::string_view kIsPas = "is_pas"; -CompressionGroupConcatenator::CompressionType GetResponseCompressionType( - const std::vector& headers) { - for (const quiche::BinaryHttpMessage::Field& header : headers) { - if (absl::AsciiStrToLower(header.name) != kAcceptEncodingHeader) continue; - // TODO(b/278271389): Right now for simplicity we support Accept-Encoding: - // br - if (absl::AsciiStrToLower(header.value) == 
kBrotliAlgorithmHeader) { - return CompressionGroupConcatenator::CompressionType::kBrotli; +absl::Status GetCompressionGroupContentAsJsonList( + const std::vector& partition_output_strings, + std::string& content, + const RequestContextFactory& request_context_factory) { + nlohmann::json json_partition_output_list = nlohmann::json::array(); + for (auto&& partition_output_string : partition_output_strings) { + auto partition_output_json = + nlohmann::json::parse(partition_output_string, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + if (partition_output_json.is_discarded()) { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << "json parse failed for " << partition_output_string; + continue; } + json_partition_output_list.emplace_back(partition_output_json); } - return CompressionGroupConcatenator::CompressionType::kUncompressed; + if (json_partition_output_list.size() == 0) { + return absl::InvalidArgumentError( + "Converting partition outputs to JSON returned empty list"); + } + content = json_partition_output_list.dump(); + return absl::OkStatus(); } + +absl::Status GetCompressionGroupContentAsCborList( + std::vector& partition_output_strings, std::string& content, + const RequestContextFactory& request_context_factory) { + RepeatedPtrField partition_outputs; + for (auto& partition_output_string : partition_output_strings) { + auto partition_output = + application_pa::PartitionOutputFromJson(partition_output_string); + if (partition_output.ok()) { + *partition_outputs.Add() = partition_output.value(); + } else { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << partition_output.status(); + } + } + + const auto cbor_string = PartitionOutputsCborEncode(partition_outputs); + if (!cbor_string.ok()) { + PS_VLOG(2, request_context_factory.Get().GetPSLogContext()) + << "CBOR encode failed for partition outputs"; + } + content = cbor_string.value(); + return absl::OkStatus(); +} + } // namespace grpc::Status 
GetValuesV2Handler::GetValuesHttp( @@ -74,141 +114,35 @@ grpc::Status GetValuesV2Handler::GetValuesHttp( const std::multimap& headers, const GetValuesHttpRequest& request, google::api::HttpBody* response, ExecutionMetadata& execution_metadata) const { + auto v2_codec = V2EncoderDecoder::Create(V2EncoderDecoder::GetContentType( + headers, V2EncoderDecoder::ContentType::kJson)); return FromAbslStatus( GetValuesHttp(request_context_factory, request.raw_body().data(), - *response->mutable_data(), execution_metadata, - GetContentType(headers, ContentType::kJson))); + *response->mutable_data(), execution_metadata, *v2_codec)); } absl::Status GetValuesV2Handler::GetValuesHttp( RequestContextFactory& request_context_factory, std::string_view request, std::string& response, ExecutionMetadata& execution_metadata, - ContentType content_type) const { - v2::GetValuesRequest request_proto; - if (content_type == ContentType::kJson) { - PS_RETURN_IF_ERROR( - google::protobuf::util::JsonStringToMessage(request, &request_proto)); - } else { // proto - if (!request_proto.ParseFromString(request)) { - auto error_message = - "Cannot parse request as a valid serilized proto object."; - PS_VLOG(4, request_context_factory.Get().GetPSLogContext()) - << error_message; - return absl::InvalidArgumentError(error_message); - } - } + const V2EncoderDecoder& v2_codec) const { + PS_ASSIGN_OR_RETURN(v2::GetValuesRequest request_proto, + v2_codec.DecodeToV2GetValuesRequestProto(request)); PS_VLOG(9) << "Converted the http request to proto: " << request_proto.DebugString(); v2::GetValuesResponse response_proto; - PS_RETURN_IF_ERROR( - GetValues(request_context_factory, request_proto, &response_proto, - execution_metadata)); - if (content_type == ContentType::kJson) { - return MessageToJsonString(response_proto, &response); - } - // content_type == proto - if (!response_proto.SerializeToString(&response)) { - auto error_message = "Cannot serialize the response as a proto."; - PS_VLOG(4, 
request_context_factory.Get().GetPSLogContext()) - << error_message; - return absl::InvalidArgumentError(error_message); - } + PS_RETURN_IF_ERROR(GetValues( + request_context_factory, request_proto, &response_proto, + execution_metadata, IsSinglePartitionUseCase(request_proto), v2_codec)); + PS_ASSIGN_OR_RETURN(response, + v2_codec.EncodeV2GetValuesResponse(response_proto)); return absl::OkStatus(); } -grpc::Status GetValuesV2Handler::BinaryHttpGetValues( - RequestContextFactory& request_context_factory, - const v2::BinaryHttpGetValuesRequest& bhttp_request, - google::api::HttpBody* response, - ExecutionMetadata& execution_metadata) const { - return FromAbslStatus(BinaryHttpGetValues( - request_context_factory, bhttp_request.raw_body().data(), - *response->mutable_data(), execution_metadata)); -} - -GetValuesV2Handler::ContentType GetValuesV2Handler::GetContentType( - const quiche::BinaryHttpRequest& deserialized_req) const { - for (const auto& header : deserialized_req.GetHeaderFields()) { - if (absl::AsciiStrToLower(header.name) == kContentTypeHeader && - absl::AsciiStrToLower(header.value) == - kContentEncodingProtoHeaderValue) { - return ContentType::kProto; - } - } - return ContentType::kJson; -} - -GetValuesV2Handler::ContentType GetValuesV2Handler::GetContentType( - const std::multimap& headers, - ContentType default_content_type) const { - for (const auto& [header_name, header_value] : headers) { - if (absl::AsciiStrToLower(std::string_view( - header_name.data(), header_name.size())) == kKVContentTypeHeader) { - if (absl::AsciiStrToLower( - std::string_view(header_value.data(), header_value.size())) == - kContentEncodingBhttpHeaderValue) { - return ContentType::kBhttp; - } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), - header_value.size())) == - kContentEncodingProtoHeaderValue) { - return ContentType::kProto; - } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), - header_value.size())) == - 
kContentEncodingJsonHeaderValue) { - return ContentType::kJson; - } - } - } - return default_content_type; -} - -absl::StatusOr -GetValuesV2Handler::BuildSuccessfulGetValuesBhttpResponse( - RequestContextFactory& request_context_factory, - std::string_view bhttp_request_body, - ExecutionMetadata& execution_metadata) const { - PS_VLOG(9) << "Handling the binary http layer"; - PS_ASSIGN_OR_RETURN(quiche::BinaryHttpRequest deserialized_req, - quiche::BinaryHttpRequest::Create(bhttp_request_body), - _ << "Failed to deserialize binary http request"); - PS_VLOG(3) << "BinaryHttpGetValues request: " - << deserialized_req.DebugString(); - std::string response; - auto content_type = GetContentType(deserialized_req); - PS_RETURN_IF_ERROR(GetValuesHttp(request_context_factory, - deserialized_req.body(), response, - execution_metadata, content_type)); - quiche::BinaryHttpResponse bhttp_response(200); - if (content_type == ContentType::kProto) { - bhttp_response.AddHeaderField({ - .name = std::string(kContentTypeHeader), - .value = std::string(kContentEncodingProtoHeaderValue), - }); - } - bhttp_response.set_body(std::move(response)); - return bhttp_response; -} - -absl::Status GetValuesV2Handler::BinaryHttpGetValues( - RequestContextFactory& request_context_factory, - std::string_view bhttp_request_body, std::string& response, - ExecutionMetadata& execution_metadata) const { - static quiche::BinaryHttpResponse const* kDefaultBhttpResponse = - new quiche::BinaryHttpResponse(500); - const quiche::BinaryHttpResponse* bhttp_response = kDefaultBhttpResponse; - absl::StatusOr maybe_successful_bhttp_response = - BuildSuccessfulGetValuesBhttpResponse( - request_context_factory, bhttp_request_body, execution_metadata); - if (maybe_successful_bhttp_response.ok()) { - bhttp_response = &(maybe_successful_bhttp_response.value()); - } - PS_ASSIGN_OR_RETURN(auto serialized_bhttp_response, - bhttp_response->Serialize()); - - response = std::move(serialized_bhttp_response); - PS_VLOG(9) << 
"BinaryHttpGetValues finished successfully"; - return absl::OkStatus(); +bool IsSinglePartitionUseCase(const v2::GetValuesRequest& request) { + const auto is_pas_field = request.metadata().fields().find(kIsPas); + return (is_pas_field != request.metadata().fields().end() && + is_pas_field->second.string_value() == "true"); } grpc::Status GetValuesV2Handler::ObliviousGetValues( @@ -219,31 +153,29 @@ grpc::Status GetValuesV2Handler::ObliviousGetValues( ExecutionMetadata& execution_metadata) const { PS_VLOG(9) << "Received ObliviousGetValues request. "; OhttpServerEncryptor encryptor(key_fetcher_manager_); - auto maybe_plain_text = + auto maybe_padded_plain_text = encryptor.DecryptRequest(oblivious_request.raw_body().data(), request_context_factory.Get().GetPSLogContext()); - if (!maybe_plain_text.ok()) { - return FromAbslStatus(maybe_plain_text.status()); + if (!maybe_padded_plain_text.ok()) { + return FromAbslStatus(maybe_padded_plain_text.status()); } std::string response; - auto content_type = GetContentType(headers, ContentType::kBhttp); - if (content_type == ContentType::kBhttp) { - // Now process the binary http request - if (const auto s = - BinaryHttpGetValues(request_context_factory, *maybe_plain_text, - response, execution_metadata); - !s.ok()) { - return FromAbslStatus(s); - } - } else { - if (const auto s = - GetValuesHttp(request_context_factory, *maybe_plain_text, response, - execution_metadata, content_type); - !s.ok()) { - return FromAbslStatus(s); - } + absl::StatusOr + decoded_request = privacy_sandbox::server_common::DecodeRequestPayload( + *maybe_padded_plain_text); + if (!decoded_request.ok()) { + return FromAbslStatus(decoded_request.status()); } - auto encoded_data_size = GetEncodedDataSize(response.size()); + auto v2_codec = V2EncoderDecoder::Create(V2EncoderDecoder::GetContentType( + headers, V2EncoderDecoder::ContentType::kCbor)); + if (const auto s = GetValuesHttp(request_context_factory, + std::move(decoded_request->compressed_data), + 
response, execution_metadata, *v2_codec); + !s.ok()) { + return FromAbslStatus(s); + } + auto encoded_data_size = privacy_sandbox::server_common::GetEncodedDataSize( + response.size(), kMinResponsePaddingBytes); auto maybe_padded_response = privacy_sandbox::server_common::EncodeResponsePayload( privacy_sandbox::server_common::CompressionType::kUncompressed, @@ -259,7 +191,7 @@ grpc::Status GetValuesV2Handler::ObliviousGetValues( absl::StrCat(encrypted_response.status().code(), " : ", encrypted_response.status().message())); } - oblivious_response->set_content_type(std::string(kOHTTPResponseContentType)); + oblivious_response->set_content_type(std::string(kKVOhttpResponseLabel)); oblivious_response->set_data(*encrypted_response); return grpc::Status::OK; } @@ -273,10 +205,13 @@ absl::Status GetValuesV2Handler::ProcessOnePartition( resp_partition.set_id(req_partition.id()); UDFExecutionMetadata udf_metadata; *udf_metadata.mutable_request_metadata() = req_metadata; + if (!req_partition.metadata().fields().empty()) { + *udf_metadata.mutable_partition_metadata() = req_partition.metadata(); + } - const auto maybe_output_string = udf_client_.ExecuteCode( - std::move(request_context_factory), std::move(udf_metadata), - req_partition.arguments(), execution_metadata); + const auto maybe_output_string = + udf_client_.ExecuteCode(request_context_factory, std::move(udf_metadata), + req_partition.arguments(), execution_metadata); if (!maybe_output_string.ok()) { resp_partition.mutable_status()->set_code( static_cast(maybe_output_string.status().code())); @@ -290,26 +225,78 @@ absl::Status GetValuesV2Handler::ProcessOnePartition( return absl::OkStatus(); } +absl::Status GetValuesV2Handler::ProcessMultiplePartitions( + const RequestContextFactory& request_context_factory, + const v2::GetValuesRequest& request, v2::GetValuesResponse& response, + ExecutionMetadata& execution_metadata, + const V2EncoderDecoder& v2_codec) const { + absl::flat_hash_map>> + compression_group_map; + for 
(const auto& partition : request.partitions()) { + int32_t compression_group_id = partition.compression_group_id(); + v2::ResponsePartition resp_partition; + if (auto single_partition_status = + ProcessOnePartition(request_context_factory, request.metadata(), + partition, resp_partition, execution_metadata); + single_partition_status.ok()) { + compression_group_map[compression_group_id].emplace_back( + partition.id(), std::move(resp_partition.string_output())); + } else { + PS_VLOG(3, request_context_factory.Get().GetPSLogContext()) + << "Failed to process partition: " << single_partition_status; + } + } + + // The content of each compressed blob is a CBOR/JSON list of partition + // outputs or a V2CompressionGroup protobuf message. + for (auto& [group_id, partition_output_pairs] : compression_group_map) { + const auto maybe_content = v2_codec.EncodePartitionOutputs( + partition_output_pairs, request_context_factory); + if (!maybe_content.ok()) { + PS_VLOG(3, request_context_factory.Get().GetPSLogContext()) + << maybe_content.status(); + continue; + } + // TODO(b/355464083): Compress the compression_group content + auto* compression_group = response.add_compression_groups(); + compression_group->set_content(std::move(*maybe_content)); + compression_group->set_compression_group_id(group_id); + } + if (response.compression_groups().empty()) { + return absl::InvalidArgumentError("All partitions failed."); + } + return absl::OkStatus(); +} + grpc::Status GetValuesV2Handler::GetValues( RequestContextFactory& request_context_factory, const v2::GetValuesRequest& request, v2::GetValuesResponse* response, - ExecutionMetadata& execution_metadata) const { + ExecutionMetadata& execution_metadata, bool single_partition_use_case, + const V2EncoderDecoder& v2_codec) const { PS_VLOG(9) << "Update log context " << request.log_context() << ";" << request.consented_debug_config(); - request_context_factory.UpdateLogContext(request.log_context(), - request.consented_debug_config()); - 
if (request.partitions().size() == 1) { - const auto partition_status = ProcessOnePartition( - request_context_factory, request.metadata(), request.partitions(0), - *response->mutable_single_partition(), execution_metadata); - return GetExternalStatusForV2(partition_status); - } + request_context_factory.UpdateLogContext( + request.log_context(), request.consented_debug_config(), + [response]() { return response->mutable_debug_info(); }); if (request.partitions().empty()) { return grpc::Status(StatusCode::INTERNAL, "At least 1 partition is required"); } - return grpc::Status(StatusCode::UNIMPLEMENTED, - "Multiple partition support is not implemented"); + if (single_partition_use_case) { + if (request.partitions().size() > 1) { + return grpc::Status(StatusCode::UNIMPLEMENTED, + "This use case only accepts single partitions, but " + "multiple partitions were found."); + } + const auto response_status = ProcessOnePartition( + request_context_factory, request.metadata(), request.partitions(0), + *response->mutable_single_partition(), execution_metadata); + return GetExternalStatusForV2(response_status); + } + const auto response_status = + ProcessMultiplePartitions(request_context_factory, request, *response, + execution_metadata, v2_codec); + return GetExternalStatusForV2(response_status); } } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_v2_handler.h b/components/data_server/request_handler/get_values_v2_handler.h index 51f198d7..963ef2be 100644 --- a/components/data_server/request_handler/get_values_v2_handler.h +++ b/components/data_server/request_handler/get_values_v2_handler.h @@ -27,7 +27,8 @@ #include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "components/data_server/cache/cache.h" -#include "components/data_server/request_handler/compression.h" +#include "components/data_server/request_handler/compression/compression.h" +#include "components/data_server/request_handler/content_type/encoder.h" 
#include "components/telemetry/server_definition.h" #include "components/udf/udf_client.h" #include "components/util/request_context.h" @@ -38,22 +39,11 @@ namespace kv_server { -// Content Type Header Name. Can be set for bhttp request to proto or json +// Content Type Header Name. Can be set for ohttp request to proto or json // values below. inline constexpr std::string_view kContentTypeHeader = "content-type"; -// Header in clear text http request/response that indicates which format is -// used by the payload. The more common "Content-Type" header is not used -// because most importantly that has CORS implications, and in addition, may not -// be forwarded by Envoy to gRPC. -inline constexpr std::string_view kKVContentTypeHeader = "kv-content-type"; -// Protobuf Content Type Header Value. -inline constexpr std::string_view kContentEncodingProtoHeaderValue = - "application/protobuf"; -// Json Content Type Header Value. -inline constexpr std::string_view kContentEncodingJsonHeaderValue = - "application/json"; -inline constexpr std::string_view kContentEncodingBhttpHeaderValue = - "message/bhttp"; + +bool IsSinglePartitionUseCase(const v2::GetValuesRequest& request); // Handles the request family of *GetValues. // See the Service proto definition for details. @@ -81,13 +71,9 @@ class GetValuesV2Handler { grpc::Status GetValues(RequestContextFactory& request_context_factory, const v2::GetValuesRequest& request, v2::GetValuesResponse* response, - ExecutionMetadata& execution_metadata) const; - - grpc::Status BinaryHttpGetValues( - RequestContextFactory& request_context_factory, - const v2::BinaryHttpGetValuesRequest& request, - google::api::HttpBody* response, - ExecutionMetadata& execution_metadata) const; + ExecutionMetadata& execution_metadata, + bool single_partition_use_case, + const V2EncoderDecoder& v2_codec) const; // Supports requests encrypted with a fixed key for debugging/demoing. // X25519 Secret key (priv key). 
@@ -102,6 +88,8 @@ class GetValuesV2Handler { // KDF: HKDF-SHA256 0x0001 // AEAD: AES-256-GCM 0X0002 // (https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md#encryption) + // + // The default content type for OHTTP is cbor. grpc::Status ObliviousGetValues( RequestContextFactory& request_context_factory, const std::multimap& headers, @@ -110,40 +98,11 @@ class GetValuesV2Handler { ExecutionMetadata& execution_metadata) const; private: - enum class ContentType { - kJson = 0, - kProto, - kBhttp, - }; - ContentType GetContentType( - const quiche::BinaryHttpRequest& deserialized_req) const; - - ContentType GetContentType( - const std::multimap& headers, - ContentType default_content_type) const; - - absl::Status GetValuesHttp( - RequestContextFactory& request_context_factory, std::string_view request, - std::string& json_response, ExecutionMetadata& execution_metadata, - ContentType content_type = ContentType::kJson) const; - - // On success, returns a BinaryHttpResponse with a successful response. The - // reason that this is a separate function is so that the error status - // returned from here can be encoded as a BinaryHTTP response code. So even if - // this function fails, the final grpc code may still be ok. - absl::StatusOr - BuildSuccessfulGetValuesBhttpResponse( - RequestContextFactory& request_context_factory, - std::string_view bhttp_request_body, - ExecutionMetadata& execution_metadata) const; - - // Returns error only if the response cannot be serialized into Binary HTTP - // response. For all other failures, the error status will be inside the - // Binary HTTP message. 
- absl::Status BinaryHttpGetValues( - RequestContextFactory& request_context_factory, - std::string_view bhttp_request_body, std::string& response, - ExecutionMetadata& execution_metadata) const; + absl::Status GetValuesHttp(RequestContextFactory& request_context_factory, + std::string_view request, + std::string& json_response, + ExecutionMetadata& execution_metadata, + const V2EncoderDecoder& v2_codec) const; // Invokes UDF to process one partition. absl::Status ProcessOnePartition( @@ -153,6 +112,13 @@ class GetValuesV2Handler { v2::ResponsePartition& resp_partition, ExecutionMetadata& execution_metadata) const; + // Invokes UDF to process multiple partitions. + absl::Status ProcessMultiplePartitions( + const RequestContextFactory& request_context_factory, + const v2::GetValuesRequest& request, v2::GetValuesResponse& response, + ExecutionMetadata& execution_metadata, + const V2EncoderDecoder& v2_codec) const; + const UdfClient& udf_client_; std::function create_compression_group_concatenator_; diff --git a/components/data_server/request_handler/get_values_v2_handler_test.cc b/components/data_server/request_handler/get_values_v2_handler_test.cc index f9fd95a3..003b705b 100644 --- a/components/data_server/request_handler/get_values_v2_handler_test.cc +++ b/components/data_server/request_handler/get_values_v2_handler_test.cc @@ -20,8 +20,10 @@ #include #include "absl/log/log.h" +#include "components/data/converters/cbor_converter.h" #include "components/data_server/cache/cache.h" #include "components/data_server/cache/mocks.h" +#include "components/data_server/request_handler/content_type/json_encoder.h" #include "components/udf/mocks.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" @@ -31,12 +33,14 @@ #include "public/constants.h" #include "public/test_util/proto_matcher.h" #include "public/test_util/request_example.h" -#include "quiche/binary_http/binary_http_message.h" #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" 
#include "quiche/oblivious_http/oblivious_http_client.h" #include "src/communication/encoding_utils.h" +#include "src/communication/framing_utils.h" #include "src/encryption/key_fetcher/fake_key_fetcher_manager.h" +#include "cbor.h" + namespace kv_server { namespace { @@ -46,16 +50,15 @@ using testing::_; using testing::Return; using testing::ReturnRef; using testing::UnorderedElementsAre; -using v2::BinaryHttpGetValuesRequest; using v2::GetValuesHttpRequest; using v2::ObliviousGetValuesRequest; enum class ProtocolType { kPlain = 0, - kBinaryHttp, kObliviousHttp, }; +// TODO(b/355434272): Refactor struct TestingParameters { ProtocolType protocol_type; const std::string_view content_type; @@ -63,9 +66,22 @@ struct TestingParameters { const bool is_consented; }; -class GetValuesHandlerTest - : public ::testing::Test, - public ::testing::WithParamInterface { +nlohmann::json GetPartitionOutputsInJson(const nlohmann::json& content_json) { + std::vector content_cbor = nlohmann::json::to_cbor(content_json); + std::string content_cbor_string = + std::string(content_cbor.begin(), content_cbor.end()); + struct cbor_load_result result; + cbor_item_t* cbor_bytestring = cbor_load( + reinterpret_cast(content_cbor_string.data()), + content_cbor_string.size(), &result); + auto partition_output_cbor = cbor_bytestring_handle(cbor_bytestring); + auto cbor_bytestring_len = cbor_bytestring_length(cbor_bytestring); + return nlohmann::json::from_cbor(std::vector( + partition_output_cbor, partition_output_cbor + cbor_bytestring_len)); +} + +class BaseTest : public ::testing::Test, + public ::testing::WithParamInterface { protected: void SetUp() override { privacy_sandbox::server_common::log::ServerToken( @@ -83,6 +99,16 @@ class GetValuesHandlerTest return param.content_type == kContentEncodingProtoHeaderValue; } + bool IsJsonContent() { + auto param = GetParam(); + return param.content_type == kContentEncodingJsonHeaderValue; + } + + bool IsCborContent() { + auto param = GetParam(); + 
return param.content_type == kContentEncodingCborHeaderValue; + } + bool IsRequestExpectConsented() { auto param = GetParam(); return param.is_consented; @@ -110,118 +136,23 @@ class GetValuesHandlerTest std::string plain_request_body_; }; - class BHTTPRequest { - public: - explicit BHTTPRequest(PlainRequest plain_request, - bool is_protobuf_content) { - quiche::BinaryHttpRequest req_bhttp_layer({}); - if (is_protobuf_content) { - req_bhttp_layer.AddHeaderField({ - .name = std::string(kContentTypeHeader), - .value = std::string(kContentEncodingProtoHeaderValue), - }); - } - req_bhttp_layer.set_body(plain_request.RequestBody()); - auto maybe_serialized = req_bhttp_layer.Serialize(); - EXPECT_TRUE(maybe_serialized.ok()); - serialized_bhttp_request_ = *maybe_serialized; - } - - BinaryHttpGetValuesRequest Build() const { - BinaryHttpGetValuesRequest brequest; - brequest.mutable_raw_body()->set_data(serialized_bhttp_request_); - return brequest; - } - - const std::string& SerializedBHTTPRequest() const { - return serialized_bhttp_request_; - } - - private: - std::string serialized_bhttp_request_; - }; - - class BHTTPResponse { - public: - google::api::HttpBody& RawResponse() { return response_; } - int16_t ResponseCode(bool is_using_bhttp) const { - std::string response; - if (is_using_bhttp) { - response = response_.data(); - } else { - auto deframed_req = - privacy_sandbox::server_common::DecodeRequestPayload( - response_.data()); - EXPECT_TRUE(deframed_req.ok()) << deframed_req.status(); - response = deframed_req->compressed_data; - } - const absl::StatusOr maybe_res_bhttp_layer = - quiche::BinaryHttpResponse::Create(response); - EXPECT_TRUE(maybe_res_bhttp_layer.ok()) - << "quiche::BinaryHttpResponse::Create failed: " - << maybe_res_bhttp_layer.status(); - return maybe_res_bhttp_layer->status_code(); - } - - std::string Unwrap(bool is_protobuf_content, bool is_using_bhttp) const { - std::string response; - if (is_using_bhttp) { - response = response_.data(); - } 
else { - auto deframed_req = - privacy_sandbox::server_common::DecodeRequestPayload( - response_.data()); - EXPECT_TRUE(deframed_req.ok()) << deframed_req.status(); - response = deframed_req->compressed_data; - } - const absl::StatusOr maybe_res_bhttp_layer = - quiche::BinaryHttpResponse::Create(response); - EXPECT_TRUE(maybe_res_bhttp_layer.ok()) - << "quiche::BinaryHttpResponse::Create failed: " - << maybe_res_bhttp_layer.status(); - if (maybe_res_bhttp_layer->status_code() == 200 & is_protobuf_content) { - EXPECT_TRUE(HasHeader(*maybe_res_bhttp_layer, kContentTypeHeader, - kContentEncodingProtoHeaderValue)); - } - return std::string(maybe_res_bhttp_layer->body()); - } - - private: - bool HasHeader(const quiche::BinaryHttpResponse& response, - const std::string_view header_key, - const std::string_view header_value) const { - for (const auto& header : response.GetHeaderFields()) { - if (absl::AsciiStrToLower(header.name) == header_key && - absl::AsciiStrToLower(header.value) == header_value) { - return true; - } - } - return false; - } - - google::api::HttpBody response_; - }; - class OHTTPRequest; class OHTTPResponseUnwrapper { public: google::api::HttpBody& RawResponse() { return response_; } - BHTTPResponse Unwrap() { + std::string Unwrap() { uint8_t key_id = 64; auto maybe_config = quiche::ObliviousHttpHeaderKeyConfig::Create( key_id, kKEMParameter, kKDFParameter, kAEADParameter); EXPECT_TRUE(maybe_config.ok()); - - auto client = - quiche::ObliviousHttpClient::Create(public_key_, *maybe_config); - EXPECT_TRUE(client.ok()); auto decrypted_response = - client->DecryptObliviousHttpResponse(response_.data(), context_); - BHTTPResponse bhttp_response; - bhttp_response.RawResponse().set_data( + quiche::ObliviousHttpResponse::CreateClientObliviousResponse( + response_.data(), context_, kKVOhttpResponseLabel); + auto deframed_req = privacy_sandbox::server_common::DecodeRequestPayload( decrypted_response->GetPlaintextData()); - return bhttp_response; + 
EXPECT_TRUE(deframed_req.ok()) << deframed_req.status(); + return deframed_req->compressed_data; } private: @@ -238,8 +169,8 @@ class GetValuesHandlerTest class OHTTPRequest { public: - explicit OHTTPRequest(BHTTPRequest bhttp_request) - : bhttp_request_(std::move(bhttp_request)) {} + explicit OHTTPRequest(std::string raw_request) + : raw_request_(std::move(raw_request)) {} std::pair Build() const { // matches the test key pair, see common repo: @@ -248,12 +179,10 @@ class GetValuesHandlerTest auto maybe_config = quiche::ObliviousHttpHeaderKeyConfig::Create( key_id, kKEMParameter, kKDFParameter, kAEADParameter); EXPECT_TRUE(maybe_config.ok()); - - auto client = - quiche::ObliviousHttpClient::Create(public_key_, *maybe_config); - EXPECT_TRUE(client.ok()); - auto encrypted_req = client->CreateObliviousHttpRequest( - bhttp_request_.SerializedBHTTPRequest()); + auto encrypted_req = + quiche::ObliviousHttpRequest::CreateClientObliviousRequest( + raw_request_, public_key_, *std::move(maybe_config), + kKVOhttpRequestLabel); EXPECT_TRUE(encrypted_req.ok()); auto serialized_encrypted_req = encrypted_req->EncapsulateAndSerialize(); ObliviousGetValuesRequest ohttp_req; @@ -266,56 +195,85 @@ class GetValuesHandlerTest private: const std::string public_key_ = absl::HexStringToBytes(kTestPublicKey); - BHTTPRequest bhttp_request_; + std::string raw_request_; }; // For Non-plain protocols, test request and response data are converted // to/from the corresponding request/responses. 
grpc::Status GetValuesBasedOnProtocol( RequestContextFactory& request_context_factory, std::string request_body, - google::api::HttpBody* response, int16_t* bhttp_response_code, + google::api::HttpBody* response, int16_t* http_response_code, GetValuesV2Handler* handler) { PlainRequest plain_request(std::move(request_body)); - ExecutionMetadata execution_metadata; - std::multimap headers = { - {"kv-content-type", "application/json"}}; + ExecutionMetadata execution_metadata; + auto contentTypeHeader = std::string(kKVContentTypeHeader); + auto contentEncodingProtoHeaderValue = + std::string(kContentEncodingProtoHeaderValue); + auto contentEncodingJsonHeaderValue = + std::string(kContentEncodingJsonHeaderValue); + auto contentEncodingCborHeaderValue = + std::string(kContentEncodingCborHeaderValue); + std::multimap headers; if (IsUsing()) { - *bhttp_response_code = 200; - return handler->GetValuesHttp(request_context_factory, headers, - plain_request.Build(), response, - execution_metadata); - } - - BHTTPRequest bhttp_request(std::move(plain_request), IsProtobufContent()); - BHTTPResponse bresponse; - - if (IsUsing()) { - if (const auto s = handler->BinaryHttpGetValues( - request_context_factory, bhttp_request.Build(), - &bresponse.RawResponse(), execution_metadata); - !s.ok()) { - LOG(ERROR) << "BinaryHttpGetValues failed: " << s.error_message(); - return s; - } - *bhttp_response_code = bresponse.ResponseCode(true); - } else if (IsUsing()) { - OHTTPRequest ohttp_request(std::move(bhttp_request)); - // get ObliviousGetValuesRequest, OHTTPResponseUnwrapper - auto [request, response_unwrapper] = ohttp_request.Build(); - if (const auto s = handler->ObliviousGetValues( - request_context_factory, {{"kv-content-type", "message/bhttp"}}, - request, &response_unwrapper.RawResponse(), execution_metadata); - !s.ok()) { - LOG(ERROR) << "ObliviousGetValues failed: " << s.error_message(); - return s; + *http_response_code = 200; + headers.insert({ + contentTypeHeader, + 
contentEncodingJsonHeaderValue, + }); + const auto s = handler->GetValuesHttp(request_context_factory, headers, + plain_request.Build(), response, + execution_metadata); + if (!s.ok()) { + *http_response_code = s.error_code(); + LOG(ERROR) << "GetValuesHttp failed: " << s.error_message(); } - bresponse = response_unwrapper.Unwrap(); - *bhttp_response_code = bresponse.ResponseCode(false); + return s; + } + auto encoded_data_size = privacy_sandbox::server_common::GetEncodedDataSize( + plain_request.RequestBody().size(), kMinResponsePaddingBytes); + auto maybe_padded_request = + privacy_sandbox::server_common::EncodeResponsePayload( + privacy_sandbox::server_common::CompressionType::kUncompressed, + std::move(plain_request.RequestBody()), encoded_data_size); + if (!maybe_padded_request.ok()) { + LOG(ERROR) << "Padding failed: " + << maybe_padded_request.status().message(); + return privacy_sandbox::server_common::FromAbslStatus( + maybe_padded_request.status()); } - response->set_data(bresponse.Unwrap(IsProtobufContent(), - IsUsing())); + OHTTPRequest ohttp_request(*maybe_padded_request); + // get ObliviousGetValuesRequest, OHTTPResponseUnwrapper + auto [request, response_unwrapper] = ohttp_request.Build(); + if (IsProtobufContent()) { + headers.insert({ + contentTypeHeader, + contentEncodingProtoHeaderValue, + }); + } + if (IsJsonContent()) { + headers.insert({ + contentTypeHeader, + contentEncodingJsonHeaderValue, + }); + } + if (IsCborContent()) { + headers.insert({ + contentTypeHeader, + contentEncodingCborHeaderValue, + }); + } + if (const auto s = handler->ObliviousGetValues( + request_context_factory, headers, request, + &response_unwrapper.RawResponse(), execution_metadata); + !s.ok()) { + *http_response_code = s.error_code(); + LOG(ERROR) << "ObliviousGetValues failed: " << s.error_message(); + return s; + } + response->set_data(response_unwrapper.Unwrap()); + *http_response_code = 200; return grpc::Status::OK; } @@ -324,6 +282,9 @@ class GetValuesHandlerTest 
fake_key_fetcher_manager_; }; +class GetValuesHandlerTest : public BaseTest {}; +class GetValuesHandlerMultiplePartitionsTest : public BaseTest {}; + INSTANTIATE_TEST_SUITE_P( GetValuesHandlerTest, GetValuesHandlerTest, testing::Values( @@ -347,19 +308,19 @@ INSTANTIATE_TEST_SUITE_P( .is_consented = true, }, TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, + .protocol_type = ProtocolType::kObliviousHttp, .content_type = kContentEncodingJsonHeaderValue, .core_request_body = kv_server::kExampleV2RequestInJson, .is_consented = false, }, TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, + .protocol_type = ProtocolType::kObliviousHttp, .content_type = kContentEncodingJsonHeaderValue, .core_request_body = kv_server::kExampleConsentedV2RequestInJson, .is_consented = true, }, TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, + .protocol_type = ProtocolType::kObliviousHttp, .content_type = kContentEncodingJsonHeaderValue, .core_request_body = kv_server::kExampleConsentedV2RequestWithLogContextInJson, @@ -367,59 +328,93 @@ INSTANTIATE_TEST_SUITE_P( }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = kContentEncodingJsonHeaderValue, + .content_type = kContentEncodingProtoHeaderValue, .core_request_body = kv_server::kExampleV2RequestInJson, .is_consented = false, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = kContentEncodingJsonHeaderValue, + .content_type = kContentEncodingProtoHeaderValue, .core_request_body = kv_server::kExampleConsentedV2RequestInJson, .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = kContentEncodingJsonHeaderValue, + .content_type = kContentEncodingProtoHeaderValue, .core_request_body = kv_server::kExampleConsentedV2RequestWithLogContextInJson, .is_consented = true, - }, + })); + +INSTANTIATE_TEST_SUITE_P( + GetValuesHandlerMultiplePartitionsTest, + 
GetValuesHandlerMultiplePartitionsTest, + testing::Values( TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, - .content_type = kContentEncodingProtoHeaderValue, - .core_request_body = kv_server::kExampleV2RequestInJson, + .protocol_type = ProtocolType::kPlain, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kV2RequestMultiplePartitionsInJson, .is_consented = false, }, TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, - .content_type = kContentEncodingProtoHeaderValue, - .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .protocol_type = ProtocolType::kPlain, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = + kv_server::kConsentedV2RequestMultiplePartitionsInJson, .is_consented = true, }, TestingParameters{ - .protocol_type = ProtocolType::kBinaryHttp, - .content_type = kContentEncodingProtoHeaderValue, - .core_request_body = - kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .protocol_type = ProtocolType::kPlain, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server:: + kConsentedV2RequestMultiplePartitionsWithLogContextInJson, .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = kContentEncodingProtoHeaderValue, - .core_request_body = kv_server::kExampleV2RequestInJson, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kV2RequestMultiplePartitionsInJson, .is_consented = false, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = kContentEncodingProtoHeaderValue, - .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = + kv_server::kConsentedV2RequestMultiplePartitionsInJson, .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, - .content_type = 
kContentEncodingProtoHeaderValue, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server:: + kConsentedV2RequestMultiplePartitionsWithLogContextInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingCborHeaderValue, + .core_request_body = kv_server::kV2RequestMultiplePartitionsInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingCborHeaderValue, .core_request_body = - kv_server::kExampleConsentedV2RequestWithLogContextInJson, + kv_server::kConsentedV2RequestMultiplePartitionsInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingCborHeaderValue, + .core_request_body = kv_server:: + kConsentedV2RequestMultiplePartitionsWithLogContextInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingCborHeaderValue, + .core_request_body = kv_server:: + kConsentedV2RequestMultiPartWithDebugInfoResponseInJson, .is_consented = true, })); @@ -433,6 +428,12 @@ request_metadata { string_value: "example.com" } } + fields { + key: "is_pas" + value { + string_value: "true" + } + } } )", &udf_metadata); @@ -505,7 +506,7 @@ data { std::string core_request_body = GetTestRequestBody(); google::api::HttpBody response; GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - int16_t bhttp_response_code = 0; + int16_t http_response_code = 0; if (IsProtobufContent()) { v2::GetValuesRequest request_proto; ASSERT_TRUE(google::protobuf::util::JsonStringToMessage(core_request_body, @@ -518,8 +519,8 @@ data { auto request_context_factory = std::make_unique(); const auto result = GetValuesBasedOnProtocol(*request_context_factory, core_request_body, - &response, &bhttp_response_code, &handler); - 
ASSERT_EQ(bhttp_response_code, 200); + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -545,7 +546,7 @@ TEST_P(GetValuesHandlerTest, NoPartition) { })"; google::api::HttpBody response; GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - int16_t bhttp_response_code = 0; + int16_t http_response_code = 0; if (IsProtobufContent()) { v2::GetValuesRequest request_proto; @@ -557,14 +558,9 @@ TEST_P(GetValuesHandlerTest, NoPartition) { auto request_context_factory = std::make_unique(); const auto result = GetValuesBasedOnProtocol(*request_context_factory, core_request_body, - &response, &bhttp_response_code, &handler); - if (IsUsing()) { - ASSERT_FALSE(result.ok()); - EXPECT_EQ(result.error_code(), grpc::StatusCode::INTERNAL); - } else { - ASSERT_TRUE(result.ok()); - EXPECT_EQ(bhttp_response_code, 500); - } + &response, &http_response_code, &handler); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error_code(), grpc::StatusCode::INTERNAL); } TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { @@ -577,13 +573,16 @@ TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { { "id": 0, } - ] + ], + "metadata": { + "is_pas": "true" + } } )"; google::api::HttpBody response; GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - int16_t bhttp_response_code = 0; + int16_t http_response_code = 0; if (IsProtobufContent()) { v2::GetValuesRequest request_proto; @@ -595,8 +594,8 @@ TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { auto request_context_factory = std::make_unique(); const auto result = GetValuesBasedOnProtocol(*request_context_factory, core_request_body, - &response, &bhttp_response_code, &handler); - ASSERT_EQ(bhttp_response_code, 200); + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << 
", msg: " << result.error_message(); @@ -616,67 +615,859 @@ TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { EXPECT_THAT(actual_response, EqualsProto(expected_response)); } -TEST_F(GetValuesHandlerTest, PureGRPCTest) { - v2::GetValuesRequest req; - ExecutionMetadata execution_metadata; - TextFormat::ParseFromString( - R"pb(partitions { - id: 9 - arguments { data { string_value: "ECHO" } } - })pb", - &req); - GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - EXPECT_CALL( - mock_udf_client_, - ExecuteCode( - _, _, - testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) - .WillOnce(Return("ECHO")); - v2::GetValuesResponse resp; - auto request_context_factory = std::make_unique(); - const auto result = handler.GetValues(*request_context_factory, req, &resp, - execution_metadata); - ASSERT_TRUE(result.ok()) << "code: " << result.error_code() - << ", msg: " << result.error_message(); - - v2::GetValuesResponse res; - TextFormat::ParseFromString( - R"pb(single_partition { id: 9 string_output: "ECHO" })pb", &res); - EXPECT_THAT(resp, EqualsProto(res)); +TEST_P(GetValuesHandlerMultiplePartitionsTest, Success) { + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(R"( +request_metadata { + fields { + key: "hostname" + value { + string_value: "example.com" + } + } +} + )", + &udf_metadata); + UDFArgument arg1, arg2, arg3; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "structured" + } + values { + string_value: "groupNames" + } +} +data { + list_value { + values { + string_value: "hello" + } + } +})", + &arg1); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key1" + } + } +})", + &arg2); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key2" + } 
+ } +})", + &arg3); + nlohmann::json output1 = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "structured", + "groupNames" + ] + } + ] +} + )"); + nlohmann::json output2 = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + "keyValues": { + "key1": { + "value": "value1" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] +} + )"); + nlohmann::json output3 = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + "keyValues": { + "key2": { + "value": "value2" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] } + )"); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg1)), _)) + .WillOnce(Return(output1.dump())); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg2)), _)) + .WillOnce(Return(output2.dump())); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg3)), _)) + .WillOnce(Return(output3.dump())); -TEST_F(GetValuesHandlerTest, PureGRPCTestFailure) { - v2::GetValuesRequest req; - ExecutionMetadata execution_metadata; - TextFormat::ParseFromString( - R"pb(partitions { - id: 9 - arguments { data { string_value: "ECHO" } } - })pb", - &req); + std::string core_request_body = GetTestRequestBody(); + google::api::HttpBody response; GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - EXPECT_CALL( - mock_udf_client_, - ExecuteCode( - _, _, - testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) - .WillOnce(Return(absl::InternalError("UDF execution error"))); - v2::GetValuesResponse resp; + int16_t http_response_code = 0; + if (IsCborContent()) { + nlohmann::json request_body_json = + nlohmann::json::parse(core_request_body, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + ASSERT_FALSE(request_body_json.is_discarded()); + 
std::vector cbor_vector = + nlohmann::json::to_cbor(request_body_json); + core_request_body = std::string(cbor_vector.begin(), cbor_vector.end()); + } auto request_context_factory = std::make_unique(); - const auto result = handler.GetValues(*request_context_factory, req, &resp, - execution_metadata); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); - v2::GetValuesResponse res; - TextFormat::ParseFromString( - R"pb(single_partition { - id: 9 - status: { code: 13 message: "UDF execution error" } + nlohmann::json partition_output1 = {{"id", 0}}; + partition_output1.update(output1); + nlohmann::json partition_output2 = {{"id", 1}}; + partition_output2.update(output2); + nlohmann::json partition_output3 = {{"id", 2}}; + partition_output3.update(output3); + nlohmann::json compressed_partition_group0 = {partition_output1, + partition_output3}; + nlohmann::json compressed_partition_group1 = + nlohmann::json::array({partition_output2}); + if (IsCborContent()) { + nlohmann::json expected_json = { + {"compressionGroups", + {{{"compressionGroupId", 0}, {"content", compressed_partition_group0}}, + {{"compressionGroupId", 1}, + {"content", compressed_partition_group1}}}}}; + + // Convert CBOR to json to check content + nlohmann::json actual_response_from_cbor = nlohmann::json::from_cbor( + response.data(), /*strict=*/true, /*allow_exceptions=*/false); + ASSERT_FALSE(actual_response_from_cbor.is_discarded()); + auto a = GetPartitionOutputsInJson( + actual_response_from_cbor["compressionGroups"][0]["content"]); + actual_response_from_cbor["compressionGroups"][0]["content"] = a; + auto b = GetPartitionOutputsInJson( + actual_response_from_cbor["compressionGroups"][1]["content"]); + actual_response_from_cbor["compressionGroups"][1]["content"] = b; + // 
Compare compression groups lists, since ordering might throw off equality + EXPECT_THAT(std::vector(expected_json["compressionGroups"].begin(), + expected_json["compressionGroups"].end()), + testing::UnorderedElementsAreArray( + actual_response_from_cbor["compressionGroups"].begin(), + actual_response_from_cbor["compressionGroups"].end())); + return; + } + v2::GetValuesResponse actual_response; + if (IsProtobufContent()) { + ASSERT_TRUE(actual_response.ParseFromString(response.data())); + } else { + ASSERT_TRUE(google::protobuf::util::JsonStringToMessage(response.data(), + &actual_response) + .ok()); + } + EXPECT_EQ(actual_response.compression_groups().size(), 2); + std::vector contents; + for (auto&& compression_group : actual_response.compression_groups()) { + contents.emplace_back(compression_group.content()); + } + EXPECT_THAT(contents, testing::UnorderedElementsAre( + compressed_partition_group0.dump(), + compressed_partition_group1.dump())); +} + +TEST_P(GetValuesHandlerMultiplePartitionsTest, + SinglePartitionUDFFails_IgnorePartition) { + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(R"( +request_metadata { + fields { + key: "hostname" + value { + string_value: "example.com" + } + } +} + )", + &udf_metadata); + UDFArgument arg1, arg2, arg3; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "structured" + } + values { + string_value: "groupNames" + } +} +data { + list_value { + values { + string_value: "hello" + } + } +})", + &arg1); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key1" + } + } +})", + &arg2); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key2" + } + } +})", + &arg3); + nlohmann::json output1 = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + 
"keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "structured", + "groupNames" + ] + } + ] +} + )"); + nlohmann::json output2 = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + "keyValues": { + "key1": { + "value": "value1" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] +} + )"); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg1)), _)) + .WillOnce(Return(output1.dump())); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg2)), _)) + .WillOnce(Return(output2.dump())); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg3)), _)) + .WillOnce(Return(absl::InternalError("UDF execution error"))); + + std::string core_request_body = GetTestRequestBody(); + google::api::HttpBody response; + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + int16_t http_response_code = 0; + if (IsCborContent()) { + nlohmann::json request_body_json = + nlohmann::json::parse(core_request_body, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + ASSERT_FALSE(request_body_json.is_discarded()); + std::vector cbor_vector = + nlohmann::json::to_cbor(request_body_json); + core_request_body = std::string(cbor_vector.begin(), cbor_vector.end()); + } + + auto request_context_factory = std::make_unique(); + + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + nlohmann::json partition_output1 = {{"id", 0}}; + partition_output1.update(output1); + nlohmann::json partition_output2 = {{"id", 1}}; + partition_output2.update(output2); + nlohmann::json compressed_partition_group0 = + 
nlohmann::json::array({partition_output1}); + nlohmann::json compressed_partition_group1 = + nlohmann::json::array({partition_output2}); + if (IsCborContent()) { + nlohmann::json expected_json = { + {"compressionGroups", + {{{"compressionGroupId", 0}, {"content", compressed_partition_group0}}, + {{"compressionGroupId", 1}, + {"content", compressed_partition_group1}}}}}; + + // Convert CBOR to json to check content + nlohmann::json actual_response_from_cbor = nlohmann::json::from_cbor( + response.data(), /*strict=*/true, /*allow_exceptions=*/false); + ASSERT_FALSE(actual_response_from_cbor.is_discarded()); + auto a = GetPartitionOutputsInJson( + actual_response_from_cbor["compressionGroups"][0]["content"]); + actual_response_from_cbor["compressionGroups"][0]["content"] = a; + auto b = GetPartitionOutputsInJson( + actual_response_from_cbor["compressionGroups"][1]["content"]); + actual_response_from_cbor["compressionGroups"][1]["content"] = b; + // Compare compression groups lists, since ordering might throw off equality + EXPECT_THAT(std::vector(expected_json["compressionGroups"].begin(), + expected_json["compressionGroups"].end()), + testing::UnorderedElementsAreArray( + actual_response_from_cbor["compressionGroups"].begin(), + actual_response_from_cbor["compressionGroups"].end())); + return; + } + v2::GetValuesResponse actual_response; + + ASSERT_TRUE(google::protobuf::util::JsonStringToMessage(response.data(), + &actual_response) + .ok()); + + EXPECT_EQ(actual_response.compression_groups().size(), 2); + std::vector contents; + for (auto&& compression_group : actual_response.compression_groups()) { + contents.emplace_back(compression_group.content()); + } + EXPECT_THAT(contents, testing::UnorderedElementsAre( + compressed_partition_group0.dump(), + compressed_partition_group1.dump())); +} + +TEST_P(GetValuesHandlerMultiplePartitionsTest, + AllPartitionsInSingleCompressionGroupUDFFails_IgnoreCompressionGroup) { + UDFExecutionMetadata udf_metadata; + 
TextFormat::ParseFromString(R"( +request_metadata { + fields { + key: "hostname" + value { + string_value: "example.com" + } + } +} + )", + &udf_metadata); + UDFArgument arg1, arg2, arg3; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "structured" + } + values { + string_value: "groupNames" + } +} +data { + list_value { + values { + string_value: "hello" + } + } +})", + &arg1); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key1" + } + } +})", + &arg2); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key2" + } + } +})", + &arg3); + nlohmann::json output = nlohmann::json::parse(R"( +{ + "keyGroupOutputs": [ + { + "keyValues": { + "key1": { + "value": "value1" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] +} + )"); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg1)), _)) + .WillOnce(Return(absl::InternalError("UDF execution error"))); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg2)), _)) + .WillOnce(Return(output.dump())); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg3)), _)) + .WillOnce(Return(absl::InternalError("UDF execution error"))); + + std::string core_request_body = GetTestRequestBody(); + google::api::HttpBody response; + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + int16_t http_response_code = 0; + if (IsCborContent()) { + nlohmann::json request_body_json = + nlohmann::json::parse(core_request_body, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + ASSERT_FALSE(request_body_json.is_discarded()); + std::vector cbor_vector = + 
nlohmann::json::to_cbor(request_body_json); + core_request_body = std::string(cbor_vector.begin(), cbor_vector.end()); + } + auto request_context_factory = std::make_unique(); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + nlohmann::json expected_partition_output = {{"id", 1}}; + expected_partition_output.update(output); + nlohmann::json compressed_partition_group = + nlohmann::json::array({expected_partition_output}); + if (IsCborContent()) { + nlohmann::json expected_json = { + {"compressionGroups", + {{{"compressionGroupId", 1}, + {"content", compressed_partition_group}}}}}; + + // Convert CBOR to json to check content + nlohmann::json actual_response_from_cbor = nlohmann::json::from_cbor( + response.data(), /*strict=*/true, /*allow_exceptions=*/false); + ASSERT_FALSE(actual_response_from_cbor.is_discarded()); + auto a = GetPartitionOutputsInJson( + actual_response_from_cbor["compressionGroups"][0]["content"]); + actual_response_from_cbor["compressionGroups"][0]["content"] = a; + EXPECT_EQ(expected_json, actual_response_from_cbor); + return; + } + + v2::GetValuesResponse actual_response, expected_response; + auto* compression_group = expected_response.add_compression_groups(); + compression_group->set_content(compressed_partition_group.dump()); + compression_group->set_compression_group_id(1); + ASSERT_TRUE(google::protobuf::util::JsonStringToMessage(response.data(), + &actual_response) + .ok()); + EXPECT_THAT(actual_response, EqualsProto(expected_response)); +} + +TEST_P(GetValuesHandlerMultiplePartitionsTest, + AllPartitionsFail_ReturnSuccess) { + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(R"( +request_metadata { + fields { + key: "hostname" + value { + string_value: "example.com" + } + } +} + )", + 
&udf_metadata); + UDFArgument arg1, arg2, arg3; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "structured" + } + values { + string_value: "groupNames" + } +} +data { + list_value { + values { + string_value: "hello" + } + } +})", + &arg1); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key1" + } + } +})", + &arg2); + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key2" + } + } +})", + &arg3); + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, EqualsProto(udf_metadata), + testing::ElementsAre(_), _)) + .WillRepeatedly(Return(absl::InternalError("UDF execution error"))); + std::string core_request_body = GetTestRequestBody(); + google::api::HttpBody response; + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + int16_t http_response_code = 0; + if (IsCborContent()) { + nlohmann::json request_body_json = + nlohmann::json::parse(core_request_body, nullptr, + /*allow_exceptions=*/false, + /*ignore_comments=*/true); + ASSERT_FALSE(request_body_json.is_discarded()); + std::vector cbor_vector = + nlohmann::json::to_cbor(request_body_json); + core_request_body = std::string(cbor_vector.begin(), cbor_vector.end()); + } + + auto request_context_factory = std::make_unique(); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &http_response_code, &handler); + ASSERT_EQ(http_response_code, 200); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + v2::GetValuesResponse actual_response, expected_response; + if (IsJsonContent()) { + EXPECT_THAT(actual_response, EqualsProto(expected_response)); + } + if (IsCborContent()) { + nlohmann::json expected_json = { + {"compressionGroups", 
nlohmann::json::array()}}; + + // Convert CBOR to json to check content + nlohmann::json actual_response_from_cbor = nlohmann::json::from_cbor( + response.data(), /*strict=*/true, /*allow_exceptions=*/false); + ASSERT_FALSE(actual_response_from_cbor.is_discarded()); + EXPECT_EQ(expected_json, actual_response_from_cbor); + } +} + +TEST_F(GetValuesHandlerTest, PureGRPCTest_Success) { + v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; + TextFormat::ParseFromString( + R"pb(partitions { + id: 9 + arguments { data { string_value: "ECHO" } } + } + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + })pb", + &req); + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, _, + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) + .WillOnce(Return("ECHO")); + v2::GetValuesResponse resp; + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder v2_codec; + const auto result = handler.GetValues( + *request_context_factory, req, &resp, execution_metadata, + /*single_partition_use_case=*/true, v2_codec); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + v2::GetValuesResponse res; + TextFormat::ParseFromString( + R"pb(single_partition { id: 9 string_output: "ECHO" })pb", &res); + EXPECT_THAT(resp, EqualsProto(res)); +} + +TEST_F(GetValuesHandlerTest, PureGRPCTestFailure) { + v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; + TextFormat::ParseFromString( + R"pb(partitions { + id: 9 + arguments { data { string_value: "ECHO" } } + } + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + })pb", + &req); + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, _, + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) + 
.WillOnce(Return(absl::InternalError("UDF execution error"))); + v2::GetValuesResponse resp; + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder v2_codec; + const auto result = handler.GetValues( + *request_context_factory, req, &resp, execution_metadata, + /*single_partition_use_case=*/true, v2_codec); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + v2::GetValuesResponse res; + TextFormat::ParseFromString( + R"pb(single_partition { + id: 9 + status: { code: 13 message: "UDF execution error" } })pb", &res); EXPECT_THAT(resp, EqualsProto(res)); } +TEST_F(GetValuesHandlerTest, + PureGRPCTest_SinglePartitionUseCase_PassesPartitionMetadata) { + v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; + TextFormat::ParseFromString( + R"pb(partitions { + id: 9 + arguments { data { string_value: "ECHO" } } + metadata { + fields { + key: "partition_metadata_key" + value: { string_value: "my_value" } + } + } + } + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + })pb", + &req); + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(R"( + request_metadata { + fields { + key: "is_pas" + value { + string_value: "true" + } + } + } + partition_metadata { + fields { + key: "partition_metadata_key" + value { + string_value: "my_value" + } + } + } + )", + &udf_metadata); + + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) + .WillOnce(Return("ECHO")); + v2::GetValuesResponse resp; + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder v2_codec; + const auto result = handler.GetValues( + *request_context_factory, req, &resp, execution_metadata, + /*single_partition_use_case=*/true, v2_codec); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + 
<< ", msg: " << result.error_message(); + + v2::GetValuesResponse res; + TextFormat::ParseFromString( + R"pb(single_partition { id: 9 string_output: "ECHO" })pb", &res); + EXPECT_THAT(resp, EqualsProto(res)); +} + +TEST_F(GetValuesHandlerTest, PureGRPCTest_CBOR_Success) { + v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; + TextFormat::ParseFromString( + R"pb(partitions { + id: 9 + arguments { data { string_value: "ECHO" } } + } + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + })pb", + &req); + GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, _, + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) + .WillOnce(Return("ECHO")); + v2::GetValuesResponse resp; + auto request_context_factory = std::make_unique(); + JsonV2EncoderDecoder v2_codec; + const auto result = handler.GetValues( + *request_context_factory, req, &resp, execution_metadata, + /*single_partition_use_case=*/true, v2_codec); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + v2::GetValuesResponse res; + TextFormat::ParseFromString( + R"pb(single_partition { id: 9 string_output: "ECHO" })pb", &res); + EXPECT_THAT(resp, EqualsProto(res)); +} + +TEST_F(GetValuesHandlerTest, IsSinglePartitionUseCaseIsPasReturnsTrue) { + v2::GetValuesRequest req; + TextFormat::ParseFromString( + R"pb( + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + })pb", + &req); + EXPECT_TRUE(IsSinglePartitionUseCase(req)); +} + +TEST_F(GetValuesHandlerTest, IsSinglePartitonUseCaseNotIsPasReturnsFalse) { + v2::GetValuesRequest req; + TextFormat::ParseFromString( + R"pb( + metadata { + fields { + key: "some" + value { string_value: "other data" } + } + })pb", + &req); + EXPECT_FALSE(IsSinglePartitionUseCase(req)); +} + } // namespace } // namespace kv_server diff --git 
a/components/data_server/request_handler/v2_response_data.proto b/components/data_server/request_handler/v2_response_data.proto deleted file mode 100644 index 309248c7..00000000 --- a/components/data_server/request_handler/v2_response_data.proto +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package kv_server; - -import "google/protobuf/struct.proto"; - -// Proto equivalent of a compression group in the KV V2 API: -// https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md#schema-of-the-request -message V2CompressionGroup { - repeated Partition partitions = 1; -} - -message Partition { - int64 id = 1; - repeated KeyGroupOutput key_group_outputs = 2; -} - -message KeyGroupOutput { - repeated string tags = 1; - map key_values = 2; -} - -message ValueObject { - google.protobuf.Value value = 1; - int64 global_ttl_sec = 2; - int64 dedicated_ttl_sec = 3; -} diff --git a/components/data_server/request_handler/v2_response_data_proto_test.cc b/components/data_server/request_handler/v2_response_data_proto_test.cc deleted file mode 100644 index e37183b7..00000000 --- a/components/data_server/request_handler/v2_response_data_proto_test.cc +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include - -#include "components/data_server/request_handler/v2_response_data.pb.h" -#include "gmock/gmock.h" -#include "google/protobuf/text_format.h" -#include "google/protobuf/util/json_util.h" -#include "gtest/gtest.h" -#include "public/test_util/proto_matcher.h" - -using google::protobuf::TextFormat; -using testing::_; -using testing::Return; - -using google::protobuf::util::JsonStringToMessage; -using google::protobuf::util::MessageToJsonString; - -namespace kv_server { -namespace { - -TEST(V2CompressionGroupProtoTest, - SuccessfullyParsesV2ResponseCompressionGroup) { - V2CompressionGroup v2_response_data_proto; - std::string v2_response_data_json = R"( - { - "partitions": [ - { - "id": 0, - "keyGroupOutputs": [ - { - "keyValues": { - "hello": { - "value": "world" - } - }, - "tags": [ - "custom", - "keys" - ] - }, - { - "keyValues": { - "hello": { - "value": "world" - } - }, - "tags": [ - "structured", - "groupNames" - ] - } - ] - }, - { - "id": 1, - "keyGroupOutputs": [ - { - "keyValues": { - "hello2": { - "value": "world2" - } - }, - "tags": [ - "custom", - "keys" - ] - } - ] - } - ] - } -)"; - auto json_to_proto_status = - JsonStringToMessage(v2_response_data_json, &v2_response_data_proto); - EXPECT_TRUE(json_to_proto_status.ok()); - EXPECT_EQ(json_to_proto_status.message(), ""); - V2CompressionGroup expected; - TextFormat::ParseFromString( - R"pb(partitions { - key_group_outputs { - tags: "custom" - tags: "keys" - key_values { - key: "hello" - value { value { string_value: "world" } } - } - } - 
key_group_outputs { - tags: "structured" - tags: "groupNames" - key_values { - key: "hello" - value { value { string_value: "world" } } - } - } - } - partitions { - id: 1 - key_group_outputs { - tags: "custom" - tags: "keys" - key_values { - key: "hello2" - value { value { string_value: "world2" } } - } - } - })pb", - &expected); - EXPECT_THAT(v2_response_data_proto, EqualsProto(expected)); -} - -} // namespace -} // namespace kv_server diff --git a/components/data_server/server/BUILD.bazel b/components/data_server/server/BUILD.bazel index fc2a103a..4da82cf3 100644 --- a/components/data_server/server/BUILD.bazel +++ b/components/data_server/server/BUILD.bazel @@ -37,6 +37,9 @@ cc_library( "//:aws_platform": ["parameter_fetcher_aws.cc"], "//:gcp_platform": ["parameter_fetcher_gcp.cc"], "//:local_platform": ["parameter_fetcher_local.cc"], + }) + select({ + "//:nonprod_mode": ["nonprod_parameter_fetcher.cc"], + "//:prod_mode": ["prod_parameter_fetcher.cc"], }) + ["parameter_fetcher.cc"], hdrs = ["parameter_fetcher.h"], visibility = [ @@ -135,6 +138,7 @@ cc_library( ":server_log_init", "//components/cloud_config:instance_client", "//components/cloud_config:parameter_client", + "//components/cloud_config/parameter_update:parameter_notifier", "//components/data/blob_storage:blob_storage_client", "//components/data/blob_storage:delta_file_notifier", "//components/data/realtime:realtime_thread_pool_manager", diff --git a/components/data_server/server/key_value_service_v2_impl.cc b/components/data_server/server/key_value_service_v2_impl.cc index e426a917..a596779e 100644 --- a/components/data_server/server/key_value_service_v2_impl.cc +++ b/components/data_server/server/key_value_service_v2_impl.cc @@ -30,7 +30,8 @@ using v2::KeyValueService; template using HandlerFunctionT = grpc::Status (GetValuesV2Handler::*)( RequestContextFactory&, const RequestT&, ResponseT*, - ExecutionMetadata& execution_metadata) const; + ExecutionMetadata& execution_metadata, bool 
single_partition_use_case, + const V2EncoderDecoder& v2_codec) const; inline void LogTotalExecutionWithoutCustomCodeMetric( const privacy_sandbox::server_common::Stopwatch& stopwatch, @@ -50,12 +51,17 @@ template grpc::ServerUnaryReactor* HandleRequest( RequestContextFactory& request_context_factory, CallbackServerContext* context, const RequestT* request, - ResponseT* response, const GetValuesV2Handler& handler, + ResponseT* response, bool is_single_partition_use_case, + const GetValuesV2Handler& handler, HandlerFunctionT handler_function) { privacy_sandbox::server_common::Stopwatch stopwatch; ExecutionMetadata execution_metadata; + auto v2_codec = V2EncoderDecoder::Create(V2EncoderDecoder::GetContentType( + context->client_metadata(), + /*default_content_type=*/V2EncoderDecoder::ContentType::kProto)); grpc::Status status = (handler.*handler_function)( - request_context_factory, *request, response, execution_metadata); + request_context_factory, *request, response, execution_metadata, + is_single_partition_use_case, *v2_codec); auto* reactor = context->DefaultReactor(); reactor->Finish(status); LogRequestCommonSafeMetrics(request, response, status, stopwatch); @@ -89,16 +95,8 @@ grpc::ServerUnaryReactor* KeyValueServiceV2Impl::GetValues( v2::GetValuesResponse* response) { auto request_context_factory = std::make_unique(); return HandleRequest(*request_context_factory, context, request, response, - handler_, &GetValuesV2Handler::GetValues); -} - -grpc::ServerUnaryReactor* KeyValueServiceV2Impl::BinaryHttpGetValues( - CallbackServerContext* context, - const v2::BinaryHttpGetValuesRequest* request, - google::api::HttpBody* response) { - auto request_context_factory = std::make_unique(); - return HandleRequest(*request_context_factory, context, request, response, - handler_, &GetValuesV2Handler::BinaryHttpGetValues); + IsSinglePartitionUseCase(*request), handler_, + &GetValuesV2Handler::GetValues); } grpc::ServerUnaryReactor* KeyValueServiceV2Impl::ObliviousGetValues( 
diff --git a/components/data_server/server/key_value_service_v2_impl.h b/components/data_server/server/key_value_service_v2_impl.h index c95d7b3b..89780f89 100644 --- a/components/data_server/server/key_value_service_v2_impl.h +++ b/components/data_server/server/key_value_service_v2_impl.h @@ -43,11 +43,6 @@ class KeyValueServiceV2Impl final const v2::GetValuesRequest* request, v2::GetValuesResponse* response) override; - grpc::ServerUnaryReactor* BinaryHttpGetValues( - grpc::CallbackServerContext* context, - const v2::BinaryHttpGetValuesRequest* request, - google::api::HttpBody* response) override; - grpc::ServerUnaryReactor* ObliviousGetValues( grpc::CallbackServerContext* context, const v2::ObliviousGetValuesRequest* request, diff --git a/components/data_server/server/lifecycle_heartbeat_test.cc b/components/data_server/server/lifecycle_heartbeat_test.cc index 936a1f7b..b7eae8b9 100644 --- a/components/data_server/server/lifecycle_heartbeat_test.cc +++ b/components/data_server/server/lifecycle_heartbeat_test.cc @@ -54,14 +54,7 @@ class FakePeriodicClosure : public PeriodicClosure { class LifecycleHeartbeatTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } }; TEST_F(LifecycleHeartbeatTest, CantRunTwice) { diff --git a/components/data_server/request_handler/framing_utils_test.cc b/components/data_server/server/nonprod_parameter_fetcher.cc similarity index 51% rename from components/data_server/request_handler/framing_utils_test.cc rename to components/data_server/server/nonprod_parameter_fetcher.cc index 254b6c1f..88dbc6dc 100644 --- a/components/data_server/request_handler/framing_utils_test.cc +++ 
b/components/data_server/server/nonprod_parameter_fetcher.cc @@ -12,21 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "components/data_server/request_handler/framing_utils.h" +#include -#include "gtest/gtest.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "components/data_server/server/parameter_fetcher.h" namespace kv_server { -namespace { +constexpr std::string_view kAddChaffShardingClustersParameterSuffix = + "add-chaff-sharding-clusters"; -TEST(FramingUtilsTest, EncodedDataSizeMatchesTheSpec) { - EXPECT_EQ(GetEncodedDataSize(0), 8); - EXPECT_EQ(GetEncodedDataSize(1), 8); - EXPECT_EQ(GetEncodedDataSize(3), 8); - EXPECT_EQ(GetEncodedDataSize(4), 16); - EXPECT_EQ(GetEncodedDataSize(100), 128); - EXPECT_EQ(GetEncodedDataSize(1000), 1024); +bool ParameterFetcher::ShouldAddChaffCalloutsToShardCluster() const { + auto add_chaff_sharding_clusters = + GetBoolParameter(kAddChaffShardingClustersParameterSuffix); + + PS_LOG(INFO, log_context_) + << "Retrieved " << kAddChaffShardingClustersParameterSuffix + << " parameter: " << add_chaff_sharding_clusters; + + return add_chaff_sharding_clusters; } -} // namespace } // namespace kv_server diff --git a/components/data_server/server/parameter_fetcher.h b/components/data_server/server/parameter_fetcher.h index 0b368a23..9932dc97 100644 --- a/components/data_server/server/parameter_fetcher.h +++ b/components/data_server/server/parameter_fetcher.h @@ -62,13 +62,15 @@ class ParameterFetcher { virtual NotifierMetadata GetRealtimeNotifierMetadata(int32_t num_shards, int32_t shard_num) const; + virtual NotifierMetadata GetLoggingVerbosityParameterNotifierMetadata() const; + virtual bool ShouldAddChaffCalloutsToShardCluster() const; + + std::string GetParamName(std::string_view parameter_suffix) const; protected: privacy_sandbox::server_common::log::PSLogContext& log_context_; private: - std::string GetParamName(std::string_view 
parameter_suffix) const; - const std::string environment_; const ParameterClient& parameter_client_; absl::AnyInvocable diff --git a/components/data_server/server/parameter_fetcher_aws.cc b/components/data_server/server/parameter_fetcher_aws.cc index 5c3b17e5..fb8baa4c 100644 --- a/components/data_server/server/parameter_fetcher_aws.cc +++ b/components/data_server/server/parameter_fetcher_aws.cc @@ -28,6 +28,10 @@ constexpr std::string_view kDataLoadingFileChannelBucketSNSParameterSuffix = constexpr std::string_view kDataLoadingRealtimeChannelSNSParameterSuffix = "data-loading-realtime-channel-sns-arn"; +// SNS ARN for logging verbosity parameter value updates +constexpr std::string_view kLoggingVerbositySNSParameterSuffix = + "logging-verbosity-update-sns-arn"; + // Max connections for AWS's blob storage client constexpr std::string_view kS3ClientMaxConnectionsParameterSuffix = "s3client-max-connections"; @@ -75,4 +79,15 @@ NotifierMetadata ParameterFetcher::GetRealtimeNotifierMetadata( .environment = environment_}; } +NotifierMetadata +ParameterFetcher::GetLoggingVerbosityParameterNotifierMetadata() const { + std::string sns_arn = GetParameter(kLoggingVerbositySNSParameterSuffix); + PS_LOG(INFO, log_context_) + << "Retrieved " << kLoggingVerbositySNSParameterSuffix + << " parameter: " << sns_arn; + return AwsNotifierMetadata{.queue_prefix = "ParameterNotifier_", + .sns_arn = std::move(sns_arn), + .environment = environment_}; +} + } // namespace kv_server diff --git a/components/data_server/server/parameter_fetcher_gcp.cc b/components/data_server/server/parameter_fetcher_gcp.cc index 159d7856..f2f5a4f2 100644 --- a/components/data_server/server/parameter_fetcher_gcp.cc +++ b/components/data_server/server/parameter_fetcher_gcp.cc @@ -62,4 +62,11 @@ NotifierMetadata ParameterFetcher::GetRealtimeNotifierMetadata( }; } +NotifierMetadata +ParameterFetcher::GetLoggingVerbosityParameterNotifierMetadata() const { + // TODO(b/301118821): set to proper values. 
Waiting on the change notifier + // implementation. + return GcpNotifierMetadata{}; +} + } // namespace kv_server diff --git a/components/data_server/server/parameter_fetcher_local.cc b/components/data_server/server/parameter_fetcher_local.cc index 083ade02..d74f13f4 100644 --- a/components/data_server/server/parameter_fetcher_local.cc +++ b/components/data_server/server/parameter_fetcher_local.cc @@ -42,4 +42,10 @@ NotifierMetadata ParameterFetcher::GetRealtimeNotifierMetadata( return LocalNotifierMetadata{.local_directory = std::move(directory)}; } +NotifierMetadata +ParameterFetcher::GetLoggingVerbosityParameterNotifierMetadata() const { + // returns dummy notifier data + return LocalNotifierMetadata{}; +} + } // namespace kv_server diff --git a/tools/server_diagnostic/diagnostic.go b/components/data_server/server/prod_parameter_fetcher.cc similarity index 65% rename from tools/server_diagnostic/diagnostic.go rename to components/data_server/server/prod_parameter_fetcher.cc index cec749a3..cb1639b0 100644 --- a/tools/server_diagnostic/diagnostic.go +++ b/components/data_server/server/prod_parameter_fetcher.cc @@ -1,10 +1,10 @@ -// Copyright 2023 Google LLC +// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,13 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package main - -import ( - "fmt" -) - -func main() { - fmt.Println("Hello, this is the diagnostic tool") - // TODO(b/311377010): Add code +#include "components/data_server/server/parameter_fetcher.h" +namespace kv_server { +bool ParameterFetcher::ShouldAddChaffCalloutsToShardCluster() const { + return false; } + +} // namespace kv_server diff --git a/components/data_server/server/server.cc b/components/data_server/server/server.cc index 69aaeee9..2a03fe0d 100644 --- a/components/data_server/server/server.cc +++ b/components/data_server/server/server.cc @@ -70,11 +70,14 @@ using privacy_sandbox::server_common::InitTelemetry; using privacy_sandbox::server_common::TelemetryProvider; using privacy_sandbox::server_common::log::PSLogContext; using privacy_sandbox::server_common::telemetry::BuildDependentConfig; +using privacy_sandbox::server_common::telemetry::TelemetryConfig; // TODO: Use config cpio client to get this from the environment constexpr absl::string_view kDataBucketParameterSuffix = "data-bucket-id"; constexpr absl::string_view kBackupPollFrequencySecsParameterSuffix = "backup-poll-frequency-secs"; +constexpr absl::string_view kLoggingVerbosityBackupPollFreqSecsParameterSuffix = + "logging-verbosity-backup-poll-frequency-secs"; constexpr absl::string_view kMetricsExportIntervalMillisParameterSuffix = "metrics-export-interval-millis"; constexpr absl::string_view kMetricsExportTimeoutMillisParameterSuffix = @@ -151,9 +154,8 @@ void CheckMetricsCollectorEndPointConnection( "Checking connection to metrics collector", LogMetricsNoOpCallback()); } -privacy_sandbox::server_common::telemetry::TelemetryConfig -GetServerTelemetryConfig(const ParameterClient& parameter_client, - const std::string& environment) { +TelemetryConfig GetServerTelemetryConfig( + const ParameterClient& parameter_client, const std::string& environment) { ParameterFetcher parameter_fetcher(environment, parameter_client); auto config_string = 
parameter_fetcher.GetParameter(kTelemetryConfigSuffix); privacy_sandbox::server_common::telemetry::TelemetryConfig config; @@ -183,7 +185,8 @@ Server::Server() GetValuesHook::Create(GetValuesHook::OutputType::kString)), binary_get_values_hook_( GetValuesHook::Create(GetValuesHook::OutputType::kBinary)), - run_set_query_int_hook_(RunSetQueryIntHook::Create()), + run_set_query_uint32_hook_(RunSetQueryUInt32Hook::Create()), + run_set_query_uint64_hook_(RunSetQueryUInt64Hook::Create()), run_set_query_string_hook_(RunSetQueryStringHook::Create()) {} // Because the cache relies on telemetry, this function needs to be @@ -205,14 +208,13 @@ void Server::InitLogger(::opentelemetry::sdk::resource::Resource server_info, // downstream. auto verbosity_level = parameter_fetcher.GetInt32Parameter( kLoggingVerbosityLevelParameterSuffix); - absl::SetGlobalVLogLevel(verbosity_level); - privacy_sandbox::server_common::log::PS_VLOG_IS_ON(0, verbosity_level); static auto* log_provider = privacy_sandbox::server_common::ConfigurePrivateLogger(server_info, collector_endpoint) .release(); privacy_sandbox::server_common::log::logger_private = log_provider->GetLogger(kServiceName.data()).get(); + UpdateLoggingVerbosity(verbosity_level); parameter_client_->UpdateLogContext(server_safe_log_context_); instance_client_->UpdateLogContext(server_safe_log_context_); if (const bool enable_consented_log = @@ -221,6 +223,46 @@ void Server::InitLogger(::opentelemetry::sdk::resource::Resource server_info, privacy_sandbox::server_common::log::ServerToken( parameter_fetcher.GetParameter(kConsentedDebugTokenSuffix, "")); } + // Start verbosity parameter notifier to listen to the updates + uint32_t backup_poll_frequency_secs = parameter_fetcher.GetInt32Parameter( + kLoggingVerbosityBackupPollFreqSecsParameterSuffix); + auto metadata = + parameter_fetcher.GetLoggingVerbosityParameterNotifierMetadata(); + auto message_service_status = + MessageService::Create(metadata, server_safe_log_context_); + if 
(!message_service_status.ok()) { + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to setup message service for logging verbosity update"; + return; + } + message_service_verbosity_param_update_ = std::move(*message_service_status); + SetQueueManager(metadata, message_service_verbosity_param_update_.get()); + auto logging_verbosity_notifier_status = ParameterNotifier::Create( + metadata, + std::move(parameter_fetcher.GetParamName( + kLoggingVerbosityLevelParameterSuffix)), + absl::Seconds(backup_poll_frequency_secs), server_safe_log_context_); + if (logging_verbosity_notifier_status.ok()) { + logging_verbosity_param_notifier_ = + std::move(*logging_verbosity_notifier_status); + if (auto status = logging_verbosity_param_notifier_->Start( + [this](std::string_view param_name) { + return parameter_client_->GetInt32Parameter(param_name); + }, + absl::bind_front(&Server::UpdateLoggingVerbosity, this)); + !status.ok()) { + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to start the parameter notifier for logging verbosity " + "update " + << status; + return; + } + } else { + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to setup the parameter notifier for logging verbosity " + "update " + << logging_verbosity_notifier_status.status(); + } } void Server::InitializeTelemetry(const ParameterClient& parameter_client, @@ -233,8 +275,9 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, parameter_fetcher.GetBoolParameter(kEnableOtelLoggerParameterSuffix); LOG(INFO) << "Retrieved " << kEnableOtelLoggerParameterSuffix << " parameter: " << enable_otel_logger; - BuildDependentConfig telemetry_config( - GetServerTelemetryConfig(parameter_client, environment_)); + TelemetryConfig telemetry_config_proto = + GetServerTelemetryConfig(parameter_client, environment_); + BuildDependentConfig telemetry_config(telemetry_config_proto); InitTelemetry(std::string(kServiceName), std::string(BuildVersion()), telemetry_config.TraceAllowed(), 
telemetry_config.MetricAllowed(), enable_otel_logger); @@ -246,7 +289,7 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, } LOG(INFO) << "Done retrieving metrics collector endpoint"; auto* context_map = KVServerContextMap( - telemetry_config, + std::make_unique(telemetry_config_proto), ConfigurePrivateMetrics( CreateKVAttributes(instance_id, std::to_string(shard_num_), environment_), @@ -254,7 +297,7 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, AddSystemMetric(context_map); auto* internal_lookup_context_map = InternalLookupServerContextMap( - telemetry_config, + std::make_unique(telemetry_config_proto), ConfigurePrivateMetrics( CreateKVAttributes(instance_id, std::to_string(shard_num_), environment_), @@ -322,9 +365,10 @@ absl::Status Server::CreateDefaultInstancesIfNecessaryAndGetEnvironment( config_builder .RegisterStringGetValuesHook(*string_get_values_hook_) .RegisterBinaryGetValuesHook(*binary_get_values_hook_) - .RegisterRunSetQueryIntHook(*run_set_query_int_hook_) + .RegisterRunSetQueryUInt32Hook(*run_set_query_uint32_hook_) + .RegisterRunSetQueryUInt64Hook(*run_set_query_uint64_hook_) .RegisterRunSetQueryStringHook(*run_set_query_string_hook_) - .RegisterLoggingFunction() + .RegisterLoggingHook() .SetNumberOfWorkers(number_of_workers) .Config()), absl::Milliseconds(udf_timeout_ms), @@ -477,7 +521,8 @@ absl::Status Server::InitOnceInstancesAreCreated() { } auto maybe_shard_state = server_initializer->InitializeUdfHooks( *string_get_values_hook_, *binary_get_values_hook_, - *run_set_query_string_hook_, *run_set_query_int_hook_); + *run_set_query_string_hook_, *run_set_query_uint32_hook_, + *run_set_query_uint64_hook_); if (!maybe_shard_state.ok()) { return maybe_shard_state.status(); } @@ -502,6 +547,10 @@ absl::Status Server::MaybeShutdownNotifiers() { if (realtime_thread_pool_manager_) { status.Update(realtime_thread_pool_manager_->Stop()); } + if (logging_verbosity_param_notifier_ && + 
logging_verbosity_param_notifier_->IsRunning()) { + status.Update(logging_verbosity_param_notifier_->Stop()); + } return status; } @@ -718,4 +767,16 @@ std::unique_ptr Server::CreateDeltaFileNotifier( server_safe_log_context_); } +void Server::UpdateLoggingVerbosity(int32_t verbosity_value) { + if (verbosity_value >= 0) { + // absl and ps log internally will check if new verbosity value + // equals existing verbosity value and apply the update if values are + // different. + absl::SetGlobalVLogLevel(verbosity_value); + privacy_sandbox::server_common::log::SetGlobalPSVLogLevel(verbosity_value); + PS_VLOG(1, server_safe_log_context_) + << "Updated logging verbosity level to " << verbosity_value; + } +} + } // namespace kv_server diff --git a/components/data_server/server/server.h b/components/data_server/server/server.h index 68e1dda5..77bc1ea5 100644 --- a/components/data_server/server/server.h +++ b/components/data_server/server/server.h @@ -24,6 +24,7 @@ #include "absl/time/time.h" #include "components/cloud_config/instance_client.h" #include "components/cloud_config/parameter_client.h" +#include "components/cloud_config/parameter_update/parameter_notifier.h" #include "components/data/blob_storage/blob_storage_client.h" #include "components/data/blob_storage/delta_file_notifier.h" #include "components/data/realtime/realtime_thread_pool_manager.h" @@ -107,6 +108,9 @@ class Server { absl::optional collector_endpoint, const ParameterFetcher& parameter_fetcher); + // Updates max logging verbosity level for global absl and ps vlog + void UpdateLoggingVerbosity(int32_t verbosity_string_value); + // This must be first, otherwise the AWS SDK will crash when it's called: PlatformInitializer platform_initializer_; @@ -119,7 +123,8 @@ class Server { std::unique_ptr get_values_adapter_; std::unique_ptr string_get_values_hook_; std::unique_ptr binary_get_values_hook_; - std::unique_ptr run_set_query_int_hook_; + std::unique_ptr run_set_query_uint32_hook_; + std::unique_ptr 
run_set_query_uint64_hook_; std::unique_ptr run_set_query_string_hook_; // BlobStorageClient must outlive DeltaFileNotifier @@ -153,6 +158,9 @@ class Server { int32_t shard_num_; int32_t num_shards_; + std::unique_ptr message_service_verbosity_param_update_; + std::unique_ptr logging_verbosity_param_notifier_; + std::unique_ptr key_fetcher_manager_; std::unique_ptr open_telemetry_sink_; diff --git a/components/data_server/server/server_initializer.cc b/components/data_server/server/server_initializer.cc index ce8ae4a0..78768c8e 100644 --- a/components/data_server/server/server_initializer.cc +++ b/components/data_server/server/server_initializer.cc @@ -34,7 +34,8 @@ absl::Status InitializeUdfHooksInternal( GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, RunSetQueryStringHook& run_query_hook, - RunSetQueryIntHook& run_set_query_int_hook, + RunSetQueryUInt32Hook& run_set_query_uint32_hook, + RunSetQueryUInt64Hook& run_set_query_uint64_hook, privacy_sandbox::server_common::log::PSLogContext& log_context) { PS_VLOG(9, log_context) << "Finishing getValues init"; string_get_values_hook.FinishInit(get_lookup()); @@ -42,8 +43,10 @@ absl::Status InitializeUdfHooksInternal( binary_get_values_hook.FinishInit(get_lookup()); PS_VLOG(9, log_context) << "Finishing runQuery init"; run_query_hook.FinishInit(get_lookup()); - PS_VLOG(9, log_context) << "Finishing runSetQueryInt init"; - run_set_query_int_hook.FinishInit(get_lookup()); + PS_VLOG(9, log_context) << "Finishing runSetQueryUInt32 init"; + run_set_query_uint32_hook.FinishInit(get_lookup()); + PS_VLOG(9, log_context) << "Finishing runSetQueryUInt64 init"; + run_set_query_uint64_hook.FinishInit(get_lookup()); return absl::OkStatus(); } @@ -63,15 +66,16 @@ class NonshardedServerInitializer : public ServerInitializer { GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, RunSetQueryStringHook& run_query_hook, - RunSetQueryIntHook& run_set_query_int_hook) override { + 
RunSetQueryUInt32Hook& run_set_query_uint32_hook, + RunSetQueryUInt64Hook& run_set_query_uint64_hook) override { ShardManagerState shard_manager_state; auto lookup_supplier = [&cache = cache_]() { return CreateLocalLookup(cache); }; InitializeUdfHooksInternal(std::move(lookup_supplier), string_get_values_hook, binary_get_values_hook, - run_query_hook, run_set_query_int_hook, - log_context_); + run_query_hook, run_set_query_uint32_hook, + run_set_query_uint64_hook, log_context_); return shard_manager_state; } @@ -120,11 +124,14 @@ class ShardedServerInitializer : public ServerInitializer { GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, RunSetQueryStringHook& run_set_query_string_hook, - RunSetQueryIntHook& run_set_query_int_hook) override { + RunSetQueryUInt32Hook& run_set_query_uint32_hook, + RunSetQueryUInt64Hook& run_set_query_uint64_hook) override { auto maybe_shard_state = CreateShardManager(); if (!maybe_shard_state.ok()) { return maybe_shard_state.status(); } + const bool add_chaff = + parameter_fetcher_.ShouldAddChaffCalloutsToShardCluster(); auto lookup_supplier = [&local_lookup = local_lookup_, num_shards = num_shards_, current_shard_num = current_shard_num_, @@ -133,10 +140,10 @@ class ShardedServerInitializer : public ServerInitializer { return CreateShardedLookup(local_lookup, num_shards, current_shard_num, shard_manager, key_sharder); }; - InitializeUdfHooksInternal(std::move(lookup_supplier), - string_get_values_hook, binary_get_values_hook, - run_set_query_string_hook, - run_set_query_int_hook, log_context_); + InitializeUdfHooksInternal( + std::move(lookup_supplier), string_get_values_hook, + binary_get_values_hook, run_set_query_string_hook, + run_set_query_uint32_hook, run_set_query_uint64_hook, log_context_); return std::move(*maybe_shard_state); } diff --git a/components/data_server/server/server_initializer.h b/components/data_server/server/server_initializer.h index 5e30bfa7..30ff2909 100644 --- 
a/components/data_server/server/server_initializer.h +++ b/components/data_server/server/server_initializer.h @@ -56,7 +56,8 @@ class ServerInitializer { GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, RunSetQueryStringHook& run_set_query_string_hook, - RunSetQueryIntHook& run_set_query_int_hook) = 0; + RunSetQueryUInt32Hook& run_set_query_uint32_hook, + RunSetQueryUInt64Hook& run_set_query_uint64_hook) = 0; }; std::unique_ptr GetServerInitializer( diff --git a/components/data_server/server/server_local_test.cc b/components/data_server/server/server_local_test.cc index 9fa33f5b..3d6f4d82 100644 --- a/components/data_server/server/server_local_test.cc +++ b/components/data_server/server/server_local_test.cc @@ -53,6 +53,11 @@ void RegisterRequiredTelemetryExpectations(MockParameterClient& client) { EXPECT_CALL(client, GetInt32Parameter( "kv-server-environment-backup-poll-frequency-secs")) .WillOnce(::testing::Return(123)); + EXPECT_CALL( + client, + GetInt32Parameter( + "kv-server-environment-logging-verbosity-backup-poll-frequency-secs")) + .WillOnce(::testing::Return(300)); EXPECT_CALL(client, GetBoolParameter("kv-server-environment-enable-otel-logger")) .WillOnce(::testing::Return(false)); @@ -146,7 +151,8 @@ TEST_F(ServerLocalTest, InitFailsWithNoDeltaDirectory) { EXPECT_CALL( *parameter_client, GetInt32Parameter("kv-server-environment-logging-verbosity-level")) - .WillOnce(::testing::Return(0)); + .Times(3) + .WillRepeatedly(::testing::Return(0)); EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); @@ -159,6 +165,7 @@ TEST_F(ServerLocalTest, InitFailsWithNoDeltaDirectory) { absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), std::move(mock_udf_client)); + server.ForceShutdown(); EXPECT_FALSE(status.ok()); } @@ -219,7 +226,8 @@ TEST_F(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { 
EXPECT_CALL( *parameter_client, GetInt32Parameter("kv-server-environment-logging-verbosity-level")) - .WillOnce(::testing::Return(0)); + .Times(3) + .WillRepeatedly(::testing::Return(0)); EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); @@ -235,6 +243,7 @@ TEST_F(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), std::move(mock_udf_client)); + server.ForceShutdown(); EXPECT_TRUE(status.ok()); } @@ -253,7 +262,8 @@ TEST_F(ServerLocalTest, GracefulServerShutdown) { EXPECT_CALL( *parameter_client, GetInt32Parameter("kv-server-environment-logging-verbosity-level")) - .WillOnce(::testing::Return(0)); + .Times(3) + .WillRepeatedly(::testing::Return(0)); EXPECT_CALL(*parameter_client, GetParameter("kv-server-environment-directory", testing::Eq(std::nullopt))) .WillOnce(::testing::Return(::testing::TempDir())); @@ -373,7 +383,8 @@ TEST_F(ServerLocalTest, ForceServerShutdown) { EXPECT_CALL( *parameter_client, GetInt32Parameter("kv-server-environment-logging-verbosity-level")) - .WillOnce(::testing::Return(0)); + .Times(3) + .WillRepeatedly(::testing::Return(0)); EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); diff --git a/components/envoy_proxy/envoy.yaml b/components/envoy_proxy/envoy.yaml index 8b41d3c8..ab898ffd 100644 --- a/components/envoy_proxy/envoy.yaml +++ b/components/envoy_proxy/envoy.yaml @@ -34,7 +34,16 @@ static_resources: virtual_hosts: - name: local_service domains: [ "*" ] + cors: + allow_origin_string_match: + - prefix: "*" + allow_methods: GET, POST, PUT, OPTIONS + allow_headers: Origin, Content, Accept, Content-Type, Authorization, X-Requested-With + allow_credentials: true routes: + - match: { prefix: "/", headers: [ { name: ":method", exact_match: "OPTIONS" } ] } + # fake 
cluster route ... some issue in envoy. + route: { cluster: "grpc_cluster", timeout: 60s } - match: { prefix: "/kv_server.v1.KeyValueService" } route: { cluster: grpc_cluster, timeout: 60s } - match: { prefix: "/kv_server.v2.KeyValueService" } @@ -59,7 +68,13 @@ static_resources: key: 'x-fledge-bidding-signals-format-version' value: '2' append: false + - header: + key: 'Access-Control-Allow-Origin' + value: '*' http_filters: + - name: envoy.filters.http.cors + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.cors.v3.Cors # Pass the content-type to the grpc server. By default the header is overwritten to "application/grpc" # This extra filter must be used because the content-type must be copied to a metadata prior to grpc transcoding # By the time the Router filter is invoked, the original content-type is already changed. diff --git a/components/errors/BUILD.bazel b/components/errors/BUILD.bazel index a44be19a..3f88029f 100644 --- a/components/errors/BUILD.bazel +++ b/components/errors/BUILD.bazel @@ -17,7 +17,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") package(default_visibility = [ "//components:__subpackages__", "//production/packaging:__subpackages__", - "//public/data_loading:__subpackages__", + "//public:__subpackages__", ]) cc_library( diff --git a/components/internal_server/BUILD.bazel b/components/internal_server/BUILD.bazel index 44eb3fd2..0b7458e5 100644 --- a/components/internal_server/BUILD.bazel +++ b/components/internal_server/BUILD.bazel @@ -34,7 +34,7 @@ cc_library( ":internal_lookup_cc_grpc", ":lookup", ":string_padder", - "//components/data_server/request_handler:ohttp_server_encryptor", + "//components/data_server/request_handler/encryption:ohttp_server_encryptor", "//components/query:driver", "//components/query:scanner", "@com_github_grpc_grpc//:grpc++", @@ -140,7 +140,7 @@ cc_library( ":internal_lookup_cc_proto", ":lookup", "//components/data_server/cache", - 
"//components/data_server/cache:uint32_value_set", + "//components/data_server/cache:uint_value_set", "//components/errors:error_tag", "//components/query:driver", "//components/query:scanner", @@ -160,7 +160,7 @@ cc_library( ":internal_lookup_cc_proto", ":local_lookup", ":remote_lookup_client_impl", - "//components/data_server/cache:uint32_value_set", + "//components/data_server/cache:uint_value_set", "//components/query:driver", "//components/query:scanner", "//components/sharding:shard_manager", @@ -208,7 +208,7 @@ cc_library( ":constants", ":internal_lookup_cc_grpc", ":string_padder", - "//components/data_server/request_handler:ohttp_client_encryptor", + "//components/data_server/request_handler/encryption:ohttp_client_encryptor", "//components/util:request_context", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", diff --git a/components/internal_server/local_lookup.cc b/components/internal_server/local_lookup.cc index b9e535a1..fe61d49a 100644 --- a/components/internal_server/local_lookup.cc +++ b/components/internal_server/local_lookup.cc @@ -21,8 +21,7 @@ #include "absl/functional/any_invocable.h" #include "components/data_server/cache/cache.h" -#include "components/data_server/cache/uint32_value_set.h" -#include "components/errors/error_tag.h" +#include "components/data_server/cache/uint_value_set.h" #include "components/internal_server/lookup.h" #include "components/internal_server/lookup.pb.h" #include "components/query/driver.h" @@ -46,53 +45,169 @@ class LocalLookup : public Lookup { absl::StatusOr GetKeyValueSet( const RequestContext& request_context, const absl::flat_hash_set& key_set) const override { - return ProcessValueSetKeys(request_context, key_set, - SingleLookupResult::kKeysetValues); + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetInternalLookupMetricsContext()); + InternalLookupResponse response; + if (key_set.empty()) { + return response; + } + auto key_value_set_result = 
cache_.GetKeyValueSet(request_context, key_set); + for (const auto& key : key_set) { + SingleLookupResult result; + if (const auto value_set = key_value_set_result->GetValueSet(key); + !value_set.empty()) { + auto* keyset_values = result.mutable_keyset_values(); + keyset_values->mutable_values()->Reserve(value_set.size()); + keyset_values->mutable_values()->Add(value_set.begin(), + value_set.end()); + } else { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message(absl::StrCat("Key not found: ", key)); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; } absl::StatusOr GetUInt32ValueSet( const RequestContext& request_context, const absl::flat_hash_set& key_set) const override { - return ProcessValueSetKeys(request_context, key_set, - SingleLookupResult::kUintsetValues); + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetInternalLookupMetricsContext()); + InternalLookupResponse response; + if (key_set.empty()) { + return response; + } + auto key_value_set_result = + cache_.GetUInt32ValueSet(request_context, key_set); + for (const auto& key : key_set) { + SingleLookupResult result; + if (const auto value_set = key_value_set_result->GetUInt32ValueSet(key); + value_set != nullptr && !value_set->GetValues().empty()) { + auto uint32_values = value_set->GetValues(); + auto* result_values = result.mutable_uint32set_values(); + result_values->mutable_values()->Reserve(uint32_values.size()); + result_values->mutable_values()->Add(uint32_values.begin(), + uint32_values.end()); + } else { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message(absl::StrCat("Key not found: ", key)); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; + } + + absl::StatusOr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) 
const override { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetInternalLookupMetricsContext()); + InternalLookupResponse response; + if (key_set.empty()) { + return response; + } + auto key_value_set_result = + cache_.GetUInt64ValueSet(request_context, key_set); + for (const auto& key : key_set) { + SingleLookupResult result; + if (const auto value_set = key_value_set_result->GetUInt64ValueSet(key); + value_set != nullptr && !value_set->GetValues().empty()) { + auto uint64_values = value_set->GetValues(); + auto* result_values = result.mutable_uint64set_values(); + result_values->mutable_values()->Reserve(uint64_values.size()); + result_values->mutable_values()->Add(uint64_values.begin(), + uint64_values.end()); + } else { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message(absl::StrCat("Key not found: ", key)); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; } absl::StatusOr RunQuery( const RequestContext& request_context, std::string query) const override { - return ProcessQuery>>( + return ProcessQuery( request_context, std::move(query), [](const RequestContext& request_context, const Driver& driver, - const Cache& cache) { + const Cache& cache) -> absl::StatusOr { auto get_key_value_set_result = cache.GetKeyValueSet( request_context, driver.GetRootNode()->Keys()); - return driver.EvaluateQuery>( - [&get_key_value_set_result](std::string_view key) { - return get_key_value_set_result->GetValueSet(key); + auto eval_result = + driver.EvaluateQuery>( + [&get_key_value_set_result](std::string_view key) { + return get_key_value_set_result->GetValueSet(key); + }); + if (!eval_result.ok()) { + return eval_result.status(); + } + InternalRunQueryResponse response; + response.mutable_elements()->Reserve(eval_result->size()); + response.mutable_elements()->Assign(eval_result->begin(), + eval_result->end()); + return response; + }); + } + + 
absl::StatusOr RunSetQueryUInt32( + const RequestContext& request_context, std::string query) const override { + return ProcessQuery( + request_context, std::move(query), + [](const RequestContext& request_context, const Driver& driver, + const Cache& cache) + -> absl::StatusOr { + auto cache_result = cache.GetUInt32ValueSet( + request_context, driver.GetRootNode()->Keys()); + auto eval_result = driver.EvaluateQuery( + [&cache_result](std::string_view key) { + auto set = cache_result->GetUInt32ValueSet(key); + return set == nullptr ? UInt32ValueSet::bitset_type() + : set->GetValuesBitSet(); }); + if (!eval_result.ok()) { + return eval_result.status(); + } + auto uint32_set = BitSetToUint32Set(*eval_result); + InternalRunSetQueryUInt32Response response; + response.mutable_elements()->Reserve(uint32_set.size()); + response.mutable_elements()->Assign(uint32_set.begin(), + uint32_set.end()); + return response; }); } - absl::StatusOr RunSetQueryInt( + absl::StatusOr RunSetQueryUInt64( const RequestContext& request_context, std::string query) const override { - return ProcessQuery>>( + return ProcessQuery( request_context, std::move(query), [](const RequestContext& request_context, const Driver& driver, const Cache& cache) - -> absl::StatusOr> { - auto get_key_value_set_result = cache.GetUInt32ValueSet( + -> absl::StatusOr { + auto cache_result = cache.GetUInt64ValueSet( request_context, driver.GetRootNode()->Keys()); - auto query_eval_result = driver.EvaluateQuery( - [&get_key_value_set_result](std::string_view key) { - auto set = get_key_value_set_result->GetUInt32ValueSet(key); - return set == nullptr ? roaring::Roaring() + auto eval_result = driver.EvaluateQuery( + [&cache_result](std::string_view key) { + auto set = cache_result->GetUInt64ValueSet(key); + return set == nullptr ? 
UInt64ValueSet::bitset_type() : set->GetValuesBitSet(); }); - if (!query_eval_result.ok()) { - return query_eval_result.status(); + if (!eval_result.ok()) { + return eval_result.status(); } - return BitSetToUint32Set(*query_eval_result); + auto uint64_set = BitSetToUint64Set(*eval_result); + InternalRunSetQueryUInt64Response response; + response.mutable_elements()->Reserve(uint64_set.size()); + response.mutable_elements()->Assign(uint64_set.begin(), + uint64_set.end()); + return response; }); } @@ -124,68 +239,11 @@ class LocalLookup : public Lookup { return response; } - absl::StatusOr ProcessValueSetKeys( - const RequestContext& request_context, - const absl::flat_hash_set& key_set, - SingleLookupResult::SingleLookupResultCase set_type) const { - ScopeLatencyMetricsRecorder - latency_recorder(request_context.GetInternalLookupMetricsContext()); - InternalLookupResponse response; - if (key_set.empty()) { - return response; - } - std::unique_ptr key_value_set_result; - if (set_type == SingleLookupResult::kKeysetValues) { - key_value_set_result = cache_.GetKeyValueSet(request_context, key_set); - } else if (set_type == SingleLookupResult::kUintsetValues) { - key_value_set_result = cache_.GetUInt32ValueSet(request_context, key_set); - } else { - return StatusWithErrorTag(absl::InvalidArgumentError(absl::StrCat( - "Unsupported set type: ", set_type)), - __FILE__, ErrorTag::kProcessValueSetKeys); - } - for (const auto& key : key_set) { - SingleLookupResult result; - bool is_empty_value_set = false; - if (set_type == SingleLookupResult::kKeysetValues) { - if (const auto value_set = key_value_set_result->GetValueSet(key); - !value_set.empty()) { - auto* keyset_values = result.mutable_keyset_values(); - keyset_values->mutable_values()->Reserve(value_set.size()); - keyset_values->mutable_values()->Add(value_set.begin(), - value_set.end()); - } else { - is_empty_value_set = true; - } - } - if (set_type == SingleLookupResult::kUintsetValues) { - if (const auto value_set = 
key_value_set_result->GetUInt32ValueSet(key); - value_set != nullptr && !value_set->GetValues().empty()) { - auto uint32_values = value_set->GetValues(); - auto* result_values = result.mutable_uintset_values(); - result_values->mutable_values()->Reserve(uint32_values.size()); - result_values->mutable_values()->Add(uint32_values.begin(), - uint32_values.end()); - } else { - is_empty_value_set = true; - } - } - if (is_empty_value_set) { - auto status = result.mutable_status(); - status->set_code(static_cast(absl::StatusCode::kNotFound)); - status->set_message(absl::StrCat("Key not found: ", key)); - } - (*response.mutable_kv_pairs())[key] = std::move(result); - } - return response; - } - - template + template absl::StatusOr ProcessQuery( const RequestContext& request_context, std::string query, - absl::AnyInvocable + absl::AnyInvocable( + const RequestContext&, const Driver&, const Cache&)> query_eval_fn) const { ScopeLatencyMetricsRecorder @@ -208,9 +266,7 @@ class LocalLookup : public Lookup { kLocalRunQueryFailure); return result.status(); } - ResponseType response; - response.mutable_elements()->Assign(result->begin(), result->end()); - return response; + return result; } const Cache& cache_; diff --git a/components/internal_server/local_lookup_test.cc b/components/internal_server/local_lookup_test.cc index 0f8c8e97..638f0f3f 100644 --- a/components/internal_server/local_lookup_test.cc +++ b/components/internal_server/local_lookup_test.cc @@ -14,6 +14,7 @@ #include "components/internal_server/local_lookup.h" +#include #include #include #include @@ -157,8 +158,9 @@ TEST_F(LocalLookupTest, GetUInt32ValueSets_KeysFound_Success) { auto response = local_lookup->GetUInt32ValueSet(GetRequestContext(), {"key1"}); ASSERT_TRUE(response.ok()); - EXPECT_THAT(response.value().kv_pairs().at("key1").uintset_values().values(), - testing::UnorderedElementsAreArray(values)); + EXPECT_THAT( + response.value().kv_pairs().at("key1").uint32set_values().values(), + 
testing::UnorderedElementsAreArray(values)); } TEST_F(LocalLookupTest, GetUInt32ValueSets_SetEmpty_Success) { @@ -246,7 +248,7 @@ TEST_F(LocalLookupTest, RunQuery_ParsingError_Error) { EXPECT_EQ(response.status().code(), absl::StatusCode::kInvalidArgument); } -TEST_F(LocalLookupTest, Verify_RunSetQueryInt_Success) { +TEST_F(LocalLookupTest, Verify_RunSetQueryUInt32_Success) { std::string query = "A"; UInt32ValueSet value_set; auto values = std::vector({10, 20, 30, 40, 50}); @@ -259,20 +261,91 @@ TEST_F(LocalLookupTest, Verify_RunSetQueryInt_Success) { GetUInt32ValueSet(_, absl::flat_hash_set{"A"})) .WillOnce(Return(std::move(mock_get_key_value_set_result))); auto local_lookup = CreateLocalLookup(mock_cache_); - auto response = local_lookup->RunSetQueryInt(GetRequestContext(), query); + auto response = local_lookup->RunSetQueryUInt32(GetRequestContext(), query); ASSERT_TRUE(response.ok()) << response.status(); EXPECT_THAT(response.value().elements(), testing::UnorderedElementsAreArray(values.begin(), values.end())); } -TEST_F(LocalLookupTest, Verify_RunSetQueryInt_ParsingError_Error) { +TEST_F(LocalLookupTest, Verify_RunSetQueryUInt32_ParsingError_Error) { std::string query = "someset|("; auto local_lookup = CreateLocalLookup(mock_cache_); - auto response = local_lookup->RunSetQueryInt(GetRequestContext(), query); + auto response = local_lookup->RunSetQueryUInt32(GetRequestContext(), query); EXPECT_FALSE(response.ok()); EXPECT_EQ(response.status().code(), absl::StatusCode::kInvalidArgument); } +TEST_F(LocalLookupTest, Verify_RunSetQueryUInt64_Success) { + std::string query = "A"; + UInt64ValueSet value_set; + auto uint64_max = std::numeric_limits::max(); + auto values = + std::vector({uint64_max - 10, uint64_max - 20, uint64_max - 30, + uint64_max - 40, uint64_max - 50}); + value_set.Add(absl::MakeSpan(values), 1); + auto mock_result = std::make_unique(); + EXPECT_CALL(*mock_result, GetUInt64ValueSet("A")) + .WillOnce(Return(&value_set)); + EXPECT_CALL(mock_cache_, 
+ GetUInt64ValueSet(_, absl::flat_hash_set{"A"})) + .WillOnce(Return(std::move(mock_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = local_lookup->RunSetQueryUInt64(GetRequestContext(), query); + ASSERT_TRUE(response.ok()) << response.status(); + EXPECT_THAT(response.value().elements(), + testing::UnorderedElementsAreArray(values.begin(), values.end())); +} + +TEST_F(LocalLookupTest, Verify_RunSetQueryUInt64_ParsingError_Error) { + std::string query = "someset|("; + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = local_lookup->RunSetQueryUInt64(GetRequestContext(), query); + EXPECT_FALSE(response.ok()); + EXPECT_EQ(response.status().code(), absl::StatusCode::kInvalidArgument); +} + +TEST_F(LocalLookupTest, GetUInt64ValueSets_KeysFound_Success) { + auto uint64_max = std::numeric_limits::max(); + auto values = std::vector({uint64_max - 1000, uint64_max - 1001}); + UInt64ValueSet value_set; + value_set.Add(absl::MakeSpan(values), 1); + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetUInt64ValueSet("key1")) + .WillOnce(Return(&value_set)); + EXPECT_CALL(mock_cache_, GetUInt64ValueSet(_, _)) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = + local_lookup->GetUInt64ValueSet(GetRequestContext(), {"key1"}); + ASSERT_TRUE(response.ok()); + EXPECT_THAT( + response.value().kv_pairs().at("key1").uint64set_values().values(), + testing::UnorderedElementsAreArray(values)); +} + +TEST_F(LocalLookupTest, GetUInt64ValueSets_SetEmpty_Success) { + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetUInt64ValueSet("key1")) + .WillOnce(Return(nullptr)); + EXPECT_CALL(mock_cache_, GetUInt64ValueSet(_, _)) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto 
response = + local_lookup->GetUInt64ValueSet(GetRequestContext(), {"key1"}); + ASSERT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { status { code: 5 message: "Key not found: key1" } } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + } // namespace } // namespace kv_server diff --git a/components/internal_server/lookup.h b/components/internal_server/lookup.h index 363e1b7b..49072ed4 100644 --- a/components/internal_server/lookup.h +++ b/components/internal_server/lookup.h @@ -45,10 +45,17 @@ class Lookup { const RequestContext& request_context, const absl::flat_hash_set& key_set) const = 0; + virtual absl::StatusOr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const = 0; + virtual absl::StatusOr RunQuery( const RequestContext& request_context, std::string query) const = 0; - virtual absl::StatusOr RunSetQueryInt( + virtual absl::StatusOr RunSetQueryUInt32( + const RequestContext& request_context, std::string query) const = 0; + + virtual absl::StatusOr RunSetQueryUInt64( const RequestContext& request_context, std::string query) const = 0; }; diff --git a/components/internal_server/lookup.proto b/components/internal_server/lookup.proto index a31b26a5..1a509f72 100644 --- a/components/internal_server/lookup.proto +++ b/components/internal_server/lookup.proto @@ -21,20 +21,8 @@ import "src/logger/logger.proto"; // Internal Lookup Service API. service InternalLookupService { - // Endpoint for querying the server's internal datastore. Should only be used - // within TEEs. - rpc InternalLookup(InternalLookupRequest) returns (InternalLookupResponse) {} - // Endpoint for querying the datastore over the network. rpc SecureLookup(SecureLookupRequest) returns (SecureLookupResponse) {} - - // Endpoint for running a query on the server's internal datastore. Should - // only be used within TEEs. 
- rpc InternalRunQuery(InternalRunQueryRequest) returns (InternalRunQueryResponse) {} - - // Endpoint for running a set query over unsigned int sets in the server's - // internal datastore. Should only be used within TEEs. - rpc InternalRunSetQueryInt(InternalRunSetQueryIntRequest) returns (InternalRunSetQueryIntResponse) {} } // Lookup request for internal datastore. @@ -89,7 +77,8 @@ message SingleLookupResult { string value = 1; google.rpc.Status status = 2; KeysetValues keyset_values = 3; - UInt32SetValues uintset_values = 4; + UInt32SetValues uint32set_values = 4; + UInt64SetValues uint64set_values = 5; } } @@ -103,6 +92,11 @@ message UInt32SetValues { repeated uint32 values = 1; } +// UInt64 set values +message UInt64SetValues { + repeated uint64 values = 1; +} + // Run Query request. message InternalRunQueryRequest { // Query to run. @@ -120,7 +114,7 @@ message InternalRunQueryResponse { } // Run Query request. -message InternalRunSetQueryIntRequest { +message InternalRunSetQueryUInt32Request { // Query to run. optional string query = 1; // Context useful for logging and tracing requests @@ -130,6 +124,21 @@ message InternalRunSetQueryIntRequest { } // Response for running a set query using sets of unsigned ints as input. -message InternalRunSetQueryIntResponse { +message InternalRunSetQueryUInt32Response { repeated uint32 elements = 1; } + +// Run Query request. +message InternalRunSetQueryUInt64Request { + // Query to run. + optional string query = 1; + // Context useful for logging and tracing requests + privacy_sandbox.server_common.LogContext log_context = 2; + // Consented debugging configuration + privacy_sandbox.server_common.ConsentedDebugConfiguration consented_debug_config = 3; +} + +// Response for running a set query using sets of unsigned ints as input. 
+message InternalRunSetQueryUInt64Response { + repeated uint64 elements = 1; +} diff --git a/components/internal_server/lookup_server_impl.cc b/components/internal_server/lookup_server_impl.cc index f4f6575c..bfaf8381 100644 --- a/components/internal_server/lookup_server_impl.cc +++ b/components/internal_server/lookup_server_impl.cc @@ -20,7 +20,7 @@ #include "absl/functional/any_invocable.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "components/data_server/request_handler/ohttp_server_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_server_encryptor.h" #include "components/internal_server/lookup.h" #include "components/internal_server/string_padder.h" #include "google/protobuf/message.h" @@ -68,18 +68,6 @@ void LookupServiceImpl::ProcessKeysetKeys( } } -grpc::Status LookupServiceImpl::InternalLookup( - grpc::ServerContext* context, const InternalLookupRequest* request, - InternalLookupResponse* response) { - RequestContext request_context; - if (context->IsCancelled()) { - return grpc::Status(grpc::StatusCode::CANCELLED, - "Deadline exceeded or client cancelled, abandoning."); - } - ProcessKeys(request_context, request->keys(), *response); - return grpc::Status::OK; -} - grpc::Status LookupServiceImpl::SecureLookup( grpc::ServerContext* context, const SecureLookupRequest* secure_lookup_request, @@ -153,26 +141,4 @@ std::string LookupServiceImpl::GetPayload( return response.SerializeAsString(); } -grpc::Status LookupServiceImpl::InternalRunQuery( - grpc::ServerContext* context, const InternalRunQueryRequest* request, - InternalRunQueryResponse* response) { - return RunSetQuery( - context, request, response, - [this](const RequestContext& request_context, std::string query) { - return lookup_.RunQuery(request_context, query); - }); -} - -grpc::Status LookupServiceImpl::InternalRunSetQueryInt( - grpc::ServerContext* context, - const kv_server::InternalRunSetQueryIntRequest* request, - 
kv_server::InternalRunSetQueryIntResponse* response) { - return RunSetQuery( - context, request, response, - [this](const RequestContext& request_context, std::string query) { - return lookup_.RunSetQueryInt(request_context, query); - }); -} - } // namespace kv_server diff --git a/components/internal_server/lookup_server_impl.h b/components/internal_server/lookup_server_impl.h index ac3040d0..fbf2573c 100644 --- a/components/internal_server/lookup_server_impl.h +++ b/components/internal_server/lookup_server_impl.h @@ -38,25 +38,10 @@ class LookupServiceImpl final ~LookupServiceImpl() override = default; - grpc::Status InternalLookup( - grpc::ServerContext* context, - const kv_server::InternalLookupRequest* request, - kv_server::InternalLookupResponse* response) override; - grpc::Status SecureLookup(grpc::ServerContext* context, const kv_server::SecureLookupRequest* request, kv_server::SecureLookupResponse* response) override; - grpc::Status InternalRunQuery( - grpc::ServerContext* context, - const kv_server::InternalRunQueryRequest* request, - kv_server::InternalRunQueryResponse* response) override; - - grpc::Status InternalRunSetQueryInt( - grpc::ServerContext* context, - const kv_server::InternalRunSetQueryIntRequest* request, - kv_server::InternalRunSetQueryIntResponse* response) override; - private: std::string GetPayload( const RequestContext& request_context, const bool lookup_sets, @@ -72,28 +57,6 @@ class LookupServiceImpl final InternalLookupMetricsContext& metrics_context, const absl::Status& status, std::string_view error_code) const; - template - grpc::Status RunSetQuery(grpc::ServerContext* context, - const RequestType* request, ResponseType* response, - absl::AnyInvocable( - const RequestContext&, std::string)> - run_set_query_fn) { - RequestContext request_context; - if (context->IsCancelled()) { - return grpc::Status(grpc::StatusCode::CANCELLED, - "Deadline exceeded or client cancelled, abandoning."); - } - const auto process_result = - 
run_set_query_fn(request_context, request->query()); - if (!process_result.ok()) { - return ToInternalGrpcStatus( - request_context.GetInternalLookupMetricsContext(), - process_result.status(), kInternalRunQueryRequestFailure); - } - *response = std::move(*process_result); - return grpc::Status::OK; - } - const Lookup& lookup_; privacy_sandbox::server_common::KeyFetcherManagerInterface& key_fetcher_manager_; diff --git a/components/internal_server/lookup_server_impl_test.cc b/components/internal_server/lookup_server_impl_test.cc index 277d13e5..62ec43f5 100644 --- a/components/internal_server/lookup_server_impl_test.cc +++ b/components/internal_server/lookup_server_impl_test.cc @@ -15,6 +15,7 @@ #include "components/internal_server/lookup_server_impl.h" +#include #include #include "components/internal_server/mocks.h" @@ -57,74 +58,6 @@ class LookupServiceImplTest : public ::testing::Test { std::unique_ptr stub_; }; -TEST_F(LookupServiceImplTest, InternalLookup_Success) { - InternalLookupRequest request; - request.add_keys("key1"); - request.add_keys("key2"); - InternalLookupResponse expected; - TextFormat::ParseFromString(R"pb(kv_pairs { - key: "key1" - value { value: "value1" } - } - kv_pairs { - key: "key2" - value { value: "value2" } - } - )pb", - &expected); - EXPECT_CALL(mock_lookup_, GetKeyValues(_, _)).WillOnce(Return(expected)); - - InternalLookupResponse response; - grpc::ClientContext context; - - grpc::Status status = stub_->InternalLookup(&context, request, &response); - EXPECT_THAT(response, EqualsProto(expected)); -} - -TEST_F(LookupServiceImplTest, - InternalLookup_LookupReturnsStatus_EmptyResponse) { - InternalLookupRequest request; - request.add_keys("key1"); - request.add_keys("key2"); - EXPECT_CALL(mock_lookup_, GetKeyValues(_, _)) - .WillOnce(Return(absl::UnknownError("Some error"))); - - InternalLookupResponse response; - grpc::ClientContext context; - - grpc::Status status = stub_->InternalLookup(&context, request, &response); - 
InternalLookupResponse expected; - EXPECT_THAT(response, EqualsProto(expected)); -} - -TEST_F(LookupServiceImplTest, InternalRunQuery_Success) { - InternalRunQueryRequest request; - request.set_query("someset"); - - InternalRunQueryResponse expected; - expected.add_elements("value1"); - expected.add_elements("value2"); - EXPECT_CALL(mock_lookup_, RunQuery(_, _)).WillOnce(Return(expected)); - InternalRunQueryResponse response; - grpc::ClientContext context; - grpc::Status status = stub_->InternalRunQuery(&context, request, &response); - auto results = response.elements(); - EXPECT_THAT(results, - testing::UnorderedElementsAreArray({"value1", "value2"})); -} - -TEST_F(LookupServiceImplTest, InternalRunQuery_LookupError_Failure) { - InternalRunQueryRequest request; - request.set_query("fail|||||now"); - EXPECT_CALL(mock_lookup_, RunQuery(_, _)) - .WillOnce(Return(absl::UnknownError("Some error"))); - InternalRunQueryResponse response; - grpc::ClientContext context; - grpc::Status status = stub_->InternalRunQuery(&context, request, &response); - - EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); -} - TEST_F(LookupServiceImplTest, SecureLookupFailure) { SecureLookupRequest secure_lookup_request; secure_lookup_request.set_ohttp_request("garbage"); @@ -135,34 +68,6 @@ TEST_F(LookupServiceImplTest, SecureLookupFailure) { EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); } -TEST_F(LookupServiceImplTest, InternalRunSetQueryInt_Success) { - InternalRunSetQueryIntRequest request; - request.set_query("someset"); - InternalRunSetQueryIntResponse expected; - expected.add_elements(1000); - expected.add_elements(1001); - expected.add_elements(1002); - EXPECT_CALL(mock_lookup_, RunSetQueryInt(_, _)).WillOnce(Return(expected)); - InternalRunSetQueryIntResponse response; - grpc::ClientContext context; - grpc::Status status = - stub_->InternalRunSetQueryInt(&context, request, &response); - auto results = response.elements(); - EXPECT_THAT(results, 
testing::UnorderedElementsAreArray({1000, 1001, 1002})); -} - -TEST_F(LookupServiceImplTest, InternalRunSetQueryInt_LookupError_Failure) { - InternalRunSetQueryIntRequest request; - request.set_query("fail|||||now"); - EXPECT_CALL(mock_lookup_, RunSetQueryInt(_, _)) - .WillOnce(Return(absl::UnknownError("Some error"))); - InternalRunSetQueryIntResponse response; - grpc::ClientContext context; - grpc::Status status = - stub_->InternalRunSetQueryInt(&context, request, &response); - EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); -} - } // namespace } // namespace kv_server diff --git a/components/internal_server/mocks.h b/components/internal_server/mocks.h index c72f5487..104d24aa 100644 --- a/components/internal_server/mocks.h +++ b/components/internal_server/mocks.h @@ -53,10 +53,18 @@ class MockLookup : public Lookup { (const RequestContext&, const absl::flat_hash_set&), (const, override)); + MOCK_METHOD(absl::StatusOr, GetUInt64ValueSet, + (const RequestContext&, + const absl::flat_hash_set&), + (const, override)); MOCK_METHOD(absl::StatusOr, RunQuery, (const RequestContext&, std::string query), (const, override)); - MOCK_METHOD(absl::StatusOr, RunSetQueryInt, - (const RequestContext&, std::string query), (const, override)); + MOCK_METHOD(absl::StatusOr, + RunSetQueryUInt32, (const RequestContext&, std::string query), + (const, override)); + MOCK_METHOD(absl::StatusOr, + RunSetQueryUInt64, (const RequestContext&, std::string query), + (const, override)); }; } // namespace kv_server diff --git a/components/internal_server/remote_lookup_client_impl.cc b/components/internal_server/remote_lookup_client_impl.cc index a599daff..87b9b967 100644 --- a/components/internal_server/remote_lookup_client_impl.cc +++ b/components/internal_server/remote_lookup_client_impl.cc @@ -18,7 +18,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include 
"components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/data_server/request_handler/encryption/ohttp_client_encryptor.h" #include "components/internal_server/constants.h" #include "components/internal_server/lookup.grpc.pb.h" #include "components/internal_server/remote_lookup_client.h" diff --git a/components/internal_server/sharded_lookup.cc b/components/internal_server/sharded_lookup.cc index 2e98bf7a..b26c692d 100644 --- a/components/internal_server/sharded_lookup.cc +++ b/components/internal_server/sharded_lookup.cc @@ -22,7 +22,7 @@ #include #include "absl/log/check.h" -#include "components/data_server/cache/uint32_value_set.h" +#include "components/data_server/cache/uint_value_set.h" #include "components/internal_server/lookup.h" #include "components/internal_server/lookup.pb.h" #include "components/internal_server/remote_lookup_client.h" @@ -71,12 +71,13 @@ class ShardedLookup : public Lookup { explicit ShardedLookup(const Lookup& local_lookup, const int32_t num_shards, const int32_t current_shard_num, const ShardManager& shard_manager, - KeySharder key_sharder) + KeySharder key_sharder, bool add_chaff = true) : local_lookup_(local_lookup), num_shards_(num_shards), current_shard_num_(current_shard_num), shard_manager_(shard_manager), - key_sharder_(std::move(key_sharder)) { + key_sharder_(std::move(key_sharder)), + add_chaff_(add_chaff) { CHECK_GT(num_shards, 1) << "num_shards for ShardedLookup must be > 1"; } @@ -100,13 +101,109 @@ class ShardedLookup : public Lookup { absl::StatusOr GetKeyValueSet( const RequestContext& request_context, const absl::flat_hash_set& keys) const override { - return GetKeyValueSets(request_context, keys); + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalLookupResponse response; + if (keys.empty()) { + return response; + } + auto maybe_result = + GetShardedKeyValueSet(request_context, keys); + if (!maybe_result.ok()) { + 
LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetKeyValueSetKeySetRetrievalFailure); + return maybe_result.status(); + } + for (const auto& key : keys) { + SingleLookupResult result; + if (const auto key_iter = maybe_result->find(key); + key_iter == maybe_result->end()) { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetKeyValueSetKeySetNotFound); + } else { + auto* keyset_values = result.mutable_keyset_values(); + keyset_values->mutable_values()->Reserve(key_iter->second.size()); + keyset_values->mutable_values()->Add(key_iter->second.begin(), + key_iter->second.end()); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; } absl::StatusOr GetUInt32ValueSet( const RequestContext& request_context, const absl::flat_hash_set& key_set) const override { - return GetKeyValueSets(request_context, key_set); + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalLookupResponse response; + if (key_set.empty()) { + return response; + } + auto maybe_result = + GetShardedKeyValueSet(request_context, key_set); + if (!maybe_result.ok()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetUInt32ValueSetKeySetRetrievalFailure); + return maybe_result.status(); + } + for (const auto& key : key_set) { + SingleLookupResult result; + if (const auto key_iter = maybe_result->find(key); + key_iter == maybe_result->end()) { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetUInt32ValueSetKeySetNotFound); + } else { + auto* uint32set_values = result.mutable_uint32set_values(); + uint32set_values->mutable_values()->Reserve(key_iter->second.size()); + 
uint32set_values->mutable_values()->Add(key_iter->second.begin(), + key_iter->second.end()); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; + } + + absl::StatusOr GetUInt64ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalLookupResponse response; + if (key_set.empty()) { + return response; + } + auto maybe_result = + GetShardedKeyValueSet(request_context, key_set); + if (!maybe_result.ok()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetUInt64ValueSetKeySetRetrievalFailure); + return maybe_result.status(); + } + for (const auto& key : key_set) { + SingleLookupResult result; + if (const auto key_iter = maybe_result->find(key); + key_iter == maybe_result->end()) { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetUInt64ValueSetKeySetNotFound); + } else { + auto* uint64set_values = result.mutable_uint64set_values(); + uint64set_values->mutable_values()->Reserve(key_iter->second.size()); + uint64set_values->mutable_values()->Add(key_iter->second.begin(), + key_iter->second.end()); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; } absl::StatusOr RunQuery( @@ -114,15 +211,20 @@ class ShardedLookup : public Lookup { ScopeLatencyMetricsRecorder latency_recorder(request_context.GetUdfRequestMetricsContext()); - InternalRunQueryResponse response; if (query.empty()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kShardedRunQueryEmptyQuery); - return response; + return InternalRunQueryResponse(); } - auto result = - RunSetQuery, std::string>( - request_context, query); + auto result = RunSetQuery, + std::string, 
InternalRunQueryResponse>( + request_context, query, [](const auto& result_set) { + InternalRunQueryResponse response; + response.mutable_elements()->Reserve(result_set.size()); + response.mutable_elements()->Assign(result_set.begin(), + result_set.end()); + return response; + }); if (!result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kShardedRunQueryFailure); @@ -130,42 +232,80 @@ class ShardedLookup : public Lookup { } PS_VLOG(8, request_context.GetPSLogContext()) << "Driver results for query " << query; - for (const auto& value : *result) { + for (const auto& value : result->elements()) { PS_VLOG(8, request_context.GetPSLogContext()) << "Value: " << value << "\n"; } - response.mutable_elements()->Assign(result->begin(), result->end()); - return response; + return result; } - absl::StatusOr RunSetQueryInt( + absl::StatusOr RunSetQueryUInt32( const RequestContext& request_context, std::string query) const override { ScopeLatencyMetricsRecorder + kShardedLookupRunSetQueryUInt32LatencyInMicros> latency_recorder(request_context.GetUdfRequestMetricsContext()); - InternalRunSetQueryIntResponse response; if (query.empty()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryEmptyQuery); + kShardedRunSetQueryUInt32EmptyQuery); + return InternalRunSetQueryUInt32Response(); + } + auto result = RunSetQuery( + request_context, query, [](const auto& result_set) { + InternalRunSetQueryUInt32Response response; + auto uint32_set = BitSetToUint32Set(result_set); + response.mutable_elements()->Reserve(uint32_set.size()); + response.mutable_elements()->Assign(uint32_set.begin(), + uint32_set.end()); + return response; + }); + if (!result.ok()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedRunSetQueryUInt32Failure); + return result.status(); + } + PS_VLOG(8, request_context.GetPSLogContext()) + << "Driver results for query " << query; + for (const auto& value : 
result->elements()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "Value: " << value << "\n"; + } + return result; + } + + absl::StatusOr RunSetQueryUInt64( + const RequestContext& request_context, std::string query) const override { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalRunSetQueryUInt64Response response; + if (query.empty()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedRunSetQueryUInt64EmptyQuery); return response; } - auto result = - RunSetQuery(request_context, query); + auto result = RunSetQuery( + request_context, query, [](const auto& result_set) { + InternalRunSetQueryUInt64Response response; + auto uint64_set = BitSetToUint64Set(result_set); + response.mutable_elements()->Reserve(uint64_set.size()); + response.mutable_elements()->Assign(uint64_set.begin(), + uint64_set.end()); + return response; + }); if (!result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryFailure); + kShardedRunSetQueryUInt64Failure); return result.status(); } PS_VLOG(8, request_context.GetPSLogContext()) << "Driver results for query " << query; - for (const auto& value : *result) { + for (const auto& value : result->elements()) { PS_VLOG(8, request_context.GetPSLogContext()) << "Value: " << value << "\n"; } - auto uint32_set = BitSetToUint32Set(*result); - response.mutable_elements()->Reserve(uint32_set.size()); - response.mutable_elements()->Assign(uint32_set.begin(), uint32_set.end()); - return response; + return result; } private: @@ -264,10 +404,18 @@ class ShardedLookup : public Lookup { } responses.push_back(std::async( std::launch::async, - [client, &request_context](std::string_view serialized_request, - int32_t padding) { - return client->GetValues(request_context, serialized_request, - padding); + [client, &request_context, add_chaff = add_chaff_, + keys = shard_lookup_input.keys]( + std::string_view 
serialized_request, int32_t padding) { + if (!add_chaff && keys.empty()) { + InternalLookupResponse response; + absl::StatusOr maybe_response = + response; + return maybe_response; + } else { + return client->GetValues(request_context, serialized_request, + padding); + } }, shard_lookup_input.serialized_request, shard_lookup_input.padding)); } @@ -292,9 +440,12 @@ class ShardedLookup : public Lookup { if constexpr (result_type == SingleLookupResult::kKeysetValues) { return local_lookup_.GetKeyValueSet(request_context, keys); } - if constexpr (result_type == SingleLookupResult::kUintsetValues) { + if constexpr (result_type == SingleLookupResult::kUint32SetValues) { return local_lookup_.GetUInt32ValueSet(request_context, keys); } + if constexpr (result_type == SingleLookupResult::kUint64SetValues) { + return local_lookup_.GetUInt64ValueSet(request_context, keys); + } } absl::StatusOr ProcessShardedKeys( @@ -353,8 +504,18 @@ class ShardedLookup : public Lookup { } if constexpr (std::is_same_v) { if (keyset_lookup_result.single_lookup_result_case() == - SingleLookupResult::kUintsetValues) { - for (auto& v : keyset_lookup_result.uintset_values().values()) { + SingleLookupResult::kUint32SetValues) { + for (auto& v : keyset_lookup_result.uint32set_values().values()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "keyset name: " << key << " value: " << v; + value_set.emplace(std::move(v)); + } + } + } + if constexpr (std::is_same_v) { + if (keyset_lookup_result.single_lookup_result_case() == + SingleLookupResult::kUint64SetValues) { + for (auto& v : keyset_lookup_result.uint64set_values().values()) { PS_VLOG(8, request_context.GetPSLogContext()) << "keyset name: " << key << " value: " << v; value_set.emplace(std::move(v)); @@ -391,7 +552,11 @@ class ShardedLookup : public Lookup { request_context, key_list); } if constexpr (std::is_same_v) { - return GetLocalLookupResponse( + return GetLocalLookupResponse( + request_context, key_list); + } + if constexpr 
(std::is_same_v) { + return GetLocalLookupResponse( request_context, key_list); } }); @@ -415,57 +580,20 @@ class ShardedLookup : public Lookup { return key_sets; } - template - absl::StatusOr GetKeyValueSets( - const RequestContext& request_context, - const absl::flat_hash_set& keys) const { - ScopeLatencyMetricsRecorder - latency_recorder(request_context.GetUdfRequestMetricsContext()); - InternalLookupResponse response; - if (keys.empty()) { - return response; - } - absl::flat_hash_map> - key_sets; - auto get_key_value_set_result_maybe = - GetShardedKeyValueSet(request_context, keys); - if (!get_key_value_set_result_maybe.ok()) { - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedGetKeyValueSetKeySetRetrievalFailure); - return get_key_value_set_result_maybe.status(); - } - key_sets = *std::move(get_key_value_set_result_maybe); - for (const auto& key : keys) { - SingleLookupResult result; - if (const auto key_iter = key_sets.find(key); - key_iter == key_sets.end()) { - auto status = result.mutable_status(); - status->set_code(static_cast(absl::StatusCode::kNotFound)); - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedGetKeyValueSetKeySetNotFound); - } else { - if constexpr (std::is_same_v) { - auto* keyset_values = result.mutable_keyset_values(); - keyset_values->mutable_values()->Reserve(key_iter->second.size()); - keyset_values->mutable_values()->Add(key_iter->second.begin(), - key_iter->second.end()); - } - if constexpr (std::is_same_v) { - auto* uint32set_values = result.mutable_uintset_values(); - uint32set_values->mutable_values()->Reserve(key_iter->second.size()); - uint32set_values->mutable_values()->Add(key_iter->second.begin(), - key_iter->second.end()); - } - } - (*response.mutable_kv_pairs())[key] = std::move(result); + template + static BitsetType ToBitset(const absl::flat_hash_set& set) { + BitsetType bitset; + for (const auto& element : set) { + bitset.add(element); } - return 
response; + bitset.runOptimize(); + return bitset; } - template - absl::StatusOr RunSetQuery(const RequestContext& request_context, - std::string query) const { + template + absl::StatusOr RunSetQuery( + const RequestContext& request_context, std::string query, + absl::AnyInvocable to_response_fn) const { kv_server::Driver driver; std::istringstream stream(query); kv_server::Scanner scanner(stream); @@ -476,38 +604,38 @@ class ShardedLookup : public Lookup { kShardedRunQueryParsingFailure); return absl::InvalidArgumentError("Parsing failure."); } - auto get_key_value_set_result_maybe = GetShardedKeyValueSet( + auto key_value_result = GetShardedKeyValueSet( request_context, driver.GetRootNode()->Keys()); - if (!get_key_value_set_result_maybe.ok()) { + if (!key_value_result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kShardedRunQueryKeySetRetrievalFailure); - return get_key_value_set_result_maybe.status(); - } - auto keysets = std::move(*get_key_value_set_result_maybe); - return driver.EvaluateQuery([&keysets, &request_context]( - std::string_view key) { - const auto key_iter = keysets.find(key); - if (key_iter == keysets.end()) { - PS_VLOG(8, request_context.GetPSLogContext()) - << "Driver can't find " << key << "key_set. 
Returning empty."; - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryMissingKeySet); - return SetType(); - } - if constexpr (std::is_same_v>) { - return absl::flat_hash_set(key_iter->second.begin(), - key_iter->second.end()); - } - if constexpr (std::is_same_v) { - roaring::Roaring bitset; - for (const auto& element : key_iter->second) { - bitset.add(element); - } - bitset.runOptimize(); - return bitset; - } - }); + return key_value_result.status(); + } + auto query_result = driver.EvaluateQuery( + [&key_value_result, &request_context](std::string_view key) { + const auto key_iter = key_value_result->find(key); + if (key_iter == key_value_result->end()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "Driver can't find " << key << "key_set. Returning empty."; + LogUdfRequestErrorMetric( + request_context.GetUdfRequestMetricsContext(), + kShardedRunQueryMissingKeySet); + return SetType(); + } + if constexpr (std::is_same_v>) { + return absl::flat_hash_set( + key_iter->second.begin(), key_iter->second.end()); + } + if constexpr (std::is_same_v || + std::is_same_v) { + return ToBitset(key_iter->second); + } + }); + if (!query_result.ok()) { + return query_result.status(); + } + return to_response_fn(*query_result); } const Lookup& local_lookup_; @@ -516,6 +644,10 @@ class ShardedLookup : public Lookup { const std::string hashing_seed_; const ShardManager& shard_manager_; KeySharder key_sharder_; + // For prod this flag is always true. + // When this flag is on we always query all shards. This is done for + // privacy reasons. 
+ const bool add_chaff_; }; } // namespace @@ -524,10 +656,11 @@ std::unique_ptr CreateShardedLookup(const Lookup& local_lookup, const int32_t num_shards, const int32_t current_shard_num, const ShardManager& shard_manager, - KeySharder key_sharder) { + KeySharder key_sharder, + bool add_chaff) { return std::make_unique(local_lookup, num_shards, current_shard_num, shard_manager, - std::move(key_sharder)); + std::move(key_sharder), add_chaff); } } // namespace kv_server diff --git a/components/internal_server/sharded_lookup.h b/components/internal_server/sharded_lookup.h index 2619af0a..4201bc79 100644 --- a/components/internal_server/sharded_lookup.h +++ b/components/internal_server/sharded_lookup.h @@ -30,7 +30,8 @@ std::unique_ptr CreateShardedLookup(const Lookup& local_lookup, const int32_t num_shards, const int32_t current_shard_num, const ShardManager& shard_manager, - KeySharder key_sharder); + KeySharder key_sharder, + bool add_chaff = true); } // namespace kv_server diff --git a/components/internal_server/sharded_lookup_test.cc b/components/internal_server/sharded_lookup_test.cc index 81021255..7a647448 100644 --- a/components/internal_server/sharded_lookup_test.cc +++ b/components/internal_server/sharded_lookup_test.cc @@ -32,6 +32,7 @@ namespace { using google::protobuf::TextFormat; using testing::_; using testing::Return; +using ::testing::StrictMock; class ShardedLookupTest : public ::testing::Test { protected: @@ -178,6 +179,87 @@ TEST_F(ShardedLookupTest, GetKeyValues_Success) { EXPECT_THAT(response.value(), EqualsProto(expected)); } +TEST_F(ShardedLookupTest, GetKeyValues_Nochaff_Success) { + const int num_shards_three = 3; + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString(R"pb(kv_pairs { + key: "key1" + value { value: "value1" } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetKeyValues(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + std::vector>> + 
remote_lookup_client; + for (int i = 0; i < num_shards_three; i++) { + cluster_mappings.push_back({std::to_string(i)}); + + // The key piece here is the `strict` mode. If `add_chaff` was set to true, + // this test would fail because `GetValues` would be called for a non local + // lookup. In strict mode instead of returning defaults, we fail the test + // for such an unexpected call. + if (i == 1) { + auto mock_remote_lookup_client_1 = + std::make_unique>(); + const std::vector key_list_remote = {"key4"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(_, serialized_request, 0)) + .WillOnce([&]() { + InternalLookupResponse resp; + SingleLookupResult result; + result.set_value("value4"); + (*resp.mutable_kv_pairs())["key4"] = result; + return resp; + }); + remote_lookup_client.push_back(std::move(mock_remote_lookup_client_1)); + } else { + remote_lookup_client.push_back( + std::make_unique>()); + } + } + auto shard_manager = ShardManager::Create( + num_shards_three, std::move(cluster_mappings), + std::make_unique(), + [this, &remote_lookup_client](const std::string& ip) { + return std::move(remote_lookup_client[stoi(ip)]); + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_three, shard_num_, + *(*shard_manager), key_sharder_, /*add_chaff=*/false); + + // as part of this call for non-chaff we're making two requests. + // one local and one remote. With chaff it would have been three calls: + // one local and two remote (one chaff and one non-chaff). 
+ auto response = + sharded_lookup->GetKeyValues(GetRequestContext(), {"key1", "key4"}); + EXPECT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString(R"pb( + kv_pairs { + key: "key1" + value { value: "value1" } + } + kv_pairs { + key: "key4" + value { value: "value4" } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + TEST_F(ShardedLookupTest, GetKeyValues_KeyMissing_ReturnsStatus) { InternalLookupResponse local_lookup_response; TextFormat::ParseFromString( @@ -720,7 +802,7 @@ TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysFound_Success) { TextFormat::ParseFromString( R"pb(kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } )pb", &local_lookup_response); @@ -756,7 +838,7 @@ TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysFound_Success) { TextFormat::ParseFromString( R"pb(kv_pairs { key: "key1" - value { uintset_values { values: 2000 } } + value { uint32set_values { values: 2000 } } } )pb", &resp); @@ -774,11 +856,11 @@ TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysFound_Success) { TextFormat::ParseFromString( R"pb(kv_pairs { key: "key1" - value { uintset_values { values: 2000 } } + value { uint32set_values { values: 2000 } } } kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } )pb", &expected); @@ -790,7 +872,7 @@ TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysMissing_ReturnsStatus) { TextFormat::ParseFromString( R"pb(kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } )pb", &local_lookup_response); @@ -853,7 +935,7 @@ TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysMissing_ReturnsStatus) { } kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } kv_pairs { key: "key5" @@ -1153,12 +1235,12 @@ TEST_F(ShardedLookupTest, 
RunQuery_EmptyRequest_EmptyResponse) { EXPECT_TRUE(response.value().elements().empty()); } -TEST_F(ShardedLookupTest, RunSetQueryInt_Success) { +TEST_F(ShardedLookupTest, RunSetQueryUInt32_Success) { InternalLookupResponse local_lookup_response; TextFormat::ParseFromString( R"pb(kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } )pb", &local_lookup_response); @@ -1195,7 +1277,7 @@ TEST_F(ShardedLookupTest, RunSetQueryInt_Success) { TextFormat::ParseFromString( R"pb(kv_pairs { key: "key1" - value { uintset_values { values: 2000 } } + value { uint32set_values { values: 2000 } } } )pb", &resp); @@ -1207,18 +1289,18 @@ TEST_F(ShardedLookupTest, RunSetQueryInt_Success) { CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, *(*shard_manager), key_sharder_); auto response = - sharded_lookup->RunSetQueryInt(GetRequestContext(), "key1|key4"); + sharded_lookup->RunSetQueryUInt32(GetRequestContext(), "key1|key4"); EXPECT_TRUE(response.ok()); EXPECT_THAT(response.value().elements(), testing::UnorderedElementsAreArray({1000, 2000})); } -TEST_F(ShardedLookupTest, RunSetQueryInt_ShardedLookupFails_Error) { +TEST_F(ShardedLookupTest, RunSetQueryUInt32_ShardedLookupFails_Error) { InternalLookupResponse local_lookup_response; TextFormat::ParseFromString( R"pb(kv_pairs { key: "key4" - value { uintset_values { values: 1000 } } + value { uint32set_values { values: 1000 } } } )pb", &local_lookup_response); @@ -1236,12 +1318,12 @@ TEST_F(ShardedLookupTest, RunSetQueryInt_ShardedLookupFails_Error) { CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, *(*shard_manager), key_sharder_); auto response = - sharded_lookup->RunSetQueryInt(GetRequestContext(), "key1|key4"); + sharded_lookup->RunSetQueryUInt32(GetRequestContext(), "key1|key4"); EXPECT_FALSE(response.ok()); EXPECT_THAT(response.status().code(), absl::StatusCode::kInternal); } -TEST_F(ShardedLookupTest, RunSetQueryInt_EmptyRequest_EmptyResponse) 
{ +TEST_F(ShardedLookupTest, RunSetQueryUInt32_EmptyRequest_EmptyResponse) { std::vector> cluster_mappings; for (int i = 0; i < 2; i++) { cluster_mappings.push_back({std::to_string(i)}); @@ -1255,11 +1337,219 @@ TEST_F(ShardedLookupTest, RunSetQueryInt_EmptyRequest_EmptyResponse) { auto sharded_lookup = CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, *(*shard_manager), key_sharder_); - auto response = sharded_lookup->RunSetQueryInt(GetRequestContext(), ""); + auto response = sharded_lookup->RunSetQueryUInt32(GetRequestContext(), ""); EXPECT_TRUE(response.ok()); EXPECT_TRUE(response.value().elements().empty()); } +TEST_F(ShardedLookupTest, RunSetQueryUInt64_Success) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uint64set_values { values: 18446744073709551 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt64ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [this](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(_, serialized_request, 0)) + .WillOnce([&]() { + 
InternalLookupResponse resp; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { + uint64set_values { values: 18446744073709552 } + } + } + )pb", + &resp); + return resp; + }); + return mock_remote_lookup_client_1; + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = + sharded_lookup->RunSetQueryUInt64(GetRequestContext(), "key1|key4"); + EXPECT_TRUE(response.ok()); + EXPECT_THAT(response.value().elements(), + testing::UnorderedElementsAreArray( + {18446744073709551, 18446744073709552})); +} + +TEST_F(ShardedLookupTest, RunSetQueryUInt64_ShardedLookupFails_Error) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uint64set_values { values: 18446744073709551 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt64ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = + ShardManager::Create(num_shards_, std::move(cluster_mappings), + std::make_unique(), + [](const std::string& ip) { return nullptr; }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = + sharded_lookup->RunSetQueryUInt64(GetRequestContext(), "key1|key4"); + EXPECT_FALSE(response.ok()); + EXPECT_THAT(response.status().code(), absl::StatusCode::kInternal); +} + +TEST_F(ShardedLookupTest, RunSetQueryUInt64_EmptyRequest_EmptyResponse) { + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + return std::make_unique(); + }); + auto sharded_lookup = + 
CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->RunSetQueryUInt64(GetRequestContext(), ""); + EXPECT_TRUE(response.ok()); + EXPECT_TRUE(response.value().elements().empty()); +} + +TEST_F(ShardedLookupTest, GetUInt64ValueSets_KeysMissing_ReturnsStatus) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uint64set_values { values: 18446744073709551 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt64ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [this](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1", "key5"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) + .WillOnce([=](const RequestContext& request_context, + const std::string_view serialized_message, + const int32_t padding_length) { + InternalLookupRequest request; + EXPECT_TRUE(request.ParseFromString(serialized_message)); + auto request_keys = std::vector( + request.keys().begin(), request.keys().end()); + EXPECT_THAT(request.keys(), + testing::UnorderedElementsAreArray(key_list_remote)); + 
InternalLookupResponse resp; + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + (*resp.mutable_kv_pairs())["key1"] = result; + return resp; + }); + return mock_remote_lookup_client_1; + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->GetUInt64ValueSet(GetRequestContext(), + {"key1", "key4", "key5"}); + ASSERT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { status: { code: 5, message: "" } } + } + kv_pairs { + key: "key4" + value { uint64set_values { values: 18446744073709551 } } + } + kv_pairs { + key: "key5" + value { status: { code: 5, message: "" } } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + +TEST_F(ShardedLookupTest, GetUInt64ValueSet_EmptyRequest_ReturnsEmptyResponse) { + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + return std::make_unique(); + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->GetUInt64ValueSet(GetRequestContext(), {}); + EXPECT_TRUE(response.ok()); + + InternalLookupResponse expected; + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + } // namespace } // namespace kv_server diff --git a/components/query/BUILD.bazel b/components/query/BUILD.bazel index 211ea2bc..b44f2097 100644 --- a/components/query/BUILD.bazel +++ b/components/query/BUILD.bazel @@ -30,7 +30,6 @@ cc_library( ], deps = [ "@com_google_absl//absl/container:flat_hash_set", - "@roaring_bitmap//:c_roaring", ], ) @@ -42,8 +41,8 @@ 
cc_test( ], deps = [ ":sets", + "//components/data_server/cache:uint_value_set", "@com_google_googletest//:gtest_main", - "@roaring_bitmap//:c_roaring", ], ) @@ -60,6 +59,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", + "@roaring_bitmap//:c_roaring", ], ) @@ -68,12 +71,12 @@ cc_test( size = "small", srcs = [ "ast_test.cc", + "template_test_utils.h", ], deps = [ ":ast", "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest_main", - "@roaring_bitmap//:c_roaring", ], ) @@ -100,6 +103,7 @@ cc_test( size = "small", srcs = [ "driver_test.cc", + "template_test_utils.h", ], deps = [ ":driver", @@ -107,7 +111,9 @@ cc_test( ":scanner", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", + "@roaring_bitmap//:c_roaring", ], ) diff --git a/components/query/ast.cc b/components/query/ast.cc index a1d6e55d..dc188cb6 100644 --- a/components/query/ast.cc +++ b/components/query/ast.cc @@ -42,25 +42,46 @@ std::vector ComputePostfixOrder(const Node* root) { return result; } -std::string ValueNode::Accept(ASTStringVisitor& visitor) const { +absl::StatusOr ValueNode::Accept(ASTStringVisitor& visitor) const { return visitor.Visit(*this); } -std::string UnionNode::Accept(ASTStringVisitor& visitor) const { +absl::StatusOr UnionNode::Accept(ASTStringVisitor& visitor) const { return visitor.Visit(*this); } -std::string DifferenceNode::Accept(ASTStringVisitor& visitor) const { +absl::StatusOr DifferenceNode::Accept( + ASTStringVisitor& visitor) const { return visitor.Visit(*this); } -std::string IntersectionNode::Accept(ASTStringVisitor& visitor) const { +absl::StatusOr 
IntersectionNode::Accept( + ASTStringVisitor& visitor) const { return visitor.Visit(*this); } - -void ValueNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } -void UnionNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } -void IntersectionNode::Accept(ASTVisitor& visitor) const { - visitor.Visit(*this); +absl::StatusOr NumberSetNode::Accept( + ASTStringVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::StatusOr StringViewSetNode::Accept( + ASTStringVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status ValueNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status UnionNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status IntersectionNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status DifferenceNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status NumberSetNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); +} +absl::Status StringViewSetNode::Accept(ASTVisitor& visitor) const { + return visitor.Visit(*this); } -void DifferenceNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } absl::flat_hash_set OpNode::Keys() const { std::vector nodes; diff --git a/components/query/ast.h b/components/query/ast.h index b01c0f00..9c5a8657 100644 --- a/components/query/ast.h +++ b/components/query/ast.h @@ -25,7 +25,13 @@ #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "components/query/sets.h" +#include "src/util/status_macro/status_macros.h" + +#include "roaring.hh" +#include "roaring64map.hh" namespace kv_server { // All set operations using `KVStringSetView` operate on a reference to the data @@ -36,6 +42,17 @@ using KVStringSetView = absl::flat_hash_set; class ASTVisitor; class ASTStringVisitor; +// SFINAE with std::void_t 
+template > +struct has_value_type : std::false_type {}; + +template +struct has_value_type> : std::true_type { +}; + +template +inline constexpr bool has_value_type_v = has_value_type::value; + class Node { public: virtual ~Node() = default; @@ -43,8 +60,9 @@ class Node { virtual Node* Right() const { return nullptr; } // Return all Keys associated with ValueNodes in the tree. virtual absl::flat_hash_set Keys() const = 0; - virtual void Accept(ASTVisitor& visitor) const = 0; - virtual std::string Accept(ASTStringVisitor& visitor) const = 0; + virtual absl::Status Accept(ASTVisitor& visitor) const = 0; + virtual absl::StatusOr Accept( + ASTStringVisitor& visitor) const = 0; }; // The value associated with a `ValueNode` is the set with its associated `key`. @@ -53,13 +71,45 @@ class ValueNode : public Node { explicit ValueNode(std::string key) : key_(std::move(key)) {} std::string_view Key() const { return key_; } absl::flat_hash_set Keys() const override; - void Accept(ASTVisitor& visitor) const override; - std::string Accept(ASTStringVisitor& visitor) const override; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; private: std::string key_; }; +template +class SetNode : public Node { + public: + using value_type = T; + + explicit SetNode(std::vector values) + : values_(values.begin(), values.end()) {} + absl::flat_hash_set Keys() const override { return {}; }; + // TODO(b/371977043): Consider changing Vistor argument + // from const-ref to value. Then we can return an r-value and avoid copy. 
+ const absl::flat_hash_set& GetValues() const { return values_; } + + private: + absl::flat_hash_set values_; +}; + +class NumberSetNode : public SetNode { + public: + using SetNode::SetNode; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; +}; + +// View to strings who's lifetime is managed externally, +// typically the `Driver`. +class StringViewSetNode : public SetNode { + public: + using SetNode::SetNode; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; +}; + class OpNode : public Node { public: OpNode(std::unique_ptr left, std::unique_ptr right) @@ -76,22 +126,22 @@ class OpNode : public Node { class UnionNode : public OpNode { public: using OpNode::OpNode; - void Accept(ASTVisitor& visitor) const override; - std::string Accept(ASTStringVisitor& visitor) const override; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; }; class IntersectionNode : public OpNode { public: using OpNode::OpNode; - void Accept(ASTVisitor& visitor) const override; - std::string Accept(ASTStringVisitor& visitor) const override; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; }; class DifferenceNode : public OpNode { public: using OpNode::OpNode; - void Accept(ASTVisitor& visitor) const override; - std::string Accept(ASTStringVisitor& visitor) const override; + absl::Status Accept(ASTVisitor& visitor) const override; + absl::StatusOr Accept(ASTStringVisitor& visitor) const override; }; // Traverses the binary tree starting at root and returns a vector of `Node`s in @@ -103,10 +153,12 @@ std::vector ComputePostfixOrder(const Node* root); // upon inspection. 
class ASTStringVisitor { public: - virtual std::string Visit(const UnionNode&) = 0; - virtual std::string Visit(const DifferenceNode&) = 0; - virtual std::string Visit(const IntersectionNode&) = 0; - virtual std::string Visit(const ValueNode&) = 0; + virtual absl::StatusOr Visit(const UnionNode&) = 0; + virtual absl::StatusOr Visit(const DifferenceNode&) = 0; + virtual absl::StatusOr Visit(const IntersectionNode&) = 0; + virtual absl::StatusOr Visit(const ValueNode&) = 0; + virtual absl::StatusOr Visit(const NumberSetNode&) = 0; + virtual absl::StatusOr Visit(const StringViewSetNode&) = 0; }; // Defines a general AST visitor interface which can be extended to implement @@ -114,11 +166,14 @@ class ASTStringVisitor { class ASTVisitor { public: // Entrypoint for running the visitor algorithm on a given AST tree, `root`. - virtual void ConductVisit(const Node& root) = 0; - virtual void Visit(const ValueNode& node) = 0; - virtual void Visit(const UnionNode& node) = 0; - virtual void Visit(const DifferenceNode& node) = 0; - virtual void Visit(const IntersectionNode& node) = 0; + virtual ~ASTVisitor() = default; + virtual absl::Status ConductVisit(const Node& root) = 0; + virtual absl::Status Visit(const ValueNode& node) = 0; + virtual absl::Status Visit(const UnionNode& node) = 0; + virtual absl::Status Visit(const DifferenceNode& node) = 0; + virtual absl::Status Visit(const IntersectionNode& node) = 0; + virtual absl::Status Visit(const NumberSetNode& node) = 0; + virtual absl::Status Visit(const StringViewSetNode& node) = 0; }; // Implements AST tree evaluation using iterative post order processing. 
@@ -129,24 +184,52 @@ class ASTPostOrderEvalVisitor final : public ASTVisitor { absl::AnyInvocable lookup_fn) : lookup_fn_(std::move(lookup_fn)) {} - void ConductVisit(const Node& root) override { + absl::Status ConductVisit(const Node& root) override { stack_.clear(); for (const auto* node : ComputePostfixOrder(&root)) { - node->Accept(*this); + PS_RETURN_IF_ERROR(node->Accept(*this)); } + return absl::OkStatus(); } - void Visit(const ValueNode& node) override { + absl::Status Visit(const ValueNode& node) override { stack_.push_back(std::move(lookup_fn_(node.Key()))); + return absl::OkStatus(); + } + absl::Status Visit(const UnionNode& node) override { + return Visit(node, Union); } - void Visit(const UnionNode& node) override { Visit(node, Union); } - void Visit(const DifferenceNode& node) override { - Visit(node, Difference); + absl::Status Visit(const DifferenceNode& node) override { + return Visit(node, Difference); } - void Visit(const IntersectionNode& node) override { - Visit(node, Intersection); + absl::Status Visit(const IntersectionNode& node) override { + return Visit(node, Intersection); + } + + absl::Status Visit(const NumberSetNode& node) override { + if constexpr (std::is_same_v || + std::is_same_v) { + ValueT r; + for (const auto v : node.GetValues()) { + r.add(v); + } + stack_.push_back(std::move(r)); + return absl::OkStatus(); + } + return absl::InvalidArgumentError("Unexpected set type"); + } + + absl::Status Visit(const StringViewSetNode& node) override { + if constexpr (has_value_type_v) { + if constexpr (std::is_same_v) { + stack_.push_back(node.GetValues()); + return absl::OkStatus(); + } + } + return absl::InvalidArgumentError("Unexpected set type"); } ValueT GetResult() { @@ -157,13 +240,14 @@ class ASTPostOrderEvalVisitor final : public ASTVisitor { } private: - void Visit(const OpNode& node, - absl::AnyInvocable op_fn) { + absl::Status Visit(const OpNode& node, + absl::AnyInvocable op_fn) { auto right = std::move(stack_.back()); 
stack_.pop_back(); auto left = std::move(stack_.back()); stack_.pop_back(); stack_.push_back(op_fn(std::move(left), std::move(right))); + return absl::OkStatus(); } absl::AnyInvocable lookup_fn_; @@ -172,10 +256,11 @@ class ASTPostOrderEvalVisitor final : public ASTVisitor { // Accepts an AST representing a set query, creates execution plan and runs it. template -ValueT Eval(const Node& node, - absl::AnyInvocable lookup_fn) { +absl::StatusOr Eval( + const Node& node, + absl::AnyInvocable lookup_fn) { auto visitor = ASTPostOrderEvalVisitor(std::move(lookup_fn)); - visitor.ConductVisit(node); + PS_RETURN_IF_ERROR(visitor.ConductVisit(node)); return visitor.GetResult(); } diff --git a/components/query/ast_test.cc b/components/query/ast_test.cc index 1b18b9d4..e464c9ef 100644 --- a/components/query/ast_test.cc +++ b/components/query/ast_test.cc @@ -14,112 +14,280 @@ #include "components/query/ast.h" +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "roaring.hh" +#include "roaring64map.hh" namespace kv_server { namespace { const absl::flat_hash_map> - kDb = { + kStringSetDB = { {"A", {"a", "b", "c"}}, {"B", {"b", "c", "d"}}, {"C", {"c", "d", "e"}}, {"D", {"d", "e", "f"}}, }; -const absl::flat_hash_map kBitsetDb = { + +const absl::flat_hash_map kUInt32SetDb = { {"A", {1, 2, 3}}, {"B", {2, 3, 4}}, {"C", {3, 4, 5}}, {"D", {4, 5, 6}}, }; +const absl::flat_hash_map kUInt64SetDb = { + {"A", + {18446744073709551609UL, 18446744073709551610UL, 18446744073709551611UL}}, + {"B", + {18446744073709551610UL, 18446744073709551611UL, 18446744073709551612UL}}, + {"C", + {18446744073709551611UL, 18446744073709551612UL, 18446744073709551613UL}}, + {"D", + {18446744073709551612UL, 18446744073709551613UL, 18446744073709551614UL}}, +}; + +// Copied from driver_test +// TODO: move to a shared header. 
+template +struct SetTypeConverter; + +template <> +struct SetTypeConverter> { + using type = std::string_view; +}; + +template <> +struct SetTypeConverter { + using type = uint32_t; +}; + +template <> +struct SetTypeConverter { + using type = uint64_t; +}; + +template +using ConvertedSetType = typename SetTypeConverter::type; + +template +SetType Lookup(std::string_view key); + +template <> absl::flat_hash_set Lookup(std::string_view key) { - if (const auto& it = kDb.find(key); it != kDb.end()) { + if (const auto& it = kStringSetDB.find(key); it != kStringSetDB.end()) { return it->second; } return {}; } -roaring::Roaring BitsetLookup(std::string_view key) { - if (const auto& it = kBitsetDb.find(key); it != kBitsetDb.end()) { +template <> +roaring::Roaring Lookup(std::string_view key) { + if (const auto& it = kUInt32SetDb.find(key); it != kUInt32SetDb.end()) { return it->second; } return {}; } -TEST(AstTest, Value) { +template <> +roaring::Roaring64Map Lookup(std::string_view key) { + if (const auto& it = kUInt64SetDb.find(key); it != kUInt64SetDb.end()) { + return it->second; + } + return {}; +} + +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v>) { + return "StringSet"; + } + if constexpr (std::is_same_v) { + return "UInt32Set"; + } + if constexpr (std::is_same_v) { + return "UInt64Set"; + } + } +}; + +template +class ASTTest : public ::testing::Test {}; + +using SetTypes = testing::Types, + roaring::Roaring, roaring::Roaring64Map>; +TYPED_TEST_SUITE(ASTTest, SetTypes, NameGenerator); + +TYPED_TEST(ASTTest, Value) { ValueNode value("A"); - EXPECT_EQ(Eval(value, Lookup), Lookup("A")); + auto result = Eval(value, Lookup); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("A")); + ValueNode value2("B"); - EXPECT_EQ(Eval(value2, Lookup), Lookup("B")); + result = Eval(value2, Lookup); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("B")); + ValueNode value3("C"); - EXPECT_EQ(Eval(value3, 
Lookup), Lookup("C")); + result = Eval(value3, Lookup); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("C")); + ValueNode value4("D"); - EXPECT_EQ(Eval(value4, Lookup), Lookup("D")); + result = Eval(value4, Lookup); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("D")); + ValueNode value5("E"); - EXPECT_EQ(Eval(value5, Lookup), Lookup("E")); + result = Eval(value5, Lookup); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("E")); } -TEST(AstTest, Union) { +TYPED_TEST(ASTTest, Set) { + StringViewSetNode ssn({"a", "b", "c"}); + auto str_result = Eval(ssn, Lookup); + if constexpr (std::is_same_v>) { + ASSERT_TRUE(str_result.ok()); + EXPECT_THAT(*str_result, testing::UnorderedElementsAre("a", "b", "c")); + } else { + // For number evals, we expect string sets to be invalid. + ASSERT_FALSE(str_result.ok()); + } + + if constexpr (std::is_same_v>) { + // For string evals, we expect strings to be invalid + auto result = Eval(NumberSetNode({1, 2, 3}), Lookup); + ASSERT_FALSE(result.ok()); + } else { + std::vector> vals = { + 0, std::numeric_limits>::max(), + std::numeric_limits>::max() - 1, + std::numeric_limits>::max() - 2}; + // Currently no bounds checking on 64-bit type fits into 32-bit range + // for 32-bit Roaring eval type. + // See TODO for eval to return an error. 
+ NumberSetNode nsn({vals.begin(), vals.end()}); + auto num_result = Eval(nsn, Lookup); + typename decltype(num_result)::value_type expected(vals.size(), + vals.data()); + EXPECT_EQ(*num_result, expected); + } +} + +TYPED_TEST(ASTTest, Union) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr b = std::make_unique("B"); UnionNode op(std::move(a), std::move(b)); - absl::flat_hash_set expected = {"a", "b", "c", "d"}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 2, 3, 4})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map( + {18446744073709551609UL, 18446744073709551610UL, + 18446744073709551611UL, 18446744073709551612UL})); + } } -TEST(AstTest, UnionSelf) { +TYPED_TEST(ASTTest, UnionSelf) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr a2 = std::make_unique("A"); UnionNode op(std::move(a), std::move(a2)); - absl::flat_hash_set expected = {"a", "b", "c"}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("A")); } -TEST(AstTest, Intersection) { +TYPED_TEST(ASTTest, Intersection) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr b = std::make_unique("B"); IntersectionNode op(std::move(a), std::move(b)); - absl::flat_hash_set expected = {"b", "c"}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({2, 3})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map( + {18446744073709551610UL, 18446744073709551611UL})); + } } 
-TEST(AstTest, IntersectionSelf) { +TYPED_TEST(ASTTest, IntersectionSelf) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr a2 = std::make_unique("A"); IntersectionNode op(std::move(a), std::move(a2)); - absl::flat_hash_set expected = {"a", "b", "c"}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(*result, Lookup("A")); } -TEST(AstTest, Difference) { +TYPED_TEST(ASTTest, Difference) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr b = std::make_unique("B"); DifferenceNode op(std::move(a), std::move(b)); - absl::flat_hash_set expected = {"a"}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL})); + } std::unique_ptr a2 = std::make_unique("A"); std::unique_ptr b2 = std::make_unique("B"); DifferenceNode op2(std::move(b2), std::move(a2)); - absl::flat_hash_set expected2 = {"d"}; - EXPECT_EQ(Eval(op2, Lookup), expected2); + result = Eval(op2, Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({4})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551612UL})); + } } -TEST(AstTest, DifferenceSelf) { +TYPED_TEST(ASTTest, DifferenceSelf) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr a2 = std::make_unique("A"); DifferenceNode op(std::move(a), std::move(a2)); - absl::flat_hash_set expected = {}; - EXPECT_EQ(Eval(op, Lookup), expected); + auto result = Eval(op, Lookup); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(*result, TypeParam()); } 
-TEST(AstTest, All) { +TYPED_TEST(ASTTest, All) { // (A-B) | (C&D) = // {a} | {d,e} = // {a, d, e} @@ -132,23 +300,35 @@ TEST(AstTest, All) { std::unique_ptr right = std::make_unique(std::move(c), std::move(d)); UnionNode center(std::move(left), std::move(right)); - absl::flat_hash_set expected = {"a", "d", "e"}; - EXPECT_EQ(Eval(center, Lookup), expected); + auto result = Eval(center, Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 4, 5})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL, + 18446744073709551612UL, + 18446744073709551613UL})); + } } -TEST(AstTest, ValueNodeKeys) { +TEST(ASTKeysTest, ValueNodeKeys) { ValueNode v("A"); EXPECT_THAT(v.Keys(), testing::UnorderedElementsAre("A")); } -TEST(AstTest, OpNodeKeys) { +TEST(ASTKeysTest, OpNodeKeys) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr b = std::make_unique("B"); DifferenceNode op(std::move(b), std::move(a)); EXPECT_THAT(op.Keys(), testing::UnorderedElementsAre("A", "B")); } -TEST(AstTest, DupeNodeKeys) { +TEST(ASTKeysTest, DupeNodeKeys) { std::unique_ptr a = std::make_unique("A"); std::unique_ptr b = std::make_unique("B"); std::unique_ptr c = std::make_unique("C"); @@ -161,63 +341,5 @@ TEST(AstTest, DupeNodeKeys) { EXPECT_THAT(center.Keys(), testing::UnorderedElementsAre("A", "B", "C")); } -TEST(ASTEvalTest, VerifyValueNodeEvaluation) { - { - ValueNode root("DOES_NOT_EXIST"); - EXPECT_TRUE(Eval(root, Lookup).empty()); - } - ValueNode root("A"); - EXPECT_THAT(Eval(root, Lookup), - testing::UnorderedElementsAre("a", "b", "c")); - EXPECT_EQ(Eval(root, BitsetLookup), - roaring::Roaring({1, 2, 3})); -} - -TEST(ASTEvalTest, VerifyUnionNodeEvaluation) { - auto a = std::make_unique("A"); - auto b = std::make_unique("B"); - UnionNode root(std::move(a), std::move(b)); - 
EXPECT_THAT(Eval(root, Lookup), - testing::UnorderedElementsAre("a", "b", "c", "d")); - EXPECT_EQ(Eval(root, BitsetLookup), - roaring::Roaring({1, 2, 3, 4})); -} - -TEST(ASTEvalTest, VerifyDifferenceNodeEvaluation) { - auto a = std::make_unique("A"); - auto b = std::make_unique("B"); - DifferenceNode root(std::move(a), std::move(b)); - EXPECT_THAT(Eval(root, Lookup), - testing::UnorderedElementsAre("a")); - EXPECT_EQ(Eval(root, BitsetLookup), roaring::Roaring({1})); -} - -TEST(ASTEvalTest, VerifyIntersectionNodeEvaluation) { - auto a = std::make_unique("A"); - auto b = std::make_unique("B"); - IntersectionNode root(std::move(a), std::move(b)); - EXPECT_THAT(Eval(root, Lookup), - testing::UnorderedElementsAre("b", "c")); - EXPECT_EQ(Eval(root, BitsetLookup), - roaring::Roaring({2, 3})); -} - -TEST(ASTEvalTest, VerifyComplexNodeEvaluation) { - // (A-B) | (C&D) = - // {a} | {d,e} = - // {a, d, e} - auto a = std::make_unique("A"); - auto b = std::make_unique("B"); - auto c = std::make_unique("C"); - auto d = std::make_unique("D"); - auto left = std::make_unique(std::move(a), std::move(b)); - auto right = std::make_unique(std::move(c), std::move(d)); - UnionNode root(std::move(left), std::move(right)); - EXPECT_THAT(Eval(root, Lookup), - testing::UnorderedElementsAre("a", "d", "e")); - EXPECT_THAT(Eval(root, BitsetLookup), - roaring::Roaring({1, 4, 5})); -} - } // namespace } // namespace kv_server diff --git a/components/query/driver.h b/components/query/driver.h index 794daa86..bf823339 100644 --- a/components/query/driver.h +++ b/components/query/driver.h @@ -17,10 +17,12 @@ #ifndef COMPONENTS_QUERY_DRIVER_H_ #define COMPONENTS_QUERY_DRIVER_H_ +#include #include #include #include #include +#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -56,10 +58,26 @@ class Driver { // Clients should not call these functions, they are called by the parser. 
void SetAst(std::unique_ptr); void SetError(std::string error); - void ClearError() { status_ = absl::OkStatus(); } + void Clear() { + status_ = absl::OkStatus(); + ast_ = nullptr; + buffer_.clear(); + } + std::vector StoreStrings(std::vector strings) { + std::vector views; + views.reserve(strings.size()); + for (auto& string : strings) { + buffer_.push_back(std::move(string)); + views.push_back(buffer_.back()); + } + return views; + } private: std::unique_ptr ast_; + // using list since we require pointer stabilty on string_view that references + // them. + std::list buffer_; absl::Status status_ = absl::OkStatus(); }; diff --git a/components/query/driver_test.cc b/components/query/driver_test.cc index 4404402a..366572b4 100644 --- a/components/query/driver_test.cc +++ b/components/query/driver_test.cc @@ -19,11 +19,17 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/synchronization/notification.h" #include "components/query/scanner.h" +#include "components/query/template_test_utils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "roaring.hh" +#include "roaring64map.hh" + namespace kv_server { namespace { @@ -35,6 +41,53 @@ const absl::flat_hash_map> {"D", {"d", "e", "f"}}, }; +const absl::flat_hash_map kUInt32SetDb = { + {"A", {1, 2, 3}}, + {"B", {2, 3, 4}}, + {"C", {3, 4, 5}}, + {"D", {4, 5, 6}}, +}; + +const absl::flat_hash_map kUInt64SetDb = { + {"A", + {18446744073709551609UL, 18446744073709551610UL, 18446744073709551611UL}}, + {"B", + {18446744073709551610UL, 18446744073709551611UL, 18446744073709551612UL}}, + {"C", + {18446744073709551611UL, 18446744073709551612UL, 18446744073709551613UL}}, + {"D", + {18446744073709551612UL, 18446744073709551613UL, 18446744073709551614UL}}, +}; + +template +SetType GetExpectations(std::vector> elems); + +template <> +absl::flat_hash_set GetExpectations( + std::vector expected) { + 
absl::flat_hash_set views; + for (auto e : expected) { + if (e.front() == '"' && e.back() == '"') { + views.insert(e.substr(1, e.size() - 2)); + } + } + return views; +} + +template <> +roaring::Roaring GetExpectations(std::vector expected) { + return roaring::Roaring(expected.size(), expected.data()); +} + +template <> +roaring::Roaring64Map GetExpectations(std::vector expected) { + return roaring::Roaring64Map(expected.size(), expected.data()); +} + +template +SetType Lookup(std::string_view key); + +template <> absl::flat_hash_set Lookup(std::string_view key) { if (const auto& it = kStringSetDB.find(key); it != kStringSetDB.end()) { return it->second; @@ -42,6 +95,23 @@ absl::flat_hash_set Lookup(std::string_view key) { return {}; } +template <> +roaring::Roaring Lookup(std::string_view key) { + if (const auto& it = kUInt32SetDb.find(key); it != kUInt32SetDb.end()) { + return it->second; + } + return {}; +} + +template <> +roaring::Roaring64Map Lookup(std::string_view key) { + if (const auto& it = kUInt64SetDb.find(key); it != kUInt64SetDb.end()) { + return it->second; + } + return {}; +} + +template class DriverTest : public ::testing::Test { protected: void SetUp() override { @@ -62,139 +132,278 @@ class DriverTest : public ::testing::Test { std::vector drivers_; }; -TEST_F(DriverTest, EmptyQuery) { - Parse(""); - EXPECT_EQ(driver_->GetRootNode(), nullptr); +class NameGenerator { + public: + template + static std::string GetName(int) { + if constexpr (std::is_same_v>) { + return "StringSet"; + } + if constexpr (std::is_same_v) { + return "UInt32Set"; + } + if constexpr (std::is_same_v) { + return "UInt64Set"; + } + } +}; +using SetTypes = testing::Types, + roaring::Roaring, roaring::Roaring64Map>; +TYPED_TEST_SUITE(DriverTest, SetTypes, NameGenerator); + +TYPED_TEST(DriverTest, EmptyQuery) { + this->Parse(""); + EXPECT_EQ(this->driver_->GetRootNode(), nullptr); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); 
ASSERT_TRUE(result.ok()); - absl::flat_hash_set expected; - EXPECT_EQ(*result, expected); + EXPECT_EQ(*result, TypeParam()); } -TEST_F(DriverTest, InvalidTokensQuery) { - Parse("!! hi"); - EXPECT_EQ(driver_->GetRootNode(), nullptr); +TYPED_TEST(DriverTest, InvalidTokensQuery) { + this->Parse("!! hi"); + EXPECT_EQ(this->driver_->GetRootNode(), nullptr); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } -TEST_F(DriverTest, MissingOperatorVar) { - Parse("A A"); - EXPECT_EQ(driver_->GetRootNode(), nullptr); +TYPED_TEST(DriverTest, MissingOperatorVar) { + this->Parse("A A"); + EXPECT_EQ(this->driver_->GetRootNode(), nullptr); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } -TEST_F(DriverTest, MissingOperatorExp) { - Parse("(A) (A)"); - EXPECT_EQ(driver_->GetRootNode(), nullptr); +TYPED_TEST(DriverTest, MissingOperatorExp) { + this->Parse("(A) (A)"); + EXPECT_EQ(this->driver_->GetRootNode(), nullptr); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } -TEST_F(DriverTest, InvalidOp) { - Parse("A UNION "); - EXPECT_EQ(driver_->GetRootNode(), nullptr); +TYPED_TEST(DriverTest, InvalidOp) { + this->Parse("A UNION "); + EXPECT_EQ(this->driver_->GetRootNode(), nullptr); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } -TEST_F(DriverTest, KeyOnly) { - Parse("A"); +TYPED_TEST(DriverTest, KeyOnly) { + this->Parse("A"); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", 
"c")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 2, 3})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL, + 18446744073709551610UL, + 18446744073709551611UL})); + } - Parse("B"); - result = - driver_->EvaluateQuery>(Lookup); + this->Parse("B"); + result = this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c", "d")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c", "d")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({2, 3, 4})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551610UL, + 18446744073709551611UL, + 18446744073709551612UL})); + } } -TEST_F(DriverTest, Union) { - Parse("A UNION B"); +TYPED_TEST(DriverTest, InlineSetOnly) { + std::vector> multi_elem; + ConvertedSetType elem; + if constexpr (std::is_same_v>) { + elem = "\"a\""; + multi_elem = {"\"a\"", "\"b\""}; + } + if constexpr (std::is_same_v) { + elem = 1; + multi_elem = {2, 3}; + } + if constexpr (std::is_same_v) { + elem = 184467440737551610UL; + multi_elem = {18446744073709551609UL, 18446744073709551610UL}; + } + this->Parse(absl::StrCat("Set(", elem, ")")); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); + EXPECT_THAT(*result, GetExpectations({elem})); - Parse("A | B"); - result = - driver_->EvaluateQuery>(Lookup); + this->Parse(absl::StrCat("Set(", absl::StrJoin(multi_elem, ","), ")")); + result = this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); + 
EXPECT_THAT(*result, GetExpectations(multi_elem)); } -TEST_F(DriverTest, Difference) { - Parse("A - B"); +TYPED_TEST(DriverTest, InlineIntegerSetTooBig) { + this->Parse("Set(99999999999999999999)"); auto result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); - - Parse("A DIFFERENCE B"); - result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + this->driver_->template EvaluateQuery(Lookup); + ASSERT_FALSE(result.ok()); +} - Parse("B - A"); - result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); +TYPED_TEST(DriverTest, WrongInlineSetType) { + if constexpr (std::is_same_v) { + // Incorrectly mix a inline number node with a string_view set. + this->Parse("A & Set(1)"); + auto result = + this->driver_->template EvaluateQuery(Lookup); + // TODO(b/353502448): Consider returning an error. + ASSERT_TRUE(result.ok()); + // Wrong type returns empty set. 
+ EXPECT_TRUE(result->empty()); + } +} - Parse("B DIFFERENCE A"); - result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); +TYPED_TEST(DriverTest, Union) { + for (std::string_view query : {"A UNION B", "A | B"}) { + this->Parse(std::string(query)); + auto result = + this->driver_->template EvaluateQuery(Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 2, 3, 4})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map( + {18446744073709551609UL, 18446744073709551610UL, + 18446744073709551611UL, 18446744073709551612UL})); + } + } } -TEST_F(DriverTest, Intersection) { - Parse("A INTERSECTION B"); - auto result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); +TYPED_TEST(DriverTest, Difference) { + for (std::string_view query : {"A - B", "A DIFFERENCE B"}) { + this->Parse(std::string(query)); + auto result = + this->driver_->template EvaluateQuery(Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL})); + } + } + for (std::string_view query : {"B - A", "B DIFFERENCE A"}) { + this->Parse(std::string(query)); + auto result = + this->driver_->template EvaluateQuery(Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({4})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, 
roaring::Roaring64Map({18446744073709551612UL})); + } + } +} - Parse("A & B"); - result = - driver_->EvaluateQuery>(Lookup); - ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); +TYPED_TEST(DriverTest, Intersection) { + for (std::string_view query : {"A & B", "A INTERSECTION B"}) { + this->Parse(std::string(query)); + auto result = + this->driver_->template EvaluateQuery(Lookup); + ASSERT_TRUE(result.ok()); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({2, 3})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map( + {18446744073709551610UL, 18446744073709551611UL})); + } + } } -TEST_F(DriverTest, OrderOfOperations) { - Parse("A - B - C"); +TYPED_TEST(DriverTest, OrderOfOperations) { + this->Parse("A - B - C"); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL})); + } - Parse("A - (B - C)"); - result = - driver_->EvaluateQuery>(Lookup); + this->Parse("A - (B - C)"); + result = this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "c")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "c")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 3})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map( + {18446744073709551609UL, 18446744073709551611UL})); + } } -TEST_F(DriverTest, 
MultipleOperations) { - Parse("(A-B) | (C&D)"); +TYPED_TEST(DriverTest, MultipleOperations) { + this->Parse("(A-B) | (C&D)"); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 4, 5})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL, + 18446744073709551612UL, + 18446744073709551613UL})); + } } -TEST_F(DriverTest, MultipleThreads) { +TYPED_TEST(DriverTest, MultipleThreads) { absl::Notification notification; auto test_func = [¬ification](Driver* driver) { notification.WaitForNotification(); @@ -203,14 +412,24 @@ TEST_F(DriverTest, MultipleThreads) { Scanner scanner(stream); Parser parse(*driver, scanner); parse(); - auto result = - driver->EvaluateQuery>(Lookup); + auto result = driver->EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + if constexpr (std::is_same_v>) { + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring({1, 4, 5})); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map({18446744073709551609UL, + 18446744073709551612UL, + 18446744073709551613UL})); + } }; std::vector threads; - for (Driver& driver : drivers_) { + for (Driver& driver : this->drivers_) { threads.push_back(std::thread(test_func, &driver)); } notification.Notify(); @@ -219,30 +438,45 @@ TEST_F(DriverTest, MultipleThreads) { } } -TEST_F(DriverTest, EmptyResults) { +TYPED_TEST(DriverTest, EmptyResults) { // no overlap - Parse("A & D"); + this->Parse("A & D"); auto result = - driver_->EvaluateQuery>(Lookup); + 
this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_EQ(result->size(), 0); - + if constexpr (std::is_same_v>) { + EXPECT_EQ(result->size(), 0); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring()); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map()); + } // missing key - Parse("A & E"); - result = - driver_->EvaluateQuery>(Lookup); + this->Parse("A & E"); + result = this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); - EXPECT_EQ(result->size(), 0); + if constexpr (std::is_same_v>) { + EXPECT_EQ(result->size(), 0); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring()); + } + if constexpr (std::is_same_v) { + EXPECT_EQ(*result, roaring::Roaring64Map()); + } } -TEST_F(DriverTest, DriverErrorsClearedOnParse) { - Parse("A &"); +TYPED_TEST(DriverTest, DriverErrorsClearedOnParse) { + this->Parse("A &"); auto result = - driver_->EvaluateQuery>(Lookup); + this->driver_->template EvaluateQuery(Lookup); ASSERT_FALSE(result.ok()); - Parse("A"); - result = - driver_->EvaluateQuery>(Lookup); + this->Parse("A"); + result = this->driver_->template EvaluateQuery(Lookup); ASSERT_TRUE(result.ok()); } diff --git a/components/query/parser.yy b/components/query/parser.yy index 32f73221..80deb404 100644 --- a/components/query/parser.yy +++ b/components/query/parser.yy @@ -25,6 +25,8 @@ %code requires { #include #include + #include + #include "components/query/ast.h" namespace kv_server { @@ -38,17 +40,32 @@ %code { + #include "absl/strings/numbers.h" + #include "absl/strings/str_cat.h" #include "components/query/parser.h" #include "components/query/driver.h" #include "components/query/scanner.h" #undef yylex #define yylex(x) scanner.yylex(x) + + namespace { + bool PushBackUint64(kv_server::Driver& driver, std::vector& stack, char* str) { + uint64_t val; + if(!absl::SimpleAtoi(str, &val)) { + driver.SetError(absl::StrCat("Unable to parse number: ", str)); + 
return false; + } + stack.push_back(val); + return true; + } + } } /* declare tokens */ -%token UNION INTERSECTION DIFFERENCE LPAREN RPAREN +%token UNION INTERSECTION DIFFERENCE LPAREN RPAREN SET COMMA %token VAR ERROR +%token NUMBER STRING %token YYEOF 0 /* Allows defining the types returned by `term` and `exp below. */ @@ -56,6 +73,8 @@ %define api.value.type variant %type > term +%type > number_list +%type > string_list %nterm > exp /* Order of operations is left to right */ @@ -65,7 +84,7 @@ %expect 0 %initial-action { - driver.ClearError(); + driver.Clear(); } %% @@ -74,16 +93,32 @@ query: %empty | query exp YYEOF { driver.SetAst(std::move($2)); } -exp: term {$$ = std::move($1);} +exp: + term {$$ = std::move($1);} | exp UNION exp { $$ = std::make_unique(std::move($1), std::move($3)); } | exp INTERSECTION exp { $$ = std::make_unique(std::move($1), std::move($3)); } | exp DIFFERENCE exp { $$ = std::make_unique(std::move($1), std::move($3)); } | LPAREN exp RPAREN { $$ = std::move($2); } + | SET LPAREN number_list RPAREN {$$ = std::make_unique(std::move($3));} + | SET LPAREN string_list RPAREN {$$ = std::make_unique(driver.StoreStrings(std::move($3)));} | ERROR { driver.SetError("Invalid token: " + $1); YYERROR;} ; term: VAR { $$ = std::make_unique(std::move($1)); } - ; + +number_list: + NUMBER { std::vector stack; + if(PushBackUint64(driver, stack, $1)) $$ = stack; + else YYERROR;} + | number_list COMMA NUMBER { + if(PushBackUint64(driver, $1, $3)) $$ = $1; + else YYERROR;} +; + +string_list: + STRING { $$ = std::vector{$1};} + | string_list COMMA STRING { $1.emplace_back(std::move($3)); $$ = $1;} +; %% diff --git a/components/query/scanner.ll b/components/query/scanner.ll index de8d4dbb..274257eb 100644 --- a/components/query/scanner.ll +++ b/components/query/scanner.ll @@ -30,6 +30,9 @@ %option prefix="KV" %option noyywrap nounput noinput debug batch +/* States */ +%x IN_FUNC + /* Valid key name characters, this list can be expanded as needed */ VAR_CHARS 
[a-zA-Z0-9_:\.] /* @@ -41,6 +44,14 @@ OP_CHARS [|&\-+=/] %% [ \t\r\n]+ {} +(?i:SET) { yy_push_state(IN_FUNC); return kv_server::Parser::make_SET(); } +{ + "(" { return kv_server::Parser::make_LPAREN(); } + ")" { yy_pop_state(); return kv_server::Parser::make_RPAREN();} + "," { return kv_server::Parser::make_COMMA(); } + [0-9]+ { return kv_server::Parser::make_NUMBER(yytext); } + \"[^\"]*\" { yytext[strlen(yytext)-1]='\0'; return kv_server::Parser::make_STRING(yytext + 1); } +} "(" { return kv_server::Parser::make_LPAREN(); } ")" { return kv_server::Parser::make_RPAREN();} (?i:UNION) { return kv_server::Parser::make_UNION(); } diff --git a/components/query/sets.cc b/components/query/sets.cc index 2223e8e2..b01bbbc7 100644 --- a/components/query/sets.cc +++ b/components/query/sets.cc @@ -54,20 +54,4 @@ absl::flat_hash_set Difference( return std::move(left); } -template <> -roaring::Roaring Union(roaring::Roaring&& left, roaring::Roaring&& right) { - return left | right; -} - -template <> -roaring::Roaring Intersection(roaring::Roaring&& left, - roaring::Roaring&& right) { - return left & right; -} - -template <> -roaring::Roaring Difference(roaring::Roaring&& left, roaring::Roaring&& right) { - return left - right; -} - } // namespace kv_server diff --git a/components/query/sets.h b/components/query/sets.h index a6fcecc8..83f9d1ac 100644 --- a/components/query/sets.h +++ b/components/query/sets.h @@ -17,20 +17,26 @@ #ifndef COMPONENTS_QUERY_SETS_H_ #define COMPONENTS_QUERY_SETS_H_ -#include "absl/container/flat_hash_set.h" +#include -#include "roaring.hh" +#include "absl/container/flat_hash_set.h" namespace kv_server { -template -SetT Union(SetT&&, SetT&&); +template +SetType Union(SetType&& left, SetType&& right) { + return std::forward(left) | std::forward(right); +} -template -SetT Intersection(SetT&&, SetT&&); +template +SetType Intersection(SetType&& left, SetType&& right) { + return std::forward(left) & std::forward(right); +} -template -SetT Difference(SetT&&, 
SetT&&); +template +SetType Difference(SetType&& left, SetType&& right) { + return std::forward(left) - std::forward(right); +} template <> absl::flat_hash_set Union( @@ -47,16 +53,5 @@ absl::flat_hash_set Difference( absl::flat_hash_set&& left, absl::flat_hash_set&& right); -template <> -roaring::Roaring Union(roaring::Roaring&& left, roaring::Roaring&& right); - -template <> -roaring::Roaring Intersection(roaring::Roaring&& left, - roaring::Roaring&& right); - -// Subtracts `right` from `left`. -template <> -roaring::Roaring Difference(roaring::Roaring&& left, roaring::Roaring&& right); - } // namespace kv_server #endif // COMPONENTS_QUERY_SETS_H_ diff --git a/components/query/sets_test.cc b/components/query/sets_test.cc index 6f4ee250..56b00301 100644 --- a/components/query/sets_test.cc +++ b/components/query/sets_test.cc @@ -18,45 +18,46 @@ #include +#include "components/data_server/cache/uint_value_set.h" #include "gtest/gtest.h" namespace kv_server { namespace { TEST(SetsTest, VerifyBitwiseUnion) { - roaring::Roaring left({1, 2, 3, 4, 5}); - roaring::Roaring right({6, 7, 8, 9, 10}); + UInt32ValueSet::bitset_type left({1, 2, 3, 4, 5}); + UInt32ValueSet::bitset_type right({6, 7, 8, 9, 10}); EXPECT_EQ(Union(std::move(left), std::move(right)), - roaring::Roaring({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); + UInt32ValueSet::bitset_type({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); } TEST(SetsTest, VerifyBitwiseIntersection) { { - roaring::Roaring left({1, 2, 3, 4, 5}); - roaring::Roaring right({6, 7, 8, 9, 10}); + UInt32ValueSet::bitset_type left({1, 2, 3, 4, 5}); + UInt32ValueSet::bitset_type right({6, 7, 8, 9, 10}); EXPECT_EQ(Intersection(std::move(left), std::move(right)), - roaring::Roaring()); + UInt32ValueSet::bitset_type()); } { - roaring::Roaring left({1, 2, 3, 4, 5}); - roaring::Roaring right({1, 2, 3, 9, 10}); + UInt32ValueSet::bitset_type left({1, 2, 3, 4, 5}); + UInt32ValueSet::bitset_type right({1, 2, 3, 9, 10}); EXPECT_EQ(Intersection(std::move(left), std::move(right)), 
- roaring::Roaring({1, 2, 3})); + UInt32ValueSet::bitset_type({1, 2, 3})); } } TEST(SetsTest, VerifyBitwiseDifference) { { - roaring::Roaring left({1, 2, 3, 4, 5}); - roaring::Roaring right({6, 7, 8, 9, 10}); + UInt32ValueSet::bitset_type left({1, 2, 3, 4, 5}); + UInt32ValueSet::bitset_type right({6, 7, 8, 9, 10}); EXPECT_EQ(Difference(std::move(left), std::move(right)), - roaring::Roaring({1, 2, 3, 4, 5})); + UInt32ValueSet::bitset_type({1, 2, 3, 4, 5})); } { - roaring::Roaring left({1, 2, 3, 4, 5}); - roaring::Roaring right({1, 2, 3, 9, 10}); + UInt32ValueSet::bitset_type left({1, 2, 3, 4, 5}); + UInt32ValueSet::bitset_type right({1, 2, 3, 9, 10}); EXPECT_EQ(Difference(std::move(left), std::move(right)), - roaring::Roaring({4, 5})); + UInt32ValueSet::bitset_type({4, 5})); } } diff --git a/components/data_server/request_handler/framing_utils.h b/components/query/template_test_utils.h similarity index 51% rename from components/data_server/request_handler/framing_utils.h rename to components/query/template_test_utils.h index 5b146c2a..c73f32b8 100644 --- a/components/data_server/request_handler/framing_utils.h +++ b/components/query/template_test_utils.h @@ -14,20 +14,34 @@ * limitations under the License. */ -#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ -#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ +#include -#include +#include "absl/container/flat_hash_set.h" + +#include "roaring.hh" +#include "roaring64map.hh" -// TODO: b/348613920 - Move framing utils to the common repo namespace kv_server { -// Gets size of the complete payload including the preamble expected by -// client, which is: 1 byte (containing version, compression details), 4 bytes -// indicating the length of the actual encoded response and any other padding -// required to make the complete payload a power of 2. 
-size_t GetEncodedDataSize(size_t encapsulated_payload_size); +template +struct SetTypeConverter; -} // namespace kv_server +template <> +struct SetTypeConverter> { + using type = std::string_view; +}; + +template <> +struct SetTypeConverter { + using type = uint32_t; +}; -#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ +template <> +struct SetTypeConverter { + using type = uint64_t; +}; + +template +using ConvertedSetType = typename SetTypeConverter::type; + +} // namespace kv_server diff --git a/components/sharding/cluster_mappings_manager_aws_test.cc b/components/sharding/cluster_mappings_manager_aws_test.cc index 738e609f..7abfccaa 100644 --- a/components/sharding/cluster_mappings_manager_aws_test.cc +++ b/components/sharding/cluster_mappings_manager_aws_test.cc @@ -30,14 +30,7 @@ namespace { class ClusterMappingsAwsTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } }; TEST_F(ClusterMappingsAwsTest, RetrieveMappingsSuccessfully) { diff --git a/components/sharding/cluster_mappings_manager_gcp_test.cc b/components/sharding/cluster_mappings_manager_gcp_test.cc index c081af1a..2c27ea33 100644 --- a/components/sharding/cluster_mappings_manager_gcp_test.cc +++ b/components/sharding/cluster_mappings_manager_gcp_test.cc @@ -30,14 +30,7 @@ namespace { class ClusterMappingsGcpTest : public ::testing::Test { protected: - void SetUp() override { - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - 
config_proto)); - } + void SetUp() override { kv_server::InitMetricsContextMap(); } }; TEST_F(ClusterMappingsGcpTest, RetrieveMappingsSuccessfully) { diff --git a/components/telemetry/error_code.h b/components/telemetry/error_code.h index 4bf282ac..6417aead 100644 --- a/components/telemetry/error_code.h +++ b/components/telemetry/error_code.h @@ -101,9 +101,17 @@ inline constexpr std::string_view kShardedKeyCollisionOnKeySetCollection = // Empty query encountered in the sharded lookup inline constexpr std::string_view kShardedRunQueryEmptyQuery = "ShardedRunQueryEmptyQuery"; +inline constexpr std::string_view kShardedRunSetQueryUInt32EmptyQuery = + "ShardedRunSetQueryUInt32EmptyQuery"; +inline constexpr std::string_view kShardedRunSetQueryUInt64EmptyQuery = + "ShardedRunSetQueryUInt64EmptyQuery"; // Failure in running query in sharded lookup inline constexpr std::string_view kShardedRunQueryFailure = "ShardedRunQueryFailure"; +inline constexpr std::string_view kShardedRunSetQueryUInt32Failure = + "ShardedRunSetQueryUInt32Failure"; +inline constexpr std::string_view kShardedRunSetQueryUInt64Failure = + "ShardedRunSetQueryUInt64Failure"; // Key set not found error in the GetValueKeySet in sharded lookup inline constexpr std::string_view kShardedGetKeyValueSetKeySetNotFound = "ShardedGetKeyValueSetKeySetNotFound"; @@ -119,6 +127,20 @@ inline constexpr std::string_view kShardedRunQueryMissingKeySet = // Query parsing failure in the run query in sharded lookup inline constexpr std::string_view kShardedRunQueryParsingFailure = "ShardedRunQueryParsingFailure"; +// Key set retrieval failure in the GetUInt32ValueSet in sharded lookup +inline constexpr std::string_view + kShardedGetUInt32ValueSetKeySetRetrievalFailure = + "ShardedGetUInt32ValueSetKeySetRetrievalFailure"; +// Key set not found error in the GetUInt32ValueSet in sharded lookup +inline constexpr std::string_view kShardedGetUInt32ValueSetKeySetNotFound = + "ShardedGetUInt32ValueSetKeySetNotFound"; +// Key set 
retrieval failure in the GetUInt64ValueSet in sharded lookup +inline constexpr std::string_view + kShardedGetUInt64ValueSetKeySetRetrievalFailure = + "ShardedGetUInt64ValueSetKeySetRetrievalFailure"; +// Key set not found error in the GetUInt64ValueSet in sharded lookup +inline constexpr std::string_view kShardedGetUInt64ValueSetKeySetNotFound = + "ShardedGetUInt64ValueSetKeySetNotFound"; // Strings must be sorted, this is required by the API of partitioned metrics inline constexpr absl::string_view kKVUdfRequestErrorCode[] = { @@ -129,6 +151,10 @@ inline constexpr absl::string_view kKVUdfRequestErrorCode[] = { kRemoteSecureLookupFailure, kShardedGetKeyValueSetKeySetNotFound, kShardedGetKeyValueSetKeySetRetrievalFailure, + kShardedGetUInt32ValueSetKeySetNotFound, + kShardedGetUInt32ValueSetKeySetRetrievalFailure, + kShardedGetUInt64ValueSetKeySetNotFound, + kShardedGetUInt64ValueSetKeySetRetrievalFailure, kShardedKeyCollisionOnKeySetCollection, kShardedKeyValueRequestFailure, kShardedKeyValueSetRequestFailure, @@ -137,6 +163,10 @@ inline constexpr absl::string_view kKVUdfRequestErrorCode[] = { kShardedRunQueryKeySetRetrievalFailure, kShardedRunQueryMissingKeySet, kShardedRunQueryParsingFailure, + kShardedRunSetQueryUInt32EmptyQuery, + kShardedRunSetQueryUInt32Failure, + kShardedRunSetQueryUInt64EmptyQuery, + kShardedRunSetQueryUInt64Failure, }; // Non request related server error diff --git a/components/telemetry/server_definition.h b/components/telemetry/server_definition.h index a8e39782..75edd71d 100644 --- a/components/telemetry/server_definition.h +++ b/components/telemetry/server_definition.h @@ -138,6 +138,24 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + 
privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kShardedLookupGetUInt32ValueSetLatencyInMicros( + "ShardedLookupGetUInt32ValueSetLatencyInMicros", + "Latency in executing GetUInt32ValueSet in the sharded lookup", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kShardedLookupGetUInt64ValueSetLatencyInMicros( + "ShardedLookupGetUInt64ValueSetLatencyInMicros", + "Latency in executing GetUInt64ValueSet in the sharded lookup", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -150,9 +168,18 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> - kShardedLookupRunSetQueryIntLatencyInMicros( - "ShardedLookupRunSetQueryIntLatencyInMicros", - "Latency in executing RunQuery in the sharded lookup", + kShardedLookupRunSetQueryUInt32LatencyInMicros( + "ShardedLookupRunSetQueryUInt32LatencyInMicros", + "Latency in executing RunSetQueryUInt32 in the sharded lookup", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kShardedLookupRunSetQueryUInt64LatencyInMicros( + 
"ShardedLookupRunSetQueryUInt64LatencyInMicros", + "Latency in executing RunSetQueryUInt64 in the sharded lookup", kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); @@ -192,6 +219,24 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kInternalGetUInt32ValueSetLatencyInMicros( + "InternalGetUInt32ValueSetLatencyInMicros", + "Latency in internal get uint32 value set call", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kInternalGetUInt64ValueSetLatencyInMicros( + "InternalGetUInt64ValueSetLatencyInMicros", + "Latency in internal get uint64 value set call", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> @@ -239,6 +284,15 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kGetUInt64ValueSetLatencyInMicros( + "GetUInt64ValueSetLatencyInMicros", + "Latency in executing GetUInt64ValueSet in cache", + 
kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> @@ -350,7 +404,7 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> kReceivedLowLatencyNotificationsCount( - "kReceivedLowLatencyNotificationsCount", + "ReceivedLowLatencyNotificationsCount", "Count of messages received through pub/sub"); inline constexpr privacy_sandbox::server_common::metrics::Definition< @@ -471,6 +525,13 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in uint32 key value set update", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kUpdateUInt64ValueSetLatency("UpdateUInt64ValueSetLatency", + "Latency in uint64 key value set update", + kLatencyInMicroSecondsBoundaries); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -491,6 +552,13 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in deleting values in an uint32 set", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kDeleteUInt64ValueSetLatency("DeleteUInt64ValueSetLatency", + "Latency in deleting values in an uint64 
set", + kLatencyInMicroSecondsBoundaries); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -516,10 +584,9 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> - kCleanUpUInt32SetMapLatency( - "CleanUpUInt32SetMapMapLatency", - "Latency in cleaning up key value uint32 set map", - kLatencyInMicroSecondsBoundaries); + kCleanUpUIntSetMapLatency("CleanUpUIntSetMapMapLatency", + "Latency in cleaning up key value uint set maps", + kLatencyInMicroSecondsBoundaries); inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, @@ -598,9 +665,12 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kShardedLookupGetKeyValuesLatencyInMicros, &kShardedLookupGetKeyValueSetLatencyInMicros, &kShardedLookupRunQueryLatencyInMicros, - &kShardedLookupRunSetQueryIntLatencyInMicros, + &kShardedLookupRunSetQueryUInt32LatencyInMicros, &kRemoteLookupGetValuesLatencyInMicros, &kTotalV2LatencyWithoutCustomCode, &kUDFExecutionLatencyInMicros, + &kShardedLookupGetUInt32ValueSetLatencyInMicros, + &kShardedLookupGetUInt64ValueSetLatencyInMicros, + &kShardedLookupRunSetQueryUInt64LatencyInMicros, // Safe metrics &kKVServerError, &privacy_sandbox::server_common::metrics::kTotalRequestCount, @@ -630,8 +700,9 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kUpdateUInt32ValueSetLatency, &kDeleteKeyLatency, &kDeleteValuesInSetLatency, &kDeleteUInt32ValueSetLatency, &kRemoveDeletedKeyLatency, &kCleanUpKeyValueMapLatency, - &kCleanUpKeyValueSetMapLatency, 
&kCleanUpUInt32SetMapLatency, - &kBlobStorageReadBytes}; + &kCleanUpKeyValueSetMapLatency, &kCleanUpUIntSetMapLatency, + &kBlobStorageReadBytes, &kUpdateUInt64ValueSetLatency, + &kDeleteUInt64ValueSetLatency}; // Internal lookup service metrics list contains metrics collected in the // internal lookup server. This separation from KV metrics list allows all @@ -649,7 +720,9 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kInternalGetKeyValueSetLatencyInMicros, &kInternalSecureLookupLatencyInMicros, &kGetValuePairsLatencyInMicros, &kGetKeyValueSetLatencyInMicros, &kGetUInt32ValueSetLatencyInMicros, - &kCacheAccessEventCount}; + &kCacheAccessEventCount, &kGetUInt64ValueSetLatencyInMicros, + &kInternalGetUInt32ValueSetLatencyInMicros, + &kInternalGetUInt64ValueSetLatencyInMicros}; inline constexpr absl::Span< const privacy_sandbox::server_common::metrics::DefinitionName* const> @@ -660,9 +733,9 @@ inline constexpr absl::Span< kInternalLookupServiceMetricsSpan = kInternalLookupServiceMetricsList; inline auto* KVServerContextMap( - std::optional< + std::unique_ptr< privacy_sandbox::server_common::telemetry::BuildDependentConfig> - config = std::nullopt, + config = nullptr, std::unique_ptr provider = nullptr, absl::string_view service = kKVServerServiceName, absl::string_view version = "") { @@ -673,9 +746,9 @@ inline auto* KVServerContextMap( } inline auto* InternalLookupServerContextMap( - std::optional< + std::unique_ptr< privacy_sandbox::server_common::telemetry::BuildDependentConfig> - config = std::nullopt, + config = nullptr, std::unique_ptr provider = nullptr, absl::string_view service = kInternalLookupServiceName, absl::string_view version = "") { @@ -736,10 +809,12 @@ inline void InitMetricsContextMap() { config_proto.set_mode( privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( + std::make_unique< + 
privacy_sandbox::server_common::telemetry::BuildDependentConfig>( config_proto)); kv_server::InternalLookupServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( + std::make_unique< + privacy_sandbox::server_common::telemetry::BuildDependentConfig>( config_proto)); } @@ -842,6 +917,9 @@ inline void LogV1RequestCommonSafeMetrics( // Measures the latency of a block of code. The latency is recorded in // microseconds as histogram metrics when the object of this class goes // out of scope. The metric can be either safe or unsafe metric. +// For unsafe metric, the metric data point will be aggregated for the given +// metric definition and the mean will be logged as final value at the +// destruction of metric context template class ScopeLatencyMetricsRecorder { public: @@ -853,9 +931,16 @@ class ScopeLatencyMetricsRecorder { stopwatch_ = std::move(stopwatch); } ~ScopeLatencyMetricsRecorder() { - LogIfError(metrics_context_.template LogHistogram( - absl::ToDoubleMicroseconds(stopwatch_->GetElapsedTime()))); + if (definition.type_privacy == + privacy_sandbox::server_common::metrics::Privacy::kImpacting) { + LogIfError(metrics_context_.template AggregateMetricToGetMean( + absl::ToDoubleMicroseconds(stopwatch_->GetElapsedTime()))); + } else { + LogIfError(metrics_context_.template LogHistogram( + absl::ToDoubleMicroseconds(stopwatch_->GetElapsedTime()))); + } } + // Returns the latency so far absl::Duration GetLatency() { return stopwatch_->GetElapsedTime(); } diff --git a/components/tools/BUILD.bazel b/components/tools/BUILD.bazel index a645d8f3..90a7cdd3 100644 --- a/components/tools/BUILD.bazel +++ b/components/tools/BUILD.bazel @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", -) load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load("@rules_pkg//pkg:tar.bzl", "pkg_tar") pkg_tar( @@ -29,33 +25,37 @@ pkg_tar( package_dir = "/opt/privacysandbox/bin", ) -container_layer( - name = "data_loading_analyzer_layer", - directory = "/", - tars = [ - ":data_loading_analyzer_binaries", - ], -) - # This image target is meant for testing running the server in an enclave using. # # See project README.md on how to run the image. -container_image( +oci_image( name = "data_loading_analyzer_enclave_image", base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64//image", + "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64", }), cmd = [ "/opt/privacysandbox/bin/proxify", "/opt/privacysandbox/bin/data_loading_analyzer", ], entrypoint = ["/bin/bash"], - layers = [ - ":data_loading_analyzer_layer", + tars = [ + ":data_loading_analyzer_binaries", ], ) +oci_load( + name = "data_loading_analyzer_enclave_tarball", + image = ":data_loading_analyzer_enclave_image", + repo_tags = ["bazel/components/tools:data_loading_analyzer_enclave"], +) + +filegroup( + name = "data_loading_analyzer_enclave_image.tar", + srcs = [":data_loading_analyzer_enclave_tarball"], + output_group = "tarball", +) + cc_binary( name = "data_loading_analyzer", srcs = ["data_loading_analyzer.cc"], @@ -262,6 +262,7 @@ cc_binary( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", ], ) diff --git a/components/tools/benchmarks/BUILD.bazel b/components/tools/benchmarks/BUILD.bazel index 8f1e5705..8c0026ce 100644 --- 
a/components/tools/benchmarks/BUILD.bazel +++ b/components/tools/benchmarks/BUILD.bazel @@ -105,6 +105,7 @@ cc_binary( ":benchmark_util", "//components/data_server/cache:get_key_value_set_result_impl", "//components/data_server/cache:key_value_cache", + "//components/data_server/cache:uint_value_set", "//components/query:ast", "//components/query:driver", "//components/query:scanner", @@ -121,6 +122,5 @@ cc_binary( "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/strings", "@com_google_benchmark//:benchmark", - "@roaring_bitmap//:c_roaring", ], ) diff --git a/components/tools/benchmarks/query_evaluation_benchmark.cc b/components/tools/benchmarks/query_evaluation_benchmark.cc index 352ee144..d62d4a96 100644 --- a/components/tools/benchmarks/query_evaluation_benchmark.cc +++ b/components/tools/benchmarks/query_evaluation_benchmark.cc @@ -25,6 +25,7 @@ #include "benchmark/benchmark.h" #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" +#include "components/data_server/cache/uint_value_set.h" #include "components/query/ast.h" #include "components/query/driver.h" #include "components/query/scanner.h" @@ -32,8 +33,6 @@ #include "components/tools/benchmarks/benchmark_util.h" #include "components/tools/util/configure_telemetry_tools.h" -#include "roaring.hh" - ABSL_FLAG(int64_t, set_size, 1000, "Number of elements in a set."); ABSL_FLAG(std::string, query, "(A - B) | (C & D)", "Query to evaluate"); ABSL_FLAG(uint32_t, range_min, 0, "Minimum element in a set"); @@ -45,11 +44,13 @@ ABSL_FLAG(std::vector, set_names, namespace kv_server { namespace { -using RoaringBitSet = roaring::Roaring; +using UInt32Set = UInt32ValueSet::bitset_type; +using UInt64Set = UInt64ValueSet::bitset_type; using StringSet = absl::flat_hash_set; std::unique_ptr STRING_SET_RESULT = nullptr; std::unique_ptr UINT32_SET_RESULT = nullptr; +std::unique_ptr UINT64_SET_RESULT = nullptr; template ValueT Lookup(std::string_view); @@ -60,10 
+61,15 @@ StringSet Lookup(std::string_view key) { } template <> -RoaringBitSet Lookup(std::string_view key) { +UInt32Set Lookup(std::string_view key) { return UINT32_SET_RESULT->GetUInt32ValueSet(key)->GetValuesBitSet(); } +template <> +UInt64Set Lookup(std::string_view key) { + return UINT64_SET_RESULT->GetUInt64ValueSet(key)->GetValuesBitSet(); +} + Driver* GetDriver() { static auto* const driver = std::make_unique().release(); return driver; @@ -87,7 +93,11 @@ void SetUpKeyValueCache(int64_t set_size, uint32_t range_min, } GetKeyValueCache()->UpdateKeyValueSet(log_context, set_name, absl::MakeSpan(nums), 1); - + auto nums64 = std::vector(); + std::transform(nums.begin(), nums.end(), std::back_inserter(nums64), + [](auto elem) { return elem; }); + GetKeyValueCache()->UpdateKeyValueSet(log_context, set_name, + absl::MakeSpan(nums64), 1); auto strings = std::vector(); std::transform(nums.begin(), nums.end(), std::back_inserter(strings), [](uint32_t elem) { return absl::StrCat(elem); }); @@ -150,28 +160,32 @@ void BM_AstTreeEvaluation(::benchmark::State& state) { } // namespace } // namespace kv_server -BENCHMARK(kv_server::BM_SetUnion); +BENCHMARK(kv_server::BM_SetUnion); +BENCHMARK(kv_server::BM_SetUnion); BENCHMARK(kv_server::BM_SetUnion); -BENCHMARK(kv_server::BM_SetDifference); +BENCHMARK(kv_server::BM_SetDifference); +BENCHMARK(kv_server::BM_SetDifference); BENCHMARK(kv_server::BM_SetDifference); -BENCHMARK(kv_server::BM_SetIntersection); +BENCHMARK(kv_server::BM_SetIntersection); +BENCHMARK(kv_server::BM_SetIntersection); BENCHMARK(kv_server::BM_SetIntersection); -BENCHMARK(kv_server::BM_AstTreeEvaluation); +BENCHMARK(kv_server::BM_AstTreeEvaluation); +BENCHMARK(kv_server::BM_AstTreeEvaluation); BENCHMARK(kv_server::BM_AstTreeEvaluation); using kv_server::ConfigureTelemetryForTools; using kv_server::GetKeyValueCache; using kv_server::RequestContext; -using kv_server::RoaringBitSet; using kv_server::SetUpKeyValueCache; using kv_server::StringSet; +using 
kv_server::UInt32Set; // Sample run: // // bazel run -c opt //components/tools/benchmarks:query_evaluation_benchmark \ // -- --benchmark_counters_tabular=true \ // --benchmark_time_unit=us \ -// --benchmark_filter="*" \ +// --benchmark_filter="." \ // --range_min=1000000 --range_max=2000000 \ // --set_size=10000 \ // --query="A & B - C | D" \ @@ -200,6 +214,9 @@ int main(int argc, char** argv) { kv_server::UINT32_SET_RESULT = GetKeyValueCache()->GetUInt32ValueSet( request_context, absl::flat_hash_set(set_names.begin(), set_names.end())); + kv_server::UINT64_SET_RESULT = GetKeyValueCache()->GetUInt64ValueSet( + request_context, absl::flat_hash_set(set_names.begin(), + set_names.end())); std::istringstream stream(absl::GetFlag(FLAGS_query)); kv_server::Scanner scanner(stream); kv_server::Parser parser(*kv_server::GetDriver(), scanner); diff --git a/components/tools/query_dot.cc b/components/tools/query_dot.cc index c2e5a404..d21231ca 100644 --- a/components/tools/query_dot.cc +++ b/components/tools/query_dot.cc @@ -14,11 +14,13 @@ #include "components/tools/query_dot.h" +#include #include #include #include #include "absl/strings/str_join.h" +#include "src/util/status_macro/status_macros.h" namespace kv_server::query_toy { @@ -28,10 +30,24 @@ namespace { // upon inspection. 
class ASTNameVisitor : public ASTStringVisitor { public: - virtual std::string Visit(const UnionNode&) { return "Union"; } - virtual std::string Visit(const DifferenceNode&) { return "Difference"; } - virtual std::string Visit(const IntersectionNode&) { return "Intersection"; } - virtual std::string Visit(const ValueNode&) { return "Value"; } + virtual absl::StatusOr Visit(const UnionNode&) { + return "Union"; + } + virtual absl::StatusOr Visit(const DifferenceNode&) { + return "Difference"; + } + virtual absl::StatusOr Visit(const IntersectionNode&) { + return "Intersection"; + } + virtual absl::StatusOr Visit(const ValueNode&) { + return "Value"; + } + virtual absl::StatusOr Visit(const NumberSetNode&) { + return "NumberSet"; + } + virtual absl::StatusOr Visit(const StringViewSetNode&) { + return "StringViewSet"; + } }; class ASTDotGraphLabelVisitor : public ASTStringVisitor { @@ -42,23 +58,36 @@ class ASTDotGraphLabelVisitor : public ASTStringVisitor { lookup_fn) : lookup_fn_(std::move(lookup_fn)) {} - virtual std::string Visit(const UnionNode& node) { + virtual absl::StatusOr Visit(const UnionNode& node) { return name_visitor_.Visit(node); } - virtual std::string Visit(const DifferenceNode& node) { + virtual absl::StatusOr Visit(const DifferenceNode& node) { return name_visitor_.Visit(node); } - virtual std::string Visit(const IntersectionNode& node) { + virtual absl::StatusOr Visit(const IntersectionNode& node) { return name_visitor_.Visit(node); } - virtual std::string Visit(const ValueNode& node) { - return absl::StrCat( - ToString(node.Keys()), "->", - ToString(Eval>( - node, [this](std::string_view key) { return lookup_fn_(key); }))); + virtual absl::StatusOr Visit(const ValueNode& node) { + PS_ASSIGN_OR_RETURN( + auto values, + Eval>( + node, [this](std::string_view key) { return lookup_fn_(key); })); + return absl::StrCat(ToString(node.Keys()), "->", ToString(values)); + } + + virtual absl::StatusOr Visit(const NumberSetNode& node) { + auto numbers = 
node.GetValues(); + std::vector strings(numbers.size()); + std::transform(numbers.begin(), numbers.end(), strings.begin(), + [](uint64_t num) { return std::to_string(num); }); + return absl::StrCat("NumberSet(", ToString(strings), ")"); + } + + virtual absl::StatusOr Visit(const StringViewSetNode& node) { + return absl::StrCat("StringViewSet(", ToString(node.GetValues()), ")"); } private: @@ -68,48 +97,54 @@ class ASTDotGraphLabelVisitor : public ASTStringVisitor { lookup_fn_; }; -std::string DotNodeName(const Node& node, uint32_t namecnt) { +absl::StatusOr DotNodeName(const Node& node, uint32_t namecnt) { ASTNameVisitor name_visitor; - return absl::StrCat(node.Accept(name_visitor), namecnt); + PS_ASSIGN_OR_RETURN(auto name, node.Accept(name_visitor)); + return absl::StrCat(name, namecnt); } -std::string ToDotGraphBody( +absl::StatusOr ToDotGraphBody( const Node& node, uint32_t* namecnt, std::function(std::string_view)> lookup_fn) { ASTDotGraphLabelVisitor label_visitor(lookup_fn); - const std::string label = node.Accept(label_visitor); - const std::string node_name = DotNodeName(node, *namecnt); + PS_ASSIGN_OR_RETURN(const std::string label, node.Accept(label_visitor)); + PS_ASSIGN_OR_RETURN(const std::string node_name, DotNodeName(node, *namecnt)); std::string dot_str = absl::StrCat(node_name, " [label=\"", label, "\"]\n"); if (node.Left() != nullptr) { *namecnt = *namecnt + 1; - const std::string arrow = - absl::StrCat(node_name, " -- ", DotNodeName(*node.Left(), *namecnt)); - absl::StrAppend(&dot_str, arrow, "\n", - ToDotGraphBody(*node.Left(), namecnt, lookup_fn)); + PS_ASSIGN_OR_RETURN(const auto node_name_left, + DotNodeName(*node.Left(), *namecnt)); + const std::string arrow = absl::StrCat(node_name, " -- ", node_name_left); + PS_ASSIGN_OR_RETURN(const auto dgb, + ToDotGraphBody(*node.Left(), namecnt, lookup_fn)); + absl::StrAppend(&dot_str, arrow, "\n", dgb); } if (node.Right() != nullptr) { *namecnt = *namecnt + 1; - const std::string arrow = - 
absl::StrCat(node_name, " -- ", DotNodeName(*node.Right(), *namecnt)); - absl::StrAppend(&dot_str, arrow, "\n", - ToDotGraphBody(*node.Right(), namecnt, lookup_fn)); + PS_ASSIGN_OR_RETURN(const auto node_name_right, + DotNodeName(*node.Right(), *namecnt)); + const std::string arrow = absl::StrCat(node_name, " -- ", node_name_right); + PS_ASSIGN_OR_RETURN(const auto dgb, + ToDotGraphBody(*node.Right(), namecnt, lookup_fn)); + absl::StrAppend(&dot_str, arrow, "\n", dgb); } return dot_str; } } // namespace -void QueryDotWriter::WriteAst( +absl::Status QueryDotWriter::WriteAst( std::string_view query, const Node& node, std::function(std::string_view)> lookup_fn) { uint32_t namecnt = 0; const std::string title = absl::StrCat("labelloc=\"t\"\nlabel=\"AST for Query: ", query, "\"\n"); - file_ << absl::StrCat("graph {\n", title, - ToDotGraphBody(node, &namecnt, std::move(lookup_fn)), - "\n}\n"); + PS_ASSIGN_OR_RETURN(const auto dgb, + ToDotGraphBody(node, &namecnt, std::move(lookup_fn))); + file_ << absl::StrCat("graph {\n", title, dgb, "\n}\n"); + return absl::OkStatus(); } void QueryDotWriter::Flush() { file_.flush(); } diff --git a/components/tools/query_dot.h b/components/tools/query_dot.h index 9894d631..cffb7422 100644 --- a/components/tools/query_dot.h +++ b/components/tools/query_dot.h @@ -32,7 +32,7 @@ class QueryDotWriter { explicit QueryDotWriter(std::string_view path) : file_(path.data()) {} ~QueryDotWriter() { file_.close(); } // Outputs the dot representation of the AST node to the output path. 
- void WriteAst( + absl::Status WriteAst( const std::string_view query, const Node& node, std::function(std::string_view key)> lookup_fn); diff --git a/components/tools/query_toy.cc b/components/tools/query_toy.cc index 76d66b04..51bbf246 100644 --- a/components/tools/query_toy.cc +++ b/components/tools/query_toy.cc @@ -126,7 +126,11 @@ void PromptForQuery( std::getline(std::cin, query); ProcessQuery(driver, query); if (dot_writer && driver.GetRootNode()) { - dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); + if (const auto status = + dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); + !status.ok()) { + std::cout << "Failed to write AST with error: " << status << std::endl; + } dot_writer->Flush(); } } @@ -149,7 +153,11 @@ int main(int argc, char* argv[]) { if (!query.empty()) { ProcessQuery(driver, query); if (dot_writer && driver.GetRootNode()) { - dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); + if (const auto status = + dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); + !status.ok()) { + std::cout << "Failed to write AST with error: " << status << std::endl; + } } return 0; } diff --git a/components/tools/sharding_correctness_validator/BUILD.bazel b/components/tools/sharding_correctness_validator/BUILD.bazel index 643d25c7..8e62c3d9 100644 --- a/components/tools/sharding_correctness_validator/BUILD.bazel +++ b/components/tools/sharding_correctness_validator/BUILD.bazel @@ -24,14 +24,16 @@ cc_binary( }), deps = [ "//components/cloud_config:parameter_client", - "//components/data_server/request_handler:ohttp_client_encryptor", + "//components/data/converters:cbor_converter", + "//components/data_server/request_handler:get_values_v2_handler", + "//components/data_server/request_handler/encryption:ohttp_client_encryptor", "//components/data_server/server:key_fetcher_factory", "//components/data_server/server:parameter_fetcher", "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", 
+ "//public:constants", "//public/applications/pa:response_utils", "//public/query/cpp:grpc_client", - "@com_github_google_quiche//quiche:binary_http_unstable_api", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/log", @@ -40,5 +42,8 @@ cc_binary( "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@google_privacysandbox_servers_common//src/communication:encoding_utils", + "@google_privacysandbox_servers_common//src/communication:framing_utils", + "@libcbor//:cbor", + "@nlohmann_json//:lib", ], ) diff --git a/components/tools/sharding_correctness_validator/validator.cc b/components/tools/sharding_correctness_validator/validator.cc index a33af827..4799e077 100644 --- a/components/tools/sharding_correctness_validator/validator.cc +++ b/components/tools/sharding_correctness_validator/validator.cc @@ -24,16 +24,22 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "components/cloud_config/parameter_client.h" -#include "components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/data/converters/cbor_converter.h" +#include "components/data_server/request_handler/encryption/ohttp_client_encryptor.h" +#include "components/data_server/request_handler/get_values_v2_handler.h" #include "components/data_server/server/key_fetcher_factory.h" #include "components/data_server/server/parameter_fetcher.h" #include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" +#include "nlohmann/json.hpp" #include "public/applications/pa/response_utils.h" +#include "public/constants.h" #include "public/query/cpp/grpc_client.h" #include "public/query/v2/get_values_v2.grpc.pb.h" -#include "quiche/binary_http/binary_http_message.h" #include "src/communication/encoding_utils.h" +#include "src/communication/framing_utils.h" + +#include "cbor.h" ABSL_DECLARE_FLAG(std::string, gcp_project_id); @@ -55,9 +61,6 @@ inline constexpr 
std::string_view kPublicKeyEndpointParameterSuffix = "public-key-endpoint"; inline constexpr std::string_view kUseRealCoordinatorsParameterSuffix = "use-real-coordinators"; -inline constexpr std::string_view kContentTypeHeader = "content-type"; -inline constexpr std::string_view kContentEncodingProtoHeaderValue = - "application/protobuf"; absl::BitGen bitgen; int total_failures = 0; @@ -76,6 +79,53 @@ int64_t Get(int64_t upper_bound) { return absl::Uniform(bitgen, 0, upper_bound); } +// Cannot use existing CborDecodeToNonBytesProto, since that doesn't handle +// bytes (content field). To fix this, convert the content from a +// cbor serialized list to a proto serialized list. +// +// This function is only used for testing purposes and the output might +// not accurately represent what an actual response proto content would look +// like. +absl::Status CborDecodeToGetValuesResponseProto( + std::string_view cbor_raw, v2::GetValuesResponse& response) { + nlohmann::json json_from_cbor = nlohmann::json::from_cbor( + cbor_raw, /*strict=*/true, /*allow_exceptions=*/false); + if (json_from_cbor.is_discarded()) { + return absl::InternalError("Failed to convert raw CBOR buffer to JSON"); + } + for (auto& json_compression_group : json_from_cbor["compressionGroups"]) { + // Convert CBOR serialized list to a JSON list of partition outputs + PS_ASSIGN_OR_RETURN( + auto json_partition_outputs, + GetPartitionOutputsInJson(json_compression_group["content"])); + + // Put the JSON list of partition outputs into a V2CompressionGroup + // This is necessary so that we can properly serialize the proto. + // The proto library only supports serializing a message, not + // a vector/list of protos. 
+ application_pa::V2CompressionGroup v2_compression_group; + for (const auto& json_partition_output : json_partition_outputs) { + PS_ASSIGN_OR_RETURN(*v2_compression_group.add_partition_outputs(), + application_pa::PartitionOutputFromJson( + json_partition_output.dump())); + } + // Serialize the V2CompressionGroup and set it as the compression group's + // content + LOG(INFO) << "V2 compression group: " << v2_compression_group; + std::string serialized_content; + if (!v2_compression_group.SerializeToString(&serialized_content)) { + return absl::InternalError(absl::StrCat( + "Failed to serialize proto to string: ", v2_compression_group)); + } + json_compression_group.erase("content"); + auto* compression_group = response.add_compression_groups(); + PS_RETURN_IF_ERROR(google::protobuf::util::JsonStringToMessage( + json_compression_group.dump(), compression_group)); + compression_group->set_content(std::move(serialized_content)); + } + return absl::OkStatus(); +} + absl::StatusOr GetPublicKey(std::unique_ptr& parameter_fetcher) { if (!parameter_fetcher->GetBoolParameter( @@ -123,32 +173,26 @@ absl::StatusOr GetValuesWithCoordinators( std::unique_ptr& public_key) { std::string serialized_req; - if (!proto_req.SerializeToString(&serialized_req)) { - return absl::Status(absl::StatusCode::kUnknown, - absl::StrCat("Protobuf SerializeToString failed!")); - } - quiche::BinaryHttpRequest req_bhttp_layer({}); - req_bhttp_layer.AddHeaderField({ - .name = std::string(kContentTypeHeader), - .value = std::string(kContentEncodingProtoHeaderValue), - }); - req_bhttp_layer.set_body(serialized_req); - auto maybe_serialized_bhttp = req_bhttp_layer.Serialize(); - if (!maybe_serialized_bhttp.ok()) { - return absl::Status( - absl::StatusCode::kInternal, - absl::StrCat(maybe_serialized_bhttp.status().message())); + PS_ASSIGN_OR_RETURN(serialized_req, + V2GetValuesRequestProtoToCborEncode(proto_req)); + auto encoded_data_size = privacy_sandbox::server_common::GetEncodedDataSize( + 
serialized_req.size(), kMinResponsePaddingBytes); + auto maybe_padded_request = + privacy_sandbox::server_common::EncodeResponsePayload( + privacy_sandbox::server_common::CompressionType::kUncompressed, + std::move(serialized_req), encoded_data_size); + if (!maybe_padded_request.ok()) { + LOG(ERROR) << "Padding failed: " << maybe_padded_request.status().message(); + return maybe_padded_request.status(); } - if (!public_key) { const std::string error = "public_key==nullptr, cannot proceed."; LOG(ERROR) << error; return absl::InternalError(error); } OhttpClientEncryptor encryptor(*public_key); - auto encrypted_serialized_request_maybe = - encryptor.EncryptRequest(*maybe_serialized_bhttp); + encryptor.EncryptRequest(*maybe_padded_request); if (!encrypted_serialized_request_maybe.ok()) { return encrypted_serialized_request_maybe.status(); } @@ -156,6 +200,8 @@ absl::StatusOr GetValuesWithCoordinators( ohttp_req.mutable_raw_body()->set_data(*encrypted_serialized_request_maybe); google::api::HttpBody ohttp_res; grpc::ClientContext context; + context.AddMetadata(std::string(kKVContentTypeHeader), + std::string(kContentEncodingCborHeaderValue)); grpc::Status status = stub->ObliviousGetValues(&context, ohttp_req, &ohttp_res); if (!status.ok()) { @@ -175,19 +221,11 @@ absl::StatusOr GetValuesWithCoordinators( LOG(ERROR) << "unpadding response failed!"; return deframed_req.status(); } - const absl::StatusOr maybe_res_bhttp_layer = - quiche::BinaryHttpResponse::Create(deframed_req->compressed_data); - if (!maybe_res_bhttp_layer.ok()) { - LOG(ERROR) << "Failed to create bhttp resonse layer!"; - return maybe_res_bhttp_layer.status(); - } - v2::GetValuesResponse get_value_response; - if (!get_value_response.ParseFromString( - std::string(maybe_res_bhttp_layer->body()))) { - return absl::Status(absl::StatusCode::kUnknown, - absl::StrCat("Protobuf ParseFromString failed!")); - } - return get_value_response; + v2::GetValuesResponse get_values_response; + 
PS_RETURN_IF_ERROR(CborDecodeToGetValuesResponseProto( + deframed_req->compressed_data, get_values_response)); + LOG(INFO) << "response: " << get_values_response; + return get_values_response; } v2::GetValuesRequest GetRequest(const std::vector& input_values) { @@ -209,12 +247,35 @@ absl::StatusOr GetValueFromResponse( if (!maybe_response.ok()) { return maybe_response.status(); } - auto output = maybe_response->single_partition().string_output(); - auto maybe_proto = application_pa::KeyGroupOutputsFromJson(output); - if (!maybe_proto.ok()) { - return maybe_proto.status(); + // We are only sending 1 partition, so should only get 1 partition back. + if (maybe_response->compression_groups().size() != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Expected compression group size is 1, but found %s.", + std::to_string(maybe_response->compression_groups().size()))); + } + // TODO(b/355464083): Will need to uncompress once compression is implemented + auto content = maybe_response->compression_groups(0).content(); + application_pa::V2CompressionGroup v2_compression_group; + if (!v2_compression_group.ParseFromString(content)) { + return absl::InvalidArgumentError(absl::StrCat( + "Could not parse V2CompressionGroup from content: ", content)); + } + + // Expecting 1 partition in the key group output + if (v2_compression_group.partition_outputs_size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected key group output to be a list of size 1, but received: ", + v2_compression_group)); + } + + auto partition_output = v2_compression_group.partition_outputs(0); + // Expecting 1 kv in the response key group output + if (partition_output.key_group_outputs_size() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected key group output to be a list of size 1, but received: ", + partition_output)); } - auto key_group_outputs = maybe_proto->key_group_outputs(); + auto key_group_outputs = partition_output.key_group_outputs(); if 
(key_group_outputs.empty()) { return absl::InvalidArgumentError("key_group_outputs empty"); } @@ -245,6 +306,7 @@ void ValidateResponse(absl::StatusOr maybe_response, std::vector& keys) { const int value_size = absl::GetFlag(FLAGS_value_size); for (const auto& key : keys) { + LOG(INFO) << "Validating key: " << key; auto maybe_response_value = GetValueFromResponse(maybe_response, key); if (!maybe_response_value.ok()) { total_failures++; @@ -321,7 +383,7 @@ void Validate( // _number_of_requests_to_make_. Each request has _batch_size_ number of keys to // lookup. The assumptions for these tests are following. The keys loaded to the // kv server are of format _key_prefix_{0....inclusive_upper_bound} Each value -// is deterministiclly mapped from the key -- const std::string +// is deterministically mapped from the key -- const std::string // expected_value(value_size, key[key.size() - 1]); For each request a random // key from the key space is selected. And the request look up that key and // _batch_size_ of the sequential keys. diff --git a/components/udf/BUILD.bazel b/components/udf/BUILD.bazel index 21fdd0d9..b24d81a3 100644 --- a/components/udf/BUILD.bazel +++ b/components/udf/BUILD.bazel @@ -112,7 +112,6 @@ cc_test( "//public/test_util:proto_matcher", "//public/test_util:request_example", "//public/udf:constants", - "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", diff --git a/components/udf/hooks/logging_hook.h b/components/udf/hooks/logging_hook.h index 020e557b..e681c46d 100644 --- a/components/udf/hooks/logging_hook.h +++ b/components/udf/hooks/logging_hook.h @@ -26,19 +26,20 @@ namespace kv_server { -// Logging function to register with Roma. -inline void LoggingFunction(absl::LogSeverity severity, - const std::weak_ptr& context, - std::string_view msg) { - std::shared_ptr request_context = context.lock(); +// UDF hook for logging a string. 
+// TODO(b/285331079): Disable for production builds. +inline void LogMessage( + google::scp::roma::FunctionBindingPayload>& + payload) { + std::shared_ptr request_context = payload.metadata.lock(); if (request_context == nullptr) { PS_VLOG(1) << "Request context is not available, the request might " "have been marked as complete"; return; } - PS_VLOG(9, request_context->GetPSLogContext()) << "Called logging hook"; - privacy_sandbox::server_common::log::LogWithPSLog( - severity, request_context->GetPSLogContext(), msg); + PS_VLOG(10, request_context->GetPSLogContext()) << "Called logging hook"; + PS_LOG(INFO, request_context->GetPSLogContext()) + << payload.io_proto.input_string(); } } // namespace kv_server diff --git a/components/udf/hooks/run_query_hook.h b/components/udf/hooks/run_query_hook.h index db260fe8..7726b631 100644 --- a/components/udf/hooks/run_query_hook.h +++ b/components/udf/hooks/run_query_hook.h @@ -63,8 +63,13 @@ constexpr std::string_view RunSetQueryHook::HookName() { if constexpr (std::is_same_v) { return "runQuery"; } - if constexpr (std::is_same_v) { - return "runSetQueryInt"; + if constexpr (std::is_same_v) { + return "runSetQueryUInt32"; + } + if constexpr (std::is_same_v) { + return "runSetQueryUInt64"; } } @@ -121,8 +126,14 @@ void RunSetQueryHook::operator()( response_or_status = lookup_->RunQuery(*request_context, payload.io_proto.input_string()); } - if constexpr (std::is_same_v) { - response_or_status = lookup_->RunSetQueryInt( + if constexpr (std::is_same_v) { + response_or_status = lookup_->RunSetQueryUInt32( + *request_context, payload.io_proto.input_string()); + } + if constexpr (std::is_same_v) { + response_or_status = lookup_->RunSetQueryUInt64( *request_context, payload.io_proto.input_string()); } if (!response_or_status.ok()) { @@ -144,11 +155,18 @@ void RunSetQueryHook::operator()( *payload.io_proto.mutable_output_list_of_string()->mutable_data() = std::move(*response_or_status.value().mutable_elements()); } - if constexpr 
(std::is_same_v) { + if constexpr (std::is_same_v) { const auto& elements = response_or_status->elements(); payload.io_proto.set_output_bytes(elements.data(), elements.size() * sizeof(uint32_t)); } + if constexpr (std::is_same_v) { + const auto& elements = response_or_status->elements(); + payload.io_proto.set_output_bytes(elements.data(), + elements.size() * sizeof(uint64_t)); + } PS_VLOG(9, request_context->GetPSLogContext()) << HookName() << " result: " << payload.io_proto.DebugString(); } @@ -160,7 +178,10 @@ RunSetQueryHook::Create() { } using RunSetQueryStringHook = RunSetQueryHook; -using RunSetQueryIntHook = RunSetQueryHook; +using RunSetQueryUInt32Hook = + RunSetQueryHook; +using RunSetQueryUInt64Hook = + RunSetQueryHook; } // namespace kv_server diff --git a/components/udf/hooks/run_query_hook_test.cc b/components/udf/hooks/run_query_hook_test.cc index 749d5edd..fec25e2b 100644 --- a/components/udf/hooks/run_query_hook_test.cc +++ b/components/udf/hooks/run_query_hook_test.cc @@ -69,22 +69,22 @@ TEST_F(RunQueryHookTest, SuccessfullyProcessesValue) { UnorderedElementsAreArray({"a", "b"})); } -TEST_F(RunQueryHookTest, VerifyProcessingIntSetsSuccessfully) { - InternalRunSetQueryIntResponse run_query_response; +TEST_F(RunQueryHookTest, VerifyProcessingUInt32SetsSuccessfully) { + InternalRunSetQueryUInt32Response run_query_response; TextFormat::ParseFromString(R"pb(elements: 1000 elements: 1001)pb", &run_query_response); auto mock_lookup = std::make_unique(); - EXPECT_CALL(*mock_lookup, RunSetQueryInt(_, "Q")) + EXPECT_CALL(*mock_lookup, RunSetQueryUInt32(_, "Q")) .WillOnce(Return(run_query_response)); FunctionBindingIoProto io; TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); - auto run_query_hook = RunSetQueryIntHook::Create(); + auto run_query_hook = RunSetQueryUInt32Hook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); FunctionBindingPayload> payload{ io, GetRequestContext()}; (*run_query_hook)(payload); 
ASSERT_TRUE(io.has_output_bytes()); - InternalRunSetQueryIntResponse actual_response; + InternalRunSetQueryUInt32Response actual_response; actual_response.mutable_elements()->Resize( io.output_bytes().size() / sizeof(uint32_t), 0); std::memcpy(actual_response.mutable_elements()->mutable_data(), @@ -111,13 +111,13 @@ TEST_F(RunQueryHookTest, RunQueryClientReturnsError) { {R"({"code":2,"message":"runQuery failed with error: Some error"})"})); } -TEST_F(RunQueryHookTest, RunSetQueryIntClientReturnsError) { +TEST_F(RunQueryHookTest, RunSetQueryUInt32ClientReturnsError) { auto mock_lookup = std::make_unique(); - EXPECT_CALL(*mock_lookup, RunSetQueryInt(_, "Q")) + EXPECT_CALL(*mock_lookup, RunSetQueryUInt32(_, "Q")) .WillOnce(Return(absl::UnknownError("Some error"))); FunctionBindingIoProto io; TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); - auto run_query_hook = RunSetQueryIntHook::Create(); + auto run_query_hook = RunSetQueryUInt32Hook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); FunctionBindingPayload> payload{ io, GetRequestContext()}; @@ -125,7 +125,7 @@ TEST_F(RunQueryHookTest, RunSetQueryIntClientReturnsError) { EXPECT_THAT( io.output_list_of_string().data(), UnorderedElementsAreArray( - {R"({"code":2,"message":"runSetQueryInt failed with error: Some error"})"})); + {R"({"code":2,"message":"runSetQueryUInt32 failed with error: Some error"})"})); } TEST_F(RunQueryHookTest, InputIsNotString) { @@ -145,5 +145,46 @@ TEST_F(RunQueryHookTest, InputIsNotString) { {R"({"code":3,"message":"runQuery input must be a string"})"})); } +TEST_F(RunQueryHookTest, VerifyProcessingUInt64SetsSuccessfully) { + InternalRunSetQueryUInt64Response run_query_response; + TextFormat::ParseFromString(R"pb(elements: 18446744073709551614 + elements: 18446744073709551615)pb", + &run_query_response); + auto mock_lookup = std::make_unique(); + EXPECT_CALL(*mock_lookup, RunSetQueryUInt64(_, "Q")) + .WillOnce(Return(run_query_response)); + FunctionBindingIoProto 
io; + TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); + auto run_query_hook = RunSetQueryUInt64Hook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + FunctionBindingPayload> payload{ + io, GetRequestContext()}; + (*run_query_hook)(payload); + ASSERT_TRUE(io.has_output_bytes()); + InternalRunSetQueryUInt64Response actual_response; + actual_response.mutable_elements()->Resize( + io.output_bytes().size() / sizeof(uint64_t), 0); + std::memcpy(actual_response.mutable_elements()->mutable_data(), + io.output_bytes().data(), io.output_bytes().size()); + EXPECT_THAT(actual_response, EqualsProto(run_query_response)); +} + +TEST_F(RunQueryHookTest, RunSetQueryUInt64ClientReturnsError) { + auto mock_lookup = std::make_unique(); + EXPECT_CALL(*mock_lookup, RunSetQueryUInt64(_, "Q")) + .WillOnce(Return(absl::UnknownError("Some error"))); + FunctionBindingIoProto io; + TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); + auto run_query_hook = RunSetQueryUInt64Hook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + FunctionBindingPayload> payload{ + io, GetRequestContext()}; + (*run_query_hook)(payload); + EXPECT_THAT( + io.output_list_of_string().data(), + UnorderedElementsAreArray( + {R"({"code":2,"message":"runSetQueryUInt64 failed with error: Some error"})"})); +} + } // namespace } // namespace kv_server diff --git a/components/udf/udf_client.cc b/components/udf/udf_client.cc index 9ed410e9..d989173d 100644 --- a/components/udf/udf_client.cc +++ b/components/udf/udf_client.cc @@ -119,7 +119,10 @@ class UdfClientImpl : public UdfClient { PS_VLOG(9, request_context_factory.Get().GetPSLogContext()) << "Executing UDF with input arg(s): " << absl::StrJoin(invocation_request.input, ","); - privacy_sandbox::server_common::Stopwatch stopwatch; + ScopeLatencyMetricsRecorder + latency_recorder( + request_context_factory.Get().GetUdfRequestMetricsContext()); const auto status = roma_service_.Execute( std::make_unique>>( 
std::move(invocation_request)), @@ -132,10 +135,10 @@ class UdfClientImpl : public UdfClient { } notification->Notify(); }); - if (!status.ok()) { + if (!status.status().ok()) { PS_LOG(ERROR, request_context_factory.Get().GetPSLogContext()) - << "Error sending UDF for execution: " << status; - return status; + << "Error sending UDF for execution: " << status.status(); + return status.status(); } notification->WaitForNotificationWithTimeout(udf_timeout_); @@ -152,13 +155,8 @@ class UdfClientImpl : public UdfClient { } // TODO(b/338813575): waiting on the K&B team. Once that's // implemented we should just use that number. - const auto udf_execution_time = stopwatch.GetElapsedTime(); metadata.custom_code_total_execution_time_micros = - absl::ToInt64Milliseconds(udf_execution_time); - LogIfError(request_context_factory.Get() - .GetUdfRequestMetricsContext() - .LogHistogram( - absl::ToDoubleMicroseconds(udf_execution_time))); + absl::ToInt64Milliseconds(latency_recorder.GetLatency()); return *result; } absl::Status Init() { return roma_service_.Init(); } diff --git a/components/udf/udf_client_test.cc b/components/udf/udf_client_test.cc index 10cc5092..d661bf48 100644 --- a/components/udf/udf_client_test.cc +++ b/components/udf/udf_client_test.cc @@ -20,7 +20,6 @@ #include #include -#include "absl/log/scoped_mock_log.h" #include "absl/status/statusor.h" #include "components/internal_server/mocks.h" #include "components/udf/code_config.h" @@ -456,15 +455,16 @@ TEST_F(UdfClientTest, JsJSONObjectInWithRunQueryHookSucceeds) { TEST_F(UdfClientTest, VerifyJsRunSetQueryIntHookSucceeds) { auto mock_lookup = std::make_unique(); - InternalRunSetQueryIntResponse response; + InternalRunSetQueryUInt32Response response; TextFormat::ParseFromString(R"pb(elements: 1000 elements: 1001)pb", &response); - ON_CALL(*mock_lookup, RunSetQueryInt(_, _)).WillByDefault(Return(response)); - auto run_query_hook = RunSetQueryIntHook::Create(); + ON_CALL(*mock_lookup, RunSetQueryUInt32(_, _)) + 
.WillByDefault(Return(response)); + auto run_query_hook = RunSetQueryUInt32Hook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); UdfConfigBuilder config_builder; absl::StatusOr> udf_client = UdfClient::Create( - std::move(config_builder.RegisterRunSetQueryIntHook(*run_query_hook) + std::move(config_builder.RegisterRunSetQueryUInt32Hook(*run_query_hook) .SetNumberOfWorkers(1) .Config())); EXPECT_TRUE(udf_client.ok()); @@ -472,7 +472,7 @@ TEST_F(UdfClientTest, VerifyJsRunSetQueryIntHookSucceeds) { .js = R"( function hello(input) { let keys = input.keys; - let bytes = runSetQueryInt(keys[0]); + let bytes = runSetQueryUInt32(keys[0]); if (bytes instanceof Uint8Array) { return Array.from(new Uint32Array(bytes.buffer)); } @@ -492,7 +492,7 @@ TEST_F(UdfClientTest, VerifyJsRunSetQueryIntHookSucceeds) { EXPECT_TRUE(stop.ok()); } -TEST_F(UdfClientTest, JsCallsLoggingFunctionLogForConsentedRequests) { +TEST_F(UdfClientTest, JsCallsLogMessageConsentedSucceeds) { std::stringstream log_ss; auto* logger_provider = opentelemetry::sdk::logs::LoggerProviderFactory::Create( @@ -503,19 +503,17 @@ TEST_F(UdfClientTest, JsCallsLoggingFunctionLogForConsentedRequests) { .release(); privacy_sandbox::server_common::log::logger_private = logger_provider->GetLogger("test").get(); + UdfConfigBuilder config_builder; absl::StatusOr> udf_client = - UdfClient::Create(std::move(config_builder.RegisterLoggingFunction() - .SetNumberOfWorkers(1) - .Config())); + UdfClient::Create(std::move( + config_builder.RegisterLoggingHook().SetNumberOfWorkers(1).Config())); EXPECT_TRUE(udf_client.ok()); absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ .js = R"( function hello(input) { - const a = console.error("Error message"); - const b = console.warn("Warning message"); - const c = console.log("Info message"); + logMessage("first message"); return ""; } )", @@ -536,9 +534,7 @@ TEST_F(UdfClientTest, JsCallsLoggingFunctionLogForConsentedRequests) { 
EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("")"); auto output_log = log_ss.str(); - EXPECT_THAT(output_log, ContainsRegex("Error message")); - EXPECT_THAT(output_log, ContainsRegex("Warning message")); - EXPECT_THAT(output_log, ContainsRegex("Info message")); + EXPECT_THAT(output_log, ContainsRegex("first message")); absl::Status stop = udf_client.value()->Stop(); EXPECT_TRUE(stop.ok()); @@ -557,17 +553,14 @@ TEST_F(UdfClientTest, JsCallsLoggingFunctionNoLogForNonConsentedRequests) { logger_provider->GetLogger("test").get(); UdfConfigBuilder config_builder; absl::StatusOr> udf_client = - UdfClient::Create(std::move(config_builder.RegisterLoggingFunction() - .SetNumberOfWorkers(1) - .Config())); + UdfClient::Create(std::move( + config_builder.RegisterLoggingHook().SetNumberOfWorkers(1).Config())); EXPECT_TRUE(udf_client.ok()); absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ .js = R"( function hello(input) { - const a = console.error("Error message"); - const b = console.warn("Warning message"); - const c = console.log("Info message"); + logMessage("first message"); return ""; } )", @@ -702,6 +695,11 @@ TEST_F(UdfClientTest, MetadataPassedSuccesfully) { { return "true"; } + if(metadata.partitionMetadata && + metadata.partitionMetadata.partition_level_key) + { + return "true"; + } return "false"; } )", @@ -711,8 +709,9 @@ TEST_F(UdfClientTest, MetadataPassedSuccesfully) { }); EXPECT_TRUE(code_obj_status.ok()); v2::GetValuesRequest req; - (*(req.mutable_metadata()->mutable_fields()))["is_pas"].set_string_value( - "true"); + auto* fields = req.mutable_metadata()->mutable_fields(); + (*fields)["is_pas"].set_string_value("true"); + (*fields)["partition_level_key"].set_string_value("true"); UDFExecutionMetadata udf_metadata; *udf_metadata.mutable_request_metadata() = *req.mutable_metadata(); google::protobuf::RepeatedPtrField args; @@ -749,6 +748,7 @@ TEST_F(UdfClientTest, DefaultUdfPASucceeds) { UdfConfigBuilder config_builder; 
absl::StatusOr> udf_client = UdfClient::Create( std::move(config_builder.RegisterStringGetValuesHook(*get_values_hook) + .RegisterLoggingHook() .SetNumberOfWorkers(1) .Config())); EXPECT_TRUE(udf_client.ok()); @@ -795,6 +795,7 @@ TEST_F(UdfClientTest, DefaultUdfPasKeyLookupFails) { UdfConfigBuilder config_builder; absl::StatusOr> udf_client = UdfClient::Create( std::move(config_builder.RegisterStringGetValuesHook(*get_values_hook) + .RegisterLoggingHook() .SetNumberOfWorkers(1) .Config())); EXPECT_TRUE(udf_client.ok()); @@ -847,6 +848,7 @@ TEST_F(UdfClientTest, DefaultUdfPasSucceeds) { UdfConfigBuilder config_builder; absl::StatusOr> udf_client = UdfClient::Create( std::move(config_builder.RegisterStringGetValuesHook(*get_values_hook) + .RegisterLoggingHook() .SetNumberOfWorkers(1) .Config())); EXPECT_TRUE(udf_client.ok()); @@ -885,6 +887,46 @@ TEST_F(UdfClientTest, DefaultUdfPasSucceeds) { EXPECT_TRUE(stop.ok()); } -} // namespace +TEST_F(UdfClientTest, VerifyJsRunSetQueryUInt64HookSucceeds) { + auto mock_lookup = std::make_unique(); + InternalRunSetQueryUInt64Response response; + TextFormat::ParseFromString(R"pb(elements: 18446744073709551614 + elements: 18446744073709551615)pb", + &response); + ON_CALL(*mock_lookup, RunSetQueryUInt64(_, _)) + .WillByDefault(Return(response)); + auto run_query_hook = RunSetQueryUInt64Hook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + UdfConfigBuilder config_builder; + absl::StatusOr> udf_client = UdfClient::Create( + std::move(config_builder.RegisterRunSetQueryUInt64Hook(*run_query_hook) + .SetNumberOfWorkers(1) + .Config())); + EXPECT_TRUE(udf_client.ok()); + absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ + .js = R"( + function hello(input) { + const keys = input.keys; + const bytes = runSetQueryUInt64(keys[0]); + if (bytes instanceof Uint8Array) { + const uint64Array = new BigUint64Array(bytes.buffer); + return Array.from(uint64Array, uint64 => uint64.toString()); + } + return 
"runSetQueryInt failed."; + } + )", + .udf_handler_name = "hello", + .logical_commit_time = 1, + .version = 1, + }); + EXPECT_TRUE(code_obj_status.ok()); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); + ASSERT_TRUE(result.ok()) << result.status(); + EXPECT_EQ(*result, "[\"18446744073709551614\",\"18446744073709551615\"]"); + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} +} // namespace } // namespace kv_server diff --git a/components/udf/udf_config_builder.cc b/components/udf/udf_config_builder.cc index 6ba2e3c8..40925781 100644 --- a/components/udf/udf_config_builder.cc +++ b/components/udf/udf_config_builder.cc @@ -38,7 +38,9 @@ using google::scp::roma::FunctionBindingPayload; constexpr char kStringGetValuesHookJsName[] = "getValues"; constexpr char kBinaryGetValuesHookJsName[] = "getValuesBinary"; constexpr char kRunQueryHookJsName[] = "runQuery"; -constexpr char kRunSetQueryIntHookJsName[] = "runSetQueryInt"; +constexpr char kRunSetQueryUInt32HookJsName[] = "runSetQueryUInt32"; +constexpr char kRunSetQueryUInt64HookJsName[] = "runSetQueryUInt64"; +constexpr char kLoggingHookJsName[] = "logMessage"; std::unique_ptr>> GetValuesFunctionObject(GetValuesHook& get_values_hook, @@ -84,22 +86,40 @@ UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryStringHook( return *this; } -UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryIntHook( - RunSetQueryIntHook& run_set_query_int_hook) { +UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryUInt32Hook( + RunSetQueryUInt32Hook& run_set_query_uint32_hook) { auto run_query_function_object = std::make_unique< FunctionBindingObjectV2>>(); - run_query_function_object->function_name = kRunSetQueryIntHookJsName; + run_query_function_object->function_name = kRunSetQueryUInt32HookJsName; run_query_function_object->function = - [&run_set_query_int_hook]( + [&run_set_query_uint32_hook]( 
FunctionBindingPayload>& in) { - run_set_query_int_hook(in); + run_set_query_uint32_hook(in); }; config_.RegisterFunctionBinding(std::move(run_query_function_object)); return *this; } -UdfConfigBuilder& UdfConfigBuilder::RegisterLoggingFunction() { - config_.SetLoggingFunction(LoggingFunction); +UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryUInt64Hook( + RunSetQueryUInt64Hook& run_set_query_uint64_hook) { + auto run_query_function_object = std::make_unique< + FunctionBindingObjectV2>>(); + run_query_function_object->function_name = kRunSetQueryUInt64HookJsName; + run_query_function_object->function = + [&run_set_query_uint64_hook]( + FunctionBindingPayload>& in) { + run_set_query_uint64_hook(in); + }; + config_.RegisterFunctionBinding(std::move(run_query_function_object)); + return *this; +} + +UdfConfigBuilder& UdfConfigBuilder::RegisterLoggingHook() { + auto logging_function_object = std::make_unique< + FunctionBindingObjectV2>>(); + logging_function_object->function_name = kLoggingHookJsName; + logging_function_object->function = LogMessage; + config_.RegisterFunctionBinding(std::move(logging_function_object)); return *this; } diff --git a/components/udf/udf_config_builder.h b/components/udf/udf_config_builder.h index e6bcd965..988e7cc0 100644 --- a/components/udf/udf_config_builder.h +++ b/components/udf/udf_config_builder.h @@ -30,10 +30,13 @@ class UdfConfigBuilder { UdfConfigBuilder& RegisterRunSetQueryStringHook( RunSetQueryStringHook& run_query_hook); - UdfConfigBuilder& RegisterRunSetQueryIntHook( - RunSetQueryIntHook& run_set_query_int_hook); + UdfConfigBuilder& RegisterRunSetQueryUInt32Hook( + RunSetQueryUInt32Hook& run_set_query_uint32_hook); - UdfConfigBuilder& RegisterLoggingFunction(); + UdfConfigBuilder& RegisterRunSetQueryUInt64Hook( + RunSetQueryUInt64Hook& run_set_query_uint64_hook); + + UdfConfigBuilder& RegisterLoggingHook(); UdfConfigBuilder& SetNumberOfWorkers(int number_of_workers); diff --git a/components/util/request_context.cc 
b/components/util/request_context.cc index 6531b4fc..4616b334 100644 --- a/components/util/request_context.cc +++ b/components/util/request_context.cc @@ -38,9 +38,12 @@ RequestLogContext& RequestContext::GetRequestLogContext() const { void RequestContext::UpdateLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - consented_debug_config) { - request_log_context_ = - std::make_unique(log_context, consented_debug_config); + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt) { + request_log_context_ = std::make_unique( + log_context, consented_debug_config, std::move(debug_info_opt)); if (request_log_context_->GetRequestLoggingContext().is_consented()) { const std::string generation_id = request_log_context_->GetLogContext().generation_id().empty() @@ -71,11 +74,21 @@ RequestContext::GetPSLogContext() const { RequestLogContext::RequestLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - consented_debug_config) + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt) : log_context_(log_context), consented_debug_config_(consented_debug_config), request_logging_context_(GetContextMap(log_context), - consented_debug_config) {} + consented_debug_config) { + if (debug_info_opt.has_value()) { + request_logging_context_ = + privacy_sandbox::server_common::log::ContextImpl<>( + GetContextMap(log_context), consented_debug_config, + std::move(debug_info_opt.value())); + } +} privacy_sandbox::server_common::log::ContextImpl<>& RequestLogContext::GetRequestLoggingContext() { @@ -111,8 +124,12 @@ const RequestContext& RequestContextFactory::Get() const { void RequestContextFactory::UpdateLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - 
consented_debug_config) { - request_context_->UpdateLogContext(log_context, consented_debug_config); + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt) { + request_context_->UpdateLogContext(log_context, consented_debug_config, + std::move(debug_info_opt)); } } // namespace kv_server diff --git a/components/util/request_context.h b/components/util/request_context.h index 07fc8c80..835aaf48 100644 --- a/components/util/request_context.h +++ b/components/util/request_context.h @@ -18,6 +18,7 @@ #define COMPONENTS_UTIL_REQUEST_CONTEXT_H_ #include +#include #include #include @@ -33,7 +34,10 @@ class RequestLogContext { explicit RequestLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - consented_debug_config); + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt = std::nullopt); privacy_sandbox::server_common::log::ContextImpl<>& GetRequestLoggingContext(); @@ -79,7 +83,10 @@ class RequestContext { void UpdateLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - consented_debug_config); + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt = std::nullopt); UdfRequestMetricsContext& GetUdfRequestMetricsContext() const; InternalLookupMetricsContext& GetInternalLookupMetricsContext() const; RequestLogContext& GetRequestLogContext() const; @@ -128,7 +135,10 @@ class RequestContextFactory { void UpdateLogContext( const privacy_sandbox::server_common::LogContext& log_context, const privacy_sandbox::server_common::ConsentedDebugConfiguration& - consented_debug_config); + consented_debug_config, + std::optional< + absl::AnyInvocable> + debug_info_opt = std::nullopt); // Not movable and copyable to prevent making unnecessary // copies of underlying shared_ptr of request context, and moving of // shared ownership 
of request context diff --git a/components/util/sleepfor_mock.h b/components/util/sleepfor_mock.h index cfb577f7..d806451c 100644 --- a/components/util/sleepfor_mock.h +++ b/components/util/sleepfor_mock.h @@ -21,13 +21,13 @@ namespace kv_server { class MockSleepFor : public SleepFor { public: - MOCK_METHOD(bool, Duration, (absl::Duration), (const override)); + MOCK_METHOD(bool, Duration, (absl::Duration), (const, override)); MOCK_METHOD(absl::Status, Stop, (), (override)); }; class MockUnstoppableSleepFor : public UnstoppableSleepFor { public: - MOCK_METHOD(bool, Duration, (absl::Duration), (const override)); + MOCK_METHOD(bool, Duration, (absl::Duration), (const, override)); MOCK_METHOD(absl::Status, Stop, (), (override)); }; diff --git a/docs/AWS_Terraform_vars.md b/docs/AWS_Terraform_vars.md index 7892310c..d7401a64 100644 --- a/docs/AWS_Terraform_vars.md +++ b/docs/AWS_Terraform_vars.md @@ -1,5 +1,10 @@ # AWS Key Value Server Terraform vars documentation +- **add_chaff_sharding_clusters** + + Whether to add chaff sharding clusters. Only works for nonprod. For prod mode requests are + always chaffed. + - **add_missing_keys_v1** Add missing keys v1. @@ -123,6 +128,10 @@ supported from the [AWS article](https://docs.aws.amazon.com/enclaves/latest/user/nitro-enclave.html). +- **logging_verbosity_backup_poll_frequency_secs** + + Backup poll frequency in seconds for the logging verbosity parameter. + - **logging_verbosity_level** Logging verbosity level diff --git a/docs/GCP_Terraform_vars.md b/docs/GCP_Terraform_vars.md index c98ac831..c983e461 100644 --- a/docs/GCP_Terraform_vars.md +++ b/docs/GCP_Terraform_vars.md @@ -1,5 +1,10 @@ # GCP Key Value Server Terraform vars documentation +- **add_chaff_sharding_clusters** + + Whether to add chaff sharding clusters. Only works for nonprod. For prod mode requests are + always chaffed. + - **add_missing_keys_v1** Add missing keys v1. 
@@ -28,6 +33,10 @@ The grpc port that receives traffic destined for the OpenTelemetry collector +- **collector_startup_script_path** + + Relative path from main.tf to collector service startup script. + - **consented_debug_token** Consented debug token to enable the otel collection of consented logs. Empty token means no-op @@ -94,6 +103,10 @@ The grpc port that receives traffic destined for the frontend service. +- **logging_verbosity_backup_poll_frequency_secs** + + Backup poll frequency in seconds for the logging verbosity parameter. + - **logging_verbosity_level** Logging verbosity level diff --git a/docs/ami_structure_aws.md b/docs/ami_structure_aws.md index d6b068a3..62092508 100644 --- a/docs/ami_structure_aws.md +++ b/docs/ami_structure_aws.md @@ -50,5 +50,5 @@ Let: ## How is the server launched on EC2 instance launch? The -[instance launch script](https://github.com/privacysandbox/fledge-key-value-service/blob/main/production/terraform/aws/services/autoscaling/instance_init_script.tftpl) +[instance launch script](https://github.com/privacysandbox/protected-auction-key-value-service/blob/main/production/terraform/aws/services/autoscaling/instance_init_script.tftpl) is used to startup the KV server and other relevant components when the EC2 instance is launched. 
diff --git a/docs/assets/gcp_instance_prod_logs.png b/docs/assets/gcp_instance_prod_logs.png new file mode 100644 index 00000000..729c2137 Binary files /dev/null and b/docs/assets/gcp_instance_prod_logs.png differ diff --git a/docs/cloud_build/cloud_build_gcp.md b/docs/cloud_build/cloud_build_gcp.md new file mode 100644 index 00000000..5f3d2acc --- /dev/null +++ b/docs/cloud_build/cloud_build_gcp.md @@ -0,0 +1,84 @@ +# GCP Cloud Build for Key/Value Server + +## Overview + +This doc contains instructions on how to setup [GCP Cloud Build](https://cloud.google.com/build) to +build the Key/Value server Docker Images for use in +[Confidential Spaces](https://cloud.google.com/docs/security/confidential-space). These images can +be directly used for Key/Value server deployment on GCP. + +### Why do this? + +The Key/Value server can take around 1.5 ~ 2 hours (with 32 cores) to build. If you create an +automated build pipeline that builds new Key/Value server releases, you can avoid manual labor and +increase operational efficiency. Binaries and docker images will be provided directly in the future. + +## Cloud Build Configuration + +### Prerequisites + +#### Connecting to Github + +First, follow the steps to +[connect a Github repository](https://cloud.google.com/build/docs/automating-builds/github/connect-repo-github?generation=2nd-gen) +and create a host connection. You will need to clone the +[Key/Value server repo](https://github.com/privacysandbox/protected-auction-key-value-service) to +your own Github account before you can connect it to your GCP project's Cloud Build. Make sure that +your fork, if updated automatically, also fetches the tags from the upstream repo -- that way, you +can build directly from the semantically versioned tags. See +[here](/production/packaging/sync_key_value_repo.yaml) for an example Github Action that handles +syncing. 
+ +#### Configuring an Image Repo + +Please create an [Artifact Registry](https://cloud.google.com/artifact-registry) repo to hold all of +the Key/Value server images that will be created. We use a default repo name of +`us-docker.pkg.dev/${PROJECT_ID}/kvs-docker-repo-shared/kv-service`. + +#### Service Account Permissions + +Navigate to the Cloud Build page in the GCP GUI and click on Settings. Make sure the service account +permissions have 'Service Account User' enabled. Then, in IAM, additionally make sure that the +service account has Artifact Registry Writer permissions. The build script will attempt to push +images to the image repo specified using the service account for permissions. + +### Create a Trigger + +#### Source + +You must create a build trigger. Starting with a +[manual](https://cloud.google.com/build/docs/triggers#manual) or +[Github](https://cloud.google.com/build/docs/triggers#github) trigger is recommended. Please make +sure to use a 'REPOSITORIES (RECOMMENDED)' repository source type. + +Recommendation 1: Use a `Push a new tag` Event to build `.*` tags. + +Recommendation 2: Create a separate trigger for each `_BUILD_FLAVOR` (see below). + +#### Configuration + +1. Type: Cloud Build configuration file (yaml or json) +1. Location: Repository + + ```plaintext + production/packaging/gcp/cloud_build/cloudbuild.yaml + ``` + +1. Substitution Variables + + Note: these will override variables in the cloudbuild.yaml. + + ```plaintext + key: _BUILD_FLAVOR + value: prod or non_prod. While 'prod' allows for attestation against production private keys, nonprod has enhanced logging. + + key: _GCP_IMAGE_REPO + value: service images repo URI from prerequisites (default: us-docker.pkg.dev/${PROJECT_ID}/kvs-docker-repo-shared/kv-service) + + key: _GCP_IMAGE_TAG + value: any tag (default: ${BUILD_ID}) + ``` + +1. Service account: Use the account created [previously](#service-account-permissions). + +After configuring your Trigger, click Save. 
You may manually run it from the Triggers page. diff --git a/docs/cloud_build/codebuild_aws.md b/docs/cloud_build/codebuild_aws.md new file mode 100644 index 00000000..f405b6a6 --- /dev/null +++ b/docs/cloud_build/codebuild_aws.md @@ -0,0 +1,209 @@ +# AWS CodeBuild for Key/Value Server + +## Overview + +This README contains instructions on how to setup [AWS CodeBuild](https://aws.amazon.com/codebuild/) +to build the Key/Value server [AMIs](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/AMIs.html). +These AMIs can then be directly used for Key/Value server deployment on AWS. + +### Why do this? + +The Key/Value server can take around 2 hours (with 72 cores) to build. If you create an automated +build pipeline that builds new Key/Value server releases, you can avoid manual labor and increase +operational efficiency. + +## CodeBuild Configuration + +### Prerequisites + +#### Choose an AWS Region + +Setup will be simplest if you do everything in a single region. Verify that AWS CodeBuild is +available in your chosen region via the +[Global Infrastructure page](https://aws.amazon.com/about-aws/global-infrastructure/regional-product-services/). + +#### Service Role + +You must first create a role via [IAM](https://aws.amazon.com/iam). This service role will need to +be able to build and upload AMIs, so we will need to modify its permission policies. 
Attach the +following policy (you can add it as an inline policy via the Role Permissions -> Add permissions +feature in IAM): + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "ecr-public:GetAuthorizationToken", + "sts:GetServiceBearerToken", + "ec2:AttachVolume", + "ec2:AuthorizeSecurityGroupIngress", + "ec2:CopyImage", + "ec2:CreateImage", + "ec2:CreateKeyPair", + "ec2:CreateSecurityGroup", + "ec2:CreateSnapshot", + "ec2:CreateTags", + "ec2:CreateVolume", + "ec2:DeleteKeyPair", + "ec2:DeleteSecurityGroup", + "ec2:DeleteSnapshot", + "ec2:DeleteVolume", + "ec2:DeregisterImage", + "ec2:DescribeImageAttribute", + "ec2:DescribeImages", + "ec2:DescribeInstances", + "ec2:DescribeInstanceStatus", + "ec2:DescribeRegions", + "ec2:DescribeSecurityGroups", + "ec2:DescribeSnapshots", + "ec2:DescribeSubnets", + "ec2:DescribeTags", + "ec2:DescribeVolumes", + "ec2:DetachVolume", + "ec2:GetPasswordData", + "ec2:ModifyImageAttribute", + "ec2:ModifyInstanceAttribute", + "ec2:ModifySnapshotAttribute", + "ec2:RegisterImage", + "ec2:RunInstances", + "ec2:StopInstances", + "ec2:TerminateInstances" + ], + "Resource": "*" + } + ] +} +``` + +Additionally, increase the `Maximum session duration` to 4 hours by editing the Role Summary. This +is necessary because the build uses the role's session token and we want to avoid the token expiring +while the build is still running. + +#### Docker Login Credentials + +You will need a [Docker Hub](https://hub.docker.com/) account because the images we use to run the +build are, by default, sourced from Docker Hub. + +Create an AWS Secret (via [Secret Manager](https://aws.amazon.com/secrets-manager/)) to store your +Docker Hub password. When creating the secret, choose secret type 'Other' and make sure to store the +password as plaintext (not as a key/value pair). 
Add the following permissions to your secret: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam:::role/" + }, + "Action": "secretsmanager:GetSecretValue", + "Resource": "*" + } + ] +} +``` + +The build process will login to the provided Docker Hub account at build time. + +### Project + +Create a build project in your preferred AWS Region. The following sections will guide you through +the build project configuration. + +### Source Repo + +Set Github as your source provider, and use the public repository: +`https://github.com/privacysandbox/protected-auction-key-value-service.git`. + +#### Webhooks + +Webhook events are not available for the 'Public repository' option. If you want to run an automatic +build for every release, you will have to clone the public repo into a Github repo under your +ownership and keep your clone up to date with the public repo. If you prefer to avoid cloning the +public repository, any time you want a new build for Bidding and Auction services you must follow +the steps [here](#start-build). + +If you do choose to use webhook events, select 'Repository in my Github account', and in the +'Primary source webhook events' section select 'Rebuild every time a code change is pushed to this +repository' for a 'Single build'. Then, add a 'Start a build' filter group of event type `PUSH`, +type `HEAD_REF`, and pattern `^refs/tags/.*`. This will build on every release tag. + +Make sure that your Github fork, if updated automatically, also fetches the tags from the upstream +repo -- that way, you can build directly from the semantically versioned tags. See +[here](../../production/packaging/sync_key_value_repo.yaml) for an example Github Action that +handles syncing. + +### Environment + +Use the following configuration: + +1. Provisioning model: On-Demand +1. Environment image: Managed image +1. Compute: EC2 +1. Operating system: Ubuntu +1. Runtime: Standard +1. 
Image: `aws/codebuild/standard:7.0` +1. Image version: Always use latest +1. Service role: see the [service role section](#service-role-permissions). +1. Report auto-discover: disable + +Additional configuration: + +1. Timeout: 4 Hours +1. Queued timeout: 8 Hours +1. Privileged: Enable +1. Certificate: Do not install +1. Compute: `145 GB memory, 72 vCPUs` +1. Environment variables: + + ```plaintext + key: AMI_REGION + value: + + key: DOCKERHUB_USERNAME + value: + + key: DOCKERHUB_PASSWORD_SECRET_NAME + value: + + key: BUILD_FLAVOR + value: + ``` + +### Buildspec + +Select `Use a buildspec file`. Enter the following: +`production/packaging/aws/codebuild/buildspec.yaml`. + +### Batch configuration + +Leave blank. + +### Artifacts + +Type: No artifacts + +Additional configuration: + +Encryption key: use default + +### Logs + +Enable CloudWatch logs. + +### Service role permissions + +Choose the service role you created in the [prerequisites](#service-role) and enable allowing AWS +CodeBuild to modify the role so it can be used with this build project. + +### Start Build + +Save the changes and create/update the project. Then, click `Start build with overrides` from the +Build projects directory and try building the Key/value server with a desired source version (e.g., +`main` or `v0.16.0`). + +If the build succeeds, you can find the new AMI at the end of the build log. diff --git a/docs/data_loading/data_loading_capabilities.md b/docs/data_loading/data_loading_capabilities.md index bcbd37f6..45c17802 100644 --- a/docs/data_loading/data_loading_capabilities.md +++ b/docs/data_loading/data_loading_capabilities.md @@ -10,9 +10,9 @@ different chunks of a single data file concurrently. 
For AWS S3, the level of co reader and the number of concurrent S3 client connections can be tuned using the following terraform parameters: -- [data_loading_num_threads](https://github.com/privacysandbox/fledge-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) +- [data_loading_num_threads](https://github.com/privacysandbox/protected-auction-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) sets the number of concurrent threads used to read a single data file [default: 16]. -- [s3client_max_connections](https://github.com/privacysandbox/fledge-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) +- [s3client_max_connections](https://github.com/privacysandbox/protected-auction-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) sets the maximum number of concurrent connections used by the S3 client [default: 64]. ## Benchmarking tool diff --git a/docs/data_loading/loading_data.md b/docs/data_loading/loading_data.md index 867cfee6..daec3099 100644 --- a/docs/data_loading/loading_data.md +++ b/docs/data_loading/loading_data.md @@ -268,7 +268,7 @@ delta file to a dedicated broadcast topic. In the case of AWS it is a Simple Notification Service (SNS) topic. That topic is created in terraform -[here](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/services/data_storage/main.tf#L107). +[here](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/services/data_storage/main.tf#L107). Delta files contain multiple rows, which allows you to batch multiple updates together. 
There is a [limit](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-messages.html) of 256KB for the message size. diff --git a/docs/data_loading/realtime_updates_capabilities.md b/docs/data_loading/realtime_updates_capabilities.md index 85834d3a..9d0f2db6 100644 --- a/docs/data_loading/realtime_updates_capabilities.md +++ b/docs/data_loading/realtime_updates_capabilities.md @@ -3,7 +3,7 @@ ## Overview A parameter -([AWS](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/environments/kv_server.tf#L51), +([AWS](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/environments/kv_server.tf#L51), [GCP](/docs/GCP_Terraform_vars.md#L96)) sets the size of the thread pool that reads off a queue. The bigger that number is, the smaller the batch size can be. It is preferred to use a larger batch size where possible. @@ -45,7 +45,7 @@ While similar logic applies, the GCP SDK has superior performance due to ## Multiple threads To get to a higher QPS we can have multiple threads reading off a queue. This is a parameter -([AWS](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/environments/kv_server.tf#L51), +([AWS](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/environments/kv_server.tf#L51), [GCP](/docs/GCP_Terraform_vars.md#L96)) that our solution exposes. It can be increased to match specific QPS requirements and underlying hardware - based on the number of cores. @@ -170,10 +170,10 @@ the batch size. You can test our service with our tools. 
-[Data generation script](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/tools/serving_data_generator/generate_load_test_data) +[Data generation script](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/tools/serving_data_generator/generate_load_test_data) Allows to generate N deltas, with B updates per delta. You can configure the N, and B (batch size). -[Publisher](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/components/tools/realtime_updates_publisher.cc#L122) +[Publisher](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/components/tools/realtime_updates_publisher.cc#L122) allows to insert batched updates at the specified rate from a specified folder. ### Getting p values @@ -183,7 +183,7 @@ allows to insert batched updates at the specified rate from a specified folder. You can query Prometheus directly by using the script below. You can update the p value, 0.5 in this case, to the value you're interested in. AMP_QUERY_ENDPOINT can be found in AWS UI. It is a url for the prometheus workspace that's created by -[this](https://github.com/privacysandbox/fledge-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/services/telemetry/main.tf#L17) +[this](https://github.com/privacysandbox/protected-auction-key-value-service/blob/7f3710b1f1c944d7879718a334afd5cb8f80f3d9/production/terraform/aws/services/telemetry/main.tf#L17) ```sh AWS_ACCESS_KEY_ID=... 
diff --git a/docs/deployment/deploying_locally.md b/docs/deployment/deploying_locally.md index d0e1ce03..be50f982 100644 --- a/docs/deployment/deploying_locally.md +++ b/docs/deployment/deploying_locally.md @@ -29,12 +29,12 @@ follow the instructions for ## Get the source code from GitHub The code for the FLEDGE Key/Value server is released on -[GitHub](https://github.com/privacysandbox/fledge-key-value-service). +[GitHub](https://github.com/privacysandbox/protected-auction-key-value-service). Using Git, clone the repository into a folder: ```sh -git clone https://github.com/privacysandbox/fledge-key-value-service.git +git clone https://github.com/privacysandbox/protected-auction-key-value-service.git ``` ## Build the local binary diff --git a/docs/deployment/deploying_on_aws.md b/docs/deployment/deploying_on_aws.md index c8f9c424..902d0c7e 100644 --- a/docs/deployment/deploying_on_aws.md +++ b/docs/deployment/deploying_on_aws.md @@ -18,9 +18,8 @@ To learn more about FLEDGE and the Key/Value server, take a look at the followin > with the functionality and high level user experience. As more privacy protection mechanisms > are added to the system, this document will be updated accordingly. -For the initial testing of the Key/Value server, you must have or -[create an Amazon Web Services (AWS)](https://portal.aws.amazon.com/billing/signup/iam) account. -You'll need API access, as well as your key ID and secret key. +For the initial testing of the Key/Value server, you must have or create an Amazon Web Services +(AWS) account. You'll need API access, as well as your key ID and secret key. # Set up your AWS account @@ -149,6 +148,15 @@ export AWS_REGION=us-east-1 # For example. Then run `dist/aws/push_sqs` to push the SQS cleanup lambda image to AWS ECR. 
+If you want to deploy a new version of SQS cleanup lambda image to clean up the expired sqs queues +for KV servers that already had been deployed, after running `dist/aws/push_sqs` command, run the +aws update-function-code command to notify the AWS lambda sqs clean up function to pick up the new +lambda image. + +```shell +aws lambda update-function-code --function-name kv-server--sqs-cleanup --image-uri .dkr.ecr..amazonaws.com/sqs_lambda:latest +``` + ## Set up Terraform The setup scripts require Terraform version 1.2.3. There is a helper script /tools/terraform, which @@ -180,6 +188,21 @@ Update the `[[REGION]].backend.conf`: - `region` - Set the region where Terraform will run. This should be the same as the region in the variables defined. +## Bidding and Auction services integration within the same VPC + +If you're integrating with Bidding and Auction services (B&A), you are likely going to be reusing +the same VPC (virtual private cloud), subnets and AWS AppMesh (internal LB). In this case, you need +the following changes: + +- Make sure you are deploying the Key/Value server in the same region (specified by the `region` + terraform variable) and under the same AWS account as B&A servers. +- Set the terraform variable `use_existing_vpc` to `true`. +- Set the terraform variable `existing_vpc_environment` as the environment from B&A's deployment. +- Set the terraform variable `existing_vpc_operator` as the operator from B&A's deployment (for + example, `buyer1`). +- Optionally, you can set the terraform variable `enable_external_traffic` to `false` if you only + need to handle traffic from B&A servers. 
+ ## Apply Terraform From your `repository/production/terraform/aws/environments` folder, run: @@ -304,6 +327,16 @@ Or gRPC (using [grpcurl](https://github.com/fullstorydev/grpcurl)): grpcurl --protoset dist/query_api_descriptor_set.pb -d '{"raw_body": {"data": "'"$(echo -n $BODY|base64 -w 0)"'"}}' demo.kv-server.your-domain.example:8443 kv_server.v2.KeyValueService/GetValuesHttp ``` +If you deploy the Key/Value server under the same VPC as the B&A servers (terraform variable +`use_existing_vpc` is set to `true`), you can ssh into the target B&A server (must be a server that +is configured to query the Key/Value server), and then use the following command to place a query: + +```sh +grpcurl --plaintext -d '{"kv_internal":"hi"}' kv-server--appmesh-virtual-service.kv-server.privacysandboxdemo.app:50051 kv_server.v1.KeyValueService.GetValues +``` + +where `` should be replaced by the Key/Value server's `environment`. + ## SSH into EC2 ![how a single SSH instance is used to log into multiple server instances](../assets/ssh_instance.png) @@ -357,6 +390,11 @@ instance id is `i-00f54fe22aa47367f`): mssh i-00f54fe22aa47367f --region us-east-1 ``` +### Alternative: Connect via Session Manager + +Navigate to actual EC2 instance and connect via Session Manager by following the instructions on +[Connect to your Amazon EC2 instance using Session Manager](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-connect-methods.html) + Once you have connected to the instance, run `ls` to see the content of the server. 
The output should look similar to something like this: diff --git a/docs/deployment/deploying_on_gcp.md b/docs/deployment/deploying_on_gcp.md index f3eff5c0..9ae87716 100644 --- a/docs/deployment/deploying_on_gcp.md +++ b/docs/deployment/deploying_on_gcp.md @@ -96,10 +96,10 @@ run into any Docker access errors, follow the instructions for ## Get the source code from GitHub The code for the FLEDGE Key/Value server is released on -[GitHub](https://github.com/privacysandbox/fledge-key-value-service). +[GitHub](https://github.com/privacysandbox/protected-auction-key-value-service). The main branch is under active development. For a more stable experience, please use the -[latest release branch](https://github.com/privacysandbox/fledge-key-value-service/releases). +[latest release branch](https://github.com/privacysandbox/protected-auction-key-value-service/releases). ## Build the Docker image for GCP @@ -245,17 +245,24 @@ listed to the right. Instances associated with your Kv-server have the name star ## Access server logs -In the instance details page, under `Logs`, you can access server logs in both `Logging` and -`Serial port (console)`. The `Logging` option is more powerful with better filtering and query -support on `Logs Explorer`. +### nonprod build -The console log is located under resource type `VM Instance`. When server is running in prod mode, -the console log will not be available. However, if parameter `enable_otel_logger` is set to true, KV -server will export selective server logs to `Logs Explorer` under resource type `Generic Task`. More -details about logging in `prod mode` and `nonprod mode` in -![developing the server](/docs/developing_the_server.md). +In the instance details page, under `Logs`, you can access server console logs in both `Logging` and +`Serial port (console)`. To enable console log, you need to set `use_confidential_space_debug_image` +parameter to `true`. 
The `Logging` option is more powerful with better filtering and query support +on `Logs Explorer`. In `Logs Explorer`, the server log is located under resource type +`VM Instance`(console log), and `Generic Task`(log exported through otel collector) +![how to access GCP instance logs for nonprod build server](../assets/gcp_instance_logs.png) -![how to access GCP instance logs](../assets/gcp_instance_logs.png) +### prod build + +When server is running in prod build, the console log will not be available. However, if parameter +`enable_otel_logger` is set to true, KV server will export selective server logs to `Logs Explorer` +under resource type `Generic Task`. +![how to access GCP instance logs for prod build server](../assets/gcp_instance_prod_logs.png) + +More details about logging in `prod mode` and `nonprod mode` in +[developing the server](/docs/developing_the_server.md). ## Query the server diff --git a/docs/playbook/index.md b/docs/playbook/index.md index 53dd577c..207f075b 100644 --- a/docs/playbook/index.md +++ b/docs/playbook/index.md @@ -28,10 +28,8 @@ these features on are on our roadmap and will be available when we move to GA. [Monitoring Protected Audience API Services](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md) has high-level suggestions for Protected Audience monitoring. All metrics available to KV server are -defined in [server_definition.h](../../components/telemetry/server_definition.h). - -[Privacy safe telemetry](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#privacy-safe-telemetry) -outlines metrics noising. +defined in [server_definition.h](../../components/telemetry/server_definition.h), see details in +[MonitoringDetails](monitoring_details.md). 
### Logs diff --git a/docs/playbook/monitoring_details.md b/docs/playbook/monitoring_details.md new file mode 100644 index 00000000..413b069b --- /dev/null +++ b/docs/playbook/monitoring_details.md @@ -0,0 +1,139 @@ +# MonitoringDetails + +## Otel collector configuration + +All server metrics are collected by the +[OpenTelemetry collector](https://opentelemetry.io/docs/collector/). The otel collector will not run +in the trusted environment. On AWS, the collector runs on the same EC2 host outside the enclave in +which the TEE server is running. On GCP, the collector runs in an individual VM instance which is +deployed with KV server's out of box +[terraform configuration](../../production/terraform/gcp/services/metrics_collector/main.tf), and a +[collector end point](../../production/terraform/gcp/environments/demo/us-east1.tfvars.json#L7) is +set up during deployment for the KV server instances to connect to it. + +We provide default configurations for the otel collector: +[GCP](../../production/terraform/gcp/services/metrics_collector_autoscaling/collector_startup.sh) +and [AWS](../../production/packaging/aws/otel_collector/otel_collector_config.yaml), Adtech can +modify them without affecting the code running inside the TEE. 
The default setup enable the otel +collector to export the following server telemetry data to AWS and GCP: + +- **Metrics** + - AWS + - AWS Cloudwatch + - [Amazon Managed Prometheus](https://aws.amazon.com/prometheus/) + - GCP + - [Google Cloud Monitoring](https://cloud.google.com/monitoring?hl=en) +- **Logging** (Include + [consented request logging](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/debugging_protected_audience_api_services.md) + and non-request related logging) + - AWS: + [AWS Cloudwatch Logs](https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/WhatIsCloudWatchLogs.html) + - GCP: [Google Cloud Logging](https://cloud.google.com/logging?hl=en) +- **Traces** + - AWS: [AWS Xray](https://aws.amazon.com/xray/) + - GCP: [Google Cloud Trace](https://cloud.google.com/trace) + +## Server telemetry configuration + +The following are metrics related configurations and can be configured by server parameters. +Examples: +[AWS demo terraform setup](../../production/terraform/aws/environments/demo/us-east-1.tfvars.json), +[GCP demo terraform setup](../../production/terraform/gcp/environments/demo/us-east1.tfvars.json). + +#### metrics_export_interval_millis + +Export interval for metrics in milliseconds + +#### metrics_export_timeout_millis + +Export timeout for metrics in milliseconds + +#### telemetry_config + +Telemetry configuration to specify the mode for metrics collection. This mode is to control whether +metrics are collected raw or with noise. More details about telemetry configuration and how it works +with non_prod and prod build flavors are +[here](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#configuring-the-metric-collection-mode) + +## Common attributes + +All the metrics tracked in the KV server are tagged with the following attributes as additional +metrics dimensions. 
+ +| Attributes name | Description | +| ---------------------- | ------------------------------------------------------------------------------------------ | +| service.name | The name of the service that the metric is measured on. It is a constant value "kv-server" | +| service.version | The current version number of the server binary in use | +| host.arch | The CPU architecture the server's host system is running on | +| deployment.environment | The environment in which the server is deployed on | +| service.instance.id | The id of the machine instance the server is running on | +| shard_number | The shard number that the server instance belongs to | + +## Metrics definitions + +All metrics available to KV server are defined in +[server_definition.h](../../components/telemetry/server_definition.h). Here is a list of common +properties in the metric definition and their meanings. + +### Privacy definition + +| Property | Description | +| ---------------------- | ------------------------ | +| Privacy::kImpacting | Metric is noised with DP | +| Privacy::kNonImpacting | Metric is not noised | + +### Instrument type + +#### Instrument::kUpdownCounter + +Counter instrument implemented with +[Otel UpDownCounter](https://opentelemetry.io/docs/specs/otel/metrics/api/#updowncounter) + +#### Instrument::kPartitionedCounter + +Partitioned counter instrument implemented with +[Otel UpDownCounter](https://opentelemetry.io/docs/specs/otel/metrics/api/#updowncounter). This +instrument allows the metric to be partitioned by an additional dimension with values defined in the +partition list. + +#### Instrument::kHistogram + +Histogram instrument implemented with +[Otel Histogram](https://opentelemetry.io/docs/specs/otel/metrics/api/#histogram) + +#### Instrument::Gauge + +Gauge Instrument implemented with +[Otel Gauge](https://opentelemetry.io/docs/specs/otel/metrics/api/#gauge). 
Observable metrics such +as CPU and memory utilization are collected using this instrument + +## Differential privacy and noising + +Metrics in the request path after request is decrypted are considered privacy-sensitive thus will be +noised with differential privacy(DP). The level of noise depends on several factors, notably the +available privacy budget. The higher the budget, the less noise added to the metric. The +[total privacy budget](../../components/telemetry/server_definition.h#L96) for the server is +specified as a constant value to encompass all tracked privacy-sensitive metrics. The total privacy +budget is by default distributed evenly across all tracked privacy-sensitive metrics, we also allow +Adtech to customize the distribution across different metrics. And other supported noise related +configurations are described +[here](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#noise-related-configuration). +In addition, Adtech can also +[configure a set of metrics to monitor](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#configuring-collected-metrics), +the set of metrics will need to be selected from the defined KV metrics in +[server_definition.h](../../components/telemetry/server_definition.h). + +Metrics for +[consented requests](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/debugging_protected_audience_api_services.md) +will be unnoised and can be filtered by request's generation_id (a unique identifier passed by +client via LogContext proto in [V2 request API](../../public/query/v2/get_values_v2.proto), or a +default constant value "consented" if generation_id is not provided by client). 
+ +More detailed information about differential privacy used in the Protected Audience TEE servers is +in this +[doc](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#differential-privacy-and-noising). + +## Error codes + +All the server error codes are defined in [error_code.h](../../components/telemetry/error_code.h). +These error codes are used to report error metrics mentioned [here](total_error_rate_too_high.md) diff --git a/docs/playbook/read_latency_too_high.md b/docs/playbook/read_latency_too_high.md index 44b1cebc..848cf558 100644 --- a/docs/playbook/read_latency_too_high.md +++ b/docs/playbook/read_latency_too_high.md @@ -27,10 +27,26 @@ Check out the and see if it's outside of the expected range. [AWS Request Latency](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L110) + [GCP Request Latency](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L148) ## Troubleshooting and solution +### Look at the subcomponents + +Server overhead without the UDF custom code execution. + +[TotalLatencyWithoutCustomCode](https://github.com/privacysandbox/protected-auction-key-value-service/blob/9e6e38979bb95822a0a4c4004bb455324ddc6c90/components/telemetry/server_definition.h#L494) + +Check out how long key lookups outside of UDF sandbox takes. + +[InternalGetKeyValuesLatencyInMicros](https://github.com/privacysandbox/protected-auction-key-value-service/blob/9a60180f9d6f52a4ca805e5463ecc9e5e80e88f9/components/telemetry/server_definition.h#L162) + +If you're using set intersections, you can check out how long that part of the request takes outside +of the UDF sandbox. 
+ +[InternalRunQueryLatencyInMicros](https://github.com/privacysandbox/protected-auction-key-value-service/blob/9a60180f9d6f52a4ca805e5463ecc9e5e80e88f9/components/telemetry/server_definition.h#L156) + ### Not enough capacity Check out the CPU metrics for your machines. They are available in the dashboard linked diff --git a/docs/playbook/server_is_unhealthy.md b/docs/playbook/server_is_unhealthy.md index 50393925..0fd11d8f 100644 --- a/docs/playbook/server_is_unhealthy.md +++ b/docs/playbook/server_is_unhealthy.md @@ -2,24 +2,89 @@ ## Overview -This alert means that all kv servers (ec2/vm instances) are down and cannot serve traffic. +This alert means that kv server is unhelathy and cannot fully serve traffic. If you see this alert, there is a big chance you'll see most, if not all, other alerts firing too. +## Background + +AdTechs can define their uptime off the load balancer response codes. + +Specifically, they would get a distribution over a period of time K seconds of 2xx's - Y%, which +AdTechs can map to "up" and the rest to "down". + +KV server is a grpc server which returns grpc status codes. Those are mapped to http codes as +explained +[here](https://chromium.googlesource.com/external/github.com/grpc/grpc/+/refs/tags/v1.21.4-pre1/doc/statuscodes.md) + ## Recommended alert level -Fire an alert if over 90% of response did not return OK over 5 mins. Probe interval = 300 ms. +An AdTechs starts the downtime stopwatch once 10% error rate (non-2xx's responses) is hit. An alert +is fired after 1 minute of 10% error rate (non-2xx's responses). Once the error rate goes down below +10% the stopwatch is stopped. Once the error rate is below the 10% for 10 continuous seconds, the +alert auto resolves. ## Alert Severity Critical. -The service is down, and is fully unavailable to serve read requests. +The service is down, and is unavailable to serve read requests. This condition directly affects the uptime SLO. 
## Verification +### Load balancer + +#### GCP + +[Load balancers](https://cloud.google.com/load-balancing/docs/https/https-logging-monitoring) come +with all sorts of monitoring for different response codes, request counts, latency, throughput etc +with easy ways of setting up alerts. + +#### AWS + +Similarly, AWS load balancers and target groups provide breakdowns by code and other metrics, on top +of which alerts can be defined. + +### Mesh / Traffic director (internal load balancer) + +#### GCP + +[Less automated](https://cloud.google.com/traffic-director/docs/control-plane-observability) than +the external load balancer. + +Operator needs to enable cloud logging for the backend service and then have a dashboard that +monitors the access logs for 500s, and other metrics. + +#### AWS + +This is being defined atm, but it's probably going to be similar to GCP. + +## Shard cluster uptime definition + +A shard can be down, but KV still serves requests. + +Autoscaling group level alerts can be set up, for example, on the minimum number of available +healthy machines. + +### GCP + +Alerts can be defined for instanceGroups. In addition, this resource has good monitoring dashboards +with different tiles, as well as error logs. + +### AWS + +Alerts can defined for autoscaling groups. A number of userful tiles for the default dashboard is +provided. Logs are available to see the actions that are taken by the autoscaling groups to detect +anomalies. + +#### Other signals for uptime + +AWS and GCP have vm/ec2 level checks. + +#### Checking a specific kv server instance + _Http_ ```sh @@ -119,9 +184,9 @@ verbosity and analyze the last few entries before the machine crashes. A common technique to address this bug is to revert to a previous more stable build. If you believe that this is a KV server issue, you should escalate using the info from -[here](index.md) +[here](index.md). 
## Related Links -[Server initialization](../server_initialization.md) -- provides extra details on the server +[Server initialization](../server_initialization.md) provides extra details on the server initialization lifecycle and how it affects the health check. diff --git a/docs/profiling_the_server.md b/docs/profiling_the_server.md index 72f11e90..3d883e13 100644 --- a/docs/profiling_the_server.md +++ b/docs/profiling_the_server.md @@ -33,7 +33,7 @@ After the build completes successfully, the generated docker image tar is locate docker image using the following command: ```bash -docker load -i bazel-bin/production/packaging/local/data_server/server_profiling_docker_image.tar +docker load -i bazel-bin/production/packaging/local/data_server/server_profiling_docker_tarball/tarball.tar ``` # Some things to callout @@ -58,7 +58,7 @@ docker run \ --security-opt=seccomp=unconfined \ --privileged=true \ --cpus=4 \ - --entrypoint=/server \ + --entrypoint=/server/bin/init_server_with_profiler \ bazel/production/packaging/local/data_server:server_profiling_docker_image \ --port 50051 -delta_directory=/data --realtime_directory=/data/realtime \ --stderrthreshold=0 @@ -85,7 +85,7 @@ docker run -it --rm \ --cpus=4 \ --entrypoint=pprof \ bazel/production/packaging/local/data_server:server_profiling_docker_image \ - --http=":" /server /data/profiles/server.cpu.prof.0 + --http=":" /server/bin/server /data/profiles/server.cpu.prof.0 ``` This should open an interactive `pprof` web app showing the CPU profile information. 
The image below @@ -111,7 +111,7 @@ docker run \ --security-opt=seccomp=unconfined \ --privileged=true \ --cpus=4 \ - --entrypoint=/server \ + --entrypoint=/server/bin/init_server_with_profiler\ bazel/production/packaging/local/data_server:server_profiling_docker_image \ --port 50051 -delta_directory=/data --realtime_directory=/data/realtime \ --stderrthreshold=0 @@ -138,7 +138,7 @@ docker run -it --rm \ --cpus=4 \ --entrypoint=pprof \ bazel/production/packaging/local/data_server:server_profiling_docker_image \ - --http=":" /server /data/profiles/server.heap.hprof.0001.heap + --http=":" /server/bin/server /data/profiles/server.heap.hprof.0001.heap ``` This should open an interactive `pprof` web app showing the heap profile information. The image @@ -190,7 +190,7 @@ docker run \ --security-opt=seccomp=unconfined \ --privileged=true \ --cpus=4 \ - --entrypoint=/server \ + --entrypoint=/server/bin/server \ bazel/production/packaging/local/data_server:server_profiling_docker_image \ --port 50051 -delta_directory=/data --realtime_directory=/data/realtime \ --stderrthreshold=0 diff --git a/docs/protected_app_signals/ad_retrieval_overview.md b/docs/protected_app_signals/ad_retrieval_overview.md index 2dba3935..be61980b 100644 --- a/docs/protected_app_signals/ad_retrieval_overview.md +++ b/docs/protected_app_signals/ad_retrieval_overview.md @@ -90,8 +90,8 @@ This section describes how data is loaded into the server. #### Deployment The ad tech builds the service system by downloading the source code from the -[Github repository](https://github.com/privacysandbox/fledge-key-value-service) and following the -documentation in the repository. +[Github repository](https://github.com/privacysandbox/protected-auction-key-value-service) and +following the documentation in the repository. The ad tech deploys the system to a supported public cloud of their choice. At time of publication, the system will be available on GCP and AWS. @@ -246,7 +246,7 @@ below. 
- `runQuery(query_string)`: UDF can construct a query to perform set operations, such as union, intersection and difference. The query uses keys to represent the sets. The keys are defined as the sets are loaded into the dataset. See the exact grammar - [here](https://github.com/privacysandbox/fledge-key-value-service/blob/main/components/query/parser.yy). + [here](https://github.com/privacysandbox/protected-auction-key-value-service/blob/main/components/query/parser.yy). For more information, see [the UDF spec](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_service_user_defined_functions.md). @@ -405,7 +405,8 @@ The service codebase has an end-to-end example in the context of information ret Documentation: - High level: -- Lower level: +- Lower level: + Getting started & Example: - + diff --git a/docs/sharding/data_locality.md b/docs/sharding/data_locality.md index 730e527b..13dcd9e2 100644 --- a/docs/sharding/data_locality.md +++ b/docs/sharding/data_locality.md @@ -51,9 +51,9 @@ obfuscate key names. # Data loading and serving flow with the data locality feature A regex can be set by an AdTech on startup as a terraform -[parameter](https://github.com/privacysandbox/fledge-key-value-service/blob/b047d89ebfa6312ec8d1de275da69fd60d24eba3/production/terraform/aws/environments/kv_server_variables.tf#L254) +[parameter](https://github.com/privacysandbox/protected-auction-key-value-service/blob/b047d89ebfa6312ec8d1de275da69fd60d24eba3/production/terraform/aws/environments/kv_server_variables.tf#L254) that will be applied to all keys. This -[parameter](https://github.com/privacysandbox/fledge-key-value-service/blob/b047d89ebfa6312ec8d1de275da69fd60d24eba3/production/terraform/aws/environments/kv_server_variables.tf#L248) +[parameter](https://github.com/privacysandbox/protected-auction-key-value-service/blob/b047d89ebfa6312ec8d1de275da69fd60d24eba3/production/terraform/aws/environments/kv_server_variables.tf#L248) should be set to `true`. 
The regex is global in the sense that there is only 1 regex and it is known to all shards. diff --git a/docs/sharding/sharding.md b/docs/sharding/sharding.md index 1a73dc63..5881033e 100644 --- a/docs/sharding/sharding.md +++ b/docs/sharding/sharding.md @@ -34,7 +34,7 @@ sharding. Naturally, if the size of your data set is bigger than the biggest mac capacity, your only option is to use sharding. Additionally, sometimes it is cheaper to use multiple smaller machines, as compared to a few high capacity ones. The memory capacity of a single machine depends on what you specify in the terraform variables -[file](https://github.com/privacysandbox/fledge-key-value-service/blob/04d3e75794fadc14c17b960a9cd02088216aa138/production/terraform/aws/environments/demo/us-east-1.tfvars.json#L16). +[file](https://github.com/privacysandbox/protected-auction-key-value-service/blob/04d3e75794fadc14c17b960a9cd02088216aa138/production/terraform/aws/environments/demo/us-east-1.tfvars.json#L16). E.g. the demo value is `m5.xlarge` for AWS, which [is](https://aws.amazon.com/ec2/instance-types/m5/) 16 GB. Cloud providers give an ability to @@ -82,15 +82,15 @@ incurred. ## Sharding function -[Sharding function](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/sharding/sharding_function.h#L32) +[Sharding function](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/sharding/sharding_function.h#L32) takes a key and maps it to a shard number. It is a -[SHA256](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/sharding/sharding_function.h#L35C38-L35C38) +[SHA256](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/sharding/sharding_function.h#L35C38-L35C38) mod `number of shards`. 
## Write path Data that doesn't belong to a given shard is dropped if it makes it to the server. There is a -[metric](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data_server/data_loading/data_orchestrator.cc#L101C54-L101C54) +[metric](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data_server/data_loading/data_orchestrator.cc#L101C54-L101C54) that an AdTech can track. Ideally, this metric should be 0. Sharding helps limit the amount of data each server instance has to process on the write path. This @@ -106,9 +106,9 @@ and load it. ![Realtime sequence](../assets/grouping_records.png) A snapshot/delta file indicates its shard number through this -[field](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/data_loading/riegeli_metadata.proto#L40). +[field](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/public/data_loading/riegeli_metadata.proto#L40). This -[tool](https://github.com/privacysandbox/fledge-key-value-service/blob/252d361c7a3b291f50ffbf36d86fc4405af6a147/tools/serving_data_generator/test_serving_data_generator.cc#L36-L37) +[tool](https://github.com/privacysandbox/protected-auction-key-value-service/blob/252d361c7a3b291f50ffbf36d86fc4405af6a147/tools/serving_data_generator/test_serving_data_generator.cc#L36-L37) shows how that field can be set. If it is not set, the whole file will be read by all machines. However, only records that belong to that particular shard will be loaded in memory. If the shard number in the file does not match the server's shard number, the server can skip the file without @@ -118,8 +118,8 @@ reading the records. A message published to SNS, for AWS, or PubSub, for GCP _must_ be tagged with a shard number. 
SNS/PubSub will fan out such messages only -([AWS](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data/common/msg_svc_aws.cc#L174), -[GCP](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data/common/msg_svc_gcp.cc#L86)) +([AWS](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data/common/msg_svc_aws.cc#L174), +[GCP](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/data/common/msg_svc_gcp.cc#L86)) to the machines that are associated with that shard number. This increases the throughput for any given machine, as it has to process fewer messages and only relevant ones. @@ -147,9 +147,9 @@ from relevant shards and then combines them together and returns the result to t some keys may be looked up in memory from that server. If one of the downstream requests fails, a corresponding per key -[status](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/internal_server/sharded_lookup.cc#L85) +[status](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/internal_server/sharded_lookup.cc#L85) is set, which is different from `Not found` -[status](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/internal_server/sharded_lookup.cc#L72C24-L72C24). +[status](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/internal_server/sharded_lookup.cc#L72C24-L72C24). Similarly, if a set query needs to be run, it will be run after corresponding sets have been collected. 
However, if one of the downstream requests fails, then the whole query is failed. @@ -158,10 +158,10 @@ Each machine knows its shard number. A machine isn't ready to serve traffic unti at least one active replica for each shard cluster (from 1 to `num_shards`). That mapping from a shard cluster number to internal ip addresses, preserved in memory for performance reasons, is updated every -[`update_interval_millis`](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/sharding/cluster_mappings_manager.h#L48C30-L48C30). +[`update_interval_millis`](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/sharding/cluster_mappings_manager.h#L48C30-L48C30). When a request needs to be made to a shard cluster with K replicas, a machine is chosen -[randomly](https://github.com/privacysandbox/fledge-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/sharding/shard_manager.cc#L91C45-L91C45) +[randomly](https://github.com/privacysandbox/protected-auction-key-value-service/blob/31e6d0e3f173086214c068b62d6b95935063fd6b/components/sharding/shard_manager.cc#L91C45-L91C45) from the pool. ## Privacy @@ -209,5 +209,5 @@ replicas. ## Work in progress -[Logical Sharding Config](https://github.com/privacysandbox/fledge-key-value-service/blob/0e9b454825d641786255f11df4a2b62eee893a98/public/data_loading/riegeli_metadata.proto#L44) +[Logical Sharding Config](https://github.com/privacysandbox/protected-auction-key-value-service/blob/0e9b454825d641786255f11df4a2b62eee893a98/public/data_loading/riegeli_metadata.proto#L44) is a work in progress and you should not be using it at the moment. diff --git a/docs/testing_the_query_protocol.md b/docs/testing_the_query_protocol.md index 6b06b8a7..5e259c1e 100644 --- a/docs/testing_the_query_protocol.md +++ b/docs/testing_the_query_protocol.md @@ -73,7 +73,8 @@ containers locally. 
In addition, run the helper server alongside: `cd` into the root of the repo. ```sh -bazel run -c opt //infrastructure/testing:protocol_testing_helper_server +builders/tools/bazel-debian build //infrastructure/testing:protocol_testing_helper_server +bazel-bin/infrastructure/testing/protocol_testing_helper_server ``` For more information on how to test the query protocol with the helper server, see also the @@ -99,68 +100,6 @@ grpcurl --protoset dist/query_api_descriptor_set.pb -d '{"raw_body": {"data": "' For gRPC, use base64 --decode to convert the output to plaintext. -## Binary HTTP query ("BinaryHTTPGetValues") - -First convert the request body into Binary HTTP request: - -```sh -echo -n '{"is_request": true, "body": "'$(echo -n $BODY|base64 -w 0)'"}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/BHTTPEncapsulate -``` - -The result should be something like: - -```json -{ - "bhttp_message": "AAAAAAAAQZZ7ICJjb250ZXh0IjogeyAic3Via2V5IjogImV4YW1wbGUuY29tIiB9LCAicGFydGl0aW9ucyI6IFsgeyAiaWQiOiAwLCAiY29tcHJlc3Npb25Hcm91cCI6IDAsICJrZXlHcm91cHMiOiBbIHsgInRhZ3MiOiBbICJzdHJ1Y3R1cmVkIiwgImdyb3VwTmFtZXMiIF0sICJrZXlMaXN0IjogWyAiaGkiIF0gfSwgeyAidGFncyI6IFsgImN1c3RvbSIsICJrZXlzIiBdLCAia2V5TGlzdCI6IFsgImhpIiBdIH0gXSB9LCB7ICJpZCI6IDEsICJjb21wcmVzc2lvbkdyb3VwIjogMCwgImtleUdyb3VwcyI6IFsgeyAidGFncyI6IFsgInN0cnVjdHVyZWQiLCAiZ3JvdXBOYW1lcyIgXSwgImtleUxpc3QiOiBbICJoaSIgXSB9LCB7ICJ0YWdzIjogWyAiY3VzdG9tIiwgImtleXMiIF0sICJrZXlMaXN0IjogWyAiaGkiIF0gfSBdIH0gXSB9" -} -``` - -Assign the bhttp_message output to an environment variable: - -```sh 
-BHTTP_REQ=AAAAAAAAQZZ7ICJjb250ZXh0IjogeyAic3Via2V5IjogImV4YW1wbGUuY29tIiB9LCAicGFydGl0aW9ucyI6IFsgeyAiaWQiOiAwLCAiY29tcHJlc3Npb25Hcm91cCI6IDAsICJrZXlHcm91cHMiOiBbIHsgInRhZ3MiOiBbICJzdHJ1Y3R1cmVkIiwgImdyb3VwTmFtZXMiIF0sICJrZXlMaXN0IjogWyAiaGkiIF0gfSwgeyAidGFncyI6IFsgImN1c3RvbSIsICJrZXlzIiBdLCAia2V5TGlzdCI6IFsgImhpIiBdIH0gXSB9LCB7ICJpZCI6IDEsICJjb21wcmVzc2lvbkdyb3VwIjogMCwgImtleUdyb3VwcyI6IFsgeyAidGFncyI6IFsgInN0cnVjdHVyZWQiLCAiZ3JvdXBOYW1lcyIgXSwgImtleUxpc3QiOiBbICJoaSIgXSB9LCB7ICJ0YWdzIjogWyAiY3VzdG9tIiwgImtleXMiIF0sICJrZXlMaXN0IjogWyAiaGkiIF0gfSBdIH0gXSB9 -``` - -Send the request to the k/v server. - -HTTP: - -```sh -BHTTP_RES=$(curl -svX POST --data-binary @<(echo -n $BHTTP_REQ|base64 --decode) http://localhost:51052/v2/bhttp_getvalues|base64 -w 0);echo $BHTTP_RES -``` - -Or gRPC: - -```sh -grpcurl --protoset dist/query_api_descriptor_set.pb --protoset dist/query_api_descriptor_set.pb -d '{"raw_body": {"data": "'"$BHTTP_REQ"'"}}' -plaintext localhost:50051 kv_server.v2.KeyValueService/BinaryHttpGetValues -``` - -The result should look similar to: - -```json -{ - "data": "AUDIAEFcAAABWHsicGFydGl0aW9ucyI6W3siaWQiOjAsImtleUdyb3VwT3V0cHV0cyI6W3sia2V5VmFsdWVzIjp7ImhpIjp7InZhbHVlIjoiSGVsbG8sIHdvcmxkISBJZiB5b3UgYXJlIHNlZWluZyB0aGlzLCBpdCBtZWFucyB5b3UgY2FuIHF1ZXJ5IG1lIHN1Y2Nlc3NmdWxseSJ9fSwidGFncyI6WyJjdXN0b20iLCJrZXlzIl19XX0seyJpZCI6MSwia2V5R3JvdXBPdXRwdXRzIjpbeyJrZXlWYWx1ZXMiOnsiaGkiOnsidmFsdWUiOiJIZWxsbywgd29ybGQhIElmIHlvdSBhcmUgc2VlaW5nIHRoaXMsIGl0IG1lYW5zIHlvdSBjYW4gcXVlcnkgbWUgc3VjY2Vzc2Z1bGx5In19LCJ0YWdzIjpbImN1c3RvbSIsImtleXMiXX1dfV19" -} -``` - -Assign the data to `BHTTP_RES` and decode the BHTTP layer: - -```sh -echo -n '{"is_request": false, "bhttp_message": "'$BHTTP_RES'"}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/BHTTPDecapsulate -``` - -Result: - -```json -{ - "body": 
"AAABWHsicGFydGl0aW9ucyI6W3siaWQiOjAsImtleUdyb3VwT3V0cHV0cyI6W3sia2V5VmFsdWVzIjp7ImhpIjp7InZhbHVlIjoiSGVsbG8sIHdvcmxkISBJZiB5b3UgYXJlIHNlZWluZyB0aGlzLCBpdCBtZWFucyB5b3UgY2FuIHF1ZXJ5IG1lIHN1Y2Nlc3NmdWxseSJ9fSwidGFncyI6WyJjdXN0b20iLCJrZXlzIl19XX0seyJpZCI6MSwia2V5R3JvdXBPdXRwdXRzIjpbeyJrZXlWYWx1ZXMiOnsiaGkiOnsidmFsdWUiOiJIZWxsbywgd29ybGQhIElmIHlvdSBhcmUgc2VlaW5nIHRoaXMsIGl0IG1lYW5zIHlvdSBjYW4gcXVlcnkgbWUgc3VjY2Vzc2Z1bGx5In19LCJ0YWdzIjpbImN1c3RvbSIsImtleXMiXX1dfV19" -} -``` - -The returned data is base64 encoded. Decode the content with base64 --decode and you should see the -expected response. Note that the first 32 bits of the response stores the size of the response. That -is part of the compression layer, which is not turned on in this exercise. - ## Oblivious HTTP query ("ObliviousGetValues") Oblivious HTTP request is encrypted with a public key as one of the initial input. The testing @@ -173,7 +112,7 @@ grpcurl -plaintext localhost:50050 kv_server.ProtocolTestingHelper/GetTestConfig To build the request: ```sh -echo -n '{"body": "'$BHTTP_REQ'", "key_id": 1, "public_key": "MeHwWnQBAhFSIOmvkY9zhnSuyV9U224E63Baro55gVU="}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/OHTTPEncapsulate +echo -n '{"body": "'$(echo -n $BODY|base64 -w 0)'", "key_id": 1, "public_key": "87ey8XZPXAd+/+ytKv2GFUWW5j9zdepSJ2G4gebDwyM="}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/OHTTPEncapsulate ``` The output is similar to: @@ -185,9 +124,17 @@ The output is similar to: } ``` +Set the `OTTP_REQ` and `CONTEXT_TOKEN` env vars, for example: + +```sh 
+OHTTP_REQ="AQAgAAEAATVl8Lz2p4B27AbFoIT+R2H7jRCp+Q/c87qruxKbXLRnNdMHGZjJLCaNSs9caPvgHpo4uYB4g9fdL/a+/mJglyME1B7ngo5mJX7puHHl8aoEWeIugq/pJjvrGI38P4z3gQlb4mBinGPhqOTdH+xvfMss5b44PwqacbjZYJ3eb1hDjXsgmsTGa0ZzlFUymqI/9P7ZsdQAwtD9cxuywZsKF9A1aRhwRuA1Y/9iMCmpJlX9SmGeN8FptL4VnoAo4eJwPSS6Z/OHPsfP/d6CQZH4hGudjGgtbzzPItD/drK8MMiCKq3PPffCgcDXP/0u9SWXOim3/gzMDsU/uh47JhbYhjhOQ4DJAaxcG/DQqRqLKd1Z4sHechv9xdoJJbV7laPoxyEFWMiWwSTHL+kZVRBc0uQSWBRgyDxxknjl71g/3SeLOjz9ovC4DOouLFAWbWMpgxRHJRA4GsevdBq3Od3I7AEvtJ2AfIMpo3tsch7iJzcaORV0Ml/TgASSdliaThYj2e/G38GQYdzHQfHmcB6r+2M0DC/bEN29JEJayWIfl7DUOs1U1GLLh0+y7+mH85zFhu4lb4lX0PtzcN/TrNOtB19d/YQ6Mv2n+Dbea6S9hg==" + +CONTEXT_TOKEN="1675366132" +``` + The context_token is used as one input to decode the response later. -Call the k/v server with the request (stored as OHTTP_REQ env var): +Call the k/v server with the request (stored as `OHTTP_REQ` env var): HTTP: @@ -205,7 +152,7 @@ Result is similar to: ```json { - "contentType": "message/ohttp-res", + "contentType": "message/ad-auction-trusted-signals-response", "data": "TFZDIlvBIBUfq4fzHvWwa58pjRrMmyE8mkfQshA4N9SDD6Ts28KigYIU3OcV30/+ZrCmStdCg/BcgY59Rod6TCLkSfI32Gk25oY+9I+vVxpj7FG67vWoQdbee7FUvn7TxsrdCSd9ulwpixbE7KtSw7MmX6Y0y0I7xHkx9N7zKSu/cmabg9ZgdQFipDUdBaBNPScNOrwh6b6nZhWHbW/oUWCFMHtDa9sLVP5cNi9oMjb7AFdK5NKeq1qiCuhKTi3RZ7bKNbk98JnmyGI6OwAs2631Gl+S0npPR/KDblWQJ2ZCI0maek0zIVPhWLs2/kA+etwOCmRzB7syxDwwT3MRDo6wWJdcKKHC8Y48XgKEv5NvTLC39tsEniSvPdymevNfG2PTLJDKaAocb/WVLj5wm08UNjAv+Pxu8a+wRDxP+kxm+TnKMCPapnRcplU4D3+VH4YdhQbF2V1kwsyfBQxQMr4XX1w6n87ah8qUBucjveKPSa6kqKVSk2w261McQobJW54=" } ``` @@ -213,23 +160,16 @@ Result is similar to: Decode the response: ```sh -echo -n '{"context_token": 1675366132, "ohttp_response": "'$OHTTP_RES'"}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/OHTTPDecapsulate +echo -n '{"context_token": "'${CONTEXT_TOKEN}'", "ohttp_response": "'$OHTTP_RES'"}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/OHTTPDecapsulate ``` 
Result example: ```json { - "body": "AUDIAEFcAAABWHsicGFydGl0aW9ucyI6W3siaWQiOjAsImtleUdyb3VwT3V0cHV0cyI6W3sia2V5VmFsdWVzIjp7ImhpIjp7InZhbHVlIjoiSGVsbG8sIHdvcmxkISBJZiB5b3UgYXJlIHNlZWluZyB0aGlzLCBpdCBtZWFucyB5b3UgY2FuIHF1ZXJ5IG1lIHN1Y2Nlc3NmdWxseSJ9fSwidGFncyI6WyJjdXN0b20iLCJrZXlzIl19XX0seyJpZCI6MSwia2V5R3JvdXBPdXRwdXRzIjpbeyJrZXlWYWx1ZXMiOnsiaGkiOnsidmFsdWUiOiJIZWxsbywgd29ybGQhIElmIHlvdSBhcmUgc2VlaW5nIHRoaXMsIGl0IG1lYW5zIHlvdSBjYW4gcXVlcnkgbWUgc3VjY2Vzc2Z1bGx5In19LCJ0YWdzIjpbImN1c3RvbSIsImtleXMiXX1dfV19" + "body": "eyJjb21wcmVzc2lvbkdyb3VwcyI6W3siY29tcHJlc3Npb25Hcm91cElkIjowLCJjb250ZW50IjoiW3tcImlkXCI6MCxcImtleUdyb3VwT3V0cHV0c1wiOlt7XCJrZXlWYWx1ZXNcIjp7XCJoaVwiOntcInZhbHVlXCI6XCJIZWxsbywgd29ybGQhIElmIHlvdSBhcmUgc2VlaW5nIHRoaXMsIGl0IG1lYW5zIHlvdSBjYW4gcXVlcnkgbWUgc3VjY2Vzc2Z1bGx5XCJ9fSxcInRhZ3NcIjpbXCJzdHJ1Y3R1cmVkXCIsXCJncm91cE5hbWVzXCJdfSx7XCJrZXlWYWx1ZXNcIjp7XCJoaVwiOntcInZhbHVlXCI6XCJIZWxsbywgd29ybGQhIElmIHlvdSBhcmUgc2VlaW5nIHRoaXMsIGl0IG1lYW5zIHlvdSBjYW4gcXVlcnkgbWUgc3VjY2Vzc2Z1bGx5XCJ9fSxcInRhZ3NcIjpbXCJjdXN0b21cIixcImtleXNcIl19XX0se1wiaWRcIjowLFwia2V5R3JvdXBPdXRwdXRzXCI6W3tcImtleVZhbHVlc1wiOntcImhpXCI6e1widmFsdWVcIjpcIkhlbGxvLCB3b3JsZCEgSWYgeW91IGFyZSBzZWVpbmcgdGhpcywgaXQgbWVhbnMgeW91IGNhbiBxdWVyeSBtZSBzdWNjZXNzZnVsbHlcIn19LFwidGFnc1wiOltcInN0cnVjdHVyZWRcIixcImdyb3VwTmFtZXNcIl19LHtcImtleVZhbHVlc1wiOntcImhpXCI6e1widmFsdWVcIjpcIkhlbGxvLCB3b3JsZCEgSWYgeW91IGFyZSBzZWVpbmcgdGhpcywgaXQgbWVhbnMgeW91IGNhbiBxdWVyeSBtZSBzdWNjZXNzZnVsbHlcIn19LFwidGFnc1wiOltcImN1c3RvbVwiLFwia2V5c1wiXX1dfV0ifV19" } ``` -The body here is in Binary HTTP format. Decapsulate it the same way as the Binary HTTP response: - -```sh -echo -n '{"is_request": false, "bhttp_message": "'$BHTTP_RES'"}'|grpcurl -plaintext -d @ localhost:50050 kv_server.ProtocolTestingHelper/BHTTPDecapsulate -``` - The returned data is base64 encoded. Decode the content with base64 --decode and you should see the -expected response. Note that the first 32 bits of the response stores the size of the response. 
That -is part of the compression layer, which is not turned on in this exercise. +expected response. diff --git a/docs/working_with_set_queries.md b/docs/working_with_set_queries.md index efcff5af..79717327 100644 --- a/docs/working_with_set_queries.md +++ b/docs/working_with_set_queries.md @@ -1,9 +1,9 @@ # Set query language -The K/V server supports a simple set query language that can be invoked using two UDF reap APIs (1) -`runQuery("query")` and `runSetQueryInt("query")`. The query language implements threee set -operations `union` denoted as `|`, `difference` denoted as `-` and `intersection` denoted as `&` and -queries use sets of strings and numbers as input. +The K/V server supports a simple set query language that can be invoked using three UDF reap APIs +(1) `runQuery("query")`, `runSetQueryUInt32("query")` and `runSetQueryUInt64("query")`. The query +language implements threee set operations `union` denoted as `|`, `difference` denoted as `-` and +`intersection` denoted as `&` and queries use sets of strings and numbers as input. As an example, suppose that we have indexed two sets of 32 bit integer ad ids targeting `games` and `news` to the K/V server memory store. We can find the sets of ad ids targeting both `games` and @@ -25,10 +25,12 @@ The set query language [grammar](/components/query/parser.yy) supports three bin ## Query operands -Queries support two types of sets as operands: +Queries support three types of sets as operands: -- set of numbers - only 32 bit unsigned numbers are supported. The data loading type is - `UInt32Set` in [data_loading.fbs](/public/data_loading/data_loading.fbs). +- set of `uint32` numbers - For working with uint32 numbers. The data loading type is `UInt32Set` + in [data_loading.fbs](/public/data_loading/data_loading.fbs). +- set of `uint64` numbers - For working with uint64 (supports also uint32) numbers. The data + loading type is `UInt64Set` in [data_loading.fbs](/public/data_loading/data_loading.fbs). 
- set of strings - arbitrary length UTF-8 or 7-bit ASCII strings are supported. The data loading type is `StringSet` in [data_loading.fbs](/public/data_loading/data_loading.fbs). @@ -66,17 +68,27 @@ The K/V server supports two APIs for running queries inside UDFs. - `runQuery("query")` - Takes a valid `query` as input and returns a set of strings. - See [run_query_udf.js](/tools/udf/sample_udf/run_query_udf.js) for a JavaScript example. -- `runSetQueryInt("query")` +- `runSetQueryUInt32("query")` - Takes a valid `query` as input and returns a byte array of serialized 32 bit usinged integers. - - See [run_set_query_int_udf.js](/tools/udf/sample_udf/run_set_query_int_udf.js) for a + - See [run_set_query_uint32_udf.js](/tools/udf/sample_udf/run_set_query_uint32_udf.js) for a + JavaScript example. +- `runSetQueryUInt64("query")` + - Takes a valid `query` as input and returns a byte array of serialized 64 bit usinged + integers. + - Note that `uint32` numbers can be loaded using `UInt64Set` and then queries can be evaluated + using `runSetQueryUInt64("query")`. However, this approach can be significantly less + efficient than using `UInt32Set` and `runSetQueryUInt32("query")`. + - See [run_set_query_uint64_udf.js](/tools/udf/sample_udf/run_set_query_uint64_udf.js) for a JavaScript example. ## Which API should I use, `runQuery` vs. `runSetQueryInt`? -`runSetQueryInt` implements a much more perfomant version of query evaluation based on bitmaps. So -if your sets can be represented as 32 bit unsigned integer sets and are relatively dense compared to -the range of numbers, then use `runSetQueryInt`. We also provide a benchmarking tool +`runSetQueryUInt32` and `runSetQueryUInt64` implements a much more perfomant version of query +evaluation based on bitmaps. 
So (1) if your sets can be represented as 32 bit unsigned integer sets +and are relatively dense compared to the range of numbers, then use `runSetQueryUInt32` and (2) if +your sets must be represented as 64 bit unsigned integer sets and are relatively dense compared to +the range of numbers, then use `runSetQueryUInt64`. We also provide a benchmarking tool [query_evaluation_benchmark](/components/tools/benchmarks/query_evaluation_benchmark.cc) that can be used to determine whether using integer sets vs. string sets would be much more performant for a given scenario. For example: @@ -94,8 +106,7 @@ given scenario. For example: benchamrks evaluating set operations and the query `(A - B) | (C & D)` using sets with 500,000 elements randomly selected from the range `[1,000,000 - 2,000,000]`. On my machine, the benchmark -produces the following output which shows superior performance for integer sets (Note that output -for integer sets is denoted using `.*kv_server::RoaringBitSet.*`): +produces the following output which shows superior performance for integer sets: ```bash Run on (128 X 2450 MHz CPU s) @@ -104,21 +115,25 @@ CPU Caches: L1 Instruction 32 KiB (x64) L2 Unified 512 KiB (x64) L3 Unified 32768 KiB (x8) -Load Average: 4.73, 5.96, 6.31 ---------------------------------------------------------------------------------------------------------------- -Benchmark Time CPU Iterations Ops/s ---------------------------------------------------------------------------------------------------------------- -kv_server::BM_SetUnion 15.5 us 15.5 us 45888 64.4784k/s -kv_server::BM_SetUnion 60104 us 60104 us 10 16.6378/s -kv_server::BM_SetDifference 15.3 us 15.3 us 43768 65.4571k/s -kv_server::BM_SetDifference 21194 us 21192 us 32 47.1883/s -kv_server::BM_SetIntersection 16.6 us 16.6 us 43286 60.1952k/s -kv_server::BM_SetIntersection 20601 us 20599 us 32 48.5452/s ------------------------------------------------------------------------------------------------------------------ 
-Benchmark Time CPU Iterations QueryEvals/s ------------------------------------------------------------------------------------------------------------------ -kv_server::BM_AstTreeEvaluation 43.4 us 43.4 us 15977 23.0672k/s -kv_server::BM_AstTreeEvaluation 60834 us 60826 us 10 16.4404/s +Load Average: 30.38, 15.45, 9.58 +----------------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations Ops/s +----------------------------------------------------------------------------------------------------------- +kv_server::BM_SetUnion 15.2 us 15.2 us 45767 65.834k/s +kv_server::BM_SetUnion 22.6 us 22.6 us 30639 44.2866k/s +kv_server::BM_SetUnion 51034 us 51031 us 12 19.596/s +kv_server::BM_SetDifference 16.3 us 16.3 us 44607 61.4837k/s +kv_server::BM_SetDifference 22.9 us 22.9 us 30509 43.6579k/s +kv_server::BM_SetDifference 21059 us 21057 us 33 47.4906/s +kv_server::BM_SetIntersection 16.4 us 16.4 us 42387 60.9809k/s +kv_server::BM_SetIntersection 22.0 us 22.0 us 31571 45.3572k/s +kv_server::BM_SetIntersection 21156 us 21156 us 33 47.2686/s +------------------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations QueryEvals/s +------------------------------------------------------------------------------------------------------------- +kv_server::BM_AstTreeEvaluation 43.8 us 43.8 us 15727 22.8436k/s +kv_server::BM_AstTreeEvaluation 63.0 us 63.0 us 11544 15.8808k/s +kv_server::BM_AstTreeEvaluation 62148 us 62135 us 11 16.0939/s ``` # Loading sets into the K/V server diff --git a/getting_started/quick_start.md b/getting_started/quick_start.md index 1347b16d..264a19b1 100644 --- a/getting_started/quick_start.md +++ b/getting_started/quick_start.md @@ -11,13 +11,15 @@ Before starting the build process, install [Docker](https://docs.docker.com/engi [BuildKit](https://docs.docker.com/build/buildkit/). 
If you run into any Docker access errors, follow the instructions for [setting up sudoless Docker](https://docs.docker.com/engine/install/linux-postinstall/#manage-docker-as-a-non-root-user). +You will also need the +[Docker Compose v2](https://github.com/docker/compose#where-to-get-docker-compose) plugin. ## Clone the github repo Using Git, clone the repository into a folder: ```sh -git clone https://github.com/privacysandbox/fledge-key-value-service.git +git clone https://github.com/privacysandbox/protected-auction-key-value-service.git ``` The default branch contains the latest stable release. To get the latest code, switch to the `main` @@ -58,8 +60,8 @@ cp bazel-bin/components/data_server/server/server dist/server ## Run the server ```sh -docker-compose -f getting_started/quick_start_assets/docker-compose.yaml build kvserver -docker-compose -f getting_started/quick_start_assets/docker-compose.yaml run --rm kvserver +docker compose -f getting_started/quick_start_assets/docker-compose.yaml build kvserver +docker compose -f getting_started/quick_start_assets/docker-compose.yaml run --rm kvserver ``` In a separate terminal, at the repo root, run @@ -102,7 +104,7 @@ chmod 744 dist/query_api_descriptor_set.pb ```sh chmod 444 components/envoy_proxy/envoy.yaml -docker-compose -f getting_started/quick_start_assets/docker-compose.yaml up +docker compose -f getting_started/quick_start_assets/docker-compose.yaml up ``` In a separate terminal, run: @@ -264,7 +266,7 @@ function getKeyGroupOutputs(hostname, udf_arguments) { } function HandleRequest(executionMetadata, ...udf_arguments) { - console.log(JSON.stringify(executionMetadata)); + logMessage(JSON.stringify(executionMetadata)); const keyGroupOutputs = getKeyGroupOutputs(executionMetadata.requestMetadata.hostname, udf_arguments); return {keyGroupOutputs, udfOutputApiVersion: 1}; } diff --git a/infrastructure/testing/BUILD.bazel b/infrastructure/testing/BUILD.bazel index 76ea12e2..8f1d77af 100644 --- 
a/infrastructure/testing/BUILD.bazel +++ b/infrastructure/testing/BUILD.bazel @@ -21,14 +21,16 @@ cc_binary( srcs = ["protocol_testing_helper_server.cc"], deps = [ ":protocol_testing_helper_server_cc_grpc", + "//components/data/converters:cbor_converter", "//public:constants", - "@com_github_google_quiche//quiche:binary_http_unstable_api", "@com_github_google_quiche//quiche:oblivious_http_unstable_api", "@com_github_grpc_grpc//:grpc++", "@com_github_grpc_grpc//:grpc++_reflection", # for grpc_cli "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/communication:encoding_utils", + "@google_privacysandbox_servers_common//src/communication:framing_utils", ], ) diff --git a/infrastructure/testing/protocol_testing_helper_server.cc b/infrastructure/testing/protocol_testing_helper_server.cc index ec136142..de66025a 100644 --- a/infrastructure/testing/protocol_testing_helper_server.cc +++ b/infrastructure/testing/protocol_testing_helper_server.cc @@ -20,22 +20,48 @@ #include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "components/data/converters/cbor_converter.h" #include "grpcpp/ext/proto_server_reflection_plugin.h" #include "grpcpp/grpcpp.h" #include "infrastructure/testing/protocol_testing_helper_server.grpc.pb.h" #include "public/constants.h" -#include "quiche/binary_http/binary_http_message.h" #include "quiche/oblivious_http/oblivious_http_client.h" +#include "src/communication/encoding_utils.h" +#include "src/communication/framing_utils.h" ABSL_FLAG(uint16_t, port, 50050, "Port the server is listening on. 
Defaults to 50050."); namespace kv_server { +namespace { using grpc::Server; using grpc::ServerBuilder; using grpc::ServerContext; +absl::Status CborDecodeToGetValuesResponseJsonString( + std::string_view cbor_raw, std::string& json_response) { + nlohmann::json json_from_cbor = nlohmann::json::from_cbor( + cbor_raw, /*strict=*/true, /*allow_exceptions=*/false); + if (json_from_cbor.is_discarded()) { + return absl::InternalError("Failed to convert raw CBOR buffer to JSON"); + } + for (auto& json_compression_group : json_from_cbor["compressionGroups"]) { + // Convert CBOR serialized list to a JSON list of partition outputs + auto json_partition_outputs = + GetPartitionOutputsInJson(json_compression_group["content"]); + if (!json_partition_outputs.ok()) { + LOG(ERROR) << json_partition_outputs.status(); + return json_partition_outputs.status(); + } + LOG(INFO) << "json_partition_outputs" << json_partition_outputs->dump(); + json_compression_group["content"] = json_partition_outputs->dump(); + } + json_response = json_from_cbor.dump(); + return absl::OkStatus(); +} +} // namespace + class ProtocolTestingHelperServiceImpl final : public ProtocolTestingHelper::Service { grpc::Status GetTestConfig(ServerContext* context, @@ -45,47 +71,6 @@ class ProtocolTestingHelperServiceImpl final return grpc::Status::OK; } - grpc::Status BHTTPEncapsulate(ServerContext* context, - const BHTTPEncapsulateRequest* request, - BHTTPEncapsulateResponse* response) override { - const auto process = [&request, &response](auto&& bhttp_layer) { - bhttp_layer.set_body(request->body()); - auto maybe_serialized = bhttp_layer.Serialize(); - if (!maybe_serialized.ok()) { - return grpc::Status(grpc::INTERNAL, - std::string(maybe_serialized.status().message())); - } - response->set_bhttp_message(*maybe_serialized); - return grpc::Status::OK; - }; - if (request->is_request()) { - return process(quiche::BinaryHttpRequest({})); - } - return process(quiche::BinaryHttpResponse(200)); - } - - grpc::Status 
BHTTPDecapsulate(ServerContext* context, - const BHTTPDecapsulateRequest* request, - BHTTPDecapsulateResponse* response) override { - const auto process = [&request, &response](auto&& maybe_bhttp_layer) { - if (!maybe_bhttp_layer.ok()) { - return grpc::Status(grpc::INTERNAL, - std::string(maybe_bhttp_layer.status().message())); - } - std::string body; - maybe_bhttp_layer->swap_body(body); - response->set_body(std::move(body)); - return grpc::Status::OK; - }; - - if (request->is_request()) { - return process( - quiche::BinaryHttpRequest::Create(request->bhttp_message())); - } - return process( - quiche::BinaryHttpResponse::Create(request->bhttp_message())); - } - grpc::Status OHTTPEncapsulate(ServerContext* context, const OHTTPEncapsulateRequest* request, OHTTPEncapsulateResponse* response) override { @@ -96,15 +81,30 @@ class ProtocolTestingHelperServiceImpl final return grpc::Status(grpc::INTERNAL, std::string(maybe_config.status().message())); } - - auto client = quiche::ObliviousHttpClient::Create(request->public_key(), - *maybe_config); - if (!client.ok()) { + const auto maybe_cbor_body = + V2GetValuesRequestJsonStringCborEncode(request->body()); + if (!maybe_cbor_body.ok()) { + LOG(ERROR) << "Converting JSON to CBOR failed: " + << maybe_cbor_body.status().message(); return grpc::Status(grpc::INTERNAL, - std::string(client.status().message())); + std::string(maybe_cbor_body.status().message())); } - - auto encrypted_req = client->CreateObliviousHttpRequest(request->body()); + auto encoded_data_size = privacy_sandbox::server_common::GetEncodedDataSize( + maybe_cbor_body->size(), kMinResponsePaddingBytes); + auto maybe_padded_request = + privacy_sandbox::server_common::EncodeResponsePayload( + privacy_sandbox::server_common::CompressionType::kUncompressed, + std::move(*maybe_cbor_body), encoded_data_size); + if (!maybe_padded_request.ok()) { + LOG(ERROR) << "Padding failed: " + << maybe_padded_request.status().message(); + return grpc::Status(grpc::INTERNAL, + 
std::string(maybe_padded_request.status().message())); + } + auto encrypted_req = + quiche::ObliviousHttpRequest::CreateClientObliviousRequest( + std::move(*maybe_padded_request), request->public_key(), + *std::move(maybe_config), kKVOhttpRequestLabel); if (!encrypted_req.ok()) { return grpc::Status(grpc::INTERNAL, std::string(encrypted_req.status().message())); @@ -149,9 +149,26 @@ class ProtocolTestingHelperServiceImpl final quiche::ObliviousHttpRequest::Context context = std::move(context_iter->second); context_map_.erase(context_iter); - auto decrypted_response = client->DecryptObliviousHttpResponse( - request->ohttp_response(), context); - response->set_body(decrypted_response->GetPlaintextData()); + auto decrypted_response = + quiche::ObliviousHttpResponse::CreateClientObliviousResponse( + request->ohttp_response(), context, kKVOhttpResponseLabel); + if (!decrypted_response.ok()) { + LOG(ERROR) << decrypted_response.status(); + } + + auto deframed_req = privacy_sandbox::server_common::DecodeRequestPayload( + std::move(*decrypted_response).ConsumePlaintextData()); + if (!deframed_req.ok()) { + LOG(ERROR) << "unpadding response failed!"; + return grpc::Status(grpc::INTERNAL, + std::string(deframed_req.status().message())); + } + + std::string resp_json_string; + CborDecodeToGetValuesResponseJsonString(deframed_req->compressed_data, + resp_json_string); + + response->set_body(resp_json_string); return grpc::Status::OK; } } diff --git a/infrastructure/testing/protocol_testing_helper_server.proto b/infrastructure/testing/protocol_testing_helper_server.proto index f53f5880..7da676c2 100644 --- a/infrastructure/testing/protocol_testing_helper_server.proto +++ b/infrastructure/testing/protocol_testing_helper_server.proto @@ -30,13 +30,6 @@ service ProtocolTestingHelper { // Returns test parameters like public key used for testing. 
rpc GetTestConfig(GetTestConfigRequest) returns (GetTestConfigResponse) {} - // Given a cleartext http message body in bytes format, wraps it in BinaryHTTP - // format. - rpc BHTTPEncapsulate(BHTTPEncapsulateRequest) returns (BHTTPEncapsulateResponse) {} - - // Given a BinaryHTTP message, unwraps it into a cleartext message. - rpc BHTTPDecapsulate(BHTTPDecapsulateRequest) returns (BHTTPDecapsulateResponse) {} - // Wraps a byte string with Oblivious HTTP encryption. rpc OHTTPEncapsulate(OHTTPEncapsulateRequest) returns (OHTTPEncapsulateResponse) {} @@ -58,31 +51,6 @@ message GetTestConfigResponse { bytes public_key = 1; } -// ======================== BHTTP functions ======================== -message BHTTPEncapsulateRequest { - // Message body to be wrapped - bytes body = 1; - // True: request False: response - bool is_request = 2; -} - -message BHTTPEncapsulateResponse { - // serialized BHTTP message - bytes bhttp_message = 1; -} - -message BHTTPDecapsulateRequest { - // serialized BHTTP message - bytes bhttp_message = 1; - // True: request False: response - bool is_request = 2; -} - -message BHTTPDecapsulateResponse { - // Message body unwrapped - bytes body = 1; -} - // ======================== OHTTP functions ======================== message OHTTPEncapsulateRequest { // Request to be encrypted diff --git a/production/packaging/aws/build_and_test b/production/packaging/aws/build_and_test index 9c99fcd1..726e2c48 100755 --- a/production/packaging/aws/build_and_test +++ b/production/packaging/aws/build_and_test @@ -34,6 +34,7 @@ function _print_runtime() { } declare MODE=prod +declare -i SKIP_EIF=0 function usage() { local exitval=${1-1} @@ -46,6 +47,7 @@ usage: --no-precommit Skip precommit checks --with-tests Also runs tests before building --mode Mode can be prod or nonprod. 
Default: ${MODE} + --skip-eif Skip building the eif and checking pcr0 environment variables (all optional): WORKSPACE Set the path to the workspace (repo root) @@ -79,6 +81,10 @@ while [[ $# -gt 0 ]]; do shift shift ;; + --skip-eif) + SKIP_EIF=1 + shift + ;; --verbose) BUILD_AND_TEST_ARGS+=("--verbose") set -o xtrace @@ -131,54 +137,33 @@ printf "==== Creating dist dir =====\n" mkdir -p "${DIST}"/aws chmod 770 "${DIST}" "${DIST}"/aws -printf "==== build AWS artifacts using build-amazonlinux2 =====\n" -# build nitro enclave image, collect eif artifacts -IMAGE_URI=bazel/production/packaging/aws/data_server -readonly IMAGE_URI -BUILD_ARCH=$("${WORKSPACE}"/builders/tools/get-architecture) -IMAGE_TAG=$(mktemp --dry-run temp-XXXXXX) -builder::cbuild_al2 $" -set -o errexit -# extract server docker image into local docker client and retag it -docker load -i dist/server_docker_image.tar -docker tag ${IMAGE_URI}:server_docker_image ${IMAGE_URI}:'${IMAGE_TAG}' -rm -f dist/aws/{server_enclave_image.eif,server_enclave_image.json,pcr0.json} -nitro-cli build-enclave --docker-uri ${IMAGE_URI}:'${IMAGE_TAG}' --output-file dist/aws/server_enclave_image.eif >dist/aws/server_enclave_image.json -if [[ \$? -ne 0 ]]; then - printf 'error building nitro eif image\n' - exit 1 -fi -jq --compact-output '{PCR0: .Measurements.PCR0}' dist/aws/server_enclave_image.json >dist/aws/pcr0.json -jq --compact-output --raw-output '.Measurements.PCR0' dist/aws/server_enclave_image.json >dist/aws/pcr0.txt -cat dist/aws/pcr0.json -exit 0 -" -docker image rm --force ${IMAGE_URI}:"${IMAGE_TAG}" || fail "Unable to remove Docker image" - -printf "==== Checking PCR0 =====\n" -PCR0_REL_DIR=production/packaging/aws/data_server/nitro-pcr0 -readonly PCR0_REL_DIR -PCR0_DIR="${WORKSPACE}"/${PCR0_REL_DIR} -readonly PCR0_DIR -PCR0_FILE="${PCR0_DIR}"/${BUILD_ARCH}.json -readonly PCR0_FILE -declare -i OUTDATED_PCR0=0 -if [[ ! 
-s ${PCR0_FILE} ]]; then - printf "warning: PCR0 file doesn't exist or is empty [%s]\n" "${PCR0_FILE}" - OUTDATED_PCR0=1 -elif ! diff -q "${DIST}"/aws/pcr0.json "${PCR0_FILE}" &>/dev/null; then - printf "warning: PCR0 hash differs\n" - OUTDATED_PCR0=1 -fi -if [[ ${OUTDATED_PCR0} -eq 0 ]]; then - printf "PCR0 hash unchanged\n" -else - printf "PCR0 hash differs: %s:\n" "${PCR0_FILE}" - cat "${PCR0_FILE}" +if [[ ${SKIP_EIF} -eq 0 ]]; then + printf "==== build AWS artifacts using build-amazonlinux2023 =====\n" + # build nitro enclave image, collect eif artifacts + IMAGE_URI=bazel/production/packaging/aws/data_server + readonly IMAGE_URI + IMAGE_TAG=$(mktemp --dry-run temp-XXXXXX) + builder::cbuild_al2023 $" + set -o errexit + # extract server docker image into local docker client and retag it + docker load -i dist/server_docker_image.tar + docker tag ${IMAGE_URI}:server_docker_image ${IMAGE_URI}:'${IMAGE_TAG}' + rm -f dist/aws/{server_enclave_image.eif,server_enclave_image.json,pcr0.json} + nitro-cli build-enclave --docker-uri ${IMAGE_URI}:'${IMAGE_TAG}' --output-file dist/aws/server_enclave_image.eif >dist/aws/server_enclave_image.json + if [[ \$? -ne 0 ]]; then + printf 'error building nitro eif image\n' + exit 1 + fi + jq --compact-output '{PCR0: .Measurements.PCR0}' dist/aws/server_enclave_image.json >dist/aws/pcr0.json + jq --compact-output --raw-output '.Measurements.PCR0' dist/aws/server_enclave_image.json >dist/aws/pcr0.txt + cat dist/aws/pcr0.json + exit 0 + " + docker image rm --force ${IMAGE_URI}:"${IMAGE_TAG}" || fail "Unable to remove Docker image" fi printf "==== Copying to dist =====\n" -builder::cbuild_al2 $" +builder::cbuild_al2023 $" trap _collect_logs EXIT function _collect_logs() { local -r -i STATUS=\$? 
@@ -207,7 +192,7 @@ if [[ -n ${AMI_REGIONS[0]} ]]; then docker run --rm --user "${DOCKER_USER}" --entrypoint=unzip --volume "${AWSDIR}:${AWSDIR}" --workdir "${AWSDIR}" "${UTILS_IMAGE}" -o -q "${AWSDIR}"/aws_artifacts.zip printf "==== build AWS AMI (using packer) =====\n" regions="$(arr_to_string_list AMI_REGIONS)" - builder::cbuild_al2 " + builder::cbuild_al2023 " set -o errexit packer build \ -var=regions='${regions}' \ diff --git a/production/packaging/aws/codebuild/buildspec.yaml b/production/packaging/aws/codebuild/buildspec.yaml new file mode 100644 index 00000000..939fb503 --- /dev/null +++ b/production/packaging/aws/codebuild/buildspec.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: 0.2 + +env: + variables: + # CODEBUILD_SRC_DIR is a CodeBuild-provided env variable. + # Other variables can be provided by the CodeBuild project GUI or AWS CLI. + # Check the README and the following link for more information about + # env vars required by this script: + # https://docs.aws.amazon.com/codebuild/latest/userguide/build-env-ref-env-vars.html + AMI_REGION: us-east-1 # Default. Region where to upload the final AMIs. + BUILD_FLAVOR: prod # Default. Use nonprod for enhanced logging output. + DOCKERHUB_USERNAME: YOUR_USERNAME # Default. Used to login to Docker with DOCKERHUB_PASSWORD_SECRET_NAME. + DOCKERHUB_PASSWORD_SECRET_NAME: codebuild-dockerhub-password # Default. Name of AWS Secret resource. 
+ +phases: + build: + commands: + - | + aws secretsmanager get-secret-value \ + --secret-id ${DOCKERHUB_PASSWORD_SECRET_NAME} \ + --query SecretString --output text \ + | docker login -u ${DOCKERHUB_USERNAME} --password-stdin + - | + WORKSPACE_MOUNT=${CODEBUILD_SRC_DIR} \ + production/packaging/aws/build_and_test \ + --with-ami ${AMI_REGION} \ + --mode ${BUILD_FLAVOR} \ + --no-precommit diff --git a/production/packaging/aws/data_server/BUILD.bazel b/production/packaging/aws/data_server/BUILD.bazel index be177338..6821bb9e 100644 --- a/production/packaging/aws/data_server/BUILD.bazel +++ b/production/packaging/aws/data_server/BUILD.bazel @@ -13,7 +13,7 @@ # limitations under the License. load("@container_structure_test//:defs.bzl", "container_structure_test") -load("@rules_oci//oci:defs.bzl", "oci_image", "oci_tarball") +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -63,8 +63,8 @@ pkg_tar( oci_image( name = "server_docker_image", base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64//image", + "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64", }), entrypoint = [ "/init_server_basic", @@ -75,29 +75,35 @@ oci_image( "--port=50051", # These affect PCR0, so changing these would result in the loss of ability to communicate with # the downstream components - "--public_key_endpoint=https://publickeyservice.pa-3.aws.privacysandboxservices.com/v1alpha/publicKeys", + "--public_key_endpoint=https://publickeyservice.pa.aws.privacysandboxservices.com/.well-known/protected-auction/v1/public-keys", "--stderrthreshold=0", ], tars = [ "@google_privacysandbox_servers_common//src/aws/proxy:libnsm_and_proxify_tar", "@google_privacysandbox_servers_common//src/cpio/client_providers/kms_client_provider/aws:kms_binaries", - 
"//production/packaging/aws/resolv:resolv_config_layer", + "//production/packaging/aws/resolv:resolv_config_tar", ":server_binaries_tar", ], ) -oci_tarball( +oci_load( name = "server_docker_tarball", image = ":server_docker_image", repo_tags = ["bazel/production/packaging/aws/data_server:server_docker_image"], ) +filegroup( + name = "server_docker_tarball.tar", + srcs = [":server_docker_tarball"], + output_group = "tarball", +) + container_structure_test( name = "structure_test", size = "medium", configs = ["test/structure.yaml"], driver = "tar", - image = ":server_docker_tarball", + image = ":server_docker_tarball.tar", ) container_structure_test( @@ -118,14 +124,14 @@ genrule( name = "copy_to_dist", srcs = [ ":server_artifacts", - ":server_docker_tarball", + ":server_docker_tarball.tar", "//public/query:query_api_descriptor_set", ], outs = ["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p dist/debian cp $(execpath :server_artifacts) dist/debian -cp $(execpath :server_docker_tarball) dist/server_docker_image.tar +cp $(execpath :server_docker_tarball.tar) dist/server_docker_image.tar cp $(execpath //public/query:query_api_descriptor_set) dist # retain previous server_docker_image.tar location as a symlink ln -rsf dist/server_docker_image.tar dist/debian/server_docker_image.tar diff --git a/production/packaging/aws/data_server/ami/BUILD.bazel b/production/packaging/aws/data_server/ami/BUILD.bazel index 3e294b2f..91e85f56 100644 --- a/production/packaging/aws/data_server/ami/BUILD.bazel +++ b/production/packaging/aws/data_server/ami/BUILD.bazel @@ -18,7 +18,7 @@ load("@rules_pkg//pkg:zip.bzl", "pkg_zip") pkg_zip( name = "aws_artifacts", srcs = [ - "//components/aws:sqs_lambda.tar", + "//components/aws:sqs_lambda_tarball", "//production/packaging/aws/otel_collector:aws-otel-collector.rpm", "//production/packaging/aws/otel_collector:aws_otel_collector_cfg", "@google_privacysandbox_servers_common//src/aws/proxy", diff --git 
a/production/packaging/aws/data_server/ami/image.pkr.hcl b/production/packaging/aws/data_server/ami/image.pkr.hcl index 38ae72f9..455b85b7 100644 --- a/production/packaging/aws/data_server/ami/image.pkr.hcl +++ b/production/packaging/aws/data_server/ami/image.pkr.hcl @@ -71,7 +71,7 @@ source "amazon-ebs" "dataserver" { ami_regions = var.regions source_ami_filter { filters = { - name = "amzn2-ami-kernel-*-x86_64-gp2" + name = "al2023-ami-20*-kernel-*-x86_64" root-device-type = "ebs" virtualization-type = "hvm" } diff --git a/production/packaging/aws/data_server/ami/setup.sh b/production/packaging/aws/data_server/ami/setup.sh index 74d6d994..0f8ffc65 100644 --- a/production/packaging/aws/data_server/ami/setup.sh +++ b/production/packaging/aws/data_server/ami/setup.sh @@ -20,7 +20,6 @@ chmod 500 /home/ec2-user/proxy chmod 500 /home/ec2-user/server_enclave_image.eif sudo mkdir /opt/privacysandbox -mkdir /tmp/proxy sudo cp /home/ec2-user/vsockproxy.service /etc/systemd/system/vsockproxy.service sudo cp /home/ec2-user/proxy /opt/privacysandbox/proxy @@ -36,10 +35,10 @@ sudo chmod 555 /opt/privacysandbox/hc.bash sudo chmod 555 /opt/privacysandbox/health.proto # Install necessary dependencies -sudo yum update -y -sudo yum install -y docker -sudo yum localinstall -y /home/ec2-user/aws-otel-collector.rpm -sudo amazon-linux-extras install -y aws-nitro-enclaves-cli +sudo dnf update -y +sudo dnf install -y docker +sudo dnf localinstall -y /home/ec2-user/aws-otel-collector.rpm +sudo dnf install aws-nitro-enclaves-cli -y sudo usermod -a -G docker ec2-user sudo usermod -a -G ne ec2-user diff --git a/production/packaging/aws/data_server/ami/vsockproxy.service b/production/packaging/aws/data_server/ami/vsockproxy.service index 26d2e053..3a6b929e 100644 --- a/production/packaging/aws/data_server/ami/vsockproxy.service +++ b/production/packaging/aws/data_server/ami/vsockproxy.service @@ -3,7 +3,7 @@ Description=Vsock proxy for allowing outgoing connections from enclave image 
After=network.target [Service] ExecStart=/opt/privacysandbox/proxy -WorkingDirectory=/tmp/proxy +WorkingDirectory=/tmp Restart=always [Install] WantedBy=multi-user.target diff --git a/production/packaging/aws/data_server/nitro-pcr0/amd64.json b/production/packaging/aws/data_server/nitro-pcr0/amd64.json deleted file mode 100644 index 615805cc..00000000 --- a/production/packaging/aws/data_server/nitro-pcr0/amd64.json +++ /dev/null @@ -1 +0,0 @@ -{"PCR0":"76ece8efa6f8c5723edd64b3a0464ec72411651be7642c951c73bcbc42863a8f1c376970726109c0e718b2f0d326e39d"} diff --git a/production/packaging/aws/data_server/nitro-pcr0/arm64.json b/production/packaging/aws/data_server/nitro-pcr0/arm64.json deleted file mode 100644 index 517e62da..00000000 --- a/production/packaging/aws/data_server/nitro-pcr0/arm64.json +++ /dev/null @@ -1 +0,0 @@ -{"PCR0":"1147a9ea02769b53a45bdb5a52aa5a291fa314049c15e2fc8e22fb216cd1cfb13967c944cd76c73287c894e64184d8ad"} diff --git a/production/packaging/aws/otel_collector/BUILD.bazel b/production/packaging/aws/otel_collector/BUILD.bazel index 9e3d2d56..dbbab38c 100644 --- a/production/packaging/aws/otel_collector/BUILD.bazel +++ b/production/packaging/aws/otel_collector/BUILD.bazel @@ -12,32 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@bazel_skylib//rules:copy_file.bzl", "copy_file") load("@rules_pkg//pkg:mappings.bzl", "pkg_files") package(default_visibility = ["//visibility:public"]) -otel_version = "0.36.0" - -otel_tag = "v" + otel_version - -target_file = otel_tag + ".tar.gz" - -genrule( +copy_file( name = "aws_otel_collector", - outs = ["aws-otel-collector.rpm"], - cmd = """ - set -x - yum install -y wget - yum install -y rpm-build - wget https://github.com/aws-observability/aws-otel-collector/archive/refs/tags/%s - tar xvf %s - cd aws-otel-collector-%s - # Remove linting from build targets. 
- sed -i 's/build: install-tools lint multimod-verify/build:/g' Makefile - export HOME=/root - make package-rpm - cp build/packages/linux/amd64/aws-otel-collector.rpm ../$@ - """ % (target_file, target_file, otel_version), + src = select({ + "@platforms//cpu:aarch64": "@otel_collector_aarch64//file", + "@platforms//cpu:x86_64": "@otel_collector_amd64//file", + }), + out = "aws-otel-collector.rpm", ) pkg_files( diff --git a/production/packaging/aws/resolv/BUILD.bazel b/production/packaging/aws/resolv/BUILD.bazel index b4bcdd35..8a365b2d 100644 --- a/production/packaging/aws/resolv/BUILD.bazel +++ b/production/packaging/aws/resolv/BUILD.bazel @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@io_bazel_rules_docker//container:container.bzl", - "container_layer", -) load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -37,13 +33,5 @@ pkg_tar( srcs = [ ":etc_resolv_files", ], -) - -container_layer( - name = "resolv_config_layer", - directory = "/", - tars = [ - ":resolv_config_tar", - ], visibility = ["//production/packaging:__subpackages__"], ) diff --git a/production/packaging/gcp/cloud_build/Dockerfile b/production/packaging/gcp/cloud_build/Dockerfile new file mode 100644 index 00000000..f6cb3ddc --- /dev/null +++ b/production/packaging/gcp/cloud_build/Dockerfile @@ -0,0 +1,21 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Use a debian-based image with gcloud and docker installed: +FROM google/cloud-sdk:473.0.0 + +# Install additional dependencies +RUN apt-get update && apt-get install -y \ + git \ + gettext-base # Package for envsubst diff --git a/production/packaging/gcp/cloud_build/cloudbuild.yaml b/production/packaging/gcp/cloud_build/cloudbuild.yaml new file mode 100644 index 00000000..748ab946 --- /dev/null +++ b/production/packaging/gcp/cloud_build/cloudbuild.yaml @@ -0,0 +1,46 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + - name: gcr.io/cloud-builders/docker + args: + - build + - -t + - service-builder + - -f + - production/packaging/gcp/cloud_build/Dockerfile + - . + - name: service-builder + env: + - WORKSPACE_MOUNT=/workspace + script: | + #!/usr/bin/env bash + production/packaging/gcp/build_and_test \ + --mode ${_BUILD_FLAVOR} \ + && production/packaging/gcp/docker_push_gcp_repo \ + --gcp-image-repo $_GCP_IMAGE_REPO \ + --gcp-image-tag $_GCP_IMAGE_TAG +substitutions: + # The following variables may be overridden via the gcloud CLI or the + # CloudBuild Trigger GUI. + # See https://cloud.google.com/build/docs/configuring-builds/substitute-variable-values + # for more information. + _BUILD_FLAVOR: prod # Default. Use nonprod for enhanced logging output. + _GCP_IMAGE_TAG: ${BUILD_ID} # Default. Required for server deployment later. + _GCP_IMAGE_REPO: us-docker.pkg.dev/${PROJECT_ID}/kvs-docker-repo-shared/kv-service # Default. 
Artifact Registry repo to house images. +timeout: 3h +options: + machineType: E2_HIGHCPU_32 + automapSubstitutions: true + logging: CLOUD_LOGGING_ONLY diff --git a/production/packaging/gcp/data_server/BUILD.bazel b/production/packaging/gcp/data_server/BUILD.bazel index eb077cc8..3ac9a540 100644 --- a/production/packaging/gcp/data_server/BUILD.bazel +++ b/production/packaging/gcp/data_server/BUILD.bazel @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@io_bazel_rules_docker//container:container.bzl", - "container_flatten", - "container_image", - "container_layer", -) +load("@container_structure_test//:defs.bzl", "container_structure_test") +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -54,12 +50,18 @@ pkg_tar( srcs = server_binaries, ) -container_layer( - name = "server_binary_layer", - directory = "/", - tars = [ - ":server_binaries_tar", +pkg_files( + name = "envoy_executables", + srcs = [ + "@envoy_binary//file", ], + attributes = pkg_attributes(mode = "0555"), + prefix = "/usr/local/bin", +) + +pkg_tar( + name = "envoy_binary_tar", + srcs = [":envoy_executables"], ) pkg_files( @@ -82,53 +84,15 @@ pkg_tar( ], ) -container_layer( - name = "envoy_config_layer", - directory = "/", - tars = [ - ":envoy_config_tar", - ], - visibility = [ - "//production/packaging:__subpackages__", - "//services:__subpackages__", - ], -) - -container_flatten( - name = "envoy_distroless_flat", - image = select({ - "@platforms//cpu:arm64": "@envoy-distroless-arm64//image", - "@platforms//cpu:x86_64": "@envoy-distroless-amd64//image", - }), -) - -container_layer( - name = "envoy_distroless_layer", - tars = [ - ":envoy_distroless_flat.tar", - ], - visibility = [ - "//components:__subpackages__", - "//production/packaging:__subpackages__", - ], -) - -# This image target is meant for testing running the server in an enclave using. 
-# -# See project README.md on how to run the image. -container_image( +oci_image( name = "server_docker_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-root-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-root-amd64//image", + "@platforms//cpu:arm64": "@runtime-debian-debug-root-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-root-amd64", }), entrypoint = [ "/init_server_basic", - # These affect PCR0, so changing these would result in the loss of ability to communicate with + # These affect image digest, so changing these would result in the loss of ability to communicate with # the downstream components "--public_key_endpoint=https://publickeyservice-a.pa-3.gcp.privacysandboxservices.com/.well-known/protected-auction/v1/public-keys", "--stderrthreshold=0", @@ -136,20 +100,48 @@ container_image( env = { "GRPC_DNS_RESOLVER": "native", }, - labels = {"tee.launch_policy.log_redirect": "debugonly"}, - layers = [ - ":server_binary_layer", - ":envoy_distroless_layer", - ":envoy_config_layer", + exposed_ports = [ + "50050/tcp", + "50051/tcp", + "50100/tcp", + "51052/tcp", ], - ports = [ - "50050", - "50051", - "50100", - "51052", + labels = {"tee.launch_policy.log_redirect": "debugonly"}, + tars = [ + ":server_binaries_tar", + ":envoy_binary_tar", + ":envoy_config_tar", ], ) +oci_load( + name = "server_docker_tarball", + image = ":server_docker_image", + repo_tags = ["bazel/production/packaging/gcp/data_server:server_docker_image"], +) + +filegroup( + name = "server_docker_tarball.tar", + srcs = [":server_docker_tarball"], + output_group = "tarball", +) + +container_structure_test( + name = "structure_test", + size = "medium", + configs = ["test/structure.yaml"], + driver = "tar", + image = ":server_docker_tarball.tar", +) + +container_structure_test( + name = "commands_test", + size = "small", + configs = 
["test/commands.yaml"], + driver = "docker", + image = ":server_docker_image", +) + # server artifacts pkg_zip( name = "server_artifacts", @@ -160,14 +152,15 @@ genrule( name = "copy_to_dist", srcs = [ ":server_artifacts", - ":server_docker_image.tar", + ":server_docker_tarball.tar", "//public/query:query_api_descriptor_set", ], outs = ["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p dist/debian cp $(execpath :server_artifacts) dist/debian -cp $(execpath :server_docker_image.tar) $(execpath //public/query:query_api_descriptor_set) dist +cp $(execpath :server_docker_tarball.tar) dist/server_docker_image.tar +cp $(execpath //public/query:query_api_descriptor_set) dist # retain previous server_docker_image.tar location as a symlink ln -rsf dist/server_docker_image.tar dist/debian/server_docker_image.tar builders/tools/normalize-dist diff --git a/production/packaging/gcp/data_server/bin/init_server_main.cc b/production/packaging/gcp/data_server/bin/init_server_main.cc index 34df2cb1..7b62d8be 100644 --- a/production/packaging/gcp/data_server/bin/init_server_main.cc +++ b/production/packaging/gcp/data_server/bin/init_server_main.cc @@ -113,8 +113,8 @@ int main(int argc, char* argv[]) { if (PrepareTlsKeyCertForEnvoy()) { // Starts Envoy and server in separate processes - if (const pid_t pid = fork(); pid == 1) { - LOG(ERROR) << "Fork failure!"; + if (const pid_t pid = fork(); pid == -1) { + PLOG(ERROR) << "Fork failure!"; return errno; } else if (pid == 0) { StartEnvoy(); diff --git a/production/packaging/gcp/data_server/envoy/envoy.yaml b/production/packaging/gcp/data_server/envoy/envoy.yaml index c1ea1b38..4e4b5efd 100644 --- a/production/packaging/gcp/data_server/envoy/envoy.yaml +++ b/production/packaging/gcp/data_server/envoy/envoy.yaml @@ -30,7 +30,16 @@ static_resources: virtual_hosts: - name: local_service domains: [ "*" ] + cors: + allow_origin_string_match: + - prefix: "*" + allow_methods: GET, POST, PUT, OPTIONS + allow_headers: Origin, Content, 
Accept, Content-Type, Authorization, X-Requested-With + allow_credentials: true routes: + - match: { prefix: "/", headers: [ { name: ":method", exact_match: "OPTIONS" } ] } + # fake cluster route ... some issue in envoy. + route: { cluster: "grpc_cluster", timeout: 60s } - match: { prefix: "/kv_server.v1.KeyValueService" } route: { cluster: grpc_cluster, timeout: 60s } - match: { prefix: "/kv_server.v2.KeyValueService" } @@ -40,6 +49,10 @@ static_resources: retry_policy: retry_on: 5xx num_retries: 10 + request_headers_to_add: + - header: + key: "kv-content-type" + value: "%DYNAMIC_METADATA(tkv:content-type)%" response_headers_to_add: - header: key: 'Ad-Auction-Allowed' @@ -49,7 +62,46 @@ static_resources: key: 'x-fledge-bidding-signals-format-version' value: '2' append: false + - header: + key: 'Access-Control-Allow-Origin' + value: '*' http_filters: + # TODO(b/362736353): Remove lua filter. + # Add a LUA filter to get the request body. + # On GCP, the request body does not seem to get fully propagated to the server. + # Adding a `request_handle:body()` call fixes this. + - name: envoy.filters.http.lua + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.lua.v3.Lua + default_source_code: + inline_string: | + function envoy_on_request(request_handle) + -- If path is grpc reflection, don't wait for body or envoy will get stuck + if(request_handle:headers():get(":path") ~= "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo") + then + request_handle:body() + end + end + - name: envoy.filters.http.cors + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.cors.v3.Cors + # Pass the content-type to the grpc server. By default the header is overwritten to "application/grpc" + # This extra filter must be used because the content-type must be copied to a metadata prior to grpc transcoding + # By the time the Router filter is invoked, the original content-type is already changed. 
+ - name: envoy.filters.http.header_to_metadata + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.header_to_metadata.v3.Config + request_rules: + - header: "content-type" + on_header_present: + metadata_namespace: tkv + key: content-type + type: STRING + on_header_missing: + metadata_namespace: tkv + key: content-type + value: 'unknown' + type: STRING - name: envoy.filters.http.grpc_stats, typed_config: "@type": type.googleapis.com/envoy.extensions.filters.http.grpc_stats.v3.FilterConfig diff --git a/production/packaging/gcp/data_server/test/commands.yaml b/production/packaging/gcp/data_server/test/commands.yaml new file mode 100644 index 00000000..2754458f --- /dev/null +++ b/production/packaging/gcp/data_server/test/commands.yaml @@ -0,0 +1,24 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# refer to docs at https://github.com/GoogleContainerTools/container-structure-test + +schemaVersion: 2.0.0 + +# command tests require the docker toolchain +commandTests: + - name: "server help" + command: "/server" + args: ["--help"] + exitCode: 1 diff --git a/production/packaging/gcp/data_server/test/structure.yaml b/production/packaging/gcp/data_server/test/structure.yaml new file mode 100644 index 00000000..d5b60c76 --- /dev/null +++ b/production/packaging/gcp/data_server/test/structure.yaml @@ -0,0 +1,40 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# refer to docs at https://github.com/GoogleContainerTools/container-structure-test + +schemaVersion: 2.0.0 + +fileExistenceTests: + - name: init_server_basic + path: /init_server_basic + shouldExist: true + isExecutableBy: any + + - name: server + path: /server + shouldExist: true + isExecutableBy: any + + - name: envoy-config + path: /etc/envoy/envoy.yaml + shouldExist: true + + - name: envoy-binary + path: /usr/local/bin/envoy + shouldExist: true + isExecutableBy: any + +licenseTests: + - debian: true diff --git a/production/packaging/local/data_server/BUILD.bazel b/production/packaging/local/data_server/BUILD.bazel index c2d02cf3..fb80c09b 100644 --- a/production/packaging/local/data_server/BUILD.bazel +++ b/production/packaging/local/data_server/BUILD.bazel @@ -12,98 +12,131 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") +load("@rules_pkg//pkg:mappings.bzl", "pkg_attributes", "pkg_files") +load("@rules_pkg//pkg:tar.bzl", "pkg_tar") +load("@rules_pkg//pkg:zip.bzl", "pkg_zip") + +pkg_files( + name = "server_executables", + srcs = [ + "//components/data_server/server", + ], + attributes = pkg_attributes(mode = "0755"), + prefix = "/server/bin", +) + +server_binaries = [ + ":server_executables", +] + +pkg_zip( + name = "server_binaries", + srcs = server_binaries, +) + +pkg_tar( + name = "server_binaries_tar", + srcs = server_binaries, ) -load("@io_bazel_rules_docker//docker/util:run.bzl", "container_run_and_commit_layer") -container_run_and_commit_layer( - name = "profiling_tools_layer", - commands = [ - # Install killall, wget, ps, binutils and linux perf. - "apt-get update", - "apt-get install wget psmisc procps binutils linux-tools-generic -y", - # Installing graphviz and gperftools requires these ENV variables. 
- "export DEBIAN_FRONTEND=noninteractive; export TZ=Etc/UTC", - # Install graphviz and libgperf tools. - "apt-get install graphviz libgoogle-perftools-dev -y", - # Install pprof. - "wget https://go.dev/dl/go1.21.6.linux-amd64.tar.gz", - "rm -rf /usr/local/go && tar -C /usr/local -xzf go1.21.6.linux-amd64.tar.gz && rm go1.21.6.linux-amd64.tar.gz", - "/usr/local/go/bin/go install github.com/google/pprof@latest", +# server artifacts +pkg_zip( + name = "server_artifacts", + srcs = server_binaries, +) + +pkg_files( + name = "init_server_with_profiler_execs", + srcs = [ + ":init_server_with_profiler", + ], + attributes = pkg_attributes(mode = "0755"), + prefix = "/server/bin", +) + +pkg_tar( + name = "init_server_with_profiler_tar", + srcs = [ + ":init_server_with_profiler_execs", ], - image = select({ - "@platforms//cpu:arm64": "@runtime-ubuntu-fulldist-debug-root-arm64//image", - "@platforms//cpu:x86_64": "@runtime-ubuntu-fulldist-debug-root-amd64//image", - }), ) # This image target is meant for cpu and memory profiling of the server. 
-container_image( +oci_image( name = "server_profiling_docker_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), base = select({ - "@platforms//cpu:arm64": "@runtime-ubuntu-fulldist-debug-root-arm64//image", - "@platforms//cpu:x86_64": "@runtime-ubuntu-fulldist-debug-root-amd64//image", + "@platforms//cpu:arm64": "@runtime-ubuntu-fulldist-debug-root-arm64", + "@platforms//cpu:x86_64": "@runtime-ubuntu-fulldist-debug-root-amd64", }), - entrypoint = [ - "/server", - "--port=50051", - "--delta_directory=/data", - "--realtime_directory=/data/realtime", - "--stderrthreshold=0", + cmd = [ + "/server/bin/init_server_with_profiler", ], - layers = [ - ":profiling_tools_layer", - "//production/packaging/gcp/data_server:server_binary_layer", + entrypoint = ["/bin/bash"], + tars = [ + ":server_binaries_tar", + ":init_server_with_profiler_tar", ], - symlinks = { - "/usr/local/bin/linux-perf": "/usr/lib/linux-tools/5.4.0-169-generic/perf", - "/usr/local/bin/pprof": "/root/go/bin/pprof", - }, +) + +oci_load( + name = "server_profiling_docker_tarball", + image = ":server_profiling_docker_image", + repo_tags = ["bazel/production/packaging/local/data_server:server_profiling_docker_image"], +) + +filegroup( + name = "server_profiling_docker_image.tar", + srcs = [":server_profiling_docker_tarball"], + output_group = "tarball", ) # This image target is meant for testing running the server in an enclave using. # # See project README.md on how to run the image. 
-container_image( +oci_image( name = "server_docker_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64//image", + "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64", }), entrypoint = [ - "/server", + "/server/bin/server", "--port=50051", "--delta_directory=/data", "--realtime_directory=/data/realtime", "--stderrthreshold=0", ], - layers = [ - "//production/packaging/gcp/data_server:server_binary_layer", + tars = [ + ":server_binaries_tar", ], ) +oci_load( + name = "server_docker_tarball", + image = ":server_docker_image", + repo_tags = ["bazel/production/packaging/local/data_server:server_docker_image"], +) + +filegroup( + name = "server_docker_image.tar", + srcs = [":server_docker_tarball"], + output_group = "tarball", +) + genrule( name = "copy_to_dist", srcs = [ ":server_docker_image.tar", - "//production/packaging/gcp/data_server:server_artifacts", + ":server_artifacts", "//public/query:query_api_descriptor_set", ], outs = ["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p dist/debian -cp $(execpath //production/packaging/gcp/data_server:server_artifacts) dist/debian -cp $(execpath :server_docker_image.tar) $(execpath //public/query:query_api_descriptor_set) dist +cp $(execpath :server_artifacts) dist/debian +cp $(execpath :server_docker_image.tar) dist/server_docker_image.tar +cp $(execpath //public/query:query_api_descriptor_set) dist # retain previous server_docker_image.tar location as a symlink ln -rsf dist/server_docker_image.tar dist/debian/server_docker_image.tar builders/tools/normalize-dist diff --git a/production/packaging/local/data_server/init_server_with_profiler b/production/packaging/local/data_server/init_server_with_profiler new file mode 
100644 index 00000000..a0c189a6 --- /dev/null +++ b/production/packaging/local/data_server/init_server_with_profiler @@ -0,0 +1,33 @@ +#!/bin/bash +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -o errexit +# Install killall, wget, ps, binutils and linux perf. +apt-get update +apt-get install wget psmisc procps binutils linux-tools-generic -y +# Installing graphviz and gperftools requires these ENV variables. +export DEBIAN_FRONTEND=noninteractive +export TZ=Etc/UTC +# Install graphviz and libgperf tools. +apt-get install graphviz libgoogle-perftools-dev -y +# Install pprof. 
+wget https://go.dev/dl/go1.21.6.linux-amd64.tar.gz +rm -rf /usr/local/go && tar -C /usr/local -xzf go1.21.6.linux-amd64.tar.gz && rm go1.21.6.linux-amd64.tar.gz +/usr/local/go/bin/go install github.com/google/pprof@latest +# Create symlink +ln -s /usr/lib/linux-tools/5.4.0-169-generic/perf /usr/local/bin/linux-perf +ln -s /root/go/bin/pprof /usr/local/bin/pprof +# Run the server +/server/bin/server --port=50051 --delta_directory=/data --realtime_directory=/data/realtime --stderrthreshold=0 diff --git a/production/packaging/sync_key_value_repo.yaml b/production/packaging/sync_key_value_repo.yaml new file mode 100644 index 00000000..0b651197 --- /dev/null +++ b/production/packaging/sync_key_value_repo.yaml @@ -0,0 +1,44 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file contains a GitHub Action that syncs with +# the Protected Auction Key/Value service repo, including tags, +# on a regular schedule. It will push any new tags +# to your repo, and you can use Cloud Build, +# CodeBuild, or any webhook-based system to trigger +# an automatic build based on the new tag. 
+name: Sync with github.com/privacysandbox/protected-auction-key-value-service + +on: + schedule: + - cron: '*/30 * * * *' # every 30 minutes + workflow_dispatch: # on button click + +jobs: + sync_code: + runs-on: ubuntu-latest + steps: + - uses: tgymnich/fork-sync@v1.8 + continue-on-error: true + with: + base: main + head: main + - name: Checkout Code + uses: actions/checkout@v3 + if: always() # Always checkout, even if sync fails + - name: Sync Tags with Upstream + if: always() # Always sync tags, even if sync or checkout fails + run: | + git fetch https://github.com/privacysandbox/protected-auction-key-value-service --tags --force + git push origin --tags --force diff --git a/production/packaging/tools/BUILD.bazel b/production/packaging/tools/BUILD.bazel index 15706f82..cc0a2e77 100644 --- a/production/packaging/tools/BUILD.bazel +++ b/production/packaging/tools/BUILD.bazel @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", -) +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -87,27 +83,27 @@ pkg_tar( srcs = [":tools_executables"], ) -container_layer( - name = "tools_binaries_layer", - directory = "/", +oci_image( + name = "tools_binaries_image", + base = select({ + "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64", + }), tars = [ ":tools_binaries_tar", ], ) -container_image( +oci_load( name = "tools_binaries_docker_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), - base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64//image", - }), - layers = [ - ":tools_binaries_layer", - ], + image = ":tools_binaries_image", + repo_tags = ["bazel/production/packaging/tools:tools_binaries_docker_image"], +) + +filegroup( + name = "tools_binaries_docker_image.tar", + srcs = [":tools_binaries_docker_image"], + output_group = "tarball", ) genrule( @@ -120,7 +116,7 @@ genrule( cmd_bash = """cat << EOF > '$@' mkdir -p dist/debian cp $(execpath :tools_binaries) dist/debian -cp $(execpath :tools_binaries_docker_image.tar) dist +cp $(execpath :tools_binaries_docker_image.tar) dist/tools_binaries_docker_image.tar builders/tools/normalize-dist EOF""", executable = True, diff --git a/production/packaging/tools/request_simulation/BUILD.bazel b/production/packaging/tools/request_simulation/BUILD.bazel index ccafeaa3..00cb62aa 100644 --- a/production/packaging/tools/request_simulation/BUILD.bazel +++ b/production/packaging/tools/request_simulation/BUILD.bazel @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", -) +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -49,33 +45,33 @@ pkg_tar( srcs = request_simulation_binaries, ) -container_layer( - name = "request_simulation_binary_layer", - directory = "/", - tars = [ - ":request_simulation_tar", - ], -) - -container_image( - name = "request_simulation_docker_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), +oci_image( + name = "request_simulation_image", base = select({ - "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64//image", - "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64//image", + "@platforms//cpu:arm64": "@runtime-debian-debug-nonroot-arm64", + "@platforms//cpu:x86_64": "@runtime-debian-debug-nonroot-amd64", }), cmd = [ "/request_simulation/bin/start_request_simulation_system", ], entrypoint = ["/bin/bash"], - layers = [ - ":request_simulation_binary_layer", + tars = [ + ":request_simulation_tar", ], ) +oci_load( + name = "request_simulation_docker_image", + image = ":request_simulation_image", + repo_tags = ["bazel/production/packaging/tools/request_simulation:request_simulation_docker_image"], +) + +filegroup( + name = "request_simulation_docker_image.tar", + srcs = [":request_simulation_docker_image"], + output_group = "tarball", +) + pkg_zip( name = "request_simulation_artifacts", srcs = request_simulation_binaries, @@ -90,7 +86,8 @@ genrule( outs = ["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p dist/request_simulation/aws -cp $(execpath :request_simulation_artifacts) $(execpath :request_simulation_docker_image.tar) dist/request_simulation/aws +cp $(execpath :request_simulation_artifacts) dist/request_simulation/aws +cp $(execpath :request_simulation_docker_image.tar) dist/request_simulation/aws/request_simulation_docker_image.tar 
builders/tools/normalize-dist EOF""", executable = True, diff --git a/production/packaging/tools/request_simulation/otel_collector/BUILD.bazel b/production/packaging/tools/request_simulation/otel_collector/BUILD.bazel index 51b64c23..2d8d73b1 100644 --- a/production/packaging/tools/request_simulation/otel_collector/BUILD.bazel +++ b/production/packaging/tools/request_simulation/otel_collector/BUILD.bazel @@ -18,28 +18,9 @@ package(default_visibility = [ "//production/packaging/tools/request_simulation:__subpackages__", ]) -otel_version = "0.23.0" - -otel_tag = "v" + otel_version - -target_file = otel_tag + ".tar.gz" - -genrule( - name = "aws_otel_collector", - outs = ["aws-otel-collector.rpm"], - cmd = """ - set -x - yum install -y wget - yum install -y rpm-build - wget https://github.com/aws-observability/aws-otel-collector/archive/refs/tags/%s - tar xvf %s - cd aws-otel-collector-%s - # Remove linting from build targets. - sed -i 's/build: install-tools lint multimod-verify/build:/g' Makefile - export HOME=/root - make package-rpm - cp build/packages/linux/amd64/aws-otel-collector.rpm ../$@ - """ % (target_file, target_file, otel_version), +filegroup( + name = "aws-otel-collector.rpm", + srcs = ["//production/packaging/aws/otel_collector:aws-otel-collector.rpm"], ) pkg_files( diff --git a/production/terraform/aws/environments/demo/us-east-1.tfvars.json b/production/terraform/aws/environments/demo/us-east-1.tfvars.json index 2f13cb21..246880c2 100644 --- a/production/terraform/aws/environments/demo/us-east-1.tfvars.json +++ b/production/terraform/aws/environments/demo/us-east-1.tfvars.json @@ -1,4 +1,5 @@ { + "add_chaff_sharding_clusters": true, "add_missing_keys_v1": true, "autoscaling_desired_capacity": 4, "autoscaling_max_size": 6, @@ -25,6 +26,7 @@ "http_api_paths": ["/v1/*", "/v2/*", "/healthcheck"], "instance_ami_id": "ami-0000000", "instance_type": "m5.xlarge", + "logging_verbosity_backup_poll_frequency_secs": 300, "logging_verbosity_level": 0, 
"metrics_collector_endpoint": "", "metrics_export_interval_millis": 5000, diff --git a/production/terraform/aws/environments/kv_server.tf b/production/terraform/aws/environments/kv_server.tf index cac8ce72..ab1ba879 100644 --- a/production/terraform/aws/environments/kv_server.tf +++ b/production/terraform/aws/environments/kv_server.tf @@ -88,9 +88,10 @@ module "kv_server" { data_loading_blob_prefix_allowlist = var.data_loading_blob_prefix_allowlist # Variables related to sharding. - num_shards = var.num_shards - use_sharding_key_regex = var.use_sharding_key_regex - sharding_key_regex = var.sharding_key_regex + num_shards = var.num_shards + use_sharding_key_regex = var.use_sharding_key_regex + sharding_key_regex = var.sharding_key_regex + add_chaff_sharding_clusters = var.add_chaff_sharding_clusters # Variables related to UDF execution. udf_num_workers = var.udf_num_workers @@ -109,10 +110,11 @@ module "kv_server" { public_key_endpoint = var.public_key_endpoint # Variables related to logging - logging_verbosity_level = var.logging_verbosity_level - enable_otel_logger = var.enable_otel_logger - consented_debug_token = var.consented_debug_token - enable_consented_log = var.enable_consented_log + logging_verbosity_level = var.logging_verbosity_level + logging_verbosity_backup_poll_frequency_secs = var.logging_verbosity_backup_poll_frequency_secs + enable_otel_logger = var.enable_otel_logger + consented_debug_token = var.consented_debug_token + enable_consented_log = var.enable_consented_log } output "kv_server_url" { diff --git a/production/terraform/aws/environments/kv_server_variables.tf b/production/terraform/aws/environments/kv_server_variables.tf index 23e7cf72..e5fb19dc 100644 --- a/production/terraform/aws/environments/kv_server_variables.tf +++ b/production/terraform/aws/environments/kv_server_variables.tf @@ -223,6 +223,12 @@ variable "add_missing_keys_v1" { type = bool } +variable "add_chaff_sharding_clusters" { + description = "Whether to add chaff when 
querying sharding clusters." + default = true + type = bool +} + variable "use_real_coordinators" { description = "Use real coordinators." type = bool @@ -250,6 +256,12 @@ variable "logging_verbosity_level" { type = number } +variable "logging_verbosity_backup_poll_frequency_secs" { + description = "Backup poll frequency in seconds for the logging verbosity parameter." + default = "300" + type = number +} + variable "run_server_outside_tee" { description = "Whether to run the server outside the TEE, in a docker container. Untrusted mode, for debugging." default = false diff --git a/production/terraform/aws/modules/kv_server/main.tf b/production/terraform/aws/modules/kv_server/main.tf index 984d5d1c..ea0cae24 100644 --- a/production/terraform/aws/modules/kv_server/main.tf +++ b/production/terraform/aws/modules/kv_server/main.tf @@ -44,18 +44,6 @@ module "data_storage" { bucket_notification_dependency = [module.sqs_cleanup.allow_sqs_cleanup_execution_as_dependency] } -module "sqs_cleanup" { - source = "../../services/sqs_cleanup" - environment = var.environment - service = local.service - sqs_cleanup_image_uri = var.sqs_cleanup_image_uri - lambda_role_arn = module.iam_roles.lambda_role_arn - sqs_cleanup_schedule = var.sqs_cleanup_schedule - sns_data_updates_topic_arn = module.data_storage.sns_data_updates_topic_arn - sqs_queue_timeout_secs = var.sqs_queue_timeout_secs - sns_realtime_topic_arn = module.data_storage.sns_realtime_topic_arn -} - module "networking" { source = "../../services/networking" service = local.service @@ -182,6 +170,13 @@ module "ssh" { instance_profile_name = module.iam_roles.ssh_instance_profile_name } + +module "parameter_notification" { + source = "../../services/parameter_notification" + service = local.service + environment = var.environment +} + module "parameter" { source = "../../services/parameter" service = local.service @@ -206,6 +201,7 @@ module "parameter" { udf_min_log_level_parameter_value = var.udf_min_log_level 
route_v1_requests_to_v2_parameter_value = var.route_v1_requests_to_v2 add_missing_keys_v1_parameter_value = var.add_missing_keys_v1 + add_chaff_sharding_clusters_parameter_value = var.add_chaff_sharding_clusters use_real_coordinators_parameter_value = var.use_real_coordinators primary_coordinator_account_identity_parameter_value = var.primary_coordinator_account_identity secondary_coordinator_account_identity_parameter_value = var.secondary_coordinator_account_identity @@ -218,12 +214,27 @@ module "parameter" { enable_consented_log_parameter_value = var.enable_consented_log - data_loading_file_format_parameter_value = var.data_loading_file_format - logging_verbosity_level_parameter_value = var.logging_verbosity_level - use_sharding_key_regex_parameter_value = var.use_sharding_key_regex - sharding_key_regex_parameter_value = var.sharding_key_regex - enable_otel_logger_parameter_value = var.enable_otel_logger - data_loading_blob_prefix_allowlist = var.data_loading_blob_prefix_allowlist + data_loading_file_format_parameter_value = var.data_loading_file_format + logging_verbosity_level_parameter_value = var.logging_verbosity_level + logging_verbosity_update_sns_arn_parameter_value = module.parameter_notification.logging_verbosity_updates_topic_arn + logging_verbosity_backup_poll_frequency_secs_parameter_value = var.logging_verbosity_backup_poll_frequency_secs + use_sharding_key_regex_parameter_value = var.use_sharding_key_regex + sharding_key_regex_parameter_value = var.sharding_key_regex + enable_otel_logger_parameter_value = var.enable_otel_logger + data_loading_blob_prefix_allowlist = var.data_loading_blob_prefix_allowlist +} + +module "sqs_cleanup" { + source = "../../services/sqs_cleanup" + environment = var.environment + service = local.service + sqs_cleanup_image_uri = var.sqs_cleanup_image_uri + lambda_role_arn = module.iam_roles.lambda_role_arn + sqs_cleanup_schedule = var.sqs_cleanup_schedule + sns_data_updates_topic_arn = 
module.data_storage.sns_data_updates_topic_arn + sqs_queue_timeout_secs = var.sqs_queue_timeout_secs + sns_realtime_topic_arn = module.data_storage.sns_realtime_topic_arn + sns_logging_verbosity_updates_topic_arn = module.parameter_notification.logging_verbosity_updates_topic_arn } module "security_group_rules" { @@ -244,15 +255,16 @@ module "security_group_rules" { } module "iam_role_policies" { - source = "../../services/iam_role_policies" - service = local.service - environment = var.environment - server_instance_role_name = module.iam_roles.instance_role_name - sqs_cleanup_lambda_role_name = module.iam_roles.lambda_role_name - s3_delta_file_bucket_arn = module.data_storage.s3_data_bucket_arn - sns_data_updates_topic_arn = module.data_storage.sns_data_updates_topic_arn - sns_realtime_topic_arn = module.data_storage.sns_realtime_topic_arn - ssh_instance_role_name = module.iam_roles.ssh_instance_role_name + source = "../../services/iam_role_policies" + service = local.service + environment = var.environment + server_instance_role_name = module.iam_roles.instance_role_name + sqs_cleanup_lambda_role_name = module.iam_roles.lambda_role_name + s3_delta_file_bucket_arn = module.data_storage.s3_data_bucket_arn + sns_data_updates_topic_arn = module.data_storage.sns_data_updates_topic_arn + sns_realtime_topic_arn = module.data_storage.sns_realtime_topic_arn + logging_verbosity_updates_topic_arn = module.parameter_notification.logging_verbosity_updates_topic_arn + ssh_instance_role_name = module.iam_roles.ssh_instance_role_name server_parameter_arns = [ module.parameter.s3_bucket_parameter_arn, module.parameter.bucket_update_sns_arn_parameter_arn, @@ -271,8 +283,11 @@ module "iam_role_policies" { module.parameter.udf_num_workers_parameter_arn, module.parameter.route_v1_requests_to_v2_parameter_arn, module.parameter.add_missing_keys_v1_parameter_arn, + module.parameter.add_chaff_sharding_clusters_parameter_arn, module.parameter.data_loading_file_format_parameter_arn, 
module.parameter.logging_verbosity_level_parameter_arn, + module.parameter.logging_verbosity_update_sns_arn_parameter_arn, + module.parameter.logging_verbosity_backup_poll_frequency_secs_parameter_arn, module.parameter.use_real_coordinators_parameter_arn, module.parameter.use_sharding_key_regex_parameter_arn, module.parameter.udf_timeout_millis_parameter_arn, diff --git a/production/terraform/aws/modules/kv_server/variables.tf b/production/terraform/aws/modules/kv_server/variables.tf index cdaf5293..285c4152 100644 --- a/production/terraform/aws/modules/kv_server/variables.tf +++ b/production/terraform/aws/modules/kv_server/variables.tf @@ -224,6 +224,11 @@ variable "add_missing_keys_v1" { type = bool } +variable "add_chaff_sharding_clusters" { + description = "Whether to add chaff when querying sharding clusters." + type = bool +} + variable "use_real_coordinators" { description = "Will use real coordinators. `enclave_enable_debug_mode` should be set to `false` if the attestation check is enabled for coordinators. Attestation check is enabled on all production instances, and might be disabled for testing purposes only on staging/dev environments." type = bool @@ -249,6 +254,11 @@ variable "logging_verbosity_level" { type = number } +variable "logging_verbosity_backup_poll_frequency_secs" { + description = "Backup poll frequency in seconds for the logging verbosity parameter." + type = number +} + variable "run_server_outside_tee" { description = "Whether to run the server outside the TEE, in a docker container. Untrusted mode, for debugging." 
type = bool diff --git a/production/terraform/aws/services/iam_role_policies/main.tf b/production/terraform/aws/services/iam_role_policies/main.tf index ce1bc9be..7b74232e 100644 --- a/production/terraform/aws/services/iam_role_policies/main.tf +++ b/production/terraform/aws/services/iam_role_policies/main.tf @@ -91,6 +91,11 @@ data "aws_iam_policy_document" "instance_policy_doc" { actions = ["sns:Subscribe"] resources = [var.sns_realtime_topic_arn] } + statement { + sid = "AllowInstancesToSubscribeToLoggingVerbosityParameterUpdates" + actions = ["sns:Subscribe"] + resources = [var.logging_verbosity_updates_topic_arn] + } statement { sid = "AllowXRay" actions = [ @@ -148,6 +153,11 @@ resource "aws_iam_role_policy_attachment" "instance_role_policy_attachment" { role = var.server_instance_role_name } +resource "aws_iam_role_policy_attachment" "ssm_instance_role_attachment" { + role = var.server_instance_role_name + policy_arn = "arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore" +} + # Set up access policies for the SQS cleanup lambda function. data "aws_iam_policy_document" "sqs_cleanup_lambda_policy_doc" { statement { diff --git a/production/terraform/aws/services/iam_role_policies/variables.tf b/production/terraform/aws/services/iam_role_policies/variables.tf index 191701b3..8fd8bd2e 100644 --- a/production/terraform/aws/services/iam_role_policies/variables.tf +++ b/production/terraform/aws/services/iam_role_policies/variables.tf @@ -75,6 +75,11 @@ variable "sns_realtime_topic_arn" { type = string } +variable "logging_verbosity_updates_topic_arn" { + description = "ARN for the sns topic that receives logging verbosity parameter updates." + type = string +} + variable "ssh_instance_role_name" { description = "Role for SSH instance (bastion)." 
} diff --git a/production/terraform/aws/services/mesh_service/main.tf b/production/terraform/aws/services/mesh_service/main.tf index 890718cc..00fe5c90 100644 --- a/production/terraform/aws/services/mesh_service/main.tf +++ b/production/terraform/aws/services/mesh_service/main.tf @@ -38,6 +38,9 @@ resource "aws_service_discovery_service" "cloud_map_service" { health_check_custom_config { failure_threshold = 1 } + + # Ensure all cloud map entries are deleted. + force_destroy = true } resource "aws_appmesh_virtual_node" "appmesh_virtual_node" { diff --git a/production/terraform/aws/services/parameter/main.tf b/production/terraform/aws/services/parameter/main.tf index eecf0fa8..2ab2fef3 100644 --- a/production/terraform/aws/services/parameter/main.tf +++ b/production/terraform/aws/services/parameter/main.tf @@ -142,6 +142,13 @@ resource "aws_ssm_parameter" "add_missing_keys_v1_parameter" { overwrite = true } +resource "aws_ssm_parameter" "add_chaff_sharding_clusters_parameter" { + name = "${var.service}-${var.environment}-add-chaff-sharding-clusters" + type = "String" + value = var.add_chaff_sharding_clusters_parameter_value + overwrite = true +} + resource "aws_ssm_parameter" "use_real_coordinators_parameter" { name = "${var.service}-${var.environment}-use-real-coordinators" type = "String" @@ -219,6 +226,20 @@ resource "aws_ssm_parameter" "logging_verbosity_level_parameter" { overwrite = true } +resource "aws_ssm_parameter" "logging_verbosity_update_sns_arn_parameter" { + name = "${var.service}-${var.environment}-logging-verbosity-update-sns-arn" + type = "String" + value = var.logging_verbosity_update_sns_arn_parameter_value + overwrite = true +} + +resource "aws_ssm_parameter" "logging_verbosity_backup_poll_frequency_secs_parameter" { + name = "${var.service}-${var.environment}-logging-verbosity-backup-poll-frequency-secs" + type = "String" + value = var.logging_verbosity_backup_poll_frequency_secs_parameter_value + overwrite = true +} + resource 
"aws_ssm_parameter" "use_sharding_key_regex_parameter" { name = "${var.service}-${var.environment}-use-sharding-key-regex" type = "String" diff --git a/production/terraform/aws/services/parameter/outputs.tf b/production/terraform/aws/services/parameter/outputs.tf index 9b6fd2ca..ae334282 100644 --- a/production/terraform/aws/services/parameter/outputs.tf +++ b/production/terraform/aws/services/parameter/outputs.tf @@ -90,6 +90,10 @@ output "add_missing_keys_v1_parameter_arn" { value = aws_ssm_parameter.add_missing_keys_v1_parameter.arn } +output "add_chaff_sharding_clusters_parameter_arn" { + value = aws_ssm_parameter.add_chaff_sharding_clusters_parameter.arn +} + output "use_real_coordinators_parameter_arn" { value = aws_ssm_parameter.use_real_coordinators_parameter.arn } @@ -130,6 +134,14 @@ output "logging_verbosity_level_parameter_arn" { value = aws_ssm_parameter.logging_verbosity_level_parameter.arn } +output "logging_verbosity_update_sns_arn_parameter_arn" { + value = aws_ssm_parameter.logging_verbosity_update_sns_arn_parameter.arn +} + +output "logging_verbosity_backup_poll_frequency_secs_parameter_arn" { + value = aws_ssm_parameter.logging_verbosity_backup_poll_frequency_secs_parameter.arn +} + output "use_sharding_key_regex_parameter_arn" { value = aws_ssm_parameter.use_sharding_key_regex_parameter.arn } diff --git a/production/terraform/aws/services/parameter/variables.tf b/production/terraform/aws/services/parameter/variables.tf index 3ed6e121..694dafbb 100644 --- a/production/terraform/aws/services/parameter/variables.tf +++ b/production/terraform/aws/services/parameter/variables.tf @@ -104,6 +104,11 @@ variable "add_missing_keys_v1_parameter_value" { type = bool } +variable "add_chaff_sharding_clusters_parameter_value" { + description = "Add chaff when querying sharding clusters." + type = bool +} + variable "use_real_coordinators_parameter_value" { description = "Number of parallel threads for reading and loading data files." 
type = bool @@ -134,6 +139,16 @@ variable "logging_verbosity_level_parameter_value" { type = number } +variable "logging_verbosity_update_sns_arn_parameter_value" { + description = "Value for the logging verbosity update SNS ARN parameter." + type = string +} + +variable "logging_verbosity_backup_poll_frequency_secs_parameter_value" { + description = "Backup poll frequency in seconds for the logging verbosity parameter" + type = number +} + variable "use_sharding_key_regex_parameter_value" { description = "Use sharding key regex. This is useful if you want to use data locality feature for sharding." type = bool diff --git a/production/terraform/aws/services/parameter_notification/main.tf b/production/terraform/aws/services/parameter_notification/main.tf new file mode 100644 index 00000000..87ec8eed --- /dev/null +++ b/production/terraform/aws/services/parameter_notification/main.tf @@ -0,0 +1,74 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# SNS topic for listening to logging verbosity update +resource "aws_sns_topic" "logging_verbosity_level_sns_topic" { + name = "${var.service}-${var.environment}-logging-verbosity-update-sns-topic" + + tags = { + Name = "${var.service}-${var.environment}-logging-verbosity-update-sns-topic" + } +} + +resource "aws_cloudwatch_event_rule" "parameter_update_event_rule" { + name = "${var.service}-${var.environment}-parameter-update-event-rule" + description = "Event rule to trigger events of parameter update notification" + event_pattern = jsonencode({ + "source" : ["aws.ssm"], + "detail-type" : ["Parameter Store Change"], + "detail" : { + "name" : [ + "${var.service}-${var.environment}-logging-verbosity-level" + ], + "operation" : [ + "Update" + ] + } + }) +} + +# Allow Event bridge to publish events to SNS topic. +data "aws_iam_policy_document" "sns_topic_policy_doc" { + statement { + principals { + identifiers = [ + "events.amazonaws.com" + ] + type = "Service" + } + + actions = [ + "SNS:Publish" + ] + + resources = [ + aws_sns_topic.logging_verbosity_level_sns_topic.arn + ] + } +} + +resource "aws_sns_topic_policy" "sns_topic_policy" { + arn = aws_sns_topic.logging_verbosity_level_sns_topic.arn + policy = data.aws_iam_policy_document.sns_topic_policy_doc.json +} + +resource "aws_cloudwatch_event_target" "logging_parameter_update_target" { + target_id = "${var.service}-${var.environment}-logging-verbosity-update-target" + rule = aws_cloudwatch_event_rule.parameter_update_event_rule.name + arn = aws_sns_topic.logging_verbosity_level_sns_topic.arn + depends_on = [ + aws_sns_topic.logging_verbosity_level_sns_topic, + aws_cloudwatch_event_rule.parameter_update_event_rule + ] +} diff --git a/production/packaging/aws/data_server/nitro-pcr0/BUILD.bazel b/production/terraform/aws/services/parameter_notification/outputs.tf similarity index 70% rename from production/packaging/aws/data_server/nitro-pcr0/BUILD.bazel rename to 
production/terraform/aws/services/parameter_notification/outputs.tf index a010b10a..7a3a23b1 100644 --- a/production/packaging/aws/data_server/nitro-pcr0/BUILD.bazel +++ b/production/terraform/aws/services/parameter_notification/outputs.tf @@ -1,10 +1,10 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_visibility = ["//production/packaging:__subpackages__"]) - -exports_files([ - "amd64.json", - "arm64.json", -]) +output "logging_verbosity_updates_topic_arn" { + value = aws_sns_topic.logging_verbosity_level_sns_topic.arn +} diff --git a/production/terraform/aws/services/parameter_notification/variables.tf b/production/terraform/aws/services/parameter_notification/variables.tf new file mode 100644 index 00000000..23096811 --- /dev/null +++ b/production/terraform/aws/services/parameter_notification/variables.tf @@ -0,0 +1,23 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +variable "service" { + description = "Assigned name of the KV server." + type = string +} + +variable "environment" { + description = "Assigned environment name to group related resources." + type = string +} diff --git a/production/terraform/aws/services/sqs_cleanup/main.tf b/production/terraform/aws/services/sqs_cleanup/main.tf index 483c710c..97d71410 100644 --- a/production/terraform/aws/services/sqs_cleanup/main.tf +++ b/production/terraform/aws/services/sqs_cleanup/main.tf @@ -51,7 +51,9 @@ resource "aws_cloudwatch_event_target" "sqs_cleanup_target" { "queue_prefix": "BlobNotifier_", "timeout_secs": "${var.sqs_queue_timeout_secs}", "realtime_sns_topic": "${var.sns_realtime_topic_arn}", - "realtime_queue_prefix": "QueueNotifier_" + "realtime_queue_prefix": "QueueNotifier_", + "logging_verbosity_updates_sns_topic": "${var.sns_logging_verbosity_updates_topic_arn}", + "parameter_queue_prefix": "ParameterNotifier_" } JSON } diff --git a/production/terraform/aws/services/sqs_cleanup/variables.tf b/production/terraform/aws/services/sqs_cleanup/variables.tf index a504d052..1e9b416b 100644 --- a/production/terraform/aws/services/sqs_cleanup/variables.tf +++ b/production/terraform/aws/services/sqs_cleanup/variables.tf @@ -54,3 +54,8 @@ variable "sns_realtime_topic_arn" { description = "SNS topic where realtime updates are pushed to." type = string } + +variable "sns_logging_verbosity_updates_topic_arn" { + description = "SNS topic where logging verbosity updates are pushed to." 
+ type = string +} diff --git a/production/terraform/gcp/environments/demo/us-east1.tfvars.json b/production/terraform/gcp/environments/demo/us-east1.tfvars.json index 212f64ff..24dd7b35 100644 --- a/production/terraform/gcp/environments/demo/us-east1.tfvars.json +++ b/production/terraform/gcp/environments/demo/us-east1.tfvars.json @@ -1,4 +1,5 @@ { + "add_chaff_sharding_clusters": true, "add_missing_keys_v1": true, "backup_poll_frequency_secs": 5, "collector_dns_zone": "your-dns-zone-name", @@ -6,6 +7,7 @@ "collector_machine_type": "e2-micro", "collector_service_name": "otel-collector", "collector_service_port": 4317, + "collector_startup_script_path": "../../services/metrics_collector_autoscaling/collector_startup.sh", "consented_debug_token": "EMPTY_STRING", "cpu_utilization_percent": 0.9, "data_bucket_id": "your-delta-file-bucket", @@ -21,6 +23,7 @@ "gcp_image_tag": "demo", "instance_template_waits_for_instances": true, "kv_service_port": 50051, + "logging_verbosity_backup_poll_frequency_secs": 300, "logging_verbosity_level": 0, "machine_type": "n2d-standard-4", "max_replicas_per_service_region": 5, diff --git a/production/terraform/gcp/environments/kv_server.tf b/production/terraform/gcp/environments/kv_server.tf index 4fb509c2..0f28d307 100644 --- a/production/terraform/gcp/environments/kv_server.tf +++ b/production/terraform/gcp/environments/kv_server.tf @@ -70,46 +70,48 @@ module "kv_server" { enable_external_traffic = var.enable_external_traffic parameters = { - data-bucket-id = var.data_bucket_id - launch-hook = "${local.kv_service}-${var.environment}-launch-hook" - use-external-metrics-collector-endpoint = var.use_external_metrics_collector_endpoint - metrics-collector-endpoint = "${var.environment}-${var.collector_service_name}.${var.collector_domain_name}:${var.collector_service_port}" - metrics-export-interval-millis = var.metrics_export_interval_millis - metrics-export-timeout-millis = var.metrics_export_timeout_millis - backup-poll-frequency-secs = 
var.backup_poll_frequency_secs - realtime-updater-num-threads = var.realtime_updater_num_threads - data-loading-num-threads = var.data_loading_num_threads - num-shards = var.num_shards - udf-num-workers = var.udf_num_workers - udf-timeout-millis = var.udf_timeout_millis - udf-update-timeout-millis = var.udf_update_timeout_millis - udf-min-log-level = var.udf_min_log_level - route-v1-to-v2 = var.route_v1_to_v2 - add-missing-keys-v1 = var.add_missing_keys_v1 - use-real-coordinators = var.use_real_coordinators - environment = var.environment - project-id = var.project_id - primary-key-service-cloud-function-url = var.primary_key_service_cloud_function_url - primary-workload-identity-pool-provider = var.primary_workload_identity_pool_provider - secondary-key-service-cloud-function-url = var.secondary_key_service_cloud_function_url - secondary-workload-identity-pool-provider = var.secondary_workload_identity_pool_provider - primary-coordinator-account-identity = var.primary_coordinator_account_identity - secondary-coordinator-account-identity = var.secondary_coordinator_account_identity - primary-coordinator-private-key-endpoint = var.primary_coordinator_private_key_endpoint - primary-coordinator-region = var.primary_coordinator_region - secondary-coordinator-private-key-endpoint = var.secondary_coordinator_private_key_endpoint - secondary-coordinator-region = var.secondary_coordinator_region - public-key-endpoint = var.public_key_endpoint - logging-verbosity-level = var.logging_verbosity_level - use-sharding-key-regex = var.use_sharding_key_regex - sharding-key-regex = var.sharding_key_regex - tls-key = var.tls_key - tls-cert = var.tls_cert - enable-otel-logger = var.enable_otel_logger - enable-external-traffic = var.enable_external_traffic - telemetry-config = var.telemetry_config - data-loading-blob-prefix-allowlist = var.data_loading_blob_prefix_allowlist - consented-debug-token = var.consented_debug_token - enable-consented-log = var.enable_consented_log + 
data-bucket-id = var.data_bucket_id + launch-hook = "${local.kv_service}-${var.environment}-launch-hook" + use-external-metrics-collector-endpoint = var.use_external_metrics_collector_endpoint + metrics-collector-endpoint = "${var.environment}-${var.collector_service_name}.${var.collector_domain_name}:${var.collector_service_port}" + metrics-export-interval-millis = var.metrics_export_interval_millis + metrics-export-timeout-millis = var.metrics_export_timeout_millis + backup-poll-frequency-secs = var.backup_poll_frequency_secs + realtime-updater-num-threads = var.realtime_updater_num_threads + data-loading-num-threads = var.data_loading_num_threads + num-shards = var.num_shards + udf-num-workers = var.udf_num_workers + udf-timeout-millis = var.udf_timeout_millis + udf-update-timeout-millis = var.udf_update_timeout_millis + udf-min-log-level = var.udf_min_log_level + route-v1-to-v2 = var.route_v1_to_v2 + add-missing-keys-v1 = var.add_missing_keys_v1 + add-chaff-sharding-clusters = var.add_chaff_sharding_clusters + use-real-coordinators = var.use_real_coordinators + environment = var.environment + project-id = var.project_id + primary-key-service-cloud-function-url = var.primary_key_service_cloud_function_url + primary-workload-identity-pool-provider = var.primary_workload_identity_pool_provider + secondary-key-service-cloud-function-url = var.secondary_key_service_cloud_function_url + secondary-workload-identity-pool-provider = var.secondary_workload_identity_pool_provider + primary-coordinator-account-identity = var.primary_coordinator_account_identity + secondary-coordinator-account-identity = var.secondary_coordinator_account_identity + primary-coordinator-private-key-endpoint = var.primary_coordinator_private_key_endpoint + primary-coordinator-region = var.primary_coordinator_region + secondary-coordinator-private-key-endpoint = var.secondary_coordinator_private_key_endpoint + secondary-coordinator-region = var.secondary_coordinator_region + public-key-endpoint 
= var.public_key_endpoint + logging-verbosity-backup-poll-frequency-secs = var.logging_verbosity_backup_poll_frequency_secs + logging-verbosity-level = var.logging_verbosity_level + use-sharding-key-regex = var.use_sharding_key_regex + sharding-key-regex = var.sharding_key_regex + tls-key = var.tls_key + tls-cert = var.tls_cert + enable-otel-logger = var.enable_otel_logger + enable-external-traffic = var.enable_external_traffic + telemetry-config = var.telemetry_config + data-loading-blob-prefix-allowlist = var.data_loading_blob_prefix_allowlist + consented-debug-token = var.consented_debug_token + enable-consented-log = var.enable_consented_log } } diff --git a/production/terraform/gcp/environments/kv_server_variables.tf b/production/terraform/gcp/environments/kv_server_variables.tf index 94247d2a..b71c1ea9 100644 --- a/production/terraform/gcp/environments/kv_server_variables.tf +++ b/production/terraform/gcp/environments/kv_server_variables.tf @@ -196,6 +196,12 @@ variable "add_missing_keys_v1" { description = "Add missing keys for v1." } +variable "add_chaff_sharding_clusters" { + type = bool + default = true + description = "Add chaff sharding clusters." +} + variable "use_real_coordinators" { type = bool description = "Use real coordinators." @@ -298,6 +304,13 @@ variable "logging_verbosity_level" { type = string } +variable "logging_verbosity_backup_poll_frequency_secs" { + description = "Backup poll frequency in seconds for the logging verbosity parameter." + default = 60 + type = number +} + + variable "use_sharding_key_regex" { description = "Use sharding key regex. This is useful if you want to use data locality feature for sharding." 
default = false diff --git a/production/terraform/gcp/services/data_storage/main.tf b/production/terraform/gcp/services/data_storage/main.tf index bc5c3093..c8b6b2c6 100644 --- a/production/terraform/gcp/services/data_storage/main.tf +++ b/production/terraform/gcp/services/data_storage/main.tf @@ -19,5 +19,9 @@ resource "google_storage_bucket" "default" { location = "US" storage_class = "STANDARD" + versioning { + enabled = true + } + uniform_bucket_level_access = true } diff --git a/public/BUILD b/public/BUILD index 95b5b8e7..7f0d2c2c 100644 --- a/public/BUILD +++ b/public/BUILD @@ -8,7 +8,6 @@ package( proto_library( name = "api_schema_proto", srcs = ["api_schema.proto"], - cc_api_version = 2, deps = [ "//google/protobuf:struct", ], @@ -22,7 +21,6 @@ cc_proto_library( proto_library( name = "base_types_proto", srcs = ["base_types.proto"], - cc_api_version = 2, ) cc_proto_library( diff --git a/public/api_schema.proto b/public/api_schema.proto index 6de552e3..37b7de83 100644 --- a/public/api_schema.proto +++ b/public/api_schema.proto @@ -28,6 +28,8 @@ message UDFExecutionMetadata { int32 udf_interface_version = 1; // Metadata passed in from the request that are related to the request. google.protobuf.Struct request_metadata = 2; + // Metadata passed in from the request that are related to the partition. + google.protobuf.Struct partition_metadata = 3; } // Represents one argument to UDF. 
One UDF invocation may have multiple diff --git a/public/applications/pa/BUILD.bazel b/public/applications/pa/BUILD.bazel index ff1bdd0b..bc0f5a1a 100644 --- a/public/applications/pa/BUILD.bazel +++ b/public/applications/pa/BUILD.bazel @@ -36,6 +36,7 @@ cc_library( hdrs = ["response_utils.h"], deps = [ "api_overlay_cc_proto", + "//components/errors:error_tag", "@com_google_absl//absl/status:statusor", "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", ], diff --git a/public/applications/pa/api_overlay.proto b/public/applications/pa/api_overlay.proto index 3ba00a4b..1eb5b9e3 100644 --- a/public/applications/pa/api_overlay.proto +++ b/public/applications/pa/api_overlay.proto @@ -20,6 +20,16 @@ import "google/protobuf/struct.proto"; // https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md#query-api-version-2 +message V2CompressionGroup { + repeated PartitionOutput partition_outputs = 1; +} + +message PartitionOutput { + optional int64 id = 1; + repeated KeyGroupOutput key_group_outputs = 2; + optional int32 udf_output_api_version = 3; +} + message KeyGroupOutput { repeated string tags = 1; map key_values = 2; @@ -27,11 +37,4 @@ message KeyGroupOutput { message ValueObject { google.protobuf.Value value = 1; - int64 global_ttl_sec = 2; - int64 dedicated_ttl_sec = 3; -} - -message KeyGroupOutputs { - repeated KeyGroupOutput key_group_outputs = 1; - int32 udf_output_api_version = 2; } diff --git a/public/applications/pa/response_utils.cc b/public/applications/pa/response_utils.cc index bbc44965..c2ba3d76 100644 --- a/public/applications/pa/response_utils.cc +++ b/public/applications/pa/response_utils.cc @@ -14,25 +14,40 @@ #include "public/applications/pa/response_utils.h" +#include "components/errors/error_tag.h" #include "google/protobuf/util/json_util.h" #include "src/util/status_macro/status_macros.h" namespace kv_server::application_pa { +enum class ErrorTag : int { + kJsonStringToMessageError = 1, + 
kMessageToJsonStringError = 2 +}; + using google::protobuf::util::JsonStringToMessage; using google::protobuf::util::MessageToJsonString; -absl::StatusOr KeyGroupOutputsFromJson( +absl::StatusOr PartitionOutputFromJson( std::string_view json_str) { - KeyGroupOutputs outputs_proto; - PS_RETURN_IF_ERROR(JsonStringToMessage(json_str, &outputs_proto)); - return outputs_proto; + PartitionOutput partition_output_proto; + if (const auto status = + JsonStringToMessage(json_str, &partition_output_proto); + !status.ok()) { + return StatusWithErrorTag(status, __FILE__, + ErrorTag::kJsonStringToMessageError); + } + return partition_output_proto; } -absl::StatusOr KeyGroupOutputsToJson( - const KeyGroupOutputs& key_group_outputs) { +absl::StatusOr PartitionOutputToJson( + const PartitionOutput& partition_output) { std::string json_str; - PS_RETURN_IF_ERROR(MessageToJsonString(key_group_outputs, &json_str)); + if (const auto status = MessageToJsonString(partition_output, &json_str); + !status.ok()) { + return StatusWithErrorTag(status, __FILE__, + ErrorTag::kMessageToJsonStringError); + } return json_str; } diff --git a/public/applications/pa/response_utils.h b/public/applications/pa/response_utils.h index c6a39153..60af00a4 100644 --- a/public/applications/pa/response_utils.h +++ b/public/applications/pa/response_utils.h @@ -24,11 +24,11 @@ namespace kv_server::application_pa { -absl::StatusOr KeyGroupOutputsFromJson( +absl::StatusOr PartitionOutputFromJson( std::string_view json_str); -absl::StatusOr KeyGroupOutputsToJson( - const KeyGroupOutputs& key_group_outputs); +absl::StatusOr PartitionOutputToJson( + const PartitionOutput& key_group_outputs); } // namespace kv_server::application_pa diff --git a/public/applications/pa/response_utils_test.cc b/public/applications/pa/response_utils_test.cc index 5c2b42eb..3e5da934 100644 --- a/public/applications/pa/response_utils_test.cc +++ b/public/applications/pa/response_utils_test.cc @@ -25,8 +25,8 @@ namespace { using 
google::protobuf::TextFormat; -TEST(ResponseUtils, KeyGroupOutputsFromAndToJson) { - KeyGroupOutputs proto; +TEST(ResponseUtils, PartitionOutputFromAndToJson) { + PartitionOutput proto; TextFormat::ParseFromString( R"( key_group_outputs { @@ -61,16 +61,26 @@ TEST(ResponseUtils, KeyGroupOutputsFromAndToJson) { } )", &proto); - auto maybe_json = KeyGroupOutputsToJson(proto); + auto maybe_json = PartitionOutputToJson(proto); ASSERT_TRUE(maybe_json.ok()); std::string expected_json = "{\"keyGroupOutputs\":[{\"tags\":[\"tag1\",\"tag2\"],\"keyValues\":{" "\"key1\":{\"value\":\"str_val\"},\"key2\":{\"value\":[\"item1\"," "\"item2\",\"item3\"]}}}]}"; EXPECT_EQ(*maybe_json, expected_json); - auto maybe_proto = KeyGroupOutputsFromJson(expected_json); + auto maybe_proto = PartitionOutputFromJson(expected_json); ASSERT_TRUE(maybe_proto.ok()); EXPECT_THAT(*maybe_proto, EqualsProto(proto)); } + +TEST(ResponseUtils, PartitionOutputFromJson_InvalidJsonError) { + std::string expected_json = + "{\"keyGroupOutputs\":{\"tags\":[\"tag1\",\"tag2\"],\"keyValues\":{" + "\"key1\":{\"value\":\"str_val\"},\"key2\":{\"value\":[\"item1\"," + "\"item2\",\"item3\"]}}}]}"; + const auto maybe_proto = PartitionOutputFromJson(expected_json); + ASSERT_FALSE(maybe_proto.ok()); +} + } // namespace } // namespace kv_server::application_pa diff --git a/public/constants.h b/public/constants.h index ea2a346a..b90630e2 100644 --- a/public/constants.h +++ b/public/constants.h @@ -44,6 +44,9 @@ constexpr int kFileGroupFileIndexDigits = 5; // Number of digits used to represent the size of a file group. constexpr int kFileGroupSizeDigits = 6; +// Minimum size of the returned response in bytes. +constexpr int kMinResponsePaddingBytes = 0; + // "DELTA_\d{16}" // The first component represents the file type. // @@ -105,6 +108,13 @@ const uint16_t kKDFParameter = 0x0001; // AEAD: AES-256-GCM const uint16_t kAEADParameter = 0x0002; +// Custom media types for KV. 
Used as input for ohttp request/response +// encryption/decryption. +inline constexpr absl::string_view kKVOhttpRequestLabel = + "message/ad-auction-trusted-signals-request"; +inline constexpr absl::string_view kKVOhttpResponseLabel = + "message/ad-auction-trusted-signals-response"; + constexpr std::string_view kServiceName = "kv-server"; // Returns a compiled logical sharding config file name regex defined as diff --git a/public/data_loading/csv/BUILD.bazel b/public/data_loading/csv/BUILD.bazel index 396907f0..6e4a67f1 100644 --- a/public/data_loading/csv/BUILD.bazel +++ b/public/data_loading/csv/BUILD.bazel @@ -21,7 +21,9 @@ package(default_visibility = [ cc_library( name = "constants", hdrs = ["constants.h"], - deps = [], + deps = [ + "@com_google_riegeli//riegeli/csv:csv_record", + ], ) cc_library( diff --git a/public/data_loading/csv/constants.h b/public/data_loading/csv/constants.h index 49af6ba0..1f54f78b 100644 --- a/public/data_loading/csv/constants.h +++ b/public/data_loading/csv/constants.h @@ -17,9 +17,10 @@ #ifndef TOOLS_DATA_CLI_CSV_CONSTANTS_H_ #define TOOLS_DATA_CLI_CSV_CONSTANTS_H_ -#include #include +#include "riegeli/csv/csv_record.h" + namespace kv_server { inline constexpr std::string_view kUpdateMutationType = "update"; @@ -34,6 +35,7 @@ inline constexpr std::string_view kValueTypeColumn = "value_type"; inline constexpr std::string_view kValueTypeString = "string"; inline constexpr std::string_view kValueTypeStringSet = "string_set"; inline constexpr std::string_view kValueTypeUInt32Set = "uint32_set"; +inline constexpr std::string_view kValueTypeUInt64Set = "uint64_set"; inline constexpr std::string_view kRecordTypeColumn = "record_type"; inline constexpr std::string_view kRecordTypeKVMutation = "key_value_mutation"; @@ -49,16 +51,15 @@ inline constexpr std::string_view kVersionColumn = "version"; inline constexpr std::string_view kLogicalShardColumn = "logical_shard"; inline constexpr std::string_view kPhysicalShardColumn = "physical_shard"; 
-inline constexpr std::array kKeyValueMutationRecordHeader = - {kKeyColumn, kLogicalCommitTimeColumn, kMutationTypeColumn, kValueColumn, - kValueTypeColumn}; +inline constexpr riegeli::CsvHeaderConstant kKeyValueMutationRecordHeader = { + kKeyColumn, kLogicalCommitTimeColumn, kMutationTypeColumn, kValueColumn, + kValueTypeColumn}; -inline constexpr std::array - kUserDefinedFunctionsConfigHeader = {kCodeSnippetColumn, kHandlerNameColumn, - kLogicalCommitTimeColumn, - kLanguageColumn, kVersionColumn}; +inline constexpr riegeli::CsvHeaderConstant kUserDefinedFunctionsConfigHeader = + {kCodeSnippetColumn, kHandlerNameColumn, kLogicalCommitTimeColumn, + kLanguageColumn, kVersionColumn}; -inline constexpr std::array kShardMappingRecordHeader = { +inline constexpr riegeli::CsvHeaderConstant kShardMappingRecordHeader = { kLogicalShardColumn, kPhysicalShardColumn}; } // namespace kv_server diff --git a/public/data_loading/csv/csv_delta_record_stream_reader.cc b/public/data_loading/csv/csv_delta_record_stream_reader.cc index a7b94f9b..e0a6c849 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader.cc @@ -24,6 +24,7 @@ #include "absl/strings/match.h" #include "absl/strings/str_split.h" #include "public/data_loading/record_utils.h" +#include "src/util/status_macro/status_macros.h" namespace kv_server { namespace { @@ -60,12 +61,13 @@ absl::StatusOr> BuildSetValue( if constexpr (std::is_same_v) { result.push_back(std::string(set_value)); } - if constexpr (std::is_same_v) { - if (uint32_t number; absl::SimpleAtoi(set_value, &number)) { + if constexpr (std::is_same_v || + std::is_same_v) { + if (ElementType number; absl::SimpleAtoi(set_value, &number)) { result.push_back(number); } else { - return absl::InvalidArgumentError(absl::StrCat( - "Cannot convert: ", set_value, " to a uint32 number.")); + return absl::InvalidArgumentError( + absl::StrCat("Cannot convert: ", set_value, " to a number.")); } } } @@ 
-140,13 +142,20 @@ absl::Status SetRecordValue(char value_separator, return absl::OkStatus(); } if (absl::EqualsIgnoreCase(type, kValueTypeUInt32Set)) { - auto maybe_value = - GetSetValue(csv_record, value_separator, csv_encoding); - if (!maybe_value.ok()) { - return maybe_value.status(); - } + PS_ASSIGN_OR_RETURN( + auto value, + GetSetValue(csv_record, value_separator, csv_encoding)); UInt32SetT set_value; - set_value.value = std::move(*maybe_value); + set_value.value = std::move(value); + mutation_record.value.Set(std::move(set_value)); + return absl::OkStatus(); + } + if (absl::EqualsIgnoreCase(type, kValueTypeUInt64Set)) { + PS_ASSIGN_OR_RETURN( + auto value, + GetSetValue(csv_record, value_separator, csv_encoding)); + UInt64SetT set_value; + set_value.value = std::move(value); mutation_record.value.Set(std::move(set_value)); return absl::OkStatus(); } diff --git a/public/data_loading/csv/csv_delta_record_stream_reader.h b/public/data_loading/csv/csv_delta_record_stream_reader.h index 31ca6d1b..5c179a0f 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader.h +++ b/public/data_loading/csv/csv_delta_record_stream_reader.h @@ -19,7 +19,6 @@ #include #include -#include #include "absl/log/log.h" #include "public/data_loading/csv/constants.h" @@ -116,26 +115,21 @@ riegeli::CsvReaderBase::Options GetRecordReaderOptions( riegeli::CsvReaderBase::Options reader_options; reader_options.set_field_separator(options.field_separator); - std::vector header; + riegeli::CsvHeader header; switch (options.record_type) { case Record::KeyValueMutationRecord: - header = - std::vector(kKeyValueMutationRecordHeader.begin(), - kKeyValueMutationRecordHeader.end()); + header = *kKeyValueMutationRecordHeader; break; case Record::UserDefinedFunctionsConfig: - header = std::vector( - kUserDefinedFunctionsConfigHeader.begin(), - kUserDefinedFunctionsConfigHeader.end()); + header = *kUserDefinedFunctionsConfigHeader; break; case Record::ShardMappingRecord: - header = 
std::vector(kShardMappingRecordHeader.begin(), - kShardMappingRecordHeader.end()); + header = *kShardMappingRecordHeader; break; default: LOG(ERROR) << "Unable to set CSV reader header"; } - reader_options.set_required_header(riegeli::CsvHeader(std::move(header))); + reader_options.set_required_header(std::move(header)); return reader_options; } } // namespace internal diff --git a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc index 9add4b92..ab176c6f 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc @@ -39,10 +39,11 @@ StringSetT GetStringSetValue(const std::vector& values) { return string_set; } -UInt32SetT GetUInt32SetValue(const std::vector& values) { - UInt32SetT uint32_set; - uint32_set.value = {values.begin(), values.end()}; - return uint32_set; +template +UIntSetType GetUIntSetValue(const std::vector& values) { + UIntSetType uint_set; + uint_set.value = {values.begin(), values.end()}; + return uint_set; } template @@ -228,8 +229,34 @@ TEST(CsvDeltaRecordStreamReaderTest, }; std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); - auto [legacy_mutation, mutation] = - GetKVMutationRecord(GetUInt32SetValue(values), values); + auto [legacy_mutation, mutation] = GetKVMutationRecord( + GetUIntSetValue(values), values); + auto input = GetDataRecord(legacy_mutation); + auto expected = GetNativeDataRecord(mutation); + EXPECT_TRUE(record_writer.WriteRecord(input).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + LOG(INFO) << string_stream.str(); + CsvDeltaRecordStreamReader record_reader(string_stream); + auto status = + record_reader.ReadRecords([&expected](const DataRecord& record) { + std::unique_ptr native_type_record(record.UnPack()); + EXPECT_EQ(*native_type_record, expected); + return absl::OkStatus(); + }); + EXPECT_TRUE(status.ok()) << status; +} + 
+TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingAndWriting_KVMutation_UInt64SetValues_Success) { + const std::vector values{ + 18446744073709551613UL, + 18446744073709551614UL, + 18446744073709551615UL, + }; + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer(string_stream); + auto [legacy_mutation, mutation] = GetKVMutationRecord( + GetUIntSetValue(values), values); auto input = GetDataRecord(legacy_mutation); auto expected = GetNativeDataRecord(mutation); EXPECT_TRUE(record_writer.WriteRecord(input).ok()); diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.cc b/public/data_loading/csv/csv_delta_record_stream_writer.cc index 5f807d07..f5fa93a7 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.cc +++ b/public/data_loading/csv/csv_delta_record_stream_writer.cc @@ -16,7 +16,6 @@ #include "public/data_loading/csv/csv_delta_record_stream_writer.h" -#include "absl/log/log.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -94,6 +93,13 @@ absl::StatusOr GetRecordValue( value_separator), }; } + if constexpr (std::is_same_v>) { + return ValueStruct{ + .value_type = std::string(kValueTypeUInt64Set), + .value = absl::StrJoin(MaybeEncode(arg, csv_encoding), + value_separator), + }; + } return absl::InvalidArgumentError("Value must be set."); }, value); @@ -110,8 +116,7 @@ absl::StatusOr MakeCsvRecordWithKVMutation( const auto record = std::get(data_record.record); - riegeli::CsvHeader header(kKeyValueMutationRecordHeader); - riegeli::CsvRecord csv_record(header); + riegeli::CsvRecord csv_record(*kKeyValueMutationRecordHeader); csv_record[kKeyColumn] = record.key; absl::StatusOr value = GetRecordValue( record.value, std::string(1, value_separator), csv_encoding); @@ -152,8 +157,8 @@ absl::StatusOr MakeCsvRecordWithUdfConfig( const auto udf_config = std::get(data_record.record); - riegeli::CsvHeader header(kUserDefinedFunctionsConfigHeader); - 
riegeli::CsvRecord csv_record(header); + riegeli::CsvRecord csv_record(*kUserDefinedFunctionsConfigHeader); + csv_record[kCodeSnippetColumn] = udf_config.code_snippet; csv_record[kHandlerNameColumn] = udf_config.handler_name; csv_record[kLogicalCommitTimeColumn] = @@ -176,8 +181,7 @@ absl::StatusOr MakeCsvRecordWithShardMapping( } const auto shard_mapping_struct = std::get(data_record.record); - riegeli::CsvHeader header(kShardMappingRecordHeader); - riegeli::CsvRecord csv_record(header); + riegeli::CsvRecord csv_record(*kShardMappingRecordHeader); csv_record[kLogicalShardColumn] = absl::StrCat(shard_mapping_struct.logical_shard); csv_record[kPhysicalShardColumn] = diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.h b/public/data_loading/csv/csv_delta_record_stream_writer.h index 849e7fc6..4f4cc39f 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.h +++ b/public/data_loading/csv/csv_delta_record_stream_writer.h @@ -18,7 +18,6 @@ #define PUBLIC_DATA_LOADING_CSV_CSV_DELTA_RECORD_STREAM_WRITER_H_ #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -107,25 +106,19 @@ riegeli::CsvWriterBase::Options GetRecordWriterOptions( const typename CsvDeltaRecordStreamWriter::Options& options) { riegeli::CsvWriterBase::Options writer_options; writer_options.set_field_separator(options.field_separator); - std::vector header; + riegeli::CsvHeader header; switch (options.record_type) { case DataRecordType::kKeyValueMutationRecord: - header = - std::vector(kKeyValueMutationRecordHeader.begin(), - kKeyValueMutationRecordHeader.end()); + header = *kKeyValueMutationRecordHeader; break; case DataRecordType::kUserDefinedFunctionsConfig: - header = std::vector( - kUserDefinedFunctionsConfigHeader.begin(), - kUserDefinedFunctionsConfigHeader.end()); + header = *kUserDefinedFunctionsConfigHeader; break; case DataRecordType::kShardMappingRecord: - header = std::vector(kShardMappingRecordHeader.begin(), - 
kShardMappingRecordHeader.end()); + header = *kShardMappingRecordHeader; break; } - riegeli::CsvHeader header_opt(std::move(header)); - writer_options.set_header(std::move(header_opt)); + writer_options.set_header(std::move(header)); return writer_options; } } // namespace internal diff --git a/public/data_loading/data_loading.fbs b/public/data_loading/data_loading.fbs index 955fdcb4..c1ad8587 100644 --- a/public/data_loading/data_loading.fbs +++ b/public/data_loading/data_loading.fbs @@ -11,7 +11,8 @@ table StringValue { value:string; } // (2) `Delete` mutation removes the elements from existing set. table StringSet { value:[string]; } table UInt32Set { value:[uint]; } -union Value { StringValue, StringSet, UInt32Set } +table UInt64Set { value:[ulong]; } +union Value { StringValue, StringSet, UInt32Set, UInt64Set } table KeyValueMutationRecord { // Required. For updates, the value will overwrite the previous value, if any. diff --git a/public/data_loading/data_loading_generated.h b/public/data_loading/data_loading_generated.h index 6d68243c..a6ed34a7 100644 --- a/public/data_loading/data_loading_generated.h +++ b/public/data_loading/data_loading_generated.h @@ -1,5 +1,5 @@ /* - * Copyright 2023 Google LLC + * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,6 +42,10 @@ struct UInt32Set; struct UInt32SetBuilder; struct UInt32SetT; +struct UInt64Set; +struct UInt64SetBuilder; +struct UInt64SetT; + struct KeyValueMutationRecord; struct KeyValueMutationRecordBuilder; struct KeyValueMutationRecordT; @@ -64,6 +68,8 @@ bool operator==(const StringSetT& lhs, const StringSetT& rhs); bool operator!=(const StringSetT& lhs, const StringSetT& rhs); bool operator==(const UInt32SetT& lhs, const UInt32SetT& rhs); bool operator!=(const UInt32SetT& lhs, const UInt32SetT& rhs); +bool operator==(const UInt64SetT& lhs, const UInt64SetT& rhs); +bool operator!=(const UInt64SetT& lhs, const UInt64SetT& rhs); bool operator==(const KeyValueMutationRecordT& lhs, const KeyValueMutationRecordT& rhs); bool operator!=(const KeyValueMutationRecordT& lhs, @@ -108,24 +114,26 @@ enum class Value : uint8_t { StringValue = 1, StringSet = 2, UInt32Set = 3, + UInt64Set = 4, MIN = NONE, - MAX = UInt32Set + MAX = UInt64Set }; -inline const Value (&EnumValuesValue())[4] { +inline const Value (&EnumValuesValue())[5] { static const Value values[] = {Value::NONE, Value::StringValue, - Value::StringSet, Value::UInt32Set}; + Value::StringSet, Value::UInt32Set, + Value::UInt64Set}; return values; } inline const char* const* EnumNamesValue() { - static const char* const names[5] = {"NONE", "StringValue", "StringSet", - "UInt32Set", nullptr}; + static const char* const names[6] = {"NONE", "StringValue", "StringSet", + "UInt32Set", "UInt64Set", nullptr}; return names; } inline const char* EnumNameValue(Value e) { - if (flatbuffers::IsOutRange(e, Value::NONE, Value::UInt32Set)) return ""; + if (flatbuffers::IsOutRange(e, Value::NONE, Value::UInt64Set)) return ""; const size_t index = static_cast(e); return EnumNamesValue()[index]; } @@ -150,6 +158,11 @@ struct ValueTraits { static const Value enum_value = Value::UInt32Set; }; +template <> +struct ValueTraits { + static const Value enum_value = Value::UInt64Set; +}; + template struct ValueUnionTraits { static 
const Value enum_value = Value::NONE; @@ -170,6 +183,11 @@ struct ValueUnionTraits { static const Value enum_value = Value::UInt32Set; }; +template <> +struct ValueUnionTraits { + static const Value enum_value = Value::UInt64Set; +}; + struct ValueUnion { Value type; void* value; @@ -242,6 +260,16 @@ struct ValueUnion { ? reinterpret_cast(value) : nullptr; } + kv_server::UInt64SetT* AsUInt64Set() { + return type == Value::UInt64Set + ? reinterpret_cast(value) + : nullptr; + } + const kv_server::UInt64SetT* AsUInt64Set() const { + return type == Value::UInt64Set + ? reinterpret_cast(value) + : nullptr; + } }; inline bool operator==(const ValueUnion& lhs, const ValueUnion& rhs) { @@ -262,6 +290,10 @@ inline bool operator==(const ValueUnion& lhs, const ValueUnion& rhs) { return *(reinterpret_cast(lhs.value)) == *(reinterpret_cast(rhs.value)); } + case Value::UInt64Set: { + return *(reinterpret_cast(lhs.value)) == + *(reinterpret_cast(rhs.value)); + } default: { return false; } @@ -716,6 +748,76 @@ flatbuffers::Offset CreateUInt32Set( flatbuffers::FlatBufferBuilder& _fbb, const UInt32SetT* _o, const flatbuffers::rehasher_function_t* _rehasher = nullptr); +struct UInt64SetT : public flatbuffers::NativeTable { + typedef UInt64Set TableType; + std::vector value{}; +}; + +struct UInt64Set FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UInt64SetT NativeTableType; + typedef UInt64SetBuilder Builder; + struct Traits; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUE = 4 + }; + const flatbuffers::Vector* value() const { + return GetPointer*>(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyVector(value()) && verifier.EndTable(); + } + UInt64SetT* UnPack( + const flatbuffers::resolver_function_t* _resolver = nullptr) const; + void UnPackTo( + UInt64SetT* _o, + const flatbuffers::resolver_function_t* _resolver = nullptr) 
const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder& _fbb, const UInt64SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher = nullptr); +}; + +struct UInt64SetBuilder { + typedef UInt64Set Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_value(flatbuffers::Offset> value) { + fbb_.AddOffset(UInt64Set::VT_VALUE, value); + } + explicit UInt64SetBuilder(flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUInt64Set( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset> value = 0) { + UInt64SetBuilder builder_(_fbb); + builder_.add_value(value); + return builder_.Finish(); +} + +struct UInt64Set::Traits { + using type = UInt64Set; + static auto constexpr Create = CreateUInt64Set; +}; + +inline flatbuffers::Offset CreateUInt64SetDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* value = nullptr) { + auto value__ = value ? _fbb.CreateVector(*value) : 0; + return kv_server::CreateUInt64Set(_fbb, value__); +} + +flatbuffers::Offset CreateUInt64Set( + flatbuffers::FlatBufferBuilder& _fbb, const UInt64SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher = nullptr); + struct KeyValueMutationRecordT : public flatbuffers::NativeTable { typedef KeyValueMutationRecord TableType; kv_server::KeyValueMutationType mutation_type = @@ -768,6 +870,11 @@ struct KeyValueMutationRecord FLATBUFFERS_FINAL_CLASS ? static_cast(value()) : nullptr; } + const kv_server::UInt64Set* value_as_UInt64Set() const { + return value_type() == kv_server::Value::UInt64Set + ? 
static_cast(value()) + : nullptr; + } bool Verify(flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_MUTATION_TYPE, 1) && @@ -805,6 +912,12 @@ KeyValueMutationRecord::value_as() const { return value_as_UInt32Set(); } +template <> +inline const kv_server::UInt64Set* +KeyValueMutationRecord::value_as() const { + return value_as_UInt64Set(); +} + struct KeyValueMutationRecordBuilder { typedef KeyValueMutationRecord Table; flatbuffers::FlatBufferBuilder& fbb_; @@ -1347,6 +1460,57 @@ inline flatbuffers::Offset CreateUInt32Set( return kv_server::CreateUInt32Set(_fbb, _value); } +inline bool operator==(const UInt64SetT& lhs, const UInt64SetT& rhs) { + return (lhs.value == rhs.value); +} + +inline bool operator!=(const UInt64SetT& lhs, const UInt64SetT& rhs) { + return !(lhs == rhs); +} + +inline UInt64SetT* UInt64Set::UnPack( + const flatbuffers::resolver_function_t* _resolver) const { + auto _o = std::make_unique(); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void UInt64Set::UnPackTo( + UInt64SetT* _o, const flatbuffers::resolver_function_t* _resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = value(); + if (_e) { + _o->value.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->value[_i] = _e->Get(_i); + } + } + } +} + +inline flatbuffers::Offset UInt64Set::Pack( + flatbuffers::FlatBufferBuilder& _fbb, const UInt64SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher) { + return CreateUInt64Set(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateUInt64Set( + flatbuffers::FlatBufferBuilder& _fbb, const UInt64SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder* __fbb; + const UInt64SetT* __o; + const flatbuffers::rehasher_function_t* __rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _value = _o->value.size() ? 
_fbb.CreateVector(_o->value) : 0; + return kv_server::CreateUInt64Set(_fbb, _value); +} + inline bool operator==(const KeyValueMutationRecordT& lhs, const KeyValueMutationRecordT& rhs) { return (lhs.mutation_type == rhs.mutation_type) && @@ -1627,6 +1791,10 @@ inline bool VerifyValue(flatbuffers::Verifier& verifier, const void* obj, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case Value::UInt64Set: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } @@ -1663,6 +1831,10 @@ inline void* ValueUnion::UnPack( auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case Value::UInt64Set: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -1685,6 +1857,10 @@ inline flatbuffers::Offset ValueUnion::Pack( auto ptr = reinterpret_cast(value); return CreateUInt32Set(_fbb, ptr, _rehasher).Union(); } + case Value::UInt64Set: { + auto ptr = reinterpret_cast(value); + return CreateUInt64Set(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -1708,6 +1884,11 @@ inline ValueUnion::ValueUnion(const ValueUnion& u) *reinterpret_cast(u.value)); break; } + case Value::UInt64Set: { + value = new kv_server::UInt64SetT( + *reinterpret_cast(u.value)); + break; + } default: break; } @@ -1730,6 +1911,11 @@ inline void ValueUnion::Reset() { delete ptr; break; } + case Value::UInt64Set: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } diff --git a/public/data_loading/record_utils.cc b/public/data_loading/record_utils.cc index 65126a74..dad29e4d 100644 --- a/public/data_loading/record_utils.cc +++ b/public/data_loading/record_utils.cc @@ -17,6 +17,7 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "flatbuffers/flatbuffer_builder.h" namespace kv_server { namespace { @@ -54,6 +55,11 @@ absl::Status ValidateValue(const KeyValueMutationRecord& 
kv_mutation_record) { kv_mutation_record.value_as_UInt32Set()->value() == nullptr)) { return absl::InvalidArgumentError("UInt32Set value not set."); } + if (kv_mutation_record.value_type() == Value::UInt64Set && + (kv_mutation_record.value_as_UInt64Set() == nullptr || + kv_mutation_record.value_as_UInt64Set()->value() == nullptr)) { + return absl::InvalidArgumentError("UInt64Set value not set."); + } return absl::OkStatus(); } @@ -178,4 +184,18 @@ absl::StatusOr> MaybeGetRecordValue( maybe_value->value()->end()); } +template <> +absl::StatusOr> MaybeGetRecordValue( + const KeyValueMutationRecord& record) { + const kv_server::UInt64Set* maybe_value = record.value_as_UInt64Set(); + if (!maybe_value) { + return absl::InvalidArgumentError(absl::StrCat( + "KeyValueMutationRecord does not contain expected value type. " + "Expected: UInt64Set", + ". Actual: ", EnumNameValue(record.value_type()))); + } + return std::vector(maybe_value->value()->begin(), + maybe_value->value()->end()); +} + } // namespace kv_server diff --git a/public/data_loading/record_utils.h b/public/data_loading/record_utils.h index 30a752a7..c7566b34 100644 --- a/public/data_loading/record_utils.h +++ b/public/data_loading/record_utils.h @@ -53,6 +53,13 @@ inline std::ostream& operator<<(std::ostream& os, const UInt32SetT& set_value) { return os; } +inline std::ostream& operator<<(std::ostream& os, const UInt64SetT& set_value) { + for (const auto& value : set_value.value) { + os << value << ", "; + } + return os; +} + inline std::ostream& operator<<(std::ostream& os, const ValueUnion& value_union) { switch (value_union.type) { @@ -68,6 +75,10 @@ inline std::ostream& operator<<(std::ostream& os, os << *(reinterpret_cast(value_union.value)); break; } + case Value::UInt64Set: { + os << *(reinterpret_cast(value_union.value)); + break; + } case Value::NONE: { break; } @@ -198,6 +209,12 @@ template <> absl::StatusOr> MaybeGetRecordValue( const KeyValueMutationRecord& record); +// Returns the vector of 
uint64_t stored in `record.value`. Returns error if the +// record.value is not a uint64_t set. +template <> +absl::StatusOr> MaybeGetRecordValue( + const KeyValueMutationRecord& record); + } // namespace kv_server #endif // PUBLIC_DATA_LOADING_RECORD_UTILS_H_ diff --git a/public/data_loading/record_utils_test.cc b/public/data_loading/record_utils_test.cc index e8e3064c..3c8f4c09 100644 --- a/public/data_loading/record_utils_test.cc +++ b/public/data_loading/record_utils_test.cc @@ -215,6 +215,44 @@ TEST(RecordUtilsTest, DataRecordWithKeyValueMutationRecordWithUInt32SetValue) { EXPECT_TRUE(status.ok()) << status; } +TEST(RecordUtilsTest, DataRecordWithKeyValueMutationRecordWithUInt64SetValue) { + // Serialize + KeyValueMutationRecordT kv_mutation_record_native; + kv_mutation_record_native.key = "key"; + kv_mutation_record_native.logical_commit_time = 5; + UInt64SetT value_native; + value_native.value = { + 18446744073709551613UL, + 18446744073709551614UL, + 18446744073709551615UL, + }; + kv_mutation_record_native.value.Set(std::move(value_native)); + DataRecordT data_record_native; + data_record_native.record.Set(std::move(kv_mutation_record_native)); + auto [fbs_buffer, serialized_string_view] = Serialize(data_record_native); + // Deserialize + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .Times(1) + .WillOnce([](const DataRecord& fbs_record) { + const KeyValueMutationRecord& kv_mutation_record = + *fbs_record.record_as_KeyValueMutationRecord(); + EXPECT_EQ(kv_mutation_record.key()->string_view(), "key"); + EXPECT_EQ(kv_mutation_record.logical_commit_time(), 5); + absl::StatusOr> maybe_record_value = + MaybeGetRecordValue>(kv_mutation_record); + EXPECT_TRUE(maybe_record_value.ok()) << maybe_record_value.status(); + EXPECT_THAT(*maybe_record_value, + testing::UnorderedElementsAre(18446744073709551613UL, + 18446744073709551614UL, + 18446744073709551615UL)); + return absl::OkStatus(); + }); + auto status = 
DeserializeRecord(serialized_string_view, + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, DeserializeDataRecordEmptyRecordFailure) { DataRecordT data_record_native; auto [fbs_buffer, serialized_string_view] = Serialize(data_record_native); diff --git a/public/data_loading/records_utils.cc b/public/data_loading/records_utils.cc index a16f293c..4ce130d9 100644 --- a/public/data_loading/records_utils.cc +++ b/public/data_loading/records_utils.cc @@ -58,6 +58,13 @@ ValueUnion BuildValueUnion(const KeyValueMutationRecordValueT& value, .value = CreateUInt32Set(builder, values_offset).Union(), }; } + if constexpr (std::is_same_v>) { + auto values_offset = builder.CreateVector(arg); + return ValueUnion{ + .value_type = Value::UInt64Set, + .value = CreateUInt64Set(builder, values_offset).Union(), + }; + } if constexpr (std::is_same_v) { return ValueUnion{ .value_type = Value::NONE, @@ -161,6 +168,11 @@ absl::Status ValidateValue(const KeyValueMutationRecord& kv_mutation_record) { kv_mutation_record.value_as_UInt32Set()->value() == nullptr)) { return absl::InvalidArgumentError("UInt32Set value not set."); } + if (kv_mutation_record.value_type() == Value::UInt64Set && + (kv_mutation_record.value_as_UInt64Set() == nullptr || + kv_mutation_record.value_as_UInt64Set()->value() == nullptr)) { + return absl::InvalidArgumentError("UInt64Set value not set."); + } return absl::OkStatus(); } @@ -218,6 +230,9 @@ KeyValueMutationRecordValueT GetRecordStructValue( if (fbs_record.value_type() == Value::UInt32Set) { value = GetRecordValue>(fbs_record); } + if (fbs_record.value_type() == Value::UInt64Set) { + value = GetRecordValue>(fbs_record); + } return value; } @@ -373,6 +388,12 @@ std::vector GetRecordValue(const KeyValueMutationRecord& record) { record.value_as_UInt32Set()->value()->end()); } +template <> +std::vector GetRecordValue(const KeyValueMutationRecord& record) { + return 
std::vector(record.value_as_UInt64Set()->value()->begin(), + record.value_as_UInt64Set()->value()->end()); +} + template <> KeyValueMutationRecordStruct GetTypedRecordStruct( const DataRecord& data_record) { diff --git a/public/data_loading/records_utils.h b/public/data_loading/records_utils.h index a1532079..3ce2d4d9 100644 --- a/public/data_loading/records_utils.h +++ b/public/data_loading/records_utils.h @@ -35,7 +35,8 @@ enum class DataRecordType : int { using KeyValueMutationRecordValueT = std::variant, std::vector>; + std::vector, std::vector, + std::vector>; struct KeyValueMutationRecordStruct { KeyValueMutationType mutation_type; @@ -140,6 +141,8 @@ std::vector GetRecordValue( const KeyValueMutationRecord& record); template <> std::vector GetRecordValue(const KeyValueMutationRecord& record); +template <> +std::vector GetRecordValue(const KeyValueMutationRecord& record); // Utility function to get the union record set on the `data_record`. Must // be called after checking the type of the union record using diff --git a/public/data_loading/records_utils_test.cc b/public/data_loading/records_utils_test.cc index 03c5f90f..3542d1a1 100644 --- a/public/data_loading/records_utils_test.cc +++ b/public/data_loading/records_utils_test.cc @@ -130,6 +130,11 @@ void ExpectEqual(const KeyValueMutationRecordStruct& record, testing::ContainerEq( GetRecordValue>(fbs_record))); } + if (fbs_record.value_type() == Value::UInt64Set) { + EXPECT_THAT(std::get>(record.value), + testing::ContainerEq( + GetRecordValue>(fbs_record))); + } } void ExpectEqual(const UserDefinedFunctionsConfigStruct& record, @@ -254,6 +259,26 @@ TEST(DataRecordTest, EXPECT_TRUE(status.ok()) << status; } +TEST(DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_UInt64VectorValue_Success) { + std::vector values({ + 18446744073709551613UL, + 18446744073709551614UL, + 18446744073709551615UL, + }); + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction 
record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecord& data_record_fbs) { + ExpectEqual(data_record_struct, data_record_fbs); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, DeserializeDataRecord_ToFbsRecord_KVMutation_KeyNotSet_Failure) { flatbuffers::FlatBufferBuilder builder; @@ -370,6 +395,29 @@ TEST( EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } +TEST( + DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_UInt64SetValueNotSet_Failure) { + flatbuffers::FlatBufferBuilder builder; + const auto kv_mutation_fbs = CreateKeyValueMutationRecordDirect( + builder, + /*mutation_type=*/KeyValueMutationType::Update, + /*logical_commit_time=*/0, + /*key=*/"key", + /*value_type=*/Value::UInt64Set, + /*value=*/CreateUInt64Set(builder).Union()); + const auto data_record_fbs = + CreateDataRecord(builder, /*record_type=*/Record::KeyValueMutationRecord, + kv_mutation_fbs.Union()); + builder.Finish(data_record_fbs); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call).Times(0); + auto status = DeserializeDataRecord(ToStringView(builder), + record_callback.AsStdFunction()); + ASSERT_FALSE(status.ok()) << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + TEST(DataRecordTest, DeserializeDataRecord_ToFbsRecord_UdfConfig_Success) { auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); testing::MockFunction record_callback; @@ -477,6 +525,26 @@ TEST(DataRecordTest, EXPECT_TRUE(status.ok()) << status; } +TEST(DataRecordTest, + DeserializeDataRecord_ToStruct_KVMutation_Uint64VectorValue_Success) { + std::vector values({ + 18446744073709551613UL, + 18446744073709551614UL, + 18446744073709551615UL, + }); + auto data_record_struct = 
GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecordStruct& actual_record) { + EXPECT_EQ(data_record_struct, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, DeserializeDataRecord_ToStruct_UdfConfig_Success) { auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); testing::MockFunction record_callback; diff --git a/public/query/BUILD.bazel b/public/query/BUILD.bazel index 09360863..00b55a8b 100644 --- a/public/query/BUILD.bazel +++ b/public/query/BUILD.bazel @@ -59,3 +59,19 @@ cc_grpc_library( grpc_only = True, deps = [":get_values_cc_proto"], ) + +genrule( + name = "copy_to_dist", + srcs = [ + ":query_api_descriptor_set", + ], + outs = ["copy_to_dist.bin"], + cmd_bash = """cat << EOF > '$@' +mkdir -p dist +cp $(execpath //public/query:query_api_descriptor_set) dist +builders/tools/normalize-dist +EOF""", + executable = True, + local = True, + message = "copying server artifacts to dist directory", +) diff --git a/public/query/v2/get_values_v2.proto b/public/query/v2/get_values_v2.proto index f161711d..a6de26da 100644 --- a/public/query/v2/get_values_v2.proto +++ b/public/query/v2/get_values_v2.proto @@ -43,21 +43,6 @@ service KeyValueService { rpc GetValues(GetValuesRequest) returns (GetValuesResponse) {} - // Debugging API to communication in Binary Http. - // - // The body should be a binary Http request described in - // https://www.rfc-editor.org/rfc/rfc9292.html - // - // The response will be a binary Http response. The response will return 200 - // code as long as the binary Http response can be encoded. The actual status - // code of the processing can be extracted from the binary Http response. 
- rpc BinaryHttpGetValues(BinaryHttpGetValuesRequest) returns (google.api.HttpBody) { - option (google.api.http) = { - post: "/v2/bhttp_getvalues" - body: "raw_body" - }; - } - // V2 GetValues API based on the Oblivious HTTP protocol. rpc ObliviousGetValues(ObliviousGetValuesRequest) returns (google.api.HttpBody) { option (google.api.http) = { @@ -71,10 +56,6 @@ message GetValuesHttpRequest { google.api.HttpBody raw_body = 1; } -message BinaryHttpGetValuesRequest { - google.api.HttpBody raw_body = 1; -} - message ObliviousGetValuesRequest { google.api.HttpBody raw_body = 1; } @@ -95,7 +76,9 @@ message RequestPartition { // Partitions from the same owner can be compressed together so they can // belong to the same compression group. The actual number does not matter as // long as it is unique from other owners' compression groups. - int32 compression_group_id = 2; + optional int32 compression_group_id = 2; + // Per partition metadata. + google.protobuf.Struct metadata = 3; // Each input is one argument to UDF // They are passed to the UDF in the same order as stored here. repeated UDFArgument arguments = 5; @@ -112,6 +95,9 @@ message GetValuesRequest { privacy_sandbox.server_common.LogContext log_context = 4; // Consented debugging configuration privacy_sandbox.server_common.ConsentedDebugConfiguration consented_debug_config = 5; + // Algorithm accepted by the browser for the response. + // Must contain at least one of: none, gzip, brotli. + repeated string accept_compression = 6; } message ResponsePartition { @@ -130,15 +116,24 @@ message ResponsePartition { // will compress while it builds the response, before letting gRPC send them // over the wire. In use cases where there is always only one partition, the // server will rely on gRPC compression instead. 
-message CompressionGroups { - repeated bytes compressed_partition_groups = 1; +message CompressionGroup { + // All responses for all partitions with this compression group id specified in the request are present in this object. + optional int32 compression_group_id = 1; + // Adtech-specified TTL for client-side caching. In milliseconds. Unset means no caching. + optional int32 ttl_ms = 2; + // Compressed CBOR binary string. For details see compressed response content schema -- V2CompressionGroup. + bytes content = 3; } + message GetValuesResponse { - oneof format { - // For single partition response, no explicit compression is necessary at - // request handling layer. Compression can be applied during the - // communication by the protocol layer such as gRPC or HTTP. - ResponsePartition single_partition = 1; - CompressionGroups compressed_partition_groups = 2; - } + // For single partition response use cases, no explicit compression is necessary at + // request handling layer. Compression can be applied during the + // communication by the protocol layer such as gRPC or HTTP. + // Note that single partition responses in cbor are not currently supported. 
+ ResponsePartition single_partition = 1; + repeated CompressionGroup compression_groups = 2; + // Debug logs to send back to upstream servers (only in non_prod) + // The server name in the debug info will be set by the upstream servers after + // they get response from KV server + privacy_sandbox.server_common.DebugInfo debug_info = 3; } diff --git a/public/test_util/request_example.h b/public/test_util/request_example.h index b24c9a73..7d4509ac 100644 --- a/public/test_util/request_example.h +++ b/public/test_util/request_example.h @@ -25,7 +25,8 @@ namespace kv_server { constexpr std::string_view kExampleV2RequestInJson = R"( { "metadata": { - "hostname": "example.com" + "hostname": "example.com", + "is_pas": "true" }, "partitions": [ { @@ -59,7 +60,8 @@ constexpr std::string_view kExampleV2RequestInJson = R"( constexpr std::string_view kExampleConsentedV2RequestInJson = R"( { "metadata": { - "hostname": "example.com" + "hostname": "example.com", + "is_pas": "true" }, "partitions": [ { @@ -97,7 +99,8 @@ constexpr std::string_view kExampleConsentedV2RequestInJson = R"( constexpr std::string_view kExampleConsentedV2RequestWithLogContextInJson = R"( { "metadata": { - "hostname": "example.com" + "hostname": "example.com", + "is_pas": "true" }, "partitions": [ { @@ -135,6 +138,253 @@ constexpr std::string_view kExampleConsentedV2RequestWithLogContextInJson = R"( }, })"; +// Non-consented V2 request example with multiple partitions +constexpr std::string_view kV2RequestMultiplePartitionsInJson = R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + }, + { + "id": 1, + "compressionGroupId": 1, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + }, + { + "id": 2, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + 
"key2" + ] + } + ] + } + ] + } + )"; + +// Consented V2 request example with multiple partitions without log context +constexpr std::string_view kConsentedV2RequestMultiplePartitionsInJson = + R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + }, + { + "id": 1, + "compressionGroupId": 1, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + }, + { + "id": 2, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key2" + ] + } + ] + } + ], + "consented_debug_config": { + "is_consented": true, + "token": "debug_token" + } + })"; + +// Consented V2 request example with multiple partitions with log context +constexpr std::string_view + kConsentedV2RequestMultiplePartitionsWithLogContextInJson = R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + }, + { + "id": 1, + "compressionGroupId": 1, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + }, + { + "id": 2, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key2" + ] + } + ] + } + ], + "consented_debug_config": { + "is_consented": true, + "token": "debug_token" + }, + "log_context": { + "generation_id": "client_UUID", + "adtech_debug_id": "adtech_debug_test" + } + })"; + +// Consented V2 request example with multiple partitions with log context and +// debug info response flag +constexpr std::string_view + kConsentedV2RequestMultiPartWithDebugInfoResponseInJson = + R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + 
"tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + } + ] + }, + { + "id": 1, + "compressionGroupId": 1, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + }, + { + "id": 2, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key2" + ] + } + ] + } + ], + "consented_debug_config": { + "is_consented": true, + "token": "debug_token", + "is_debug_info_in_response": true + }, + "log_context": { + "generation_id": "client_UUID", + "adtech_debug_id": "adtech_debug_test" + } + })"; + // Example consented debug token used in the unit tests constexpr std::string_view kExampleConsentedDebugToken = "debug_token"; diff --git a/public/udf/constants.h b/public/udf/constants.h index 8a7215dc..cd29376f 100644 --- a/public/udf/constants.h +++ b/public/udf/constants.h @@ -54,7 +54,7 @@ function handlePas(udf_arguments) { if (udf_arguments.length != 1) { const error_message = 'For PAS default UDF exactly one argument should be provided, but was provided ' + udf_arguments.length; - console.error(error_message); + logMessage(error_message); throw new Error(error_message); } const kv_result = JSON.parse(getValues(udf_arguments[0])); @@ -63,7 +63,7 @@ function handlePas(udf_arguments) { } const error_message = "Error executing handle PAS:" + JSON.stringify(kv_result); - console.error(error_message); + logMessage(error_message); throw new Error(error_message); } @@ -73,9 +73,10 @@ function handlePA(udf_arguments) { } function HandleRequest(executionMetadata, ...udf_arguments) { + logMessage("Executing UDF"); if(executionMetadata.requestMetadata && executionMetadata.requestMetadata.is_pas) { - console.log('Executing PAS branch'); + logMessage('Executing PAS branch'); return handlePas(udf_arguments); } return handlePA(udf_arguments); diff --git a/testing/run_local/BUILD.bazel b/testing/run_local/BUILD.bazel index f41e4cd5..de611ac4 100644 --- a/testing/run_local/BUILD.bazel +++ 
b/testing/run_local/BUILD.bazel @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", -) +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -41,27 +37,27 @@ pkg_tar( ], ) -container_layer( - name = "envoy_config_layer", - directory = "/", +oci_image( + name = "envoy_image", + base = select({ + "@platforms//cpu:arm64": "@envoy-distroless-arm64", + "@platforms//cpu:x86_64": "@envoy-distroless-amd64", + }), tars = [ ":envoy_config_tar", ], ) -container_image( - name = "envoy_image", - architecture = select({ - "@platforms//cpu:arm64": "arm64", - "@platforms//cpu:x86_64": "amd64", - }), - base = select({ - "@platforms//cpu:arm64": "@envoy-distroless-arm64//image", - "@platforms//cpu:x86_64": "@envoy-distroless-amd64//image", - }), - layers = [ - ":envoy_config_layer", - ], +oci_load( + name = "envoy_image_tarball", + image = ":envoy_image", + repo_tags = ["bazel/testing/run_local:envoy_image"], +) + +filegroup( + name = "envoy_image.tar", + srcs = [":envoy_image_tarball"], + output_group = "tarball", ) genrule( @@ -72,7 +68,7 @@ genrule( outs = ["build_envoy_image.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p testing/run_local/dist -cp $(execpath :envoy_image.tar) testing/run_local/dist +cp $(execpath :envoy_image.tar) testing/run_local/dist/envoy_image.tar EOF""", executable = True, local = True, diff --git a/third_party_deps/container_deps.bzl b/third_party_deps/container_deps.bzl index 712cad5a..80e859af 100644 --- a/third_party_deps/container_deps.bzl +++ b/third_party_deps/container_deps.bzl @@ -11,10 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -load("@io_bazel_rules_docker//container:container.bzl", "container_pull") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +load("@google_privacysandbox_servers_common//third_party:container_deps.bzl", common_container_deps = "container_deps") +load("@rules_oci//oci:pull.bzl", "oci_pull") def container_deps(): + common_container_deps() + images = { "aws-lambda-python": { "arch_hashes": { @@ -25,53 +29,34 @@ def container_deps(): "registry": "public.ecr.aws", "repository": "lambda/python", }, - # Used for deploying Envoy locally for testing + # Used for deploying Envoy locally "envoy-distroless": { "arch_hashes": { - # v1.23.1 - "amd64": "e2c642bc6949cb3053810ca14524324d7daf884a0046d7173e46e2b003144f1d", - "arm64": "7763f6325882122afb1beb6ba0a047bed318368f9656fd9c1df675f3d89f1dbe", + # v1.24.1 + "amd64": "9f5d0d7c817c588cd4bd6ef4508ad544ef19cef6d217aa894315790da7662ba7", + "arm64": "94c9e77eaa85893daaf95a20fdd5dfb3141250a8c5d707d789265ee3abe49a1e", }, "registry": "docker.io", "repository": "envoyproxy/envoy-distroless", }, - "runtime-debian-debug-nonroot": { - "arch_hashes": { - # cc-debian11:debug-nonroot - "amd64": "7caec0c1274f808d29492012a5c3f57331c7f44d5e9e83acf5819eb2e3ae14dc", - "arm64": "f17be941beeaa468ef03fc986cd525fe61e7550affc12fbd4160ec9e1dac9c1d", - }, - "registry": "gcr.io", - "repository": "distroless/cc-debian11", - }, - "runtime-debian-debug-root": { - # debug build so we can use 'sh'. 
Root, for gcp coordinators - # auth to work - "arch_hashes": { - "amd64": "6865ad48467c89c3c3524d4c426f52ad12d9ab7dec31fad31fae69da40eb6445", - "arm64": "3c399c24b13bfef7e38257831b1bb05cbddbbc4d0327df87a21b6fbbb2480bc9", - }, - "registry": "gcr.io", - "repository": "distroless/cc-debian11", - }, - # Non-distroless; only for debugging purposes - "runtime-ubuntu-fulldist-debug-root": { - # Ubuntu 20.04 - "arch_hashes": { - "amd64": "218bb51abbd1864df8be26166f847547b3851a89999ca7bfceb85ca9b5d2e95d", - "arm64": "a80d11b67ef30474bcccab048020ee25aee659c4caaca70794867deba5d392b6", - }, - "registry": "docker.io", - "repository": "library/ubuntu", - }, } [ - container_pull( - name = img_name + "-" + arch, + oci_pull( + name = "{}-{}".format(img_name, arch), digest = "sha256:" + hash, - registry = image["registry"], - repository = image["repository"], + image = "{}/{}".format(image["registry"], image["repository"]), ) for img_name, image in images.items() for arch, hash in image["arch_hashes"].items() ] + + # Used for deploying Envoy on GCP + # version 1.24.1, same version as the one used for AWS + maybe( + http_file, + name = "envoy_binary", + downloaded_file_path = "envoy", + executable = True, + url = "https://github.com/envoyproxy/envoy/releases/download/v1.24.1/envoy-1.24.1-linux-x86_64", + sha256 = "b4984647923c1506300995830f51b03008b18977e72326dc33cd414e21f5036e", + ) diff --git a/third_party_deps/cpp_repositories.bzl b/third_party_deps/cpp_repositories.bzl index e42f7f0e..361c406b 100644 --- a/third_party_deps/cpp_repositories.bzl +++ b/third_party_deps/cpp_repositories.bzl @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") def cpp_repositories(): """Entry point for all external repositories used for C++/C dependencies.""" @@ -101,3 +101,15 @@ def cpp_repositories(): "https://github.com/RoaringBitmap/CRoaring/archive/refs/tags/v3.0.1.zip", ], ) + + http_file( + name = "otel_collector_aarch64", + url = "https://aws-otel-collector.s3.amazonaws.com/amazon_linux/arm64/v0.40.0/aws-otel-collector.rpm", + sha256 = "c1860bac86d2c8b21a7448bb41b548589f3a65507f7768be94a9bf36ec188801", + ) + + http_file( + name = "otel_collector_amd64", + url = "https://aws-otel-collector.s3.amazonaws.com/amazon_linux/amd64/v0.40.0/aws-otel-collector.rpm", + sha256 = "3d3837ad0b0a32b905b94713ab3534eb58c377cf211a9c75d89d39f35b0f4152", + ) diff --git a/third_party_deps/libcbor.BUILD b/third_party_deps/libcbor.BUILD new file mode 100644 index 00000000..8e0df01f --- /dev/null +++ b/third_party_deps/libcbor.BUILD @@ -0,0 +1,15 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +cc_library( + name = "cbor", + srcs = glob([ + "allocators.c", + "cbor.c", + "cbor/**/*.c", + "cbor/**/*.h", + ]), + hdrs = [ + "cbor.h", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party_deps/libcbor.patch b/third_party_deps/libcbor.patch new file mode 100644 index 00000000..90320755 --- /dev/null +++ b/third_party_deps/libcbor.patch @@ -0,0 +1,69 @@ +diff --git a/cbor/cbor_export.h b/cbor/cbor_export.h +new file mode 100644 +index 0000000..0758ee6 +--- /dev/null ++++ b/cbor/cbor_export.h +@@ -0,0 +1,42 @@ ++ ++#ifndef CBOR_EXPORT_H ++#define CBOR_EXPORT_H ++ ++#ifdef CBOR_STATIC_DEFINE ++# define CBOR_EXPORT ++# define CBOR_NO_EXPORT ++#else ++# ifndef CBOR_EXPORT ++# ifdef cbor_EXPORTS ++ /* We are building this library */ ++# define CBOR_EXPORT ++# else ++ /* We are using this library */ ++# define CBOR_EXPORT ++# endif ++# endif ++ ++# ifndef CBOR_NO_EXPORT ++# 
define CBOR_NO_EXPORT ++# endif ++#endif ++ ++#ifndef CBOR_DEPRECATED ++# define CBOR_DEPRECATED __attribute__ ((__deprecated__)) ++#endif ++ ++#ifndef CBOR_DEPRECATED_EXPORT ++# define CBOR_DEPRECATED_EXPORT CBOR_EXPORT CBOR_DEPRECATED ++#endif ++ ++#ifndef CBOR_DEPRECATED_NO_EXPORT ++# define CBOR_DEPRECATED_NO_EXPORT CBOR_NO_EXPORT CBOR_DEPRECATED ++#endif ++ ++#if 0 /* DEFINE_NO_DEPRECATED */ ++# ifndef CBOR_NO_DEPRECATED ++# define CBOR_NO_DEPRECATED ++# endif ++#endif ++ ++#endif /* CBOR_EXPORT_H */ +diff --git a/cbor/configuration.h b/cbor/configuration.h +new file mode 100644 +index 0000000..aee03db +--- /dev/null ++++ b/cbor/configuration.h +@@ -0,0 +1,15 @@ ++#ifndef LIBCBOR_CONFIGURATION_H ++#define LIBCBOR_CONFIGURATION_H ++ ++#define CBOR_MAJOR_VERSION 0 ++#define CBOR_MINOR_VERSION 10 ++#define CBOR_PATCH_VERSION 2 ++ ++#define CBOR_BUFFER_GROWTH 2 ++#define CBOR_MAX_STACK_SIZE 2048 ++#define CBOR_PRETTY_PRINTER 1 ++ ++#define CBOR_RESTRICT_SPECIFIER restrict ++#define CBOR_INLINE_SPECIFIER ++ ++#endif //LIBCBOR_CONFIGURATION_H diff --git a/tools/benchmarking/README.md b/tools/benchmarking/README.md index dcad8f6a..191a9482 100644 --- a/tools/benchmarking/README.md +++ b/tools/benchmarking/README.md @@ -61,25 +61,28 @@ Usage: Flags: +- `--help` Display full list of flags. + - `--server-address` Required. gRPC host and port. Example: `--server-address my-server:8443` -- `--snapshot-dir` +- `--snapshot-dir`, `--snapshot-csv-dir`, or `lookup-keys-file` + + One of these options is required. - Required if `--snapshot-csv-dir` is not provided. Full path to a directory of snapshot files. - These snapshot files are converted to CSVs using the - [`data cli`](/docs/data_loading/loading_data.md). The keys from the snapshot file are used to - create requests. + - `--snapshot-dir`: Full path to a directory of snapshot files. These snapshot files are + converted to CSVs using the [`data cli`](/docs/data_loading/loading_data.md). 
The keys from + the snapshot file are used to create requests. -- `--snapshot-csv-dir` + - `--snapshot-csv-dir`: Full path to a directory of only snapshot CSV files (i.e. converted + using the [`data cli`](/docs/data_loading/loading_data.md)). This avoids doing the + conversion in the tool, which saves some time for multiple runs on the same snapshot files. - Required if `--snapshot-dir` is not provided. Full path to a directory of only snapshot CSV - files (i.e. converted using the [`data cli`](/docs/data_loading/loading_data.md)). This avoids - doing the conversion in the tool, which saves some time for multiple runs on the same snapshot - files. + - `--lookup-keys-file`: Full path to file with lookup keys. Each line should have one key. + Ignores the `filter-snapshot-by-sets` option. - `--number-of-lookup-keys-list` (Optional) @@ -177,10 +180,11 @@ Flags: Example: "" -- `--snapshot-dir` +- `--snapshot-dir` or `--lookup-keys-file` - Required. Full path to a directory of snapshot files. The keys from the snapshot file are used - to create requests. + One of these options is required. Full path to a directory of snapshot files or a file with + lookup keys. The keys are used to create requests. For the `lookup-keys-file`, each line should + have one key. Ignores the `filter-snapshot-by-sets` option. - `--csv-output` (Optional) diff --git a/tools/benchmarking/run_benchmarks b/tools/benchmarking/run_benchmarks index 9305d675..09d06656 100755 --- a/tools/benchmarking/run_benchmarks +++ b/tools/benchmarking/run_benchmarks @@ -129,11 +129,11 @@ function run_ghz_for_requests_from_file() { printf "Running ghz for number of keys %s\n" "${N}" BASE_GHZ_TAGS=$( - jq -n --arg n "${N}" --arg f "${FILENAME}" \ + builders/tools/jq -n --arg n "${N}" --arg f "${FILENAME}" \ '{"number_of_lookup_keys": $n, "keys_from_file": $f}' ) - REQUEST_METADATA_TAGS=$(echo "${REQUEST_METADATA}" | jq 'with_entries(.key |= "request_metadata."+.) 
| with_entries( .value |= @json)') - TAGS=$(echo "${GHZ_TAGS} ${BASE_GHZ_TAGS} ${REQUEST_METADATA_TAGS} ${FILTER_BY_SETS_JSON}" | jq -s -c 'add') + REQUEST_METADATA_TAGS=$(echo "${REQUEST_METADATA}" | builders/tools/jq 'with_entries(.key |= "request_metadata."+.) | with_entries( .value |= @json)') + TAGS=$(echo "${GHZ_TAGS} ${BASE_GHZ_TAGS} ${REQUEST_METADATA_TAGS} ${FILTER_BY_SETS_JSON}" | builders/tools/jq -s -c 'add') GHZ_OUTPUT_JSON_FILE="${DIR}/ghz_output.json" EXTRA_DOCKER_RUN_ARGS+=" --volume ${BASE_OUTPUT_DIR}:${DOCKER_OUTPUT_DIR} --volume ${WORKSPACE}:/src/workspace" \ builders/tools/ghz --protoset /src/workspace/dist/query_api_descriptor_set.pb \ diff --git a/tools/request_simulation/BUILD.bazel b/tools/request_simulation/BUILD.bazel index ffd6eb7d..eb12feaa 100644 --- a/tools/request_simulation/BUILD.bazel +++ b/tools/request_simulation/BUILD.bazel @@ -67,6 +67,7 @@ cc_library( name = "grpc_client", hdrs = ["grpc_client.h"], deps = [ + "//components/data_server/request_handler:get_values_v2_handler", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", diff --git a/tools/request_simulation/client_worker.h b/tools/request_simulation/client_worker.h index d4c82ae4..d2d3cef8 100644 --- a/tools/request_simulation/client_worker.h +++ b/tools/request_simulation/client_worker.h @@ -123,9 +123,10 @@ void ClientWorker::SendRequests() { metrics_collector_.IncrementRequestSentPerInterval(); auto start = absl::Now(); std::shared_ptr response = std::make_shared(); + std::shared_ptr request = std::make_shared( + request_converter_(request_body.value())); auto status = - grpc_client_->SendMessage(request_converter_(request_body.value()), - service_method_, response); + grpc_client_->SendMessage(request, service_method_, response); metrics_collector_.IncrementServerResponseStatusEvent(status); if (!status.ok()) { VLOG(8) << "Received error in response " << status; diff --git 
a/tools/request_simulation/grpc_client.h b/tools/request_simulation/grpc_client.h index a29762ff..7fc2d286 100644 --- a/tools/request_simulation/grpc_client.h +++ b/tools/request_simulation/grpc_client.h @@ -24,6 +24,7 @@ #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "absl/synchronization/notification.h" +#include "components/data_server/request_handler/get_values_v2_handler.h" #include "grpcpp/generic/generic_stub.h" #include "grpcpp/grpcpp.h" #include "src/google/protobuf/message.h" @@ -130,7 +131,7 @@ class GrpcClient { // Sends message via grpc unary call. The request method is the // api name supported by the grpc service, an example method name is // "/PackageName.ExampleService/APIName". - absl::Status SendMessage(const RequestT& request, + absl::Status SendMessage(std::shared_ptr request, const std::string& request_method, std::shared_ptr response) { if (is_client_channel_ && @@ -141,11 +142,14 @@ class GrpcClient { std::make_shared(); std::shared_ptr client_context = std::make_shared(); + client_context->AddMetadata(std::string(kContentTypeHeader), + std::string(kContentEncodingJsonHeaderValue)); std::shared_ptr grpc_status = std::make_shared(); generic_stub_->UnaryCall( - client_context.get(), request_method, grpc::StubOptions(), &request, - response.get(), [notification, grpc_status](grpc::Status status) { + client_context.get(), request_method, grpc::StubOptions(), + request.get(), response.get(), + [notification, grpc_status](grpc::Status status) { grpc_status->Update(absl::Status( absl::StatusCode(status.error_code()), status.error_message())); notification->Notify(); diff --git a/tools/request_simulation/grpc_client_test.cc b/tools/request_simulation/grpc_client_test.cc index 0bc2f84a..206ca5b9 100644 --- a/tools/request_simulation/grpc_client_test.cc +++ b/tools/request_simulation/grpc_client_test.cc @@ -59,8 +59,8 @@ TEST_F(GrpcClientTest, TestRequestOKResponse) { std::string key("key"); std::string 
method("/kv_server.v2.KeyValueService/GetValuesHttp"); auto response_ptr = std::make_shared(); - auto response = - grpc_client_->SendMessage(request_converter_(key), method, response_ptr); + auto request_ptr = std::make_shared(request_converter_(key)); + auto response = grpc_client_->SendMessage(request_ptr, method, response_ptr); EXPECT_TRUE(response.ok()); EXPECT_EQ(response_ptr->data(), "value"); } @@ -69,8 +69,8 @@ TEST_F(GrpcClientTest, TestRequestErrorResponse) { std::string key("missing"); std::string method("/kv_server.v2.KeyValueService/GetValuesHttp"); auto response_ptr = std::make_shared(); - auto response = - grpc_client_->SendMessage(request_converter_(key), method, response_ptr); + auto request_ptr = std::make_shared(request_converter_(key)); + auto response = grpc_client_->SendMessage(request_ptr, method, response_ptr); EXPECT_FALSE(response.ok()); } diff --git a/tools/request_simulation/metrics_collector.h b/tools/request_simulation/metrics_collector.h index 5b552643..11547c6a 100644 --- a/tools/request_simulation/metrics_collector.h +++ b/tools/request_simulation/metrics_collector.h @@ -71,9 +71,9 @@ inline constexpr absl::Span< kRequestSimulationMetricsSpan = kRequestSimulationMetricsList; inline auto* RequestSimulationContextMap( - std::optional< + std::unique_ptr< privacy_sandbox::server_common::telemetry::BuildDependentConfig> - config = std::nullopt, + config = nullptr, std::unique_ptr provider = nullptr, absl::string_view service = "Request-simulation", absl::string_view version = "") { diff --git a/tools/request_simulation/request_simulation_system.cc b/tools/request_simulation/request_simulation_system.cc index 039e104e..7dfa8fd8 100644 --- a/tools/request_simulation/request_simulation_system.cc +++ b/tools/request_simulation/request_simulation_system.cc @@ -190,14 +190,12 @@ absl::Status RequestSimulationSystem::Init( metrics_collector == nullptr ? 
std::make_unique(std::make_unique()) : std::move(metrics_collector); - // Initialize no-op telemetry for the new Telemetry API - // TODO(b/304306398): deprecate metric recorder and use new telemetry API to - // log metrics privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; config_proto.set_mode( privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( + std::make_unique< + privacy_sandbox::server_common::telemetry::BuildDependentConfig>( config_proto)); if (auto status = InitializeGrpcClientWorkers(); !status.ok()) { @@ -288,17 +286,6 @@ absl::Status RequestSimulationSystem::InitializeGrpcClientWorkers() { auto channel = channel_creation_fn_(server_address_, absl::GetFlag(FLAGS_server_auth_mode)); bool is_client_channel = absl::GetFlag(FLAGS_is_client_channel); - - if (is_client_channel) { - RetryUntilOk( - [channel]() { - if (channel->GetState(true) != GRPC_CHANNEL_READY) { - return absl::UnavailableError("GRPC channel is disconnected"); - } - return absl::OkStatus(); - }, - "check grpc connection in start up", LogMetricsNoOpCallback()); - } auto request_timeout = absl::GetFlag(FLAGS_request_timeout); for (int i = 0; i < num_of_workers; ++i) { auto request_converter = [](const std::string& request_body) { @@ -420,7 +407,8 @@ void RequestSimulationSystem::InitializeTelemetry() { config_proto.set_mode( privacy_sandbox::server_common::telemetry::TelemetryConfig::EXPERIMENT); auto* context_map = RequestSimulationContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( + std::make_unique< + privacy_sandbox::server_common::telemetry::BuildDependentConfig>( config_proto), ConfigurePrivateMetrics(resource, metrics_options)); } diff --git a/tools/server_diagnostic/BUILD.bazel b/tools/server_diagnostic/BUILD.bazel index 99ad3778..6332e5af 100644 --- a/tools/server_diagnostic/BUILD.bazel +++ 
b/tools/server_diagnostic/BUILD.bazel @@ -13,7 +13,7 @@ # limitations under the License. load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") -load("@rules_oci//oci:defs.bzl", "oci_image", "oci_tarball") +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -22,30 +22,24 @@ load( load("@rules_pkg//pkg:tar.bzl", "pkg_tar") go_library( - name = "diagnostic_lib", - srcs = ["diagnostic.go"], - importpath = "tools/server_diagnostic", - visibility = ["//visibility:private"], + name = "common", + srcs = ["common.go"], + importpath = "tools/server_diagnostic/common", + visibility = ["//tools/server_diagnostic:__subpackages__"], ) -go_binary( - name = "diagnostic", - embed = [":diagnostic_lib"], - visibility = ["//visibility:public"], +go_library( + name = "diagnostic_aws_lib", + srcs = ["diagnostic_aws.go"], + importpath = "tools/server_diagnostic", + visibility = ["//tools/server_diagnostic:__subpackages__"], + deps = [":common"], ) -pkg_files( +go_binary( name = "diagnostic_cli", - srcs = [ - ":diagnostic", - ], - attributes = pkg_attributes(mode = "0555"), - prefix = "/tools/diagnostic_cli", -) - -pkg_tar( - name = "diagnostic_tar", - srcs = [":diagnostic_cli"], + embed = [":diagnostic_aws_lib"], + visibility = ["//visibility:public"], ) pkg_files( @@ -122,11 +116,10 @@ pkg_files( [ oci_image( - name = "diagnostic_tools_image_{}".format(arch), - base = "@runtime-ubuntu-fulldist-debug-root-{}//image".format(arch), + name = "diagnostic_tool_box_image_{}".format(arch), + base = "@runtime-ubuntu-fulldist-debug-root-{}".format(arch), tars = [ ":helloworld_server_binaries_tar", - ":diagnostic_tar", ":query_api_descriptor_set_tar", ":grpcurl_tar_{}".format(arch), ], @@ -138,10 +131,22 @@ pkg_files( ] [ - oci_tarball( - name = "diagnostic_tools_docker_image_{}".format(arch), - image = ":diagnostic_tools_image_{}".format(arch), - repo_tags = 
["bazel/tools/server_diagnostic:diagnostic_tools_docker_image"], + oci_load( + name = "diagnostic_tool_box_docker_image_{}".format(arch), + image = ":diagnostic_tool_box_image_{}".format(arch), + repo_tags = ["bazel/tools/server_diagnostic:diagnostic_tool_box_docker_image"], + ) + for arch in [ + "arm64", + "amd64", + ] +] + +[ + filegroup( + name = "diagnostic_tool_box_docker_image_{}.tar".format(arch), + srcs = [":diagnostic_tool_box_docker_image_{}".format(arch)], + output_group = "tarball", ) for arch in [ "arm64", @@ -152,18 +157,20 @@ pkg_files( genrule( name = "copy_to_dist", srcs = [ - ":diagnostic_tools_docker_image_arm64", - ":diagnostic_tools_docker_image_amd64", + ":diagnostic_cli", + ":diagnostic_tool_box_docker_image_arm64.tar", + ":diagnostic_tool_box_docker_image_amd64.tar", ], outs = ["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' -mkdir -p dist/tools/arm64/server_diagnostic -cp $(execpath :diagnostic_tools_docker_image_arm64) dist/tools/arm64/server_diagnostic/diagnostic_tools_docker_image_arm64.tar -mkdir -p dist/tools/amd64/server_diagnostic -cp $(execpath :diagnostic_tools_docker_image_amd64) dist/tools/amd64/server_diagnostic/diagnostic_tools_docker_image_amd64.tar +mkdir -p dist/tools/server_diagnostic/arm64 +cp $(execpath :diagnostic_tool_box_docker_image_arm64.tar) dist/tools/server_diagnostic/arm64/diagnostic_tool_box_docker_image_arm64.tar +mkdir -p dist/tools/server_diagnostic/amd64 +cp $(execpath :diagnostic_tool_box_docker_image_amd64.tar) dist/tools/server_diagnostic/amd64/diagnostic_tool_box_docker_image_amd64.tar +cp $(execpath :diagnostic_cli) dist/tools/server_diagnostic/diagnostic_cli builders/tools/normalize-dist EOF""", executable = True, local = True, - message = "Copying server diagnostic artifacts to dist/tools/arm64/server_diagnostic and dist/tools/amd64/server_diagnostic directories", + message = "Copying server diagnostic artifacts to dist/tools/server_diagnostic", ) diff --git a/tools/server_diagnostic/README.md 
b/tools/server_diagnostic/README.md new file mode 100644 index 00000000..f8bf937f --- /dev/null +++ b/tools/server_diagnostic/README.md @@ -0,0 +1,180 @@ +# Diagnostic Tool + +The KV server's out-of-box build and terraform solution will handle all the setups required for the +server cloud deployment. However, Adtechs may prefer to use their own server deployment solution +tailored to their production needs. This diagnostic tool is developed to help Adtechs identify and +resolve setup issues that prevent KV server from running properly on their end. + +The diagnostic tool only supports AWS environment for now. The tool will need to be deployed and run +on the same EC2 machine where the KV server in enclave is running. The tool will perform several +checks including server health checks and system dependency checks etc, and prints the summary in +the end. + +## Check steps performed by the diagnostic tool + +1. Enclave server health check. If any of checks in this step is failed, move on to next steps + - Check if enclave is running + - Check grpc request against KV server + - Check http request against KV server +2. System dependencies check. All checks must be passed for the KV server to start and process grpc + requests. If all checks passed in this step, move on to next step + - VSock proxy check + - VPC CIDR and etc/resolv.conf checks. More details about required setup for DNS resolution in + [private_communication_aws.md](/docs/private_communication_aws.md) +3. Envoy proxy checks. Envoy proxy is not required for the KV server to run and process grpc + requests, but is required for KV server to receive http requests. All checks must be passed in + this step for the Envoy proxy to successfully forward the http requests to the KV server. + - Envoy dependency checks + - Test Envoy proxy with hello world grpc server +4. Otel collector checks. Otel collector is not required for the KV server to run and process http + and grpc requests. 
All checks must be passed in this step for the KV server to export metrics and + consented logs to AWS Cloudwatch +5. Run KV server outside of enclave to check if server can start and process grpc and http requests. + If all checks in this step are passed, it means that the server functionalities work as expected. + +If all the checks besides the enclave server health check are passed, it means something else might +be wrong to prevent the server running properly inside enclave. Adtechs will need to check the +infrastructure setup on their end and provide relevant information for further assistance. Adtechs +can also run +[CPIO validator](https://github.com/privacysandbox/data-plane-shared-libraries/blob/main/docs/cpio/validator.md) +to further troubleshoot specific configurations like parameter fetching and DNS config etc. + +## Build the tool + +From the KV server repo root, run the build command + +```shell +builders/tools/bazel-debian run //tools/server_diagnostic:copy_to_dist +``` + +The build command will generate the following binaries: + +1. The diagnostic_cli go binary + - dist/tools/server_diagnostic/diagnostic_cli. +2. The diagnostic tool box docker image. The tool box contains the hello world grpc server and + grpcurl that are required by the diagnostic tool to perform checks + - amd64:dist/tools/server_diagnostic/amd64/diagnostic_tool_box_docker_image_amd64.tar + - arm64:dist/tools/server_diagnostic/arm64/diagnostic_tool_box_docker_image_arm64.tar + +## Deploy the tool binaries to AWS EC2 + +The tool will need to be deployed and run on the same EC2 machine where the KV server enclave is +running. Follow the similar steps described in this +[doc](/docs/developing_the_server.md#develop-and-run-the-server-inside-aws-enclave) to scp the tool +binaries to EC2 ssh instance then to the server's EC2 instance. 
+ +Here are the steps and example commands (assume EC2 instance is running AMD architecture): + +1.Send public key to EC2 ssh machine to establish connection (Skip this if there is no EC2 ssh +instance setup) + +```shell +# Send public key to the EC2 ssh instance +aws ec2-instance-connect send-ssh-public-key --instance-id --availability-zone --instance-os-user ec2-user --ssh-public-key file://my_key.pub --region +``` + +2.Copy diagnostic tool binaries to EC2 ssh instance (Skip this if there is no EC2 ssh instance +setup) + +```shell +# The EC2_ADDR for scp'ing from the public internet to ssh instance is the Public IPv4 DNS, e.g., ec2-3-81-186-232.compute-1.amazonaws.com +# Copy the diagnostic_cli to the EC2 ssh instance /home/ec2-user directory +scp -o "IdentitiesOnly=yes" -i ./my_key dist/tools/server_diagnostic/diagnostic_cli ec2-user@{EC2_ADDR}:/home/ec2-user +# Copy the diagnostic tool box docker container to the EC2 ssh instance /home/ec2-user directory +scp -o "IdentitiesOnly=yes" -i ./my_key dist/tools/server_diagnostic/amd64/diagnostic_tool_box_docker_image_amd64.tar ec2-user@{EC2_ADDR}:/home/ec2-user +``` + +3.From EC2 ssh instance, send public key to server's EC2 machine to establish connection + +```shell +aws ec2-instance-connect send-ssh-public-key --instance-id --availability-zone --instance-os-user ec2-user --ssh-public-key file://my_key.pub --region +``` + +4.Copy diagnostic tool binaries from EC2 ssh instance to server's EC2 instance + +```shell +# The EC2_ADDR for scp'ing from the ssh instance is the Private IP DNS name e.g., ip-10-0-226-225.ec2.internal +# Copy the diagnostic_cli to the EC2 instance /home/ec2-user directory +scp -o "IdentitiesOnly=yes" -i ./my_key /home/ec2-user/diagnostic_cli ec2-user@{EC2_ADDR}:/home/ec2-user +# Copy the diagnostic tool box docker container to the EC2 instance /home/ec2-user directory +scp -o "IdentitiesOnly=yes" -i ./my_key /home/ec2-user/diagnostic_tool_box_docker_image_amd64.tar 
ec2-user@{EC2_ADDR}:/home/ec2-user +``` + +## Run the diagnostic tool + +Help command to see all available flags and their default values + +```shell +./diagnostic_cli -help +``` + +Example command to run the diagnostic tool + +```shell +./diagnostic_cli -environment= -region=us-east-1 +``` + +## Example summary report + +```txt +-------------------------------------SUMMARY-------------------------------------------- + +-----------------------------------INPUT FLAGS------------------------------------------ + +environment: demo +envoy_config: /etc/envoy/envoy.yaml +envoy_protoset: /etc/envoy/query_api_descriptor_set.pb +help: false +otel_collector_config: /opt/aws/aws-otel-collector/etc/otel_collector_config.yaml +otel_collector_ctl: /opt/aws/aws-otel-collector/bin/aws-otel-collector-ctl +region: us-east-1 +server_docker_image: /home/ec2-user/server_docker_image.tar +server_enclave_image: /opt/privacysandbox/server_enclave_image.eif +server_grpc_port: 50051 +server_http_port: 51052 +tool_output_dir: ./tools/output +toolbox_docker_image: ./diagnostic_tool_box_docker_image_amd64.tar +verbosity: 1 +vsock_proxy_file: /opt/privacysandbox/proxy +----------------------------ENCLAVE SERVER HEALTH CHECKS-------------------------------- +Are all checks passed? false + +OK. Server binary exists. +FAILED. Test server was running. +FAILED. Test grpc requests working. +FAILED. Test http requests working. + +----------------------------SYSTEM DEPENDENCY CHECKS------------------------------------ +All checks must be passed for server to start and process grpc requests + +OK. /opt/privacysandbox/proxy exists. +OK. vsockproxy.service is enabled. +OK. /etc/resolv.conf exists. +OK. VPC Ipv4 Cidr matches nameserver in /etc/resolv.conf. + +-------------------------------ENVOY PROXY CHECKS--------------------------------------- +All checks must be passed for server to process http requests + +OK. /etc/envoy/query_api_descriptor_set.pb exists. +OK. /etc/envoy/envoy.yaml exists. +OK. 
Envoy is running. +OK. Envoy is listening the http port 51052. +OK. Test Envoy with hello world grpc server. +Hello world server log is located in ./tools/output/hello_world_server.log + +------------------------------OTEL COLLECTOR CHECKS-------------------------------------- +All checks must be passed for server to export metrics + +OK. /opt/aws/aws-otel-collector/bin/aws-otel-collector-ctl exists. +OK. /opt/aws/aws-otel-collector/etc/otel_collector_config.yaml exists. +OK. aws-otel-collector.service is enabled. + +------------------------OUTSIDE OF ENCLAVE SERVER HEALTH CHECKS-------------------------- +Are all checks passed? true + +OK. Server binary exists. +OK. Test server was running. +OK. Test grpc requests working. +OK. Test http requests working. +Test server log is located in ./tools/output/kv_server_docker.log +``` diff --git a/tools/server_diagnostic/common.go b/tools/server_diagnostic/common.go new file mode 100644 index 00000000..80926206 --- /dev/null +++ b/tools/server_diagnostic/common.go @@ -0,0 +1,180 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "strings" +) + +const ( + systemctl = "/usr/bin/systemctl" +) + +// Status to return whether an operation executes okay or not, and +// error if any +type Status struct { + Ok bool + Err error +} + +// Converts the status result to string +func (s *Status) Result() string { + if s.Ok { + return "OK" + } + return "FAILED" +} + +// Returns formatted error message in string for a given status +func (s *Status) Error() string { + if s.Ok || s.Err == nil { + return "" + } + return fmt.Sprintf("ERROR: %v", s.Err) +} + +// Checks if a given process is listening to a given port +func CheckProcessListeningToPort(process string, port int) Status { + output, err := ExecuteBashCommand(fmt.Sprintf("sudo netstat -nlp | grep :%d | grep %s", port, process)) + if err != nil { + return Status{false, err} + } + if len(output) == 0 { + return Status{false, fmt.Errorf("no process %s is listening to the port %d", process, port)} + } + return Status{true, nil} +} + +// Stops and removes a docker container, returns output and error if any +func StopAndRemoveDockerContainer(name string) (string, error) { + output, err := ExecuteShellCommand("docker", []string{"stop", name}) + if err != nil { + return output, err + } + return RemoveDockerContainer(name) +} + +// Executes bash command and returns output and error if any +func ExecuteBashCommand(cmd string) (string, error) { + return ExecuteShellCommand("bash", []string{"-c", cmd}) +} + +// Executes simple shell command, prints to the console the execution output, +// and returns output and error if any +func ExecuteShellCommand(command string, args []string) (string, error) { + cmd := exec.Command(command, args...) 
+ fmt.Println(fmt.Sprintf("\nExecuting command: %s \n", cmd.String())) + var stdBuffer bytes.Buffer + writer := io.MultiWriter(os.Stdout, &stdBuffer) + cmd.Stdout = writer + cmd.Stderr = writer + if err := cmd.Run(); err != nil { + return stdBuffer.String(), fmt.Errorf("error %v: %s", err, stdBuffer.String()) + } + return stdBuffer.String(), nil +} + +// Removes a docker container, returns the output and error if any +func RemoveDockerContainer(name string) (string, error) { + containerId, _ := ExecuteBashCommand(fmt.Sprintf("docker ps -a | grep %s | awk '{print $1}'", name)) + containerId = strings.TrimSpace(containerId) + if len(containerId) != 0 { + output, err := ExecuteShellCommand("docker", []string{"rm", containerId}) + if err != nil { + fmt.Printf("Error removing docker container for %s, %s \n", name, containerId) + } + return output, err + } + return "", nil +} + +// Checks if a docker container is running +func CheckDockerContainerIsRunning(name string) Status { + output, err := ExecuteBashCommand(fmt.Sprintf("docker ps --no-trunc | grep %s | awk '{print $2}'", name)) + if err != nil { + fmt.Printf("Error executing docker ps for name %s! %v. 
Output %s", name, err, output) + return Status{false, err} + } + if len(output) == 0 { + return Status{false, fmt.Errorf("no docker process is found for %s", name)} + } + return Status{true, nil} +} + +// Writes a docker process's logs to a given log file, the process can be running or a stopped docker process +func WriteDockerProcessLogs(name string, log string) Status { + containerId, err := ExecuteBashCommand(fmt.Sprintf("docker ps -a | grep %s | awk '{print $1}'", name)) + if err != nil { + return Status{false, err} + } + containerId = strings.TrimSpace(containerId) + _, err = ExecuteBashCommand(fmt.Sprintf("docker logs %s > %s 2>&1 --region=us-east-1 | jq -r .Vpcs[].CidrBlock + fmt.Printf("Reading VPC CIDR block for deployed environment %s and region %s \n", *environmentFlag, *regionFlag) + output, err := common.ExecuteBashCommand( + fmt.Sprintf("aws ec2 describe-vpcs --filters Name=tag:environment,Values=%s --region=%s | jq -r .Vpcs[].CidrBlock", *environmentFlag, *regionFlag)) + if err != nil { + return common.Status{false, err} + } + cidrBlock := strings.Split(output, "/") + if len(cidrBlock) == 0 || len(cidrBlock[0]) == 0 { + return common.Status{false, errors.New("empty Cidr block read from aws describe-vpcs")} + } + cidrDomainFields := strings.Split(cidrBlock[0], ".") + if len(cidrDomainFields) != 4 { + return common.Status{false, fmt.Errorf("invalid cidr domain in the cidr block %s, expected 4 fields separated by periods", cidrBlock)} + } + cidrDomainLastField, err := strconv.Atoi(cidrDomainFields[3]) + if err != nil { + return common.Status{false, fmt.Errorf("error parsing last field of cidr domain: %v", err)} + } + // Read nameservers from /etc/resolv.conf + fmt.Printf("Reading nameservers from %s \n", resolvConf) + nameserverResult, err := common.ExecuteBashCommand(fmt.Sprintf("cat %s | grep -v '^#' | grep nameserver | awk '{print $2}'", resolvConf)) + if err != nil { + return common.Status{false, err} + } + nameservers := 
strings.Split(nameserverResult, "\n") + if len(nameservers) == 0 { + return common.Status{false, fmt.Errorf("empty nameserver defined in the %s", resolvConf)} + } + fmt.Printf("Checking if cidr block domain matches the nameserver in %s \n", resolvConf) + for _, nameserver := range nameservers { + nameserverDomainFields := strings.Split(nameserver, ".") + if len(nameserverDomainFields) == 4 { + if nameserverDomainFields[0] == cidrDomainFields[0] && nameserverDomainFields[1] == cidrDomainFields[1] && nameserverDomainFields[2] == cidrDomainFields[2] && nameserverDomainFields[3] == strconv.Itoa((cidrDomainLastField+2)) { + return common.Status{true, nil} + } + } + } + return common.Status{false, fmt.Errorf("the nameserver %s defined in the %s does not match domain name %s in the cidr block", nameserverResult, resolvConf, cidrBlock[0])} +} + +// Checks grpc request against test server +func checkGrpcRequest() common.Status { + encodedPayload := base64.StdEncoding.EncodeToString([]byte(v2Payload)) + args := fmt.Sprintf("--protoset %s -d '{\"raw_body\": {\"data\": \"%s\"}}' --plaintext localhost:%d %s", protoSetForTest, encodedPayload, *serverGrpcPortFlag, v2APIEndpoint) + status := runToolboxDockerCommand(grpcurl, "grpcurl", args, false) + common.RemoveDockerContainer("grpcurl") + if status.Err != nil { + return status + } + return common.Status{true, nil} +} + +// Check http request against test server +func checkHttpRequest() common.Status { + args := []string{"-vX", "PUT", "-d", v2Payload, fmt.Sprintf("http://localhost:%d", *serverHttpPortFlag) + "/v2/getvalues"} + output, err := common.ExecuteShellCommand(curl, args) + if err != nil { + return common.Status{false, fmt.Errorf("http request failed! 
response: %s, error: %v", output, err.Error())} + } + return common.Status{true, nil} +} + +// Executes tool box docker commands from different entry-points +func runToolboxDockerCommand(entrypoint string, name string, args string, detach bool) common.Status { + var cmd strings.Builder + cmd.WriteString("docker run ") + if detach { + cmd.WriteString("--detach ") + } + cmd.WriteString(fmt.Sprintf("--entrypoint=%s --network host --name %s --add-host=host.docker.internal:host-gateway %s %s", entrypoint, name, toolBoxContainerTag, args)) + _, err := common.ExecuteBashCommand(cmd.String()) + if err != nil { + return common.Status{false, err} + } + return common.Status{true, nil} +} diff --git a/tools/udf/inline_wasm/examples/get_values_binary_proto/my_udf.js b/tools/udf/inline_wasm/examples/get_values_binary_proto/my_udf.js index 285d4bbd..fbb998e6 100644 --- a/tools/udf/inline_wasm/examples/get_values_binary_proto/my_udf.js +++ b/tools/udf/inline_wasm/examples/get_values_binary_proto/my_udf.js @@ -16,11 +16,11 @@ async function HandleRequest(executionMetadata, ...udf_arguments) { const module = await getModule(); - console.log('Done loading WASM Module'); + logMessage('Done loading WASM Module'); // Pass in the getValuesBinary function for the C++ code to call. 
// getValuesBinary returns a Uint8Array, which emscripten converts to std::string const result = module.handleRequestCc(getValuesBinary, udf_arguments); - console.log('handleRequestCc result: ' + JSON.stringify(result)); + logMessage('handleRequestCc result: ' + JSON.stringify(result)); return result; } diff --git a/tools/udf/inline_wasm/examples/hello_world/my_udf.js b/tools/udf/inline_wasm/examples/hello_world/my_udf.js index e14ef923..81a9a3c7 100644 --- a/tools/udf/inline_wasm/examples/hello_world/my_udf.js +++ b/tools/udf/inline_wasm/examples/hello_world/my_udf.js @@ -41,9 +41,9 @@ function getKeyGroupOutputs(udf_arguments, module) { } async function HandleRequest(executionMetadata, ...udf_arguments) { - console.log('Handling request'); + logMessage('Handling request'); const module = await getModule(); - console.log('Done loading WASM Module'); + logMessage('Done loading WASM Module'); const keyGroupOutputs = getKeyGroupOutputs(udf_arguments, module); return { keyGroupOutputs, udfOutputApiVersion: 1 }; } diff --git a/tools/udf/inline_wasm/examples/js_call/my_udf.js b/tools/udf/inline_wasm/examples/js_call/my_udf.js index a156d031..038a5503 100644 --- a/tools/udf/inline_wasm/examples/js_call/my_udf.js +++ b/tools/udf/inline_wasm/examples/js_call/my_udf.js @@ -15,12 +15,12 @@ */ async function HandleRequest(executionMetadata, ...input) { - console.log('Handling request'); + logMessage('Handling request'); const module = await getModule(); - console.log('Done loading WASM Module'); + logMessage('Done loading WASM Module'); // Pass in the getValues function for the C++ code to call. 
const result = module.handleRequestCc(getValues, input); - console.log('handleRequestCc result: ' + JSON.stringify(result)); + logMessage('handleRequestCc result: ' + JSON.stringify(result)); return result; } diff --git a/tools/udf/inline_wasm/examples/protobuf/my_udf.js b/tools/udf/inline_wasm/examples/protobuf/my_udf.js index 57a5145a..89d97825 100644 --- a/tools/udf/inline_wasm/examples/protobuf/my_udf.js +++ b/tools/udf/inline_wasm/examples/protobuf/my_udf.js @@ -42,9 +42,9 @@ function getKeyGroupOutputs(udf_arguments, module) { } async function HandleRequest(executionMetadata, ...udf_arguments) { - console.log('Handling request'); + logMessage('Handling request'); const module = await getModule(); - console.log('Done loading WASM Module'); + logMessage('Done loading WASM Module'); const keyGroupOutputs = getKeyGroupOutputs(udf_arguments, module); return { keyGroupOutputs, udfOutputApiVersion: 1 }; } diff --git a/tools/udf/sample_udf/BUILD.bazel b/tools/udf/sample_udf/BUILD.bazel index 920ce70c..02f9f174 100644 --- a/tools/udf/sample_udf/BUILD.bazel +++ b/tools/udf/sample_udf/BUILD.bazel @@ -49,18 +49,35 @@ run_binary( ) run_binary( - name = "generate_run_set_query_int_delta", + name = "generate_run_set_query_uint32_delta", srcs = [ - ":run_set_query_int_udf.js", + ":run_set_query_uint32_udf.js", ], outs = [ "DELTA_0000000000000007", ], args = [ "--udf_file_path", - "$(location :run_set_query_int_udf.js)", + "$(location :run_set_query_uint32_udf.js)", "--output_path", "$(location DELTA_0000000000000007)", ], tool = "//tools/udf/udf_generator:udf_delta_file_generator", ) + +run_binary( + name = "generate_run_set_query_uint64_delta", + srcs = [ + ":run_set_query_uint64_udf.js", + ], + outs = [ + "DELTA_0000000000000008", + ], + args = [ + "--udf_file_path", + "$(location :run_set_query_uint64_udf.js)", + "--output_path", + "$(location DELTA_0000000000000008)", + ], + tool = "//tools/udf/udf_generator:udf_delta_file_generator", +) diff --git 
a/tools/udf/sample_udf/run_set_query_int_udf.js b/tools/udf/sample_udf/run_set_query_uint32_udf.js similarity index 96% rename from tools/udf/sample_udf/run_set_query_int_udf.js rename to tools/udf/sample_udf/run_set_query_uint32_udf.js index ef405476..424c73ff 100644 --- a/tools/udf/sample_udf/run_set_query_int_udf.js +++ b/tools/udf/sample_udf/run_set_query_uint32_udf.js @@ -27,7 +27,7 @@ function HandleRequest(executionMetadata, ...input) { } // Get the first key in the data. - const runQueryArray = runSetQueryInt(keyGroup.data[0]); + const runQueryArray = runSetQueryUInt32(keyGroup.data[0]); // runSetQueryInt returns an Uint8Array of 'uint32' ints and "code" on failure. // Ignore failures and only add successful runQuery results to output. if (runQueryArray instanceof Uint8Array) { diff --git a/tools/udf/sample_udf/run_set_query_uint64_udf.js b/tools/udf/sample_udf/run_set_query_uint64_udf.js new file mode 100644 index 00000000..268e2e7b --- /dev/null +++ b/tools/udf/sample_udf/run_set_query_uint64_udf.js @@ -0,0 +1,43 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +function HandleRequest(executionMetadata, ...input) { + let keyGroupOutputs = []; + for (const keyGroup of input) { + let keyGroupOutput = {}; + if (!keyGroup.tags.includes('custom') || !keyGroup.tags.includes('queries')) { + continue; + } + keyGroupOutput.tags = keyGroup.tags; + if (!Array.isArray(keyGroup.data) || !keyGroup.data.length) { + continue; + } + + // Get the first key in the data. + const runQueryArray = runSetQueryUInt64(keyGroup.data[0]); + // runSetQueryUInt64 returns an Uint8Array of 'uint64' ints and "code" on failure. + // Ignore failures and only add successful runQuery results to output. + if (runQueryArray instanceof Uint8Array) { + const keyValuesOutput = {}; + const uint64Array = new BigUint64Array(runQueryArray.buffer); + const value = Array.from(uint64Array, (uint64) => uint64.toString()); + keyValuesOutput['result'] = { value: value }; + keyGroupOutput.keyValues = keyValuesOutput; + keyGroupOutputs.push(keyGroupOutput); + } + } + return { keyGroupOutputs, udfOutputApiVersion: 1 }; +} diff --git a/tools/udf/udf_tester/udf_delta_file_tester.cc b/tools/udf/udf_tester/udf_delta_file_tester.cc index 3caf6661..fb6a14d7 100644 --- a/tools/udf/udf_tester/udf_delta_file_tester.cc +++ b/tools/udf/udf_tester/udf_delta_file_tester.cc @@ -177,7 +177,7 @@ absl::Status TestUdf(const std::string& kv_delta_file_path, config_builder.RegisterStringGetValuesHook(*string_get_values_hook) .RegisterBinaryGetValuesHook(*binary_get_values_hook) .RegisterRunSetQueryStringHook(*run_set_query_string_hook) - .RegisterLoggingFunction() + .RegisterLoggingHook() .SetNumberOfWorkers(1) .Config())) PS_RETURN_IF_ERROR(udf_client.status()) diff --git a/version.txt b/version.txt index 14a8c245..afaf360d 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.17.1 \ No newline at end of file +1.0.0 \ No newline at end of file