diff --git a/.bazelrc b/.bazelrc index fdae62c2..e3d6c35d 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,7 +5,7 @@ build --output_filter='^//((?!(third_party):).)*$'` build --color=yes build --@io_bazel_rules_docker//transitions:enable=false build --workspace_status_command="bash tools/get_workspace_status" -build --copt=-Werror=thread-safety-analysis +build --copt=-Werror=thread-safety build --config=clang build --config=noexcept # Disable some ROMA error checking @@ -118,7 +118,7 @@ build:gcp_platform --@google_privacysandbox_servers_common//:platform=gcp build:prod_mode --//:mode=prod build:prod_mode --@google_privacysandbox_servers_common//:build_flavor=prod -# --config prod_mode: builds the service in prod mode +# --config nonprod_mode: builds the service in nonprod mode build:nonprod_mode --//:mode=nonprod build:nonprod_mode --@google_privacysandbox_servers_common//:build_flavor=non_prod diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe88324b..4bc1daa7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ exclude: (?x)^( fail_fast: true repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer - id: fix-byte-order-marker @@ -60,12 +60,12 @@ repos: exclude: ^(google_internal|builders/images)/.*$ - repo: https://github.com/bufbuild/buf - rev: v1.29.0 + rev: v1.31.0 hooks: - id: buf-format - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v17.0.6 + rev: v18.1.4 hooks: - id: clang-format types_or: @@ -119,7 +119,7 @@ repos: )$ - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.12.1 + rev: v0.13.0 hooks: - id: markdownlint-cli2 name: lint markdown @@ -154,7 +154,7 @@ repos: - --quiet - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 24.4.0 hooks: - id: black name: black python formatter diff --git a/BUILD.bazel b/BUILD.bazel index 6c8c9cb4..483d1d9a 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -19,8 +19,9 @@ 
load("@io_bazel_rules_go//go:def.bzl", "nogo") package(default_visibility = ["//:__subpackages__"]) # Config settings to determine which platform the system will be built to run on +# Use --config=VALUE_platform specified in .bazelrc instead of using this flag directly. # Example: -# bazel build components/... --//:platform=aws +# bazel build components/... --config=aws_platform string_flag( name = "platform", build_setting_default = "aws", @@ -64,6 +65,10 @@ config_setting( ], ) +# Config settings to determine which instance the system will be built to run on +# Use --config=VALUE_instance specified in .bazelrc instead of using this flag directly. +# Example: +# bazel build components/... --config=aws_instance string_flag( name = "instance", build_setting_default = "aws", diff --git a/CHANGELOG.md b/CHANGELOG.md index 6699b468..07125613 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,84 @@ All notable changes to this project will be documented in this file. See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. +## 0.17.0 (2024-07-08) + + +### Features + +* Add a set wrapper around bitset for storing uint32 values +* Add a thread safe wrapper around hash map +* Add b&a e2e test env +* Add data loading support for uint32 sets +* Add health check to AWS mesh. +* Add hook for running set query using uint32 sets as input +* Add interestGroupNames to V1 API +* Add latency metrics for cache uint32 sets functions +* Add latency without custom code execution metric +* Add option to use existing network on AWS. 
+* Add padding to responses +* Add request log context to request context +* Add runsetqueryint udf hook +* Add set operation functions for bitsets +* Add support for int32_t sets to key value cache +* Add support for reading and writing int sets to csv files +* Add udf hook for running int sets set query (local lookup) +* Allow pas request to pass consented debug config and log context +* Implement sharded RunSetQueryInt rpc for lookup client +* Implement uint32 sets sharded lookup support +* Load consented debug token from server parameter +* Pass LogContext and ConsentedDebugConfig to internal lookup server in sharded case +* Plumb the safe path log context in the cache update execution path +* Set verbosity level for PS_VLOG +* Simplify thread safe hash map and use a single map for node storage +* Support uint32 sets for query parsing and evaluation +* Support uint32 sets in InternalLookup rpc +* Switch absl log for PS_LOG and PS_VLOG for unsafe code path +* Switch absl log to PS_LOG for safe code path +* Switch absl vlog to PS_VLOG for safe code path +* Update AWS coordinators public prod endpoint from GG to G3P + + +### Bug Fixes + +* Add missing include/library deps +* Augment UDF loading info message +* Correct copts build config. +* Correct verbosity flag for gcp validator. +* Effectively lock the key in the set map cleanup +* Fix detached head of continuous e2e branch. 
+* Properly initialize runSetQueryInt hook +* Remove ignore interestGroupNames from envoy +* Remove test filter to allow all unit tests run in the build +* Simplify request context and pass it as shared pointer to the hooks +* Upgrade common repo version +* Use kms_binaries tar target from common repo +* Use structured initializer for clarity + + +### Dependencies + +* **deps:** Upgrade build-system to 0.62.0 +* **deps:** Upgrade data-plane-shared-libraries to 52239f15 2024-05-21 +* **deps:** Upgrade pre-commit hooks + + +### GCP: Features + +* **GCP:** Switch to internal lb for the otlp collector +* **GCP:** Switch to internal lb for the otlp collector with bug fixes + + +### Documentation + +* Add debugging playbook +* Correct commands for sample_word2vec getting_started example +* KV onboarding guide +* Update to the ads retrieval explainer +* Update word2vec example +* Use aws_platform bazel config +* Use local_{platform,instance} bazel configs + ## 0.16.0 (2024-04-05) diff --git a/README.md b/README.md index d1632784..fcf6a733 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,16 @@ +> [!IMPORTANT] +> +> The `main` branch hosts live code with latest changes. It is unstable and is used for development. +> It is suitable for contribution and inspection of the latest code. The `release-*` branches are +> stable releases that can be used to build and deploy the system. + +--- + > FLEDGE has been renamed to Protected Audience API. To learn more about the name change, see the > [blog post](https://privacysandbox.com/intl/en_us/news/protected-audience-api-our-new-name-for-fledge). +--- + # ![Privacy Sandbox Logo](docs/assets/privacy_sandbox_logo.png) FLEDGE Key/Value service # Background @@ -40,31 +50,234 @@ moment, to load data, instead of calling the mutation API, you would place the d location that can be directly read by the server. See more details in the [data loading guide](/docs/data_loading/loading_data.md). 
-Currently, this service can be deployed to 1 region of your choice with more regions to be added -soon. Monitoring and alerts are currently unavailable. +Currently, this service can be deployed to 1 region of your choice. Multi-region configuration is up +to the service owner to configure. + +## Current features + +### Build and deployment + +- Source code is available on Github +- Releases are done on a regular basis +- Binaries can be built from source code + - C++ binary + - [AWS & GCP] Docker container image + - [AWS]: Nitro EIF + - [AWS]: Reference AMI + - Other tools +- Server can run as a standalone local process for testing without any cloud dependency or + TEE-related functionality +- Server can be deployed to AWS Nitro enclave +- Server can be deployed to GCP Confidential Space +- Reference terraform available for a clean and comprehensive deployment to AWS or GCP + - Clean: assumes the cloud environment has no preset infrastructure + - Comprehensive: deploys all dependencies and some recommended (but not necessarily required) + configuration + - Many server behaviors can be configured by parameter flags without having to rebuild + +### Data loading + +- Server loads key/value data from cloud file storage +- Server loads key/value data from cloud pub/sub services +- Server loads data into an in-RAM representation for query serving +- Server continuously monitors for new data and incrementally updates ("delta files") the in-RAM + representation +- Support independent data ingestion pipelining by monitoring directories in cloud file storage + independently +- Supports Flatbuffers as the data event format +- Supports Avro and Riegeli as the data file format +- Supports snapshot files for faster server start up +- Users can perform compactions of delta files into snapshot files in an offline path + +### Read request processing + +- Support Protected Audience Key Value Server query spec: can be used as a BYOS server to serve + requests from Chrome +- Support 
simple key value lookups for queries +- Users can write "user defined functions" to execute custom logic to process queries +- User defined functions can be written in JavaScript or WASM +- User defined functions can call "GetValues" to look up key value from the dataset + +### Advanced features + +- Set-as-a-value is supported + - A key "value" pair in the dataset can be a key and a set of values +- UDF can call "RunQuery" API to run set operations on sets (intersection, union, difference) +- For GCP, Terraform supports deploying into an existing VPC, such as used by the Bidding and + Auction services. Non-prod server logs are persisted after server shutdown +- Data can be sharded and different servers may load and serve different shards (subset) of the + dataset. +- Sharding supports data locality, where the operator specifies "sharding key" for key value pairs + so different key value pairs can have the same sharding key. + +## **Timeline and roadmap** + +The following sections include the timelines for the Trusted Key Value Server for Protected +Auctions. Protected Auctions refer to Protected Audiences and Protected App Signals ad targeting +products. + +### **Timelines** + + + + + + + + + + + + + + + + +
+ Beta testing + General availability + Enforcement +
For Protected Audience +

+(web browser on desktop) +

July 2024 +

+The Privacy-Sandbox-provided Key Value Server implementation can

    + +
  • run as a BYOS KV server +
  • support production scale traffic and common functionalities
+ +
Q4 2024 +

+Opt-in TEE mode will be available to the Adtechs. Opt-in guidance will be published in early Q4 2024. +

No sooner than Q3 2025 +
+ + + + + + + + + + + + + + + + + +
+ Beta testing + General availability +
For Protected Audience +

+(With Bidding & Auction services for Chrome or Android) +

July 2024 +

+The Privacy-Sandbox-provided Key Value Server implementation can be used with the Bidding and Auction services and

    + +
  • run as a BYOS KV server +
  • support production scale traffic and common functionalities
+ +
Dec 2024 +

+The Privacy-Sandbox-provided Key Value Server implementation can be used with the Bidding and Auction services and adtechs can opt-in TEE mode +

For Protected App Signals + June 2024 +

+The Privacy-Sandbox-provided Key Value Server implementation supports Ad retrieval server functionality and protected communication for live traffic testing +

Dec 2024 +

+The implementation supports live traffic at scale +

+ + + +### **Roadmap** + +#### June 2024 Beta release + +##### Deployment and Setup + +- For AWS, Terraform supports deploying into an existing VPC, such as the one that is used by the + Bidding and Auction services +- Internal load balancer is used for servers to send metrics to OpenTelemetry collector + - In v0.16, the communication goes through a public load balancer + +##### Integration with the Bidding & Auction services + +- The Bidding and Auction services can send encrypted requests to the Key Value Server for + Protected App Signals + +##### Debugging support + +- [Consented Debugging](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/debugging_protected_audience_api_services.md#adtech-consented-debugging) + is supported +- Diagnose tool to check the cloud environment to warn for potential setup errors before the + system is deployed +- Operational playbook +- Introduction of unsafe metrics + - Unsafe metrics have privacy protections such as differential privacy noises + - More metrics for comprehensive SLO monitoring + +##### Runtime features + +- Data loading error handling + - The system can be configured to use different error handling strategy for different dataset + +##### Performance/Cost + +- Benchmarking tools +- Cost explainer +- Sets-as-values will switch to using bitsets to represent sets for faster RunQuery performance. + +##### Support process + +- Commitment to support window for active releases + +#### Q4 2024 Chrome-PA GA + +##### Chrome integration + +- Update to V2 protocol to support the hybrid mode of BYOS & Opt-in TEE +- Chrome and Key Value server can communicate in the updated V2 protocol +- Chrome can send full publisher URL to TEE KV server under V2 protocol + +#### H2 2024 Android-PA GA, PAS GA + +##### User Defined Functions + +- UDF can perform Key/Value lookup asynchronously +- Flags can be passed from the server parameters into UDF +- One Key Value Server system can be used for multiple use cases. 
Multiple UDFs can be loaded. + Different UDF can be selected based on the request type. +- Canaring support for UDF: canary version UDF can be staged in machines with specific tags. + +##### Customization support + +- First class support for customization of the system (without violating the trust model) + +##### Debugging support -> **Attention**: The Key/Value Server is publicly queryable. It does not authenticate callers. That -> is also true for the product end state. It is recommended to only store data you do not mind seen -> by other entities. +- Diagnose tool to collect standard and necessary debug information for troubleshooting requests -## How to use this repo +##### Documentation -The `main` branch hosts live code with latest changes. It is unstable and is used for development. -It is suitable for contribution and inspection of the latest code. The `release-*` branches are -stable releases that can be used to build and deploy the system. +- Complete end to end example as a template to set up the service ## Breaking changes -This codebase right now is in a very early stage. We expect frequent updates that may not be fully -backward compatible. +While we make efforts to not introduce breaking changes, we expect that to happen occasionally. The release version follows the `[major change].[minor change].[patch]` scheme. All 0.x.x versions may contain breaking changes without notice. Refer to the [release changelog](/CHANGELOG.md) for the details of the breaking changes. -Once the codebase is in a more stable state that is version 1.0.0, we will establish additional -channels for announcing breaking changes and major version will always be incremented for breaking -changes. +At GA the version will become 1.0.0, and we will establish additional channels for announcing breaking +changes and major version will always be incremented for breaking changes. 
# Key documents diff --git a/WORKSPACE b/WORKSPACE index da073316..2813868a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -13,11 +13,11 @@ python_deps("//builders/bazel") http_archive( name = "google_privacysandbox_servers_common", - # commit b34fe82 2024-04-03 - sha256 = "2afc7017723efb9d34b6ed713be03dbdf9b45de8ba585d2ea314eb3a52903d0a", - strip_prefix = "data-plane-shared-libraries-b34fe821b982e06446df617edb7a6e3041c8b0db", + # commit 37522d6 2024-07-01 + sha256 = "ce300bc178b1eedd88d7545b89d1d672b3b9bfb62c138ab3f4a845f159436285", + strip_prefix = "data-plane-shared-libraries-37522d6ac55c8592060f636d68f50feddcb9598a", urls = [ - "https://github.com/privacysandbox/data-plane-shared-libraries/archive/b34fe821b982e06446df617edb7a6e3041c8b0db.zip", + "https://github.com/privacysandbox/data-plane-shared-libraries/archive/37522d6ac55c8592060f636d68f50feddcb9598a.zip", ], ) diff --git a/builders/.github/workflows/scorecard.yaml b/builders/.github/workflows/scorecard.yaml index dbcb6200..22bd7f8e 100644 --- a/builders/.github/workflows/scorecard.yaml +++ b/builders/.github/workflows/scorecard.yaml @@ -26,14 +26,14 @@ on: - cron: '35 10 * * 4' push: branches: - - main + - main # Declare default permissions as read only. permissions: read-all jobs: analysis: - name: Scorecard analysis + name: OpenSSF Scorecard analysis runs-on: ubuntu-latest permissions: # Needed to upload the results to code-scanning dashboard. @@ -46,12 +46,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@93ea575cb5d8a053eaa0ac8fa3b40d7e05a33cc8 # v3.1.0 + uses: actions/checkout@9bb56186c3b09b4f86b1c65136769dd318469633 # v4.1.2 with: persist-credentials: false - name: Run analysis - uses: ossf/scorecard-action@e38b1902ae4f44df626f11ba0734b14fb91f8f86 # v2.1.2 + uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1 with: results_file: results.sarif results_format: sarif @@ -73,7 +73,7 @@ jobs: # Upload the results as artifacts (optional). 
Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: Upload artifact - uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v3.1.0 + uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 with: name: SARIF file path: results.sarif @@ -81,6 +81,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard. - name: Upload to code-scanning - uses: github/codeql-action/upload-sarif@17573ee1cc1b9d061760f3a006fc4aac4f944fd5 # v2.2.4 + uses: github/codeql-action/upload-sarif@1b1aada464948af03b950897e5eb522f92603cc2 # v3.24.9 with: sarif_file: results.sarif diff --git a/builders/.pre-commit-config.yaml b/builders/.pre-commit-config.yaml index c7997b3e..e1ad672d 100644 --- a/builders/.pre-commit-config.yaml +++ b/builders/.pre-commit-config.yaml @@ -21,7 +21,7 @@ exclude: (?x)^( fail_fast: true repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: end-of-file-fixer - id: fix-byte-order-marker @@ -47,7 +47,7 @@ repos: - id: shellcheck - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.6 + rev: v18.1.4 hooks: - id: clang-format types_or: @@ -55,7 +55,7 @@ repos: - c - repo: https://github.com/bufbuild/buf - rev: v1.23.1 + rev: v1.31.0 hooks: - id: buf-format @@ -109,7 +109,7 @@ repos: - markdown - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.8.1 + rev: v0.13.0 hooks: - id: markdownlint-cli2 name: lint markdown @@ -144,7 +144,7 @@ repos: - --quiet - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 24.4.2 hooks: - id: black name: black python formatter diff --git a/builders/CHANGELOG.md b/builders/CHANGELOG.md index e323d8a3..3ced1112 100644 --- a/builders/CHANGELOG.md +++ b/builders/CHANGELOG.md @@ -2,6 +2,65 @@ All notable changes to this project will be documented in this file. 
See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. +## 0.62.0 (2024-05-10) + + +### Features + +* Add --dir flag to normalize-dist + +## 0.61.1 (2024-05-10) + + +### Bug Fixes + +* Add docker flags to container name +* Set 8h ttl for long-running build container + +## 0.61.0 (2024-05-08) + + +### Features + +* Add cbuild support for container reuse + +## 0.60.0 (2024-05-07) + + +### Dependencies + +* **deps:** Upgrade coverage-tools to ubuntu 24.04 +* **deps:** Upgrade golang to 1.22.2 + +## 0.59.0 (2024-05-02) + + +### Bug Fixes + +* **deps:** Update pre-commit hooks + + +### Dependencies + +* **deps:** Upgrade alpine base image +* **deps:** Upgrade base images for Amazon Linux +* **deps:** Upgrade grpcurl to 1.9.1 +* **deps:** Upgrade presubmit to ubuntu 24.04 + +## 0.58.0 (2024-04-26) + + +### Features + +* add missing AWS env variable for CodeBuild + +## 0.57.1 (2024-03-28) + + +### Bug Fixes + +* Upgrade OpenSSF scorecard GitHub Action + ## 0.57.0 (2024-03-10) diff --git a/builders/images/build-amazonlinux2/Dockerfile b/builders/images/build-amazonlinux2/Dockerfile index bb6c1edc..6ccce805 100644 --- a/builders/images/build-amazonlinux2/Dockerfile +++ b/builders/images/build-amazonlinux2/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM amazonlinux:2.0.20230822.0 +FROM amazonlinux:2.0.20240412.0 COPY /install_apps install_golang_apps install_go.sh generate_system_bazelrc .bazelversion /scripts/ COPY get_workspace_mount /usr/local/bin diff --git a/builders/images/build-amazonlinux2/install_apps b/builders/images/build-amazonlinux2/install_apps index a5e6ff35..43fa2166 100755 --- a/builders/images/build-amazonlinux2/install_apps +++ b/builders/images/build-amazonlinux2/install_apps @@ -22,14 +22,8 @@ while [[ $# -gt 0 ]]; do VERBOSE=1 shift ;; - -h | --help) - usage 0 - break - ;; - *) - usage - break - ;; + -h | --help) usage 0 ;; + *) usage ;; esac done diff --git a/builders/images/build-amazonlinux2023/Dockerfile b/builders/images/build-amazonlinux2023/Dockerfile index d24ba9af..268bbcba 100644 --- a/builders/images/build-amazonlinux2023/Dockerfile +++ b/builders/images/build-amazonlinux2023/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM amazonlinux:2023.1.20230825.0 +FROM amazonlinux:2023.4.20240416.0 COPY /install_apps install_golang_apps install_go.sh generate_system_bazelrc .bazelversion /scripts/ COPY get_workspace_mount /usr/local/bin diff --git a/builders/images/build-amazonlinux2023/install_apps b/builders/images/build-amazonlinux2023/install_apps index 51c64377..442021c4 100755 --- a/builders/images/build-amazonlinux2023/install_apps +++ b/builders/images/build-amazonlinux2023/install_apps @@ -22,14 +22,8 @@ while [[ $# -gt 0 ]]; do VERBOSE=1 shift ;; - -h | --help) - usage 0 - break - ;; - *) - usage - break - ;; + -h | --help) usage 0 ;; + *) usage ;; esac done diff --git a/builders/images/build-debian/Dockerfile b/builders/images/build-debian/Dockerfile index eb370730..735d9b66 100644 --- a/builders/images/build-debian/Dockerfile +++ b/builders/images/build-debian/Dockerfile @@ -18,7 +18,7 @@ ARG BASE_IMAGE=ubuntu:20.04 # hadolint ignore=DL3006 FROM ${BASE_IMAGE} as libprofiler-builder ENV CC=clang \ - CXX=clang + CXX=clang++ ADD https://github.com/gperftools/gperftools/releases/download/gperftools-2.13/gperftools-2.13.tar.gz /build/gperftools.tar.gz ADD https://apt.llvm.org/llvm.sh /build/llvm.sh COPY compile_libprofiler /scripts/ diff --git a/builders/images/build-debian/compile_libprofiler b/builders/images/build-debian/compile_libprofiler index f05397fb..bcc15a68 100755 --- a/builders/images/build-debian/compile_libprofiler +++ b/builders/images/build-debian/compile_libprofiler @@ -24,6 +24,7 @@ function install_clang() { /build/llvm.sh ${CLANG_VER} apt-get --quiet install -y --no-install-recommends libc++-${CLANG_VER}-dev update-alternatives --install /usr/bin/clang clang /usr/bin/clang-${CLANG_VER} 100 + update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-${CLANG_VER} 100 rm -f llvm.sh clang --version diff --git a/builders/images/coverage-tools/Dockerfile b/builders/images/coverage-tools/Dockerfile index c2f25551..6915b8d2 100644 --- 
a/builders/images/coverage-tools/Dockerfile +++ b/builders/images/coverage-tools/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM ubuntu:20.04 +FROM ubuntu:24.04 COPY install_apps /scripts/ diff --git a/builders/images/coverage-tools/install_apps b/builders/images/coverage-tools/install_apps index 72fd822a..5f8713b8 100755 --- a/builders/images/coverage-tools/install_apps +++ b/builders/images/coverage-tools/install_apps @@ -33,8 +33,8 @@ function apt_update() { function install_misc() { DEBIAN_FRONTEND=noninteractive apt-get --quiet install -y --no-install-recommends \ - lcov="1.*" \ - google-perftools="2.*" + google-perftools="2.*" \ + lcov="2.*" } function clean_debian() { diff --git a/builders/images/generate_system_bazelrc b/builders/images/generate_system_bazelrc index 9c8e419f..b68d9405 100755 --- a/builders/images/generate_system_bazelrc +++ b/builders/images/generate_system_bazelrc @@ -22,14 +22,12 @@ while [[ $# -gt 0 ]]; do case "$1" in --user-root-name) USER_ROOT_NAME="$2" - shift - shift + shift 2 || usage ;; -h | --help) usage 0 ;; *) printf "unrecognized arg: %s\n" "$1" usage - break ;; esac done diff --git a/builders/images/install_go.sh b/builders/images/install_go.sh index a436d20a..dc30cd73 100644 --- a/builders/images/install_go.sh +++ b/builders/images/install_go.sh @@ -22,12 +22,12 @@ function _golang_install_dir() { function install_golang() { declare -r _ARCH="$1" declare -r FNAME=gobin.tar.gz - declare -r VERSION=1.20.4 + declare -r VERSION=1.22.2 # shellcheck disable=SC2155 declare -r GO_INSTALL_DIR="$(_golang_install_dir)" declare -r -A GO_HASHES=( - [amd64]="698ef3243972a51ddb4028e4a1ac63dc6d60821bf18e59a807e051fee0a385bd" - [arm64]="105889992ee4b1d40c7c108555222ca70ae43fccb42e20fbf1eebb822f5e72c6" + [amd64]="5901c52b7a78002aeff14a21f93e0f064f74ce1360fce51c6ee68cd471216a17" + [arm64]="36e720b2d564980c162a48c7e97da2e407dfcc4239e1e58d98082dfa2486a0c1" 
) declare -r GO_HASH=${GO_HASHES[${_ARCH}]} if [[ -z ${GO_HASH} ]]; then diff --git a/builders/images/install_golang_apps b/builders/images/install_golang_apps index 9d75eb94..a6bb3cfd 100755 --- a/builders/images/install_golang_apps +++ b/builders/images/install_golang_apps @@ -23,10 +23,7 @@ while [[ $# -gt 0 ]]; do shift ;; -h | --help) usage 0 ;; - *) - usage - break - ;; + *) usage ;; esac done diff --git a/builders/images/presubmit/Dockerfile b/builders/images/presubmit/Dockerfile index 024c05fe..53ee3778 100644 --- a/builders/images/presubmit/Dockerfile +++ b/builders/images/presubmit/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM ubuntu:20.04 +FROM ubuntu:24.04 COPY install_apps install_go.sh .pre-commit-config.yaml /scripts/ COPY gitconfig /etc diff --git a/builders/images/presubmit/install_apps b/builders/images/presubmit/install_apps index c0d8d0d8..3986ec65 100755 --- a/builders/images/presubmit/install_apps +++ b/builders/images/presubmit/install_apps @@ -51,18 +51,17 @@ function apt_update() { function install_packages() { DEBIAN_FRONTEND=noninteractive apt-get --quiet install -y --no-install-recommends \ - apt-transport-https="2.0.*" \ + apt-transport-https="2.7.*" \ ca-certificates \ - libcurl4="7.68.*" \ - curl="7.68.*" \ - gnupg="2.2.*" \ - lsb-release="11.1.*" \ + libcurl4t64="8.5.*" \ + curl="8.5.*" \ + lsb-release="12.0*" \ openjdk-11-jre="11.0.*" \ - python3.9-venv="3.9.*" \ - shellcheck="0.7.*" \ + python3.12-venv="3.12.*" \ + shellcheck="0.9.*" \ software-properties-common="0.99.*" \ - wget="1.20.*" - update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 100 + wget="1.21.*" + update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 100 } # Install Docker (https://docs.docker.com/engine/install/debian/) @@ -81,9 +80,9 @@ function install_docker() { } function install_precommit() { - /usr/bin/python3.9 -m venv 
"${PRE_COMMIT_VENV_DIR}" + /usr/bin/python3.12 -m venv "${PRE_COMMIT_VENV_DIR}" "${PRE_COMMIT_VENV_DIR}"/bin/pip install \ - pre-commit~=3.1 \ + pre-commit~=3.7 \ pylint~=3.1.0 "${PRE_COMMIT_TOOL}" --version diff --git a/builders/images/test-tools/Dockerfile b/builders/images/test-tools/Dockerfile index 45fab3cf..69399eb4 100644 --- a/builders/images/test-tools/Dockerfile +++ b/builders/images/test-tools/Dockerfile @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM alpine:3.18 as slowhttptest_builder +FROM alpine:3.19 as slowhttptest_builder # hadolint ignore=DL3018 RUN apk add --no-cache autoconf automake build-base git openssl-dev WORKDIR /build ADD https://github.com/shekyan/slowhttptest/archive/refs/tags/v1.9.0.tar.gz /build/src.tar.gz RUN tar xz --strip-components 1 -f src.tar.gz && ./configure && make -FROM alpine:3.18 as wrk_builder +FROM alpine:3.19 as wrk_builder ARG TARGETARCH ENV BUILD_ARCH="${TARGETARCH}" COPY build_wrk /build/ @@ -27,14 +27,14 @@ WORKDIR /build ADD https://github.com/giltene/wrk2/archive/44a94c17d8e6a0bac8559b53da76848e430cb7a7.tar.gz /build/src.tar.gz RUN /build/build_wrk -FROM golang:1.21-alpine3.18 AS golang +FROM golang:1.22-alpine3.19 AS golang ENV GOBIN=/usr/local/go/bin COPY build_golang_apps /scripts/ RUN /scripts/build_golang_apps -FROM fullstorydev/grpcurl:v1.8.9-alpine AS grpcurl +FROM fullstorydev/grpcurl:v1.9.1-alpine AS grpcurl -FROM alpine:3.18 +FROM alpine:3.19 COPY --from=golang /usr/local/go/bin/* /usr/local/bin/ COPY --from=grpcurl /bin/grpcurl /usr/local/bin/ ARG TARGETARCH diff --git a/builders/images/utils/Dockerfile b/builders/images/utils/Dockerfile index 0d1defd6..ea4bee81 100644 --- a/builders/images/utils/Dockerfile +++ b/builders/images/utils/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM alpine:3.16 +FROM alpine:3.19 RUN apk --no-cache add \ unzip~=6.0 \ diff --git a/builders/tests/data/hashes/build-amazonlinux2 b/builders/tests/data/hashes/build-amazonlinux2 index 7aed7b0f..71f56dfb 100644 --- a/builders/tests/data/hashes/build-amazonlinux2 +++ b/builders/tests/data/hashes/build-amazonlinux2 @@ -1 +1 @@ -3efa00f3a5dbe0a4708be523aa32aca91dcd56d403d3ff32e0202756b8321b3b +57ca4f2381a0fc193b0476663171c4d339b6ef66c0d1f1c24bb3f48d368b38ab diff --git a/builders/tests/data/hashes/build-amazonlinux2023 b/builders/tests/data/hashes/build-amazonlinux2023 index 1bcc412e..5fe991bf 100644 --- a/builders/tests/data/hashes/build-amazonlinux2023 +++ b/builders/tests/data/hashes/build-amazonlinux2023 @@ -1 +1 @@ -57396ff1c765f7b63905963cfe4498912f7f75b5cb9f7bc36bd6879af69872e7 +8d01333fe93d2ac2102dd8360a58717724b7b594d51fe4e412ec20aae181efce diff --git a/builders/tests/data/hashes/build-debian b/builders/tests/data/hashes/build-debian index c89114f2..57095aed 100644 --- a/builders/tests/data/hashes/build-debian +++ b/builders/tests/data/hashes/build-debian @@ -1 +1 @@ -38cc8a23a6a56eb6567bef3685100cd3be1c0491dcc8b953993c42182da3fa40 +c194dafd287978093f8fe6e16e981fb22028e37345e20a4d7ca84caa43f0d4c0 diff --git a/builders/tests/data/hashes/coverage-tools b/builders/tests/data/hashes/coverage-tools index f0336331..e0127b80 100644 --- a/builders/tests/data/hashes/coverage-tools +++ b/builders/tests/data/hashes/coverage-tools @@ -1 +1 @@ -cd3fb189dd23793af3bdfa02d6774ccb35bddbec7059761e25c4f7be4c1e8ca1 +b768060d602e2ed1b60573edfa6afad5379e96a9d6153cd721b2a0665075fe98 diff --git a/builders/tests/data/hashes/presubmit b/builders/tests/data/hashes/presubmit index a35c6c86..b02b21b0 100644 --- a/builders/tests/data/hashes/presubmit +++ b/builders/tests/data/hashes/presubmit @@ -1 +1 @@ -d9dab1c798d51f79e68fd8eb3bb83312086808d789bbc09d0f2dbf708ef5f114 +afaf1932764d07d480c4e833e6b08877f069abae87401bdac4782277c535a298 diff --git a/builders/tests/data/hashes/test-tools 
b/builders/tests/data/hashes/test-tools index fc5e0b5c..63f4e4bd 100644 --- a/builders/tests/data/hashes/test-tools +++ b/builders/tests/data/hashes/test-tools @@ -1 +1 @@ -dd1ec6137d4dd22fec555044cd85f484adfa6c7b686880ea5449cff936bad34e +c1111c91dcb1e9f4df65f9fd5eab60b2545b0e716cfaf59fb88c1006a6496a5e diff --git a/builders/tests/data/hashes/utils b/builders/tests/data/hashes/utils index da29b2fa..188febac 100644 --- a/builders/tests/data/hashes/utils +++ b/builders/tests/data/hashes/utils @@ -1 +1 @@ -9fca27d931acc2bc96fa0560466cc0914a0d1cc73fb8749af057caacf2911f85 +f4b8d15b26c7bef3bc94038be9b71aaf8ba8ba8d33663b7d6fb55ebdff9a902e diff --git a/builders/tests/run-tests b/builders/tests/run-tests index 037067d2..b7c1a7f9 100755 --- a/builders/tests/run-tests +++ b/builders/tests/run-tests @@ -19,18 +19,20 @@ set -o pipefail set -o errexit -trap _cleanup EXIT +# shellcheck disable=SC2317 function _cleanup() { - local -r -i STATUS=$? - if [[ -d ${TMP_HASHES_DIR1} ]]; then + local -r -i _status=$? 
+ if [[ -d "${TMP_HASHES_DIR1}" ]]; then rm -rf "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" fi - if [[ ${STATUS} -ne 0 ]]; then - printf "Error: run-tests status code: %d\n" "${STATUS}" + if [[ ${_status} -ne 0 ]]; then + printf "Error: run-tests status code: %d\n" "${_status}" sleep 5s fi - exit ${STATUS} + # shellcheck disable=SC2086 + exit ${_status} } +trap _cleanup EXIT function get_image_list() { local -r _images_dir="$1" diff --git a/builders/tools/builder.sh b/builders/tools/builder.sh index 77317aae..d2ca2f98 100644 --- a/builders/tools/builder.sh +++ b/builders/tools/builder.sh @@ -112,6 +112,7 @@ function builder::add_aws_env_vars() { "AWS_REGION" "AWS_DEFAULT_REGION" "AWS_PROFILE" + "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ) } diff --git a/builders/tools/cbuild b/builders/tools/cbuild index c40a3aed..17da1336 100755 --- a/builders/tools/cbuild +++ b/builders/tools/cbuild @@ -77,6 +77,8 @@ declare -i WITH_DOCKER_SOCK=1 declare -i WITH_CMD_PROFILER=0 DOCKER_NETWORK="${DOCKER_NETWORK:-bridge}" declare -i DOCKER_SECCOMP_UNCONFINED=0 +declare -i KEEP_CONTAINER_RUNNING=0 +declare LONG_RUNNING_CONTAINER_TIMEOUT=8h while [[ $# -gt 0 ]]; do case "$1" in @@ -94,6 +96,9 @@ while [[ $# -gt 0 ]]; do ;; --image) IMAGE="$2" + if [[ ${IMAGE} =~ ^build-* ]]; then + KEEP_CONTAINER_RUNNING=1 + fi shift 2 || usage ;; --without-shared-cache) @@ -171,12 +176,14 @@ if [[ ${PWD_WORKSPACE_REL_PATH:0:1} != / ]]; then fi readonly WORKDIR -declare -a DOCKER_RUN_ARGS -DOCKER_RUN_ARGS+=( +# DOCKER_EXEC_RUN_ARGS applies to both `docker run` and `docker exec` +declare -a DOCKER_EXEC_RUN_ARGS=( + "--workdir=${WORKDIR}" +) +declare -a DOCKER_RUN_ARGS=( "--rm" "--entrypoint=/bin/bash" "--volume=${WORKSPACE_MOUNT}:/src/workspace" - "--workdir=${WORKDIR}" "--network=${DOCKER_NETWORK}" "$(echo "${EXTRA_DOCKER_RUN_ARGS}" | envsubst)" ) @@ -200,53 +207,118 @@ fi readonly BAZEL_ROOT=/bazel_root if [[ ${WITH_SHARED_CACHE} -eq 0 ]]; then # use tmpfs for as temporary, container-bound bazel cache - 
DOCKER_RUN_ARGS+=( - "--tmpfs ${BAZEL_ROOT}:exec" - ) + DOCKER_RUN_ARGS+=("--tmpfs=${BAZEL_ROOT}:exec") else # mount host filesystem for "shared" use by multiple docker container invocations - DOCKER_RUN_ARGS+=( - "--volume ${HOME}/.cache/bazel:${BAZEL_ROOT}" - ) + DOCKER_RUN_ARGS+=("--volume=${HOME}/.cache/bazel:${BAZEL_ROOT}") fi if [[ ${WITH_DOCKER_SOCK} -eq 1 ]]; then - DOCKER_RUN_ARGS+=( - "--volume /var/run/docker.sock:/var/run/docker.sock" - ) + DOCKER_RUN_ARGS+=("--volume=/var/run/docker.sock:/var/run/docker.sock") fi for evar in "${ENV_VARS[@]}" do - DOCKER_RUN_ARGS+=( - "--env=${evar}" - ) + DOCKER_EXEC_RUN_ARGS+=("--env=${evar}") done if [[ -t 0 ]] && [[ -t 1 ]]; then # stdin and stdout are open, assume it's an interactive tty session - DOCKER_RUN_ARGS+=( - --interactive - --tty + DOCKER_EXEC_RUN_ARGS+=( + "--interactive" + "--tty" ) fi +function get_container_name() { + local -r mount="$(echo "${WORKSPACE_MOUNT}" | sha256sum)" + local -r image_sha="${IMAGE_TAGGED##*-}" + local -r docker_args_sha="$({ +cat </dev/stderr + docker container rm --force "${name}" >/dev/null + printf "finished removing docker container: %s\n" "${name}" &>/dev/stderr + fi + docker "${docker_args[@]}" --filter "status=running" +} + +function long_running_container() { + local -r container_name="$1" + local -r docker_running_container="$(running_container_for "${container_name}")" + if [[ -z ${docker_running_container} ]]; then + printf "starting a new container [%s]\n" "${container_name}" &>/dev/stderr + if [[ -z ${CMD} ]]; then + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${IMAGE_TAGGED}" + else + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + --detach \ + "${IMAGE_TAGGED}" \ + --login -c " +declare -i -r pid=\$(bazel info server_pid 2>/dev/null) +# wait for pid, even if it's not a child process of this shell +timeout ${LONG_RUNNING_CONTAINER_TIMEOUT} tail 
--pid=\${pid} -f /dev/null +" &>/dev/null + fi + fi + running_container_for "${DOCKER_CONTAINER_NAME}" +} + +if [[ ${KEEP_CONTAINER_RUNNING} -eq 1 ]]; then + DOCKER_RUNNING_CONTAINER="$(long_running_container "${DOCKER_CONTAINER_NAME}")" + docker exec \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${DOCKER_RUNNING_CONTAINER}" \ + /bin/bash -c "${CMD}" else - # shellcheck disable=SC2068 - docker run \ - ${DOCKER_RUN_ARGS[@]} \ - "${IMAGE_TAGGED}" \ - --login -c "${CMD}" + if [[ -z ${CMD} ]]; then + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${IMAGE_TAGGED}" \ + --login + elif [[ ${WITH_CMD_PROFILER} -eq 1 ]]; then + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${IMAGE_TAGGED}" \ + --login -c "'${TOOLS_RELDIR}'/normalize-bazel-symlinks; env \${CMD_PROFILER} ${CMD}" + else + # shellcheck disable=SC2068 + docker run \ + ${DOCKER_RUN_ARGS[@]} \ + "${DOCKER_EXEC_RUN_ARGS[@]}" \ + "${IMAGE_TAGGED}" \ + --login -c "$CMD" + fi fi diff --git a/builders/tools/normalize-dist b/builders/tools/normalize-dist index 93627d25..cfa86e11 100755 --- a/builders/tools/normalize-dist +++ b/builders/tools/normalize-dist @@ -19,13 +19,40 @@ set -o pipefail set -o errexit +declare TOP_LEVEL_DIR=dist + +function usage() { + local exitval=${1-1} + cat &>/dev/stderr < + --dir directory to normalize recursively. Default: ${TOP_LEVEL_DIR} +USAGE + # shellcheck disable=SC2086 + exit ${exitval} +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --dir) + TOP_LEVEL_DIR="$2" + shift 2 || usage + ;; + -h | --help) usage 0 ;; + *) + printf "unrecognized arg: %s\n" "$1" + usage + ;; + esac +done + trap _cleanup EXIT function _cleanup() { local -r -i STATUS=$? 
if [[ ${STATUS} -eq 0 ]]; then - printf "normalize-dist completed successfully\n" &>/dev/stderr + printf "normalize-dist [%s] completed successfully\n" "${TOP_LEVEL_DIR}" &>/dev/stderr else - printf "Error: normalize-dist completed with status code: %s\n" "${STATUS}" &>/dev/stderr + printf "Error: normalize-dist [%s] completed with status code: %s\n" "${TOP_LEVEL_DIR}" "${STATUS}" &>/dev/stderr fi exit 0 } @@ -40,9 +67,7 @@ readonly GROUP USER="$(builder::id u)" readonly USER -readonly TOP_LEVEL_DIRS="dist" - -printf "Setting file ownership [%s], group [%s] in dirs [%s]\n" "${USER}" "${GROUP}" "${TOP_LEVEL_DIRS}" +printf "Setting file ownership [%s], group [%s] in dirs [%s]\n" "${USER}" "${GROUP}" "${TOP_LEVEL_DIR}" declare -a runner=() if [[ -f /.dockerenv ]]; then runner+=(bash -c) @@ -51,11 +76,9 @@ else fi "${runner[@]}" " -for TOP_LEVEL_DIR in ${TOP_LEVEL_DIRS}; do - find \${TOP_LEVEL_DIR} -type f ! -executable -exec chmod 644 {} \; - find \${TOP_LEVEL_DIR} -type f -executable -exec chmod 755 {} \; - find \${TOP_LEVEL_DIR} -type d -exec chmod 755 {} \; - chgrp --recursive ${GROUP} \${TOP_LEVEL_DIR} - chown --recursive ${USER} \${TOP_LEVEL_DIR} -done +find ${TOP_LEVEL_DIR} -type f ! 
-executable -exec chmod 644 {} \; +find ${TOP_LEVEL_DIR} -type f -executable -exec chmod 755 {} \; +find ${TOP_LEVEL_DIR} -type d -exec chmod 755 {} \; +chgrp --recursive ${GROUP} ${TOP_LEVEL_DIR} +chown --recursive ${USER} ${TOP_LEVEL_DIR} " diff --git a/builders/tools/terraform b/builders/tools/terraform index 71a3e708..6cfc214c 100755 --- a/builders/tools/terraform +++ b/builders/tools/terraform @@ -82,4 +82,5 @@ DOCKER_RUN_ARGS+=( # shellcheck disable=SC2068 docker run \ "${DOCKER_RUN_ARGS[@]}" \ - ${IMAGE_TAGGED} "$@" + "${IMAGE_TAGGED}" \ + "$@" diff --git a/builders/version.txt b/builders/version.txt index 78756de3..7e9253a3 100644 --- a/builders/version.txt +++ b/builders/version.txt @@ -1 +1 @@ -0.57.0 \ No newline at end of file +0.62.0 \ No newline at end of file diff --git a/components/cloud_config/BUILD.bazel b/components/cloud_config/BUILD.bazel index 8cd15ad6..d8ef7895 100644 --- a/components/cloud_config/BUILD.bazel +++ b/components/cloud_config/BUILD.bazel @@ -37,6 +37,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) @@ -55,6 +56,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/public/core/interface:errors", "@google_privacysandbox_servers_common//src/public/cpio/interface/parameter_client", ], @@ -123,6 +125,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) @@ -163,6 +166,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + 
"@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", ], ) diff --git a/components/cloud_config/instance_client.h b/components/cloud_config/instance_client.h index 8d55733e..69e35de4 100644 --- a/components/cloud_config/instance_client.h +++ b/components/cloud_config/instance_client.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "src/logger/request_context_logger.h" // TODO: Replace config cpio client once ready namespace kv_server { @@ -55,7 +56,10 @@ using DescribeInstanceGroupInput = // Client to perform instance-specific operations. class InstanceClient { public: - static std::unique_ptr Create(); + static std::unique_ptr Create( + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); virtual ~InstanceClient() = default; // Retrieves all tags for the current instance and returns the tag with the @@ -89,6 +93,12 @@ class InstanceClient { // Retrieves descriptive information about the given instances. virtual absl::StatusOr> DescribeInstances( const absl::flat_hash_set& instance_ids) = 0; + + // Updates the log context reference to enable otel logging for instance + // client. This function should be called after telemetry is initialized with + // retrieved parameters. 
+ virtual void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) = 0; }; } // namespace kv_server diff --git a/components/cloud_config/instance_client_aws.cc b/components/cloud_config/instance_client_aws.cc index 0191e891..64709b20 100644 --- a/components/cloud_config/instance_client_aws.cc +++ b/components/cloud_config/instance_client_aws.cc @@ -186,18 +186,19 @@ class AwsInstanceClient : public InstanceClient { std::string_view lifecycle_hook_name) override { const absl::StatusOr instance_id = GetInstanceId(); if (!instance_id.ok()) { - LOG(ERROR) << "Failed to get instance_id: " << instance_id.status(); + PS_LOG(ERROR, log_context_) + << "Failed to get instance_id: " << instance_id.status(); return instance_id.status(); } - LOG(INFO) << "Retrieved instance id: " << *instance_id; + PS_LOG(INFO, log_context_) << "Retrieved instance id: " << *instance_id; const absl::StatusOr auto_scaling_group_name = GetAutoScalingGroupName(*auto_scaling_client_, *instance_id); if (!auto_scaling_group_name.ok()) { return auto_scaling_group_name.status(); } - LOG(INFO) << "Retrieved auto scaling group name " - << *auto_scaling_group_name; + PS_LOG(INFO, log_context_) + << "Retrieved auto scaling group name " << *auto_scaling_group_name; Aws::AutoScaling::Model::RecordLifecycleActionHeartbeatRequest request; request.SetAutoScalingGroupName(*auto_scaling_group_name); @@ -216,18 +217,19 @@ class AwsInstanceClient : public InstanceClient { std::string_view lifecycle_hook_name) override { const absl::StatusOr instance_id = GetInstanceId(); if (!instance_id.ok()) { - LOG(ERROR) << "Failed to get instance_id: " << instance_id.status(); + PS_LOG(ERROR, log_context_) + << "Failed to get instance_id: " << instance_id.status(); return instance_id.status(); } - LOG(INFO) << "Retrieved instance id: " << *instance_id; + PS_LOG(INFO, log_context_) << "Retrieved instance id: " << *instance_id; const absl::StatusOr auto_scaling_group_name = 
GetAutoScalingGroupName(*auto_scaling_client_, *instance_id); if (!auto_scaling_group_name.ok()) { return auto_scaling_group_name.status(); } - LOG(INFO) << "Retrieved auto scaling group name " - << *auto_scaling_group_name; + PS_LOG(INFO, log_context_) + << "Retrieved auto scaling group name " << *auto_scaling_group_name; Aws::AutoScaling::Model::CompleteLifecycleActionRequest request; request.SetAutoScalingGroupName(*auto_scaling_group_name); @@ -329,7 +331,13 @@ class AwsInstanceClient : public InstanceClient { return instances; } - AwsInstanceClient() + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + + AwsInstanceClient( + privacy_sandbox::server_common::log::PSLogContext& log_context) : ec2_client_(std::make_unique()), // EC2MetadataClient does not fall back to the default client // configuration, needs to specify it to @@ -338,21 +346,24 @@ class AwsInstanceClient : public InstanceClient { ec2_metadata_client_(std::make_unique( Aws::Client::ClientConfiguration())), auto_scaling_client_( - std::make_unique()) {} + std::make_unique()), + log_context_(log_context) {} private: std::unique_ptr ec2_client_; std::unique_ptr ec2_metadata_client_; std::unique_ptr auto_scaling_client_; std::string machine_id_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; absl::StatusOr GetTag(std::string tag) { absl::StatusOr instance_id = GetInstanceId(); if (!instance_id.ok()) { - LOG(ERROR) << "Failed to get instance_id: " << instance_id.status(); + PS_LOG(ERROR, log_context_) + << "Failed to get instance_id: " << instance_id.status(); return instance_id; } - LOG(INFO) << "Retrieved instance id: " << *instance_id; + PS_LOG(INFO, log_context_) << "Retrieved instance id: " << *instance_id; Aws::EC2::Model::Filter resource_id_filter; resource_id_filter.SetName(kResourceIdFilter); resource_id_filter.AddValues(*instance_id); @@ -363,18 +374,20 @@ class AwsInstanceClient : public 
InstanceClient { Aws::EC2::Model::DescribeTagsRequest request; request.SetFilters({resource_id_filter, key_filter}); - LOG(INFO) << "Sending Aws::EC2::Model::DescribeTagsRequest to get tag: " - << tag; + PS_LOG(INFO, log_context_) + << "Sending Aws::EC2::Model::DescribeTagsRequest to get tag: " << tag; const auto outcome = ec2_client_->DescribeTags(request); if (!outcome.IsSuccess()) { - LOG(ERROR) << "Failed to get tag: " << outcome.GetError(); + PS_LOG(ERROR, log_context_) + << "Failed to get tag: " << outcome.GetError(); return AwsErrorToStatus(outcome.GetError()); } if (outcome.GetResult().GetTags().size() != 1) { const std::string error_msg = absl::StrCat( "Could not get tag ", tag, " for instance ", *instance_id); - LOG(ERROR) << error_msg << "; Retrieved " - << outcome.GetResult().GetTags().size() << " tags"; + PS_LOG(ERROR, log_context_) + << error_msg << "; Retrieved " << outcome.GetResult().GetTags().size() + << " tags"; return absl::NotFoundError(error_msg); } return outcome.GetResult().GetTags()[0].GetValue(); @@ -383,8 +396,9 @@ class AwsInstanceClient : public InstanceClient { } // namespace -std::unique_ptr InstanceClient::Create() { - return std::make_unique(); +std::unique_ptr InstanceClient::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/cloud_config/instance_client_gcp.cc b/components/cloud_config/instance_client_gcp.cc index 4839204e..5ad5f919 100644 --- a/components/cloud_config/instance_client_gcp.cc +++ b/components/cloud_config/instance_client_gcp.cc @@ -54,7 +54,6 @@ using google::cmrt::sdk::instance_service::v1:: using ::google::scp::core::ExecutionResult; using ::google::scp::core::errors::GetErrorMessage; using google::scp::cpio::InstanceClientInterface; -using google::scp::cpio::InstanceClientOptions; namespace compute = ::google::cloud::compute_instances_v1; @@ -86,9 +85,10 @@ InstanceServiceStatus 
GetInstanceServiceStatus(const std::string& status) { class GcpInstanceClient : public InstanceClient { public: - GcpInstanceClient() - : instance_client_(google::scp::cpio::InstanceClientFactory::Create( - InstanceClientOptions())) { + GcpInstanceClient( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : instance_client_(google::scp::cpio::InstanceClientFactory::Create()), + log_context_(log_context) { instance_client_->Init(); } @@ -120,7 +120,7 @@ class GcpInstanceClient : public InstanceClient { absl::Status RecordLifecycleHeartbeat( std::string_view lifecycle_hook_name) override { - LOG(INFO) << "Record lifecycle heartbeat."; + PS_LOG(INFO, log_context_) << "Record lifecycle heartbeat."; return absl::OkStatus(); } @@ -162,7 +162,7 @@ class GcpInstanceClient : public InstanceClient { std::string_view lifecycle_hook_name) override { PS_RETURN_IF_ERROR(SetInitializedLabel()) << "Error setting the initialized label"; - LOG(INFO) << "Complete lifecycle."; + PS_LOG(INFO, log_context_) << "Complete lifecycle."; return absl::OkStatus(); } @@ -227,6 +227,11 @@ class GcpInstanceClient : public InstanceClient { return std::vector{InstanceInfo{.id = *id}}; } + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + private: absl::Status GetInstanceDetails() { absl::StatusOr resource_name = @@ -239,13 +244,15 @@ class GcpInstanceClient : public InstanceClient { GetInstanceDetailsByResourceNameRequest request; request.set_instance_resource_name(resource_name.value()); - const auto& result = instance_client_->GetInstanceDetailsByResourceName( + absl::Status status = instance_client_->GetInstanceDetailsByResourceName( std::move(request), [&done, this]( const ExecutionResult& result, const GetInstanceDetailsByResourceNameResponse& response) { if (result.Successful()) { - VLOG(2) << response.DebugString(); + // TODO(b/342614468): Temporarily turn off this vlog until + // verbosity 
setting API in the common repo is fixed + // PS_VLOG(2, log_context_) << response.DebugString(); instance_id_ = std::string{response.instance_details().instance_id()}; environment_ = @@ -253,36 +260,35 @@ class GcpInstanceClient : public InstanceClient { shard_number_ = response.instance_details().labels().at(kShardNumberLabel); } else { - LOG(ERROR) << "Failed to get instance details: " - << GetErrorMessage(result.status_code); + PS_LOG(ERROR, log_context_) << "Failed to get instance details: " + << GetErrorMessage(result.status_code); } done.Notify(); }); done.WaitForNotification(); - return result.Successful() - ? absl::OkStatus() - : absl::InternalError(GetErrorMessage(result.status_code)); + return status; } absl::StatusOr GetResourceName( std::unique_ptr& instance_client) { std::string resource_name; absl::Notification done; - const auto& result = instance_client->GetCurrentInstanceResourceName( + absl::Status status = instance_client->GetCurrentInstanceResourceName( GetCurrentInstanceResourceNameRequest(), [&](const ExecutionResult& result, const GetCurrentInstanceResourceNameResponse& response) { if (result.Successful()) { resource_name = std::string{response.instance_resource_name()}; } else { - LOG(ERROR) << "Failed to get instance resource name: " - << GetErrorMessage(result.status_code); + PS_LOG(ERROR, log_context_) + << "Failed to get instance resource name: " + << GetErrorMessage(result.status_code); } done.Notify(); }); - if (!result.Successful()) { - return absl::InternalError(GetErrorMessage(result.status_code)); + if (!status.ok()) { + return status; } done.WaitForNotification(); if (resource_name.empty()) { @@ -299,10 +305,12 @@ class GcpInstanceClient : public InstanceClient { std::unique_ptr instance_client_; compute::InstancesClient client_ = compute::InstancesClient(compute::MakeInstancesConnectionRest()); + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace -std::unique_ptr InstanceClient::Create() { - return 
std::make_unique(); +std::unique_ptr InstanceClient::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/cloud_config/instance_client_local.cc b/components/cloud_config/instance_client_local.cc index 32a30f19..a66ab089 100644 --- a/components/cloud_config/instance_client_local.cc +++ b/components/cloud_config/instance_client_local.cc @@ -30,6 +30,9 @@ namespace { class LocalInstanceClient : public InstanceClient { public: + explicit LocalInstanceClient( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) {} absl::StatusOr GetEnvironmentTag() override { return absl::GetFlag(FLAGS_environment); } @@ -40,13 +43,13 @@ class LocalInstanceClient : public InstanceClient { absl::Status RecordLifecycleHeartbeat( std::string_view lifecycle_hook_name) override { - LOG(INFO) << "Record lifecycle heartbeat."; + PS_LOG(INFO, log_context_) << "Record lifecycle heartbeat."; return absl::OkStatus(); } absl::Status CompleteLifecycle( std::string_view lifecycle_hook_name) override { - LOG(INFO) << "Complete lifecycle."; + PS_LOG(INFO, log_context_) << "Complete lifecycle."; return absl::OkStatus(); } @@ -75,12 +78,21 @@ class LocalInstanceClient : public InstanceClient { } return std::vector{InstanceInfo{.id = *id}}; } + + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + + private: + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace -std::unique_ptr InstanceClient::Create() { - return std::make_unique(); +std::unique_ptr InstanceClient::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/cloud_config/parameter_client.h b/components/cloud_config/parameter_client.h index 0b6be921..b14e039c 100644 
--- a/components/cloud_config/parameter_client.h +++ b/components/cloud_config/parameter_client.h @@ -19,6 +19,7 @@ #include #include "absl/status/statusor.h" +#include "src/logger/request_context_logger.h" // TODO: Replace config cpio client once ready namespace kv_server { @@ -32,7 +33,10 @@ class ParameterClient { }; static std::unique_ptr Create( - ClientOptions client_options = ClientOptions()); + ClientOptions client_options = ClientOptions(), + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); virtual ~ParameterClient() = default; @@ -45,6 +49,12 @@ class ParameterClient { virtual absl::StatusOr GetBoolParameter( std::string_view parameter_name) const = 0; + + // Updates the log context reference to enable otel logging for parameter + // client. This function should be called after telemetry is initialized with + // retrieved parameters. + virtual void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) = 0; }; } // namespace kv_server diff --git a/components/cloud_config/parameter_client_aws.cc b/components/cloud_config/parameter_client_aws.cc index bb13143a..9a3cc563 100644 --- a/components/cloud_config/parameter_client_aws.cc +++ b/components/cloud_config/parameter_client_aws.cc @@ -41,25 +41,27 @@ class AwsParameterClient : public ParameterClient { absl::StatusOr GetParameter( std::string_view parameter_name, std::optional default_value = std::nullopt) const override { - LOG(INFO) << "Getting parameter: " << parameter_name; + PS_LOG(INFO, log_context_) << "Getting parameter: " << parameter_name; Aws::SSM::Model::GetParameterRequest request; request.SetName(std::string(parameter_name)); const auto outcome = ssm_client_->GetParameter(request); if (!outcome.IsSuccess()) { if (default_value.has_value()) { - LOG(WARNING) << "Unable to get parameter: " << parameter_name - << " with error: " << outcome.GetError() - << ", returning default value: " 
<< *default_value; + PS_LOG(WARNING, log_context_) + << "Unable to get parameter: " << parameter_name + << " with error: " << outcome.GetError() + << ", returning default value: " << *default_value; return *default_value; } else { - LOG(ERROR) << "Unable to get parameter: " << parameter_name - << " with error: " << outcome.GetError(); + PS_LOG(ERROR, log_context_) + << "Unable to get parameter: " << parameter_name + << " with error: " << outcome.GetError(); } return AwsErrorToStatus(outcome.GetError()); } std::string result = outcome.GetResult().GetParameter().GetValue(); - LOG(INFO) << "Got parameter: " << parameter_name - << " with value: " << result; + PS_LOG(INFO, log_context_) + << "Got parameter: " << parameter_name << " with value: " << result; return result; }; @@ -79,7 +81,7 @@ class AwsParameterClient : public ParameterClient { const std::string error = absl::StrFormat("Failed converting %s parameter: %s to int32.", parameter_name, *parameter); - LOG(ERROR) << error; + PS_LOG(ERROR, log_context_) << error; return absl::InvalidArgumentError(error); } @@ -102,15 +104,22 @@ class AwsParameterClient : public ParameterClient { const std::string error = absl::StrFormat("Failed converting %s parameter: %s to bool.", parameter_name, *parameter); - LOG(ERROR) << error; + PS_LOG(ERROR, log_context_) << error; return absl::InvalidArgumentError(error); } return parameter_bool; }; - explicit AwsParameterClient(ParameterClient::ClientOptions client_options) - : client_options_(std::move(client_options)) { + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + + explicit AwsParameterClient( + ParameterClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : client_options_(std::move(client_options)), log_context_(log_context) { if (client_options.client_for_unit_testing_ != nullptr) { ssm_client_.reset( 
(Aws::SSM::SSMClient*)client_options.client_for_unit_testing_); @@ -122,13 +131,16 @@ class AwsParameterClient : public ParameterClient { private: ClientOptions client_options_; std::unique_ptr ssm_client_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace std::unique_ptr ParameterClient::Create( - ParameterClient::ClientOptions client_options) { - return std::make_unique(std::move(client_options)); + ParameterClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(std::move(client_options), + log_context); } } // namespace kv_server diff --git a/components/cloud_config/parameter_client_aws_test.cc b/components/cloud_config/parameter_client_aws_test.cc index 72a128e2..5880de85 100644 --- a/components/cloud_config/parameter_client_aws_test.cc +++ b/components/cloud_config/parameter_client_aws_test.cc @@ -37,7 +37,7 @@ class MockSsmClient : public ::Aws::SSM::SSMClient { }; class ParameterClientAwsTest : public ::testing::Test { - protected: + private: PlatformInitializer initializer_; }; diff --git a/components/cloud_config/parameter_client_gcp.cc b/components/cloud_config/parameter_client_gcp.cc index a5aee5e4..e416f22f 100644 --- a/components/cloud_config/parameter_client_gcp.cc +++ b/components/cloud_config/parameter_client_gcp.cc @@ -44,7 +44,10 @@ using google::scp::cpio::ParameterClientOptions; class GcpParameterClient : public ParameterClient { public: - explicit GcpParameterClient(ParameterClient::ClientOptions client_options) { + explicit GcpParameterClient( + ParameterClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) { if (client_options.client_for_unit_testing_ == nullptr) { parameter_client_ = ParameterClientFactory::Create(ParameterClientOptions()); @@ -52,39 +55,31 @@ class GcpParameterClient : public ParameterClient { parameter_client_.reset(std::move( 
(ParameterClientInterface*)client_options.client_for_unit_testing_)); } - auto execution_result = parameter_client_->Init(); - CHECK(execution_result.Successful()) - << "Cannot init parameter client!" - << GetErrorMessage(execution_result.status_code); - execution_result = parameter_client_->Run(); - CHECK(execution_result.Successful()) - << "Cannot run parameter client!" - << GetErrorMessage(execution_result.status_code); + CHECK_OK(parameter_client_->Init()); + CHECK_OK(parameter_client_->Run()); } ~GcpParameterClient() { - auto execution_result = parameter_client_->Stop(); - if (!execution_result.Successful()) { - LOG(ERROR) << "Cannot stop parameter client!" - << GetErrorMessage(execution_result.status_code); + if (absl::Status status = parameter_client_->Stop(); !status.ok()) { + PS_LOG(ERROR, log_context_) << "Cannot stop parameter client!" << status; } } absl::StatusOr GetParameter( std::string_view parameter_name, std::optional default_value = std::nullopt) const override { - LOG(INFO) << "Getting parameter: " << parameter_name; + PS_LOG(INFO, log_context_) << "Getting parameter: " << parameter_name; GetParameterRequest get_parameter_request; get_parameter_request.set_parameter_name(parameter_name); std::string param_value; absl::BlockingCounter counter(1); - auto execution_result = parameter_client_->GetParameter( + absl::Status status = parameter_client_->GetParameter( std::move(get_parameter_request), - [¶m_value, &counter](const ExecutionResult result, - GetParameterResponse response) { + [¶m_value, &counter, &log_context = log_context_]( + const ExecutionResult result, GetParameterResponse response) { if (!result.Successful()) { - LOG(ERROR) << "GetParameter failed: " - << GetErrorMessage(result.status_code); + PS_LOG(ERROR, log_context) << "GetParameter failed: " + << GetErrorMessage(result.status_code); } else { param_value = response.parameter_value() != "EMPTY_STRING" ? 
response.parameter_value() @@ -94,21 +89,21 @@ class GcpParameterClient : public ParameterClient { counter.DecrementCount(); }); counter.Wait(); - if (!execution_result.Successful()) { - auto status = - absl::UnavailableError(GetErrorMessage(execution_result.status_code)); + if (!status.ok()) { if (default_value.has_value()) { - LOG(WARNING) << "Unable to get parameter: " << parameter_name - << " with error: " << status - << ", returning default value: " << *default_value; + PS_LOG(WARNING, log_context_) + << "Unable to get parameter: " << parameter_name + << " with error: " << status + << ", returning default value: " << *default_value; return *default_value; } - LOG(ERROR) << "Unable to get parameter: " << parameter_name - << " with error: " << status; + PS_LOG(ERROR, log_context_) + << "Unable to get parameter: " << parameter_name + << " with error: " << status; return status; } - LOG(INFO) << "Got parameter: " << parameter_name - << " with value: " << param_value; + PS_LOG(INFO, log_context_) << "Got parameter: " << parameter_name + << " with value: " << param_value; return param_value; } @@ -123,7 +118,7 @@ class GcpParameterClient : public ParameterClient { const std::string error = absl::StrFormat("Failed converting %s parameter: %s to int32.", parameter_name, *parameter); - LOG(ERROR) << error; + PS_LOG(ERROR, log_context_) << error; return absl::InvalidArgumentError(error); } @@ -143,22 +138,30 @@ class GcpParameterClient : public ParameterClient { const std::string error = absl::StrFormat("Failed converting %s parameter: %s to bool.", parameter_name, *parameter); - LOG(ERROR) << error; + PS_LOG(ERROR, log_context_) << error; return absl::InvalidArgumentError(error); } return parameter_bool; } + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + private: std::unique_ptr parameter_client_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace 
std::unique_ptr ParameterClient::Create( - ParameterClient::ClientOptions client_options) { - return std::make_unique(std::move(client_options)); + ParameterClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(std::move(client_options), + log_context); } } // namespace kv_server diff --git a/components/cloud_config/parameter_client_gcp_test.cc b/components/cloud_config/parameter_client_gcp_test.cc index 1f81c585..d87a95de 100644 --- a/components/cloud_config/parameter_client_gcp_test.cc +++ b/components/cloud_config/parameter_client_gcp_test.cc @@ -56,13 +56,12 @@ class ParameterClientGcpTest : public ::testing::Test { std::unique_ptr mock_parameter_client = std::make_unique(); EXPECT_CALL(*mock_parameter_client, Init) - .WillOnce(Return(SuccessExecutionResult())); - EXPECT_CALL(*mock_parameter_client, Run) - .WillOnce(Return(SuccessExecutionResult())); + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_parameter_client, Run).WillOnce(Return(absl::OkStatus())); EXPECT_CALL(*mock_parameter_client, GetParameter) .WillRepeatedly( [this](GetParameterRequest get_param_req, - Callback callback) -> ExecutionResult { + Callback callback) -> absl::Status { // async reading parameter like the real case bool param_not_found = false; if (expected_param_values_.find(get_param_req.parameter_name()) == @@ -83,11 +82,9 @@ class ParameterClientGcpTest : public ::testing::Test { cb(SuccessExecutionResult(), response); } })); - if (param_not_found) { - return FailureExecutionResult(5); - } else { - return SuccessExecutionResult(); - } + return param_not_found + ? 
absl::NotFoundError("Parameter not found.") + : absl::OkStatus(); }); ParameterClient::ClientOptions client_options; diff --git a/components/cloud_config/parameter_client_local.cc b/components/cloud_config/parameter_client_local.cc index ea8bf5f4..c051a649 100644 --- a/components/cloud_config/parameter_client_local.cc +++ b/components/cloud_config/parameter_client_local.cc @@ -57,6 +57,8 @@ ABSL_FLAG(std::int32_t, logging_verbosity_level, 0, "Loggging verbosity level."); ABSL_FLAG(absl::Duration, udf_timeout, absl::Seconds(5), "Timeout for one UDF invocation"); +ABSL_FLAG(absl::Duration, udf_update_timeout, absl::Seconds(30), + "Timeout for UDF code update"); ABSL_FLAG(int32_t, udf_min_log_level, 0, "Minimum logging level for UDFs. Info=0, Warn=1, Error=2. Default is " "0(info)."); @@ -67,6 +69,8 @@ ABSL_FLAG(std::string, data_loading_prefix_allowlist, "", "Allowlist for blob prefixes."); ABSL_FLAG(bool, add_missing_keys_v1, false, "Whether to add missing keys for v1."); +ABSL_FLAG(bool, enable_consented_log, false, "Whether to enable consented log"); +ABSL_FLAG(std::string, consented_debug_token, "", "Consented debug token"); namespace kv_server { namespace { @@ -79,7 +83,9 @@ namespace { // if there's a better way. class LocalParameterClient : public ParameterClient { public: - LocalParameterClient() { + LocalParameterClient( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) { string_flag_values_.insert( {"kv-server-local-directory", absl::GetFlag(FLAGS_delta_directory)}); string_flag_values_.insert({"kv-server-local-data-bucket-id", @@ -95,6 +101,8 @@ class LocalParameterClient : public ParameterClient { string_flag_values_.insert( {"kv-server-local-data-loading-blob-prefix-allowlist", absl::GetFlag(FLAGS_data_loading_prefix_allowlist)}); + string_flag_values_.insert({"kv-server-local-consented-debug-token", + absl::GetFlag(FLAGS_consented_debug_token)}); // Insert more string flag values here. 
int32_t_flag_values_.insert( @@ -127,6 +135,9 @@ class LocalParameterClient : public ParameterClient { int32_t_flag_values_.insert( {"kv-server-local-udf-timeout-millis", absl::ToInt64Milliseconds(absl::GetFlag(FLAGS_udf_timeout))}); + int32_t_flag_values_.insert( + {"kv-server-local-udf-update-timeout-millis", + absl::ToInt64Milliseconds(absl::GetFlag(FLAGS_udf_update_timeout))}); int32_t_flag_values_.insert({"kv-server-local-udf-min-log-level", absl::GetFlag(FLAGS_udf_min_log_level)}); // Insert more int32 flag values here. @@ -140,6 +151,8 @@ class LocalParameterClient : public ParameterClient { bool_flag_values_.insert({"kv-server-local-use-sharding-key-regex", false}); bool_flag_values_.insert({"kv-server-local-enable-otel-logger", absl::GetFlag(FLAGS_enable_otel_logger)}); + bool_flag_values_.insert({"kv-server-local-enable-consented-log", + absl::GetFlag(FLAGS_enable_consented_log)}); // Insert more bool flag values here. } @@ -177,17 +190,24 @@ class LocalParameterClient : public ParameterClient { } } + void UpdateLogContext( + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + log_context_ = log_context; + } + private: absl::flat_hash_map int32_t_flag_values_; absl::flat_hash_map string_flag_values_; absl::flat_hash_map bool_flag_values_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace std::unique_ptr ParameterClient::Create( - ParameterClient::ClientOptions client_options) { - return std::make_unique(); + ParameterClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/cloud_config/parameter_client_local_test.cc b/components/cloud_config/parameter_client_local_test.cc index 62393754..f3218579 100644 --- a/components/cloud_config/parameter_client_local_test.cc +++ b/components/cloud_config/parameter_client_local_test.cc @@ -104,6 +104,12 @@ 
TEST(ParameterClientLocal, ExpectedFlagDefaultsArePresent) { ASSERT_TRUE(statusor.ok()); EXPECT_EQ(5000, *statusor); } + { + const auto statusor = + client->GetInt32Parameter("kv-server-local-udf-update-timeout-millis"); + ASSERT_TRUE(statusor.ok()); + EXPECT_EQ(30000, *statusor); + } { const auto statusor = client->GetInt32Parameter("kv-server-local-udf-min-log-level"); diff --git a/components/container/BUILD.bazel b/components/container/BUILD.bazel new file mode 100644 index 00000000..876ad64c --- /dev/null +++ b/components/container/BUILD.bazel @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "thread_safe_hash_map", + hdrs = ["thread_safe_hash_map.h"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/synchronization", + ], +) + +cc_test( + name = "thread_safe_hash_map_test", + size = "small", + srcs = [ + "thread_safe_hash_map_test.cc", + ], + deps = [ + ":thread_safe_hash_map", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/container/thread_safe_hash_map.h b/components/container/thread_safe_hash_map.h new file mode 100644 index 00000000..cdeb2bed --- /dev/null +++ b/components/container/thread_safe_hash_map.h @@ -0,0 +1,279 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_CONTAINER_THREAD_SAFE_HASH_MAP_H_ +#define COMPONENTS_CONTAINER_THREAD_SAFE_HASH_MAP_H_ + +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" + +namespace kv_server { + +// Implements a mutex based "thread-safe" container wrapper around Abseil +// maps. Exposes key-level locking such that two values associated two separate +// keys can be modified concurrently. 
+// +// Note that synchronization is done using `absl::Mutex` which is not re-entrant +// so trying to obtain a `MutableLockedNode` on key while already holding +// another `MutableLockedNode` for the same key will result in a deadlock. For +// example, the following code will deadlock: +// +// ThreadSafeHashMap map; +// auto node = map.PutIfAbsent(10, 20); +// map.Get(10); // This call deadlocks. +// +// For convenience, `LockedNode` provides a function `release()` to release a +// lock on a key early (instead of relying on RAII), but accessing the +// `LockedNode`'s contents after releasing the lock results in undefined +// behavior. +template +class ThreadSafeHashMap { + public: + class const_iterator; + template + class LockedNodePtr; + // Locked read-only view for a specific key value associtation in the map. + // Must call `is_present()` before calling and dereferencing `key()` and + // `value()`. + class ConstLockedNodePtr; + // Locked view for a specific key value associtation in the map with an + // in-place modifable value. Must call `is_present()` before calling and + // dereferencing `key()` and `value()`. + using MutableLockedNodePtr = LockedNodePtr; + + ThreadSafeHashMap() : nodes_map_mutex_(std::make_unique()) {} + + // Returns a locked read-only view for `key` and it's associated value. If + // `key` does not exist in the map, then: `ConstLockedNode.is_present()` is + // false. + // Prefer this function for reads because it uses a shared + // `absl::ReaderMutexLock` for synchronization and concurrent reads do not + // block. + template + ConstLockedNodePtr CGet(Key&& key) const; + + // Returns a locked view for `key` with a modifable value. If `key` does not + // exist in the map, then: `MutableLockedNode.is_present()` is false. + // Uses an exclusive lock `absl::WriterMutexLock` so prefer `CGet()` above for + // reads only. 
+ template + MutableLockedNodePtr Get(Key&& key) const; + + // Inserts `key` and `value` mapping into the map if `key` does not exist + // in the map. + // Returns: + // - `true` and view to the newly inserted `key`, `value` mapping if `key` + // does not exist. + // - `false` and view to the existing `key`, `value` mapping if `key` exist. + template + std::pair PutIfAbsent(Key&& key, Value&& value); + + // Removes `key` and it's associated value from the map if `predicate(value)` + // is true. + template + void RemoveIf( + Key&& key, std::function predicate = + [](const ValueT&) { return true; }); + + const_iterator begin() ABSL_NO_THREAD_SAFETY_ANALYSIS; + const_iterator end() ABSL_NO_THREAD_SAFETY_ANALYSIS; + + private: + struct ValueNode { + template + explicit ValueNode(Value&& val) + : value(std::forward(val)), + mutex(std::make_unique()) {} + ValueT value; + std::unique_ptr mutex; + }; + using KeyValueNodesMapType = + absl::node_hash_map>; + + template + NodeT GetNode(Key&& key) const; + + std::unique_ptr nodes_map_mutex_; + KeyValueNodesMapType key_value_nodes_map_ ABSL_GUARDED_BY(*nodes_map_mutex_); +}; + +template +template +NodeT ThreadSafeHashMap::GetNode(Key&& key) const { + absl::ReaderMutexLock map_lock(nodes_map_mutex_.get()); + if (auto iter = key_value_nodes_map_.find(std::forward(key)); + iter == key_value_nodes_map_.end()) { + return NodeT(nullptr, nullptr, nullptr); + } else { + return NodeT(&iter->first, &iter->second->value, + std::make_unique(iter->second->mutex.get())); + } +} + +template +template +typename ThreadSafeHashMap::ConstLockedNodePtr +ThreadSafeHashMap::CGet(Key&& key) const { + return GetNode( + std::forward(key)); +} + +template +template +typename ThreadSafeHashMap::MutableLockedNodePtr +ThreadSafeHashMap::Get(Key&& key) const { + return GetNode( + std::forward(key)); +} + +template +template +std::pair::MutableLockedNodePtr, bool> +ThreadSafeHashMap::PutIfAbsent(Key&& key, Value&& value) { + absl::WriterMutexLock 
map_lock(nodes_map_mutex_.get()); + if (auto iter = key_value_nodes_map_.find(key); + iter != key_value_nodes_map_.end()) { + return std::make_pair( + MutableLockedNodePtr( + &iter->first, &iter->second->value, + std::make_unique(iter->second->mutex.get())), + false); + } + auto result = key_value_nodes_map_.emplace( + std::forward(key), + std::make_unique(std::forward(value))); + return std::make_pair( + MutableLockedNodePtr(&result.first->first, &result.first->second->value, + std::make_unique( + result.first->second->mutex.get())), + true); +} + +template +template +void ThreadSafeHashMap::RemoveIf( + Key&& key, std::function predicate) { + absl::WriterMutexLock map_lock(nodes_map_mutex_.get()); + auto iter = key_value_nodes_map_.find(std::forward(key)); + if (iter == key_value_nodes_map_.end()) { + return; + } + { + // Wait for any current threads using the value to release their locks. + absl::WriterMutexLock value_lock(iter->second->mutex.get()); + } + if (predicate(iter->second->value)) { + key_value_nodes_map_.erase(iter); + } +} + +template +typename ThreadSafeHashMap::const_iterator +ThreadSafeHashMap::begin() { + return const_iterator( + std::make_unique(nodes_map_mutex_.get()), + key_value_nodes_map_.begin()); +} + +template +typename ThreadSafeHashMap::const_iterator +ThreadSafeHashMap::end() { + return const_iterator(nullptr, key_value_nodes_map_.end()); +} + +template +template +class ThreadSafeHashMap::LockedNodePtr { + public: + bool is_present() const { return key_ != nullptr; } + const KeyT* key() const { return key_; } + ValueT* value() const { return value_; } + void release() { lock_ = nullptr; } + + private: + LockedNodePtr() : LockedNodePtr(nullptr, nullptr, nullptr) {} + LockedNodePtr(const KeyT* key, ValueT* value, + std::unique_ptr lock) + : key_(key), value_(value), lock_(std::move(lock)) {} + + friend class ThreadSafeHashMap; + + const KeyT* key_; + ValueT* value_; + std::unique_ptr lock_; +}; + +template +class 
ThreadSafeHashMap::ConstLockedNodePtr + : LockedNodePtr { + using Base = typename ThreadSafeHashMap::ConstLockedNodePtr::LockedNodePtr; + + public: + using Base::is_present; + using Base::key; + using Base::release; + const ValueT* value() const { return Base::value(); } + + private: + using Base::Base; +}; + +template +class ThreadSafeHashMap::const_iterator { + public: + using value_type = ConstLockedNodePtr; + using pointer = value_type*; + using reference = value_type&; + + reference operator*() ABSL_NO_THREAD_SAFETY_ANALYSIS { + current_node_ = std::move(ConstLockedNodePtr( + &nodes_map_iter_->first, &nodes_map_iter_->second->value, + std::make_unique( + nodes_map_iter_->second->mutex.get()))); + return current_node_; + } + pointer operator->() { return &operator*(); } + const_iterator& operator++() { + nodes_map_iter_++; + return *this; + } + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.nodes_map_iter_ == b.nodes_map_iter_; + } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == b); + } + + private: + const_iterator(std::unique_ptr nodes_map_lock, + typename KeyValueNodesMapType::iterator nodes_map_iter) + : nodes_map_lock_(std::move(nodes_map_lock)), + nodes_map_iter_(nodes_map_iter) {} + + friend class ThreadSafeHashMap; + + ConstLockedNodePtr current_node_; + std::unique_ptr nodes_map_lock_; + typename KeyValueNodesMapType::iterator nodes_map_iter_; +}; + +} // namespace kv_server + +#endif // COMPONENTS_CONTAINER_THREAD_SAFE_HASH_MAP_H_ diff --git a/components/container/thread_safe_hash_map_test.cc b/components/container/thread_safe_hash_map_test.cc new file mode 100644 index 00000000..f379fbb5 --- /dev/null +++ b/components/container/thread_safe_hash_map_test.cc @@ -0,0 +1,193 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/container/thread_safe_hash_map.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using testing::UnorderedElementsAre; +using testing::UnorderedElementsAreArray; + +TEST(ThreadSafeHashMapTest, VerifyCGet) { + ThreadSafeHashMap map; + auto node = map.CGet(10); + EXPECT_FALSE(node.is_present()); + map.PutIfAbsent(10, 20); + node = map.CGet(10); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.key(), 10); + EXPECT_EQ(*node.value(), 20); +} + +TEST(ThreadSafeHashMapTest, VerifyGet) { + ThreadSafeHashMap map; + { + auto node = map.Get("key"); + EXPECT_FALSE(node.is_present()); + } + for (auto i : std::vector{1, 2, 3, 4, 5}) { + std::string key = absl::StrCat("key", i); + map.PutIfAbsent(key, i); + auto node = map.Get(key); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.key(), key); + EXPECT_EQ(*node.value(), i); + } +} + +TEST(ThreadSafeHashMapTest, VerifyPutIfAbsent) { + ThreadSafeHashMap map; + { + auto result = map.PutIfAbsent("key", "value"); + EXPECT_TRUE(result.second); + EXPECT_EQ(*result.first.value(), "value"); + } + { + auto result = map.PutIfAbsent("key", "not applied"); + EXPECT_FALSE(result.second); + EXPECT_EQ(*result.first.value(), "value"); + } +} + +TEST(ThreadSafeHashMapTest, VerifyRemoveIf) { + ThreadSafeHashMap map; + map.PutIfAbsent(10, "value"); + { + auto node = map.CGet(10); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.key(), 10); + EXPECT_EQ(*node.value(), "value"); + 
} + map.RemoveIf(10, + [](const std::string& value) { return value == "wrong value"; }); + { + auto node = map.CGet(10); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.key(), 10); + EXPECT_EQ(*node.value(), "value"); + } + map.RemoveIf(10, [](const std::string& value) { return value == "value"; }); + { + auto node = map.CGet(10); + ASSERT_FALSE(node.is_present()); + } +} + +TEST(ThreadSafeHashMapTest, VerifyIteration) { + ThreadSafeHashMap map; + for (const auto& value : + std::vector{"one", "two", "three", "four", "five"}) { + map.PutIfAbsent(value, value); + } + std::vector values; + for (auto& node : map) { + values.push_back(*node.value()); + } + EXPECT_THAT(values, + UnorderedElementsAre("one", "two", "three", "four", "five")); +} + +TEST(ThreadSafeHashMapTest, VerifyMutableLockedValue) { + ThreadSafeHashMap map; + map.PutIfAbsent("key", "value"); + { + auto node = map.Get("key"); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.value(), "value"); + // Let's modify value in place. 
+ *node.value() = "modified"; + } + { + auto node = map.Get("key"); + ASSERT_TRUE(node.is_present()); + EXPECT_EQ(*node.value(), "modified"); + } +} + +TEST(ThreadSafeHashMapTest, VerifyMoveOnlyValues) { + ThreadSafeHashMap> map; + std::string_view key = "key"; + { + auto node = map.PutIfAbsent(key, std::make_unique("value")); + EXPECT_TRUE(node.second); + } + { + auto node = map.CGet(key); + EXPECT_TRUE(node.is_present()); + EXPECT_THAT(*node.key(), key); + EXPECT_THAT(**node.value(), "value"); + } +} + +TEST(ThreadSafeHashMapTest, VerifyMultiThreadedWritesToSimpleType) { + ThreadSafeHashMap map; + auto key = 10; + map.PutIfAbsent(key, 0); + auto incr = [key, &map]() { + auto node = map.Get(key); + int32_t* value = node.value(); + *value = *value + 1; + }; + int num_tasks = 100; + std::vector> tasks; + tasks.reserve(num_tasks); + for (int t = 0; t < num_tasks; t++) { + tasks.push_back(std::async(std::launch::async, incr)); + } + for (auto& task : tasks) { + task.get(); + } + auto node = map.CGet(key); + EXPECT_EQ(*node.value(), num_tasks); +} + +TEST(ThreadSafeHashMapTest, VerifyMultiThreadedWritesToComplexType) { + ThreadSafeHashMap> map; + int32_t key = 10; + std::size_t size = 1000; + map.PutIfAbsent(key, std::vector(size, 0)); + auto incr = [key, &map]() { + auto node = map.Get(key); + auto* values = node.value(); + for (int i = 0; i < values->size(); i++) { + (*values)[i] = (*values)[i] + 1; + } + }; + int num_tasks = 100; + std::vector> tasks; + tasks.reserve(num_tasks); + for (int t = 0; t < num_tasks; t++) { + tasks.push_back(std::async(std::launch::async, incr)); + } + for (auto& task : tasks) { + task.get(); + } + std::vector expected(size, num_tasks); + auto node = map.CGet(key); + EXPECT_THAT(*node.value(), + UnorderedElementsAreArray(expected.begin(), expected.end())); +} + +} // namespace +} // namespace kv_server diff --git a/components/data/blob_storage/BUILD.bazel b/components/data/blob_storage/BUILD.bazel index 3597529b..03d4732f 100644 --- 
a/components/data/blob_storage/BUILD.bazel +++ b/components/data/blob_storage/BUILD.bazel @@ -30,6 +30,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/telemetry:telemetry_provider", ], ) @@ -101,6 +102,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) diff --git a/components/data/blob_storage/blob_storage_change_notifier.h b/components/data/blob_storage/blob_storage_change_notifier.h index 59ad8e31..047bc0dd 100644 --- a/components/data/blob_storage/blob_storage_change_notifier.h +++ b/components/data/blob_storage/blob_storage_change_notifier.h @@ -40,7 +40,10 @@ class BlobStorageChangeNotifier { const std::function& should_stop_callback) = 0; static absl::StatusOr> Create( - NotifierMetadata notifier_metadata); + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_gcp.cc b/components/data/blob_storage/blob_storage_change_notifier_gcp.cc index 8b76c1bb..b77d8c9b 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_gcp.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_gcp.cc @@ -27,8 +27,9 @@ namespace { class GcpBlobStorageChangeNotifier : public BlobStorageChangeNotifier { public: explicit GcpBlobStorageChangeNotifier( - std::unique_ptr notifier) - : notifier_(std::move(notifier)) {} + std::unique_ptr notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : notifier_(std::move(notifier)), log_context_(log_context) {} 
~GcpBlobStorageChangeNotifier() override { sleep_for_.Stop(); } @@ -46,19 +47,24 @@ class GcpBlobStorageChangeNotifier : public BlobStorageChangeNotifier { private: std::unique_ptr notifier_; SleepFor sleep_for_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> -BlobStorageChangeNotifier::Create(NotifierMetadata notifier_metadata) { +BlobStorageChangeNotifier::Create( + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { absl::StatusOr> notifier = - ChangeNotifier::Create(std::get(notifier_metadata)); + ChangeNotifier::Create(std::get(notifier_metadata), + log_context); if (!notifier.ok()) { return notifier.status(); } - return std::make_unique(std::move(*notifier)); + return std::make_unique(std::move(*notifier), + log_context); } } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_local.cc b/components/data/blob_storage/blob_storage_change_notifier_local.cc index c2394dd6..bdf9298f 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_local.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_local.cc @@ -23,8 +23,9 @@ namespace { class LocalBlobStorageChangeNotifier : public BlobStorageChangeNotifier { public: explicit LocalBlobStorageChangeNotifier( - std::unique_ptr notifier) - : notifier_(std::move(notifier)) {} + std::unique_ptr notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : notifier_(std::move(notifier)), log_context_(log_context) {} absl::StatusOr> GetNotifications( absl::Duration max_wait, @@ -34,20 +35,24 @@ class LocalBlobStorageChangeNotifier : public BlobStorageChangeNotifier { private: std::unique_ptr notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> -BlobStorageChangeNotifier::Create(NotifierMetadata notifier_metadata) { +BlobStorageChangeNotifier::Create( + NotifierMetadata 
notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { absl::StatusOr> notifier = - ChangeNotifier::Create( - std::get(notifier_metadata)); + ChangeNotifier::Create(std::get(notifier_metadata), + log_context); if (!notifier.ok()) { return notifier.status(); } - return std::make_unique(std::move(*notifier)); + return std::make_unique(std::move(*notifier), + log_context); } } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_s3.cc b/components/data/blob_storage/blob_storage_change_notifier_s3.cc index 0f24f007..7bea7d84 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_s3.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_s3.cc @@ -26,8 +26,10 @@ namespace { class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { public: - explicit S3BlobStorageChangeNotifier(std::unique_ptr notifier) - : change_notifier_(std::move(notifier)) {} + explicit S3BlobStorageChangeNotifier( + std::unique_ptr notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : change_notifier_(std::move(notifier)), log_context_(log_context) {} absl::StatusOr> GetNotifications( absl::Duration max_wait, @@ -45,8 +47,9 @@ class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { const absl::StatusOr parsedMessage = ParseObjectKeyFromJson(message); if (!parsedMessage.ok()) { - LOG(ERROR) << "Failed to parse JSON. Error: " << parsedMessage.status() - << " Message:" << message; + PS_LOG(ERROR, log_context_) + << "Failed to parse JSON. 
Error: " << parsedMessage.status() + << " Message:" << message; LogServerErrorMetric(kAwsJsonParseError); continue; } @@ -97,23 +100,27 @@ class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { } std::unique_ptr change_notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> -BlobStorageChangeNotifier::Create(NotifierMetadata notifier_metadata) { +BlobStorageChangeNotifier::Create( + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto cloud_notifier_metadata = std::get(notifier_metadata); cloud_notifier_metadata.queue_prefix = "BlobNotifier_"; absl::StatusOr> status_or = - ChangeNotifier::Create(std::move(cloud_notifier_metadata)); + ChangeNotifier::Create(std::move(cloud_notifier_metadata), log_context); if (!status_or.ok()) { return status_or.status(); } - return std::make_unique(std::move(*status_or)); + return std::make_unique(std::move(*status_or), + log_context); } } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc index 0709d015..69bc4e8f 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc @@ -84,8 +84,10 @@ class BlobStorageChangeNotifierS3Test : public ::testing::Test { .WillOnce(::testing::Return(outcome)); } - PlatformInitializer initializer_; MockMessageService mock_message_service_; + + private: + PlatformInitializer initializer_; }; TEST_F(BlobStorageChangeNotifierS3Test, AwsSqsUnavailable) { diff --git a/components/data/blob_storage/blob_storage_client.h b/components/data/blob_storage/blob_storage_client.h index 68cb3f27..47585ef4 100644 --- a/components/data/blob_storage/blob_storage_client.h +++ b/components/data/blob_storage/blob_storage_client.h @@ -27,6 +27,7 @@ #include 
"absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "src/logger/request_context_logger.h" namespace kv_server { @@ -95,7 +96,10 @@ class BlobStorageClientFactory { virtual ~BlobStorageClientFactory() = default; virtual std::unique_ptr CreateBlobStorageClient( BlobStorageClient::ClientOptions client_options = - BlobStorageClient::ClientOptions()) = 0; + BlobStorageClient::ClientOptions(), + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) = 0; static std::unique_ptr Create(); }; diff --git a/components/data/blob_storage/blob_storage_client_gcp.cc b/components/data/blob_storage/blob_storage_client_gcp.cc index 5be23c51..1d4ae64c 100644 --- a/components/data/blob_storage/blob_storage_client_gcp.cc +++ b/components/data/blob_storage/blob_storage_client_gcp.cc @@ -86,16 +86,24 @@ class GcpBlobInputStreamBuf : public SeekingInputStreambuf { class GcpBlobReader : public BlobReader { public: - GcpBlobReader(google::cloud::storage::Client& client, - BlobStorageClient::DataLocation location) + GcpBlobReader( + google::cloud::storage::Client& client, + BlobStorageClient::DataLocation location, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) : BlobReader(), + log_context_(log_context), streambuf_(client, location, - GetOptions([this, location](absl::Status status) { - LOG(ERROR) << "Blob " - << AppendPrefix(location.key, location.prefix) - << " failed stream with: " << status; - is_.setstate(std::ios_base::badbit); - })), + GetOptions( + [this, location](absl::Status status) { + PS_LOG(ERROR, log_context_) + << "Blob " + << AppendPrefix(location.key, location.prefix) + << " failed stream with: " << status; + is_.setstate(std::ios_base::badbit); + }, + log_context)), is_(&streambuf_) {} std::istream& Stream() { return is_; } @@ -103,24 +111,28 @@ class 
GcpBlobReader : public BlobReader { private: static SeekingInputStreambuf::Options GetOptions( - std::function error_callback) { + std::function error_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context) { SeekingInputStreambuf::Options options; options.error_callback = std::move(error_callback); + options.log_context = log_context; return options; } - + privacy_sandbox::server_common::log::PSLogContext& log_context_; GcpBlobInputStreamBuf streambuf_; std::istream is_; }; } // namespace GcpBlobStorageClient::GcpBlobStorageClient( - std::unique_ptr client) - : client_(std::move(client)) {} + std::unique_ptr client, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : client_(std::move(client)), log_context_(log_context) {} std::unique_ptr GcpBlobStorageClient::GetBlobReader( DataLocation location) { - return std::make_unique(*client_, std::move(location)); + return std::make_unique(*client_, std::move(location), + log_context_); } absl::Status GcpBlobStorageClient::PutBlob(BlobReader& blob_reader, @@ -157,8 +169,9 @@ absl::StatusOr> GcpBlobStorageClient::ListBlobs( } for (auto&& object_metadata : list_object_reader) { if (!object_metadata) { - LOG(ERROR) << "Blob error when listing blobs:" - << std::move(object_metadata).status().message(); + PS_LOG(ERROR, log_context_) + << "Blob error when listing blobs:" + << std::move(object_metadata).status().message(); continue; } // Manually exclude the starting name as the StartOffset option is @@ -178,9 +191,10 @@ class GcpBlobStorageClientFactory : public BlobStorageClientFactory { public: ~GcpBlobStorageClientFactory() = default; std::unique_ptr CreateBlobStorageClient( - BlobStorageClient::ClientOptions /*client_options*/) override { + BlobStorageClient::ClientOptions /*client_options*/, + privacy_sandbox::server_common::log::PSLogContext& log_context) override { return std::make_unique( - std::make_unique()); + std::make_unique(), log_context); } }; } // namespace diff --git 
a/components/data/blob_storage/blob_storage_client_gcp.h b/components/data/blob_storage/blob_storage_client_gcp.h index d937a2b5..aaed23b2 100644 --- a/components/data/blob_storage/blob_storage_client_gcp.h +++ b/components/data/blob_storage/blob_storage_client_gcp.h @@ -28,7 +28,8 @@ namespace kv_server { class GcpBlobStorageClient : public BlobStorageClient { public: explicit GcpBlobStorageClient( - std::unique_ptr client); + std::unique_ptr client, + privacy_sandbox::server_common::log::PSLogContext& log_context); ~GcpBlobStorageClient() = default; @@ -43,5 +44,6 @@ class GcpBlobStorageClient : public BlobStorageClient { private: std::unique_ptr client_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_client_gcp_test.cc b/components/data/blob_storage/blob_storage_client_gcp_test.cc index e08c952e..ac4b4900 100644 --- a/components/data/blob_storage/blob_storage_client_gcp_test.cc +++ b/components/data/blob_storage/blob_storage_client_gcp_test.cc @@ -50,6 +50,7 @@ using testing::Property; class GcpBlobStorageClientTest : public ::testing::Test { protected: PlatformInitializer initializer_; + privacy_sandbox::server_common::log::NoOpContext no_op_context_; void SetUp() override { privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; config_proto.set_mode( @@ -73,7 +74,8 @@ TEST_F(GcpBlobStorageClientTest, DeleteBlobSucceeds) { google::cloud::make_status_or(gcs::internal::EmptyResponse{}))); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "test_bucket"; location.key = "test_object"; @@ -93,7 +95,8 @@ TEST_F(GcpBlobStorageClientTest, DeleteBlobFails) { google::cloud::StatusCode::kPermissionDenied, "uh-oh"))); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + 
std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "test_bucket"; location.key = "test_object"; @@ -118,7 +121,8 @@ TEST_F(GcpBlobStorageClientTest, ListBlobSucceeds) { .WillOnce(testing::Return(response)); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "test_bucket"; location.key = "test_object"; @@ -151,7 +155,8 @@ TEST_F(GcpBlobStorageClientTest, ListBlobWithNonInclusiveStartAfter) { .WillOnce(testing::Return(response)); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "test_bucket"; location.key = "test_object"; @@ -180,7 +185,8 @@ TEST_F(GcpBlobStorageClientTest, ListBlobWithNoNewObject) { .WillOnce(testing::Return(response)); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "test_bucket"; location.key = "test_object"; @@ -208,7 +214,8 @@ TEST_F(GcpBlobStorageClientTest, DeleteBlobWithPrefixSucceeds) { google::cloud::make_status_or(gcs::internal::EmptyResponse{}))); std::unique_ptr client = - std::make_unique(std::move(mock_client)); + std::make_unique(std::move(mock_client), + no_op_context_); BlobStorageClient::DataLocation location{ .bucket = "test_bucket", .prefix = "test_prefix", diff --git a/components/data/blob_storage/blob_storage_client_local.cc b/components/data/blob_storage/blob_storage_client_local.cc index 8f823a38..e4d7e190 100644 --- a/components/data/blob_storage/blob_storage_client_local.cc +++ b/components/data/blob_storage/blob_storage_client_local.cc @@ -50,7 +50,7 @@ std::unique_ptr FileBlobStorageClient::GetBlobReader( 
std::make_unique(GetFullPath(location)); if (!reader->Stream()) { - LOG(ERROR) << absl::ErrnoToStatus( + PS_LOG(ERROR, log_context_) << absl::ErrnoToStatus( errno, absl::StrCat("Unable to open file: ", GetFullPath(location).string())); return nullptr; @@ -135,8 +135,9 @@ class LocalBlobStorageClientFactory : public BlobStorageClientFactory { public: ~LocalBlobStorageClientFactory() = default; std::unique_ptr CreateBlobStorageClient( - BlobStorageClient::ClientOptions /*client_options*/) override { - return std::make_unique(); + BlobStorageClient::ClientOptions /*client_options*/, + privacy_sandbox::server_common::log::PSLogContext& log_context) override { + return std::make_unique(log_context); } }; } // namespace diff --git a/components/data/blob_storage/blob_storage_client_local.h b/components/data/blob_storage/blob_storage_client_local.h index 7b55a9ad..a53fe6bf 100644 --- a/components/data/blob_storage/blob_storage_client_local.h +++ b/components/data/blob_storage/blob_storage_client_local.h @@ -26,7 +26,9 @@ namespace kv_server { class FileBlobStorageClient : public BlobStorageClient { public: - FileBlobStorageClient() = default; + FileBlobStorageClient( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) {} ~FileBlobStorageClient() = default; @@ -41,5 +43,6 @@ class FileBlobStorageClient : public BlobStorageClient { private: std::filesystem::path GetFullPath(const DataLocation& location); + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_client_local_test.cc b/components/data/blob_storage/blob_storage_client_local_test.cc index e8b714d2..572090a5 100644 --- a/components/data/blob_storage/blob_storage_client_local_test.cc +++ b/components/data/blob_storage/blob_storage_client_local_test.cc @@ -33,6 +33,11 @@ namespace kv_server { namespace { +class LocalBlobStorageClientTest : public ::testing::Test { + protected: + 
privacy_sandbox::server_common::log::NoOpContext no_op_context_; +}; + void CreateSubDir(std::string_view subdir_name) { std::filesystem::create_directory( std::filesystem::path(::testing::TempDir()) / subdir_name); @@ -45,9 +50,9 @@ void CreateFileInTmpDir(const std::string& filename) { file << "arbitrary file contents"; } -TEST(LocalBlobStorageClientTest, ListNotFoundDirectory) { +TEST_F(LocalBlobStorageClientTest, ListNotFoundDirectory) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "this is not a valid directory path"; @@ -57,9 +62,9 @@ TEST(LocalBlobStorageClientTest, ListNotFoundDirectory) { client->ListBlobs(location, options).status().code()); } -TEST(LocalBlobStorageClientTest, ListEmptyDirectory) { +TEST_F(LocalBlobStorageClientTest, ListEmptyDirectory) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); BlobStorageClient::DataLocation location; // Directory contains no files by default. 
@@ -70,9 +75,9 @@ TEST(LocalBlobStorageClientTest, ListEmptyDirectory) { EXPECT_TRUE(status_or.value().empty()); } -TEST(LocalBlobStorageClientTest, ListDirectoryWithFile) { +TEST_F(LocalBlobStorageClientTest, ListDirectoryWithFile) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); CreateFileInTmpDir("a"); BlobStorageClient::DataLocation location; @@ -84,9 +89,9 @@ TEST(LocalBlobStorageClientTest, ListDirectoryWithFile) { EXPECT_EQ(*status_or, std::vector{"a"}); } -TEST(LocalBlobStorageClientTest, DeleteNotFoundBlob) { +TEST_F(LocalBlobStorageClientTest, DeleteNotFoundBlob) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); BlobStorageClient::DataLocation location; location.bucket = "this is not a valid directory path"; @@ -95,9 +100,9 @@ TEST(LocalBlobStorageClientTest, DeleteNotFoundBlob) { EXPECT_EQ(absl::StatusCode::kInternal, client->DeleteBlob(location).code()); } -TEST(LocalBlobStorageClientTest, DeleteBlob) { +TEST_F(LocalBlobStorageClientTest, DeleteBlob) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); BlobStorageClient::DataLocation location; location.bucket = ::testing::TempDir(); @@ -107,9 +112,9 @@ TEST(LocalBlobStorageClientTest, DeleteBlob) { EXPECT_EQ(absl::StatusCode::kOk, client->DeleteBlob(location).code()); } -TEST(LocalBlobStorageClientTest, PutBlob) { +TEST_F(LocalBlobStorageClientTest, PutBlob) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); BlobStorageClient::DataLocation from; from.bucket = ::testing::TempDir(); @@ -126,9 +131,9 @@ TEST(LocalBlobStorageClientTest, PutBlob) { client->PutBlob(*from_blob_reader, to).code()); } -TEST(LocalBlobStorageClientTest, DeleteBlobWithPrefix) { +TEST_F(LocalBlobStorageClientTest, DeleteBlobWithPrefix) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); CreateSubDir("prefix"); BlobStorageClient::DataLocation location{ .bucket = 
::testing::TempDir(), @@ -144,9 +149,9 @@ TEST(LocalBlobStorageClientTest, DeleteBlobWithPrefix) { EXPECT_EQ(status.code(), absl::StatusCode::kInternal) << status; } -TEST(LocalBlobStorageClientTest, ListSubDirectoryWithFiles) { +TEST_F(LocalBlobStorageClientTest, ListSubDirectoryWithFiles) { std::unique_ptr client = - std::make_unique(); + std::make_unique(no_op_context_); CreateSubDir("prefix"); CreateFileInTmpDir("prefix/object1"); CreateFileInTmpDir("prefix/object2"); diff --git a/components/data/blob_storage/blob_storage_client_s3.cc b/components/data/blob_storage/blob_storage_client_s3.cc index 984e593c..72190159 100644 --- a/components/data/blob_storage/blob_storage_client_s3.cc +++ b/components/data/blob_storage/blob_storage_client_s3.cc @@ -105,17 +105,24 @@ class S3BlobInputStreamBuf : public SeekingInputStreambuf { class S3BlobReader : public BlobReader { public: - S3BlobReader(Aws::S3::S3Client& client, - BlobStorageClient::DataLocation location, - int64_t max_range_bytes) + S3BlobReader( + Aws::S3::S3Client& client, BlobStorageClient::DataLocation location, + int64_t max_range_bytes, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) : BlobReader(), + log_context_(log_context), streambuf_(client, location, - GetOptions(max_range_bytes, - [this, location](absl::Status status) { - LOG(ERROR) << "Blob " << location.key - << " failed stream with: " << status; - is_.setstate(std::ios_base::badbit); - })), + GetOptions( + max_range_bytes, + [this, location](absl::Status status) { + PS_LOG(ERROR, log_context_) + << "Blob " << location.key + << " failed stream with: " << status; + is_.setstate(std::ios_base::badbit); + }, + log_context)), is_(&streambuf_) {} std::istream& Stream() { return is_; } @@ -123,21 +130,26 @@ class S3BlobReader : public BlobReader { private: static SeekingInputStreambuf::Options GetOptions( - int64_t buffer_size, std::function error_callback) { + 
int64_t buffer_size, std::function error_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context) { SeekingInputStreambuf::Options options; options.buffer_size = buffer_size; options.error_callback = std::move(error_callback); + options.log_context = log_context; return options; } - + privacy_sandbox::server_common::log::PSLogContext& log_context_; S3BlobInputStreamBuf streambuf_; std::istream is_; }; } // namespace S3BlobStorageClient::S3BlobStorageClient( - std::shared_ptr client, int64_t max_range_bytes) - : client_(client), max_range_bytes_(max_range_bytes) { + std::shared_ptr client, int64_t max_range_bytes, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : client_(client), + max_range_bytes_(max_range_bytes), + log_context_(log_context) { executor_ = std::make_unique( std::thread::hardware_concurrency()); Aws::Transfer::TransferManagerConfiguration transfer_config(executor_.get()); @@ -148,7 +160,7 @@ S3BlobStorageClient::S3BlobStorageClient( std::unique_ptr S3BlobStorageClient::GetBlobReader( DataLocation location) { return std::make_unique(*client_, std::move(location), - max_range_bytes_); + max_range_bytes_, log_context_); } absl::Status S3BlobStorageClient::PutBlob(BlobReader& reader, @@ -224,14 +236,15 @@ class S3BlobStorageClientFactory : public BlobStorageClientFactory { public: ~S3BlobStorageClientFactory() = default; std::unique_ptr CreateBlobStorageClient( - BlobStorageClient::ClientOptions client_options) override { + BlobStorageClient::ClientOptions client_options, + privacy_sandbox::server_common::log::PSLogContext& log_context) override { Aws::Client::ClientConfiguration config; config.maxConnections = client_options.max_connections; std::shared_ptr client = std::make_shared(config); return std::make_unique( - client, client_options.max_range_bytes); + client, client_options.max_range_bytes, log_context); } }; } // namespace diff --git a/components/data/blob_storage/blob_storage_client_s3.h 
b/components/data/blob_storage/blob_storage_client_s3.h index b2d11561..6b3743cf 100644 --- a/components/data/blob_storage/blob_storage_client_s3.h +++ b/components/data/blob_storage/blob_storage_client_s3.h @@ -25,13 +25,15 @@ #include "aws/s3/S3Client.h" #include "aws/transfer/TransferManager.h" #include "components/data/blob_storage/blob_storage_client.h" +#include "src/logger/request_context_logger.h" namespace kv_server { class S3BlobStorageClient : public BlobStorageClient { public: - explicit S3BlobStorageClient(std::shared_ptr client, - int64_t max_range_bytes); + explicit S3BlobStorageClient( + std::shared_ptr client, int64_t max_range_bytes, + privacy_sandbox::server_common::log::PSLogContext& log_context); ~S3BlobStorageClient() = default; @@ -51,5 +53,6 @@ class S3BlobStorageClient : public BlobStorageClient { std::shared_ptr client_; std::shared_ptr transfer_manager_; int64_t max_range_bytes_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_client_s3_test.cc b/components/data/blob_storage/blob_storage_client_s3_test.cc index 7d20884e..938810b8 100644 --- a/components/data/blob_storage/blob_storage_client_s3_test.cc +++ b/components/data/blob_storage/blob_storage_client_s3_test.cc @@ -55,7 +55,6 @@ class MockS3Client : public ::Aws::S3::S3Client { class BlobStorageClientS3Test : public ::testing::Test { protected: - PlatformInitializer initializer_; void SetUp() override { privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; config_proto.set_mode( @@ -64,6 +63,10 @@ class BlobStorageClientS3Test : public ::testing::Test { privacy_sandbox::server_common::telemetry::BuildDependentConfig( config_proto)); } + privacy_sandbox::server_common::log::NoOpContext no_op_context_; + + private: + PlatformInitializer initializer_; }; TEST_F(BlobStorageClientS3Test, DeleteBlobSucceeds) { @@ -73,7 +76,8 @@ TEST_F(BlobStorageClientS3Test, 
DeleteBlobSucceeds) { .WillOnce(::testing::Return(result)); std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location; EXPECT_TRUE(client->DeleteBlob(location).ok()); } @@ -83,7 +87,8 @@ TEST_F(BlobStorageClientS3Test, DeleteBlobFails) { auto mock_s3_client = std::make_shared(); std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location; EXPECT_EQ(absl::StatusCode::kUnknown, client->DeleteBlob(location).code()); } @@ -102,7 +107,8 @@ TEST_F(BlobStorageClientS3Test, ListBlobsSucceeds) { } std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location; BlobStorageClient::ListOptions list_options; absl::StatusOr> response = @@ -139,7 +145,8 @@ TEST_F(BlobStorageClientS3Test, ListBlobsSucceedsWithContinuedRequests) { } std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location; BlobStorageClient::ListOptions list_options; absl::StatusOr> response = @@ -154,7 +161,8 @@ TEST_F(BlobStorageClientS3Test, ListBlobsFails) { auto mock_s3_client = std::make_shared(); std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location; BlobStorageClient::ListOptions list_options; EXPECT_EQ(absl::StatusCode::kUnknown, @@ -173,7 +181,8 @@ TEST_F(BlobStorageClientS3Test, DeleteBlobWithPrefixSucceeds) { "prefix/object")))) .WillOnce(::testing::Return(result)); std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + 
std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location{ .bucket = "bucket", .prefix = "prefix", @@ -204,7 +213,8 @@ TEST_F(BlobStorageClientS3Test, ListBlobsWithPrefixSucceeds) { } std::unique_ptr client = - std::make_unique(mock_s3_client, kMaxRangeBytes); + std::make_unique(mock_s3_client, kMaxRangeBytes, + no_op_context_); BlobStorageClient::DataLocation location{ .bucket = "bucket", .prefix = "directory1", diff --git a/components/data/blob_storage/delta_file_notifier.cc b/components/data/blob_storage/delta_file_notifier.cc index cdef6df9..8ab7261e 100644 --- a/components/data/blob_storage/delta_file_notifier.cc +++ b/components/data/blob_storage/delta_file_notifier.cc @@ -37,17 +37,18 @@ using privacy_sandbox::server_common::SteadyClock; class DeltaFileNotifierImpl : public DeltaFileNotifier { public: - explicit DeltaFileNotifierImpl(BlobStorageClient& client, - const absl::Duration poll_frequency, - std::unique_ptr sleep_for, - SteadyClock& clock, - BlobPrefixAllowlist blob_prefix_allowlist) + explicit DeltaFileNotifierImpl( + BlobStorageClient& client, const absl::Duration poll_frequency, + std::unique_ptr sleep_for, SteadyClock& clock, + BlobPrefixAllowlist blob_prefix_allowlist, + privacy_sandbox::server_common::log::PSLogContext& log_context) : thread_manager_(ThreadManager::Create("Delta file notifier")), client_(client), poll_frequency_(poll_frequency), sleep_for_(std::move(sleep_for)), clock_(clock), - blob_prefix_allowlist_(std::move(blob_prefix_allowlist)) {} + blob_prefix_allowlist_(std::move(blob_prefix_allowlist)), + log_context_(log_context) {} absl::Status Start( BlobStorageChangeNotifier& change_notifier, @@ -94,6 +95,7 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { std::unique_ptr sleep_for_; SteadyClock& clock_; BlobPrefixAllowlist blob_prefix_allowlist_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; absl::StatusOr 
DeltaFileNotifierImpl::WaitForNotification( @@ -121,7 +123,7 @@ absl::StatusOr DeltaFileNotifierImpl::ShouldListBlobs( const absl::flat_hash_map& prefix_start_after_map) { if (!expiring_flag.Get()) { - VLOG(5) << "Backup poll"; + PS_VLOG(5, log_context_) << "Backup poll"; return true; } absl::StatusOr notification_key = @@ -129,7 +131,7 @@ absl::StatusOr DeltaFileNotifierImpl::ShouldListBlobs( // Don't poll on error. A backup poll will trigger if necessary. if (absl::IsDeadlineExceeded(notification_key.status())) { // Deadline exceeded while waiting, trigger backup poll - VLOG(5) << "Backup poll"; + PS_VLOG(5, log_context_) << "Backup poll"; return true; } if (!notification_key.ok()) { @@ -152,7 +154,8 @@ absl::flat_hash_map> ListPrefixDeltaFiles( BlobStorageClient::DataLocation location, const BlobPrefixAllowlist& prefix_allowlist, const absl::flat_hash_map& prefix_start_after_map, - BlobStorageClient& blob_client) { + BlobStorageClient& blob_client, + privacy_sandbox::server_common::log::PSLogContext& log_context) { absl::flat_hash_map> prefix_blobs_map; for (const auto& blob_prefix : prefix_allowlist.Prefixes()) { location.prefix = blob_prefix; @@ -163,7 +166,8 @@ absl::flat_hash_map> ListPrefixDeltaFiles( .start_after = (iter == prefix_start_after_map.end()) ? "" : iter->second}); if (!result.ok()) { - LOG(ERROR) << "Failed to list " << location << ": " << result.status(); + PS_LOG(ERROR, log_context) + << "Failed to list " << location << ": " << result.status(); continue; } if (result->empty()) { @@ -183,7 +187,7 @@ void DeltaFileNotifierImpl::Watch( BlobStorageClient::DataLocation location, absl::flat_hash_map&& prefix_start_after_map, std::function callback) { - LOG(INFO) << "Started to watch " << location; + PS_LOG(INFO, log_context_) << "Started to watch " << location; // Flag starts expired, and forces an initial poll. 
ExpiringFlag expiring_flag(clock_); uint32_t sequential_failures = 0; @@ -195,12 +199,12 @@ void DeltaFileNotifierImpl::Watch( const absl::Duration backoff_time = std::min(expiring_flag.GetTimeRemaining(), ExponentialBackoffForRetry(sequential_failures)); - LOG(ERROR) << "Failed to get delta file notifications: " - << should_list_blobs.status() << ". Waiting for " - << backoff_time; + PS_LOG(ERROR, log_context_) + << "Failed to get delta file notifications: " + << should_list_blobs.status() << ". Waiting for " << backoff_time; if (!sleep_for_->Duration(backoff_time)) { - LOG(ERROR) << "Failed to sleep for " << backoff_time - << ". SleepFor invalid."; + PS_LOG(ERROR, log_context_) + << "Failed to sleep for " << backoff_time << ". SleepFor invalid."; } continue; } @@ -212,8 +216,9 @@ void DeltaFileNotifierImpl::Watch( // Fake clock is moved forward in callback so flag must be set beforehand. expiring_flag.Set(poll_frequency_); int delta_file_count = 0; - auto prefix_blobs_map = ListPrefixDeltaFiles( - location, blob_prefix_allowlist_, prefix_start_after_map, client_); + auto prefix_blobs_map = + ListPrefixDeltaFiles(location, blob_prefix_allowlist_, + prefix_start_after_map, client_, log_context_); for (const auto& [prefix, prefix_blobs] : prefix_blobs_map) { for (const auto& blob : prefix_blobs) { if (!IsDeltaFilename(blob)) { @@ -225,7 +230,7 @@ void DeltaFileNotifierImpl::Watch( } } if (delta_file_count == 0) { - VLOG(2) << "No new file found"; + PS_VLOG(2, log_context_) << "No new file found"; } } } @@ -234,20 +239,22 @@ void DeltaFileNotifierImpl::Watch( std::unique_ptr DeltaFileNotifier::Create( BlobStorageClient& client, const absl::Duration poll_frequency, - BlobPrefixAllowlist blob_prefix_allowlist) { + BlobPrefixAllowlist blob_prefix_allowlist, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique( client, poll_frequency, std::make_unique(), - SteadyClock::RealClock(), std::move(blob_prefix_allowlist)); + 
SteadyClock::RealClock(), std::move(blob_prefix_allowlist), log_context); } // For test only std::unique_ptr DeltaFileNotifier::Create( BlobStorageClient& client, const absl::Duration poll_frequency, std::unique_ptr sleep_for, SteadyClock& clock, - BlobPrefixAllowlist blob_prefix_allowlist) { + BlobPrefixAllowlist blob_prefix_allowlist, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique( client, poll_frequency, std::move(sleep_for), clock, - std::move(blob_prefix_allowlist)); + std::move(blob_prefix_allowlist), log_context); } } // namespace kv_server diff --git a/components/data/blob_storage/delta_file_notifier.h b/components/data/blob_storage/delta_file_notifier.h index 6df74bce..c03f1f42 100644 --- a/components/data/blob_storage/delta_file_notifier.h +++ b/components/data/blob_storage/delta_file_notifier.h @@ -62,14 +62,20 @@ class DeltaFileNotifier { static std::unique_ptr Create( BlobStorageClient& client, const absl::Duration poll_frequency = absl::Minutes(5), - BlobPrefixAllowlist blob_prefix_allowlist = BlobPrefixAllowlist("")); + BlobPrefixAllowlist blob_prefix_allowlist = BlobPrefixAllowlist(""), + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); // Used for test static std::unique_ptr Create( BlobStorageClient& client, const absl::Duration poll_frequency, std::unique_ptr sleep_for, privacy_sandbox::server_common::SteadyClock& clock, - BlobPrefixAllowlist blob_prefix_allowlist = BlobPrefixAllowlist("")); + BlobPrefixAllowlist blob_prefix_allowlist = BlobPrefixAllowlist(""), + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/blob_storage/seeking_input_streambuf.cc b/components/data/blob_storage/seeking_input_streambuf.cc index fdb3bba1..5d48204f 100644 --- 
a/components/data/blob_storage/seeking_input_streambuf.cc +++ b/components/data/blob_storage/seeking_input_streambuf.cc @@ -35,11 +35,14 @@ constexpr std::string_view kUnderflowEventName = "SeekingInputStreambuf::underflow"; constexpr std::string_view kSeekoffEventName = "SeekingInputStreambuf::seekoff"; -void MaybeVerboseLogLatency(std::string_view event_name, absl::Duration latency, - double sampling_threshold = 0.05) { +void MaybeVerboseLogLatency( + std::string_view event_name, absl::Duration latency, + privacy_sandbox::server_common::log::PSLogContext& log_context, + double sampling_threshold = 0.05) { if ((double)std::rand() / RAND_MAX <= sampling_threshold) { - VLOG(3) << event_name << " latency: " << absl::ToDoubleMilliseconds(latency) - << " ms."; + PS_VLOG(3, log_context) + << event_name << " latency: " << absl::ToDoubleMilliseconds(latency) + << " ms."; } } } // namespace @@ -107,7 +110,8 @@ std::streampos SeekingInputStreambuf::seekoff(std::streamoff off, buffer_.data() + (new_position - BufferStartPosition()), buffer_.data() + buffer_.length()); } - MaybeVerboseLogLatency(kSeekoffEventName, latency_recorder.GetLatency()); + MaybeVerboseLogLatency(kSeekoffEventName, latency_recorder.GetLatency(), + options_.log_context); return std::streampos(std::streamoff(new_position)); } @@ -148,7 +152,12 @@ std::streambuf::int_type SeekingInputStreambuf::underflow() { buffer_.resize(total_bytes_read); } setg(buffer_.data(), buffer_.data(), buffer_.data() + buffer_.length()); - MaybeVerboseLogLatency(kUnderflowEventName, latency_recorder.GetLatency()); + LogIfError( + KVServerContextMap() + ->SafeMetric() + .template LogHistogram((int)total_bytes_read)); + MaybeVerboseLogLatency(kUnderflowEventName, latency_recorder.GetLatency(), + options_.log_context); return traits_type::to_int_type(buffer_[0]); } @@ -180,7 +189,8 @@ absl::StatusOr SeekingInputStreambuf::Size() { return size.status(); } src_cached_size_ = *size; - MaybeVerboseLogLatency(kSizeEventName, 
latency_recorder.GetLatency()); + MaybeVerboseLogLatency(kSizeEventName, latency_recorder.GetLatency(), + options_.log_context); return *size; } diff --git a/components/data/blob_storage/seeking_input_streambuf.h b/components/data/blob_storage/seeking_input_streambuf.h index c76a0c5f..c46e0979 100644 --- a/components/data/blob_storage/seeking_input_streambuf.h +++ b/components/data/blob_storage/seeking_input_streambuf.h @@ -20,6 +20,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "src/logger/request_context_logger.h" #include "src/telemetry/telemetry_provider.h" #ifndef COMPONENTS_DATA_BLOB_STORAGE_SEEKING_INPUT_STREAMBUF_H_ @@ -69,6 +70,9 @@ class SeekingInputStreambuf : public std::streambuf { // underlying source which can be painfully slow and expensive. std::int64_t buffer_size = 8 * 1024 * 1024; // 8MB std::function error_callback = [](absl::Status) {}; + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext); }; explicit SeekingInputStreambuf(Options options = Options()); diff --git a/components/data/common/BUILD.bazel b/components/data/common/BUILD.bazel index 05830260..511082f7 100644 --- a/components/data/common/BUILD.bazel +++ b/components/data/common/BUILD.bazel @@ -53,6 +53,8 @@ cc_library( "@com_google_absl//absl/random", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", ], ) @@ -92,6 +94,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/telemetry:telemetry_provider", ], ) @@ -147,6 +150,7 @@ cc_library( "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/util:duration", ], ) diff --git a/components/data/common/change_notifier.h b/components/data/common/change_notifier.h index d2e95de2..752c2ef0 100644 --- a/components/data/common/change_notifier.h +++ b/components/data/common/change_notifier.h @@ -22,6 +22,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "components/data/common/msg_svc.h" +#include "src/logger/request_context_logger.h" namespace kv_server { @@ -43,7 +44,10 @@ class ChangeNotifier { const std::function& should_stop_callback) = 0; static absl::StatusOr> Create( - NotifierMetadata notifier_metadata); + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/common/change_notifier_aws.cc b/components/data/common/change_notifier_aws.cc index 9e1ee74c..48e571b7 100644 --- a/components/data/common/change_notifier_aws.cc +++ b/components/data/common/change_notifier_aws.cc @@ -54,9 +54,12 @@ constexpr char kLastUpdatedTag[] = "last_updated"; class AwsChangeNotifier : public ChangeNotifier { public: - explicit AwsChangeNotifier(AwsNotifierMetadata notifier_metadata) + explicit AwsChangeNotifier( + AwsNotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) : sns_arn_(std::move(notifier_metadata.sns_arn)), - queue_manager_(notifier_metadata.queue_manager) { + queue_manager_(notifier_metadata.queue_manager), + log_context_(log_context) { if (notifier_metadata.only_for_testing_sqs_client_ != nullptr) { sqs_.reset(notifier_metadata.only_for_testing_sqs_client_); } else { @@ -72,12 +75,14 @@ class AwsChangeNotifier : public ChangeNotifier { absl::StatusOr> 
GetNotifications( absl::Duration max_wait, const std::function& should_stop_callback) override { - LOG(INFO) << "Getting notifications for topic " << sns_arn_; + PS_LOG(INFO, log_context_) + << "Getting notifications for topic " << sns_arn_; do { if (!queue_manager_->IsSetupComplete()) { absl::Status status = queue_manager_->SetupQueue(); if (!status.ok()) { - LOG(ERROR) << "Could not set up queue for topic " << sns_arn_; + PS_LOG(ERROR, log_context_) + << "Could not set up queue for topic " << sns_arn_; LogServerErrorMetric(kAwsChangeNotifierQueueSetupFailure); return status; } @@ -124,8 +129,9 @@ class AwsChangeNotifier : public ChangeNotifier { if (status.ok()) { last_updated_ = now; } else { - LOG(ERROR) << "Failed to TagQueue with " << kLastUpdatedTag << ": " - << tag << " " << status; + PS_LOG(ERROR, log_context_) + << "Failed to TagQueue with " << kLastUpdatedTag << ": " << tag + << " " << status; LogServerErrorMetric(kAwsChangeNotifierTagFailure); } } @@ -146,13 +152,13 @@ class AwsChangeNotifier : public ChangeNotifier { request.SetMaxNumberOfMessages(10); const auto outcome = sqs_->ReceiveMessage(request); if (!outcome.IsSuccess()) { - LOG(ERROR) << "Failed to receive message from SQS: " - << outcome.GetError().GetMessage(); + PS_LOG(ERROR, log_context_) << "Failed to receive message from SQS: " + << outcome.GetError().GetMessage(); LogServerErrorMetric(kAwsChangeNotifierMessagesReceivingFailure); if (!outcome.GetError().ShouldRetry()) { // Handle case where recreating Queue will resolve the issue. // Example: Queue accidentally deleted. 
- LOG(INFO) << "Will create a new Queue"; + PS_LOG(INFO, log_context_) << "Will create a new Queue"; queue_manager_->Reset(); } return absl::UnavailableError(outcome.GetError().GetMessage()); @@ -168,9 +174,22 @@ class AwsChangeNotifier : public ChangeNotifier { absl::Now() - receive_message_request_started))); std::vector keys; + keys.reserve(messages.size()); + size_t total_message_size = 0; for (const auto& message : messages) { keys.push_back(message.GetBody()); + total_message_size += message.GetBody().size(); } + LogIfError( + KVServerContextMap() + ->SafeMetric() + .template LogHistogram( + static_cast(total_message_size))); + LogIfError(KVServerContextMap() + ->SafeMetric() + .LogUpDownCounter( + static_cast(messages.size()))); + DeleteMessages(GetSqsUrl(), messages); if (keys.empty()) { LogServerErrorMetric(kAwsChangeNotifierMessagesDataLoss); @@ -197,8 +216,8 @@ class AwsChangeNotifier : public ChangeNotifier { req.SetEntries(std::move(delete_message_batch_request_entries)); const auto outcome = sqs_->DeleteMessageBatch(req); if (!outcome.IsSuccess()) { - LOG(ERROR) << "Failed to delete message from SQS: " - << outcome.GetError().GetMessage(); + PS_LOG(ERROR, log_context_) << "Failed to delete message from SQS: " + << outcome.GetError().GetMessage(); LogServerErrorMetric(kAwsChangeNotifierMessagesDeletionFailure); } } @@ -207,13 +226,15 @@ class AwsChangeNotifier : public ChangeNotifier { const std::string sns_arn_; absl::Time last_updated_ = absl::InfinitePast(); std::unique_ptr sqs_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> ChangeNotifier::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique( - std::move(std::get(notifier_metadata))); + std::move(std::get(notifier_metadata)), log_context); } } // namespace kv_server diff --git
a/components/data/common/change_notifier_aws_test.cc b/components/data/common/change_notifier_aws_test.cc index 2db5e4f6..8850a1ef 100644 --- a/components/data/common/change_notifier_aws_test.cc +++ b/components/data/common/change_notifier_aws_test.cc @@ -63,6 +63,8 @@ class ChangeNotifierAwsTest : public ::testing::Test { privacy_sandbox::server_common::telemetry::BuildDependentConfig( config_proto)); } + + private: PlatformInitializer initializer_; }; diff --git a/components/data/common/change_notifier_gcp.cc b/components/data/common/change_notifier_gcp.cc index b9f4c19c..06670f42 100644 --- a/components/data/common/change_notifier_gcp.cc +++ b/components/data/common/change_notifier_gcp.cc @@ -33,7 +33,9 @@ namespace { class GcpChangeNotifier : public ChangeNotifier { public: - GcpChangeNotifier() {} + explicit GcpChangeNotifier( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) {} ~GcpChangeNotifier() { sleep_for_.Stop(); } absl::StatusOr> GetNotifications( @@ -45,13 +47,15 @@ class GcpChangeNotifier : public ChangeNotifier { private: SleepFor sleep_for_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> ChangeNotifier::Create( - NotifierMetadata notifier_metadata) { - return std::make_unique(); + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/data/common/change_notifier_local.cc b/components/data/common/change_notifier_local.cc index 94d8a09d..285bea93 100644 --- a/components/data/common/change_notifier_local.cc +++ b/components/data/common/change_notifier_local.cc @@ -32,17 +32,21 @@ constexpr absl::Duration kPollInterval = absl::Seconds(5); class LocalChangeNotifier : public ChangeNotifier { public: - explicit LocalChangeNotifier(std::filesystem::path local_directory) - : local_directory_(local_directory) { - VLOG(1) 
<< "Building initial list of local files in directory: " - << local_directory_.string(); + explicit LocalChangeNotifier( + std::filesystem::path local_directory, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : local_directory_(local_directory), log_context_(log_context) { + PS_VLOG(1, log_context_) + << "Building initial list of local files in directory: " + << local_directory_.string(); auto status_or = FindNewFiles({}); if (!status_or.ok()) { - LOG(ERROR) << "Unable to build initial file list" - << status_or.status().message(); + PS_LOG(ERROR, log_context_) << "Unable to build initial file list" + << status_or.status().message(); } files_in_directory_ = std::move(status_or.value()); - VLOG(1) << "Found " << files_in_directory_.size() << " files."; + PS_VLOG(1, log_context_) + << "Found " << files_in_directory_.size() << " files."; } ~LocalChangeNotifier() { sleep_for_.Stop(); } @@ -50,17 +54,18 @@ class LocalChangeNotifier : public ChangeNotifier { absl::StatusOr> GetNotifications( absl::Duration max_wait, const std::function& should_stop_callback) override { - LOG(INFO) << "Watching for new files in directory: " - << local_directory_.string(); + PS_LOG(INFO, log_context_) + << "Watching for new files in directory: " << local_directory_.string(); while (true) { if (should_stop_callback()) { - VLOG(1) << "Callback says to stop watching, stopping."; + PS_VLOG(1, log_context_) << "Callback says to stop watching, stopping."; return std::vector{}; } if (max_wait <= absl::ZeroDuration()) { - VLOG(1) << "No new files found within timeout, stopping."; + PS_VLOG(1, log_context_) + << "No new files found within timeout, stopping."; return absl::DeadlineExceededError("No messages found"); } @@ -69,7 +74,7 @@ class LocalChangeNotifier : public ChangeNotifier { return status_or.status(); } if (!status_or->empty()) { - VLOG(1) << "Found new local files."; + PS_VLOG(1, log_context_) << "Found new local files."; // Add the new files to the running list so 
that they'll be ignored if // GetNotifications is called again. files_in_directory_.insert(status_or->begin(), status_or->end()); @@ -98,7 +103,7 @@ class LocalChangeNotifier : public ChangeNotifier { const std::string filename = std::filesystem::path(file).filename().string(); if (previous_files.find(filename) == previous_files.end()) { - VLOG(1) << "Found new file: " << filename; + PS_VLOG(1, log_context_) << "Found new file: " << filename; new_files.emplace(filename); } } @@ -112,12 +117,14 @@ class LocalChangeNotifier : public ChangeNotifier { // We can't store std::filesystem::path objects in the set because the paths // aren't guaranteed to be canonical so we store the string paths instead. absl::flat_hash_set files_in_directory_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> ChangeNotifier::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { std::error_code error_code; auto local_notifier_metadata = std::get(notifier_metadata); @@ -136,7 +143,7 @@ absl::StatusOr> ChangeNotifier::Create( } return std::make_unique( - std::move(local_notifier_metadata.local_directory)); + std::move(local_notifier_metadata.local_directory), log_context); } } // namespace kv_server diff --git a/components/data/common/msg_svc.h b/components/data/common/msg_svc.h index 6d52741c..79402b68 100644 --- a/components/data/common/msg_svc.h +++ b/components/data/common/msg_svc.h @@ -26,6 +26,7 @@ #include "absl/status/statusor.h" #include "components/data/common/msg_svc.h" #include "components/data/common/notifier_metadata.h" +#include "src/logger/request_context_logger.h" namespace kv_server { struct AwsQueueMetadata { @@ -58,7 +59,10 @@ class MessageService { virtual void Reset() = 0; static absl::StatusOr> Create( - NotifierMetadata notifier_metadata); + NotifierMetadata notifier_metadata, + 
privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/common/msg_svc_aws.cc b/components/data/common/msg_svc_aws.cc index 166a98e6..c9e4471f 100644 --- a/components/data/common/msg_svc_aws.cc +++ b/components/data/common/msg_svc_aws.cc @@ -29,11 +29,12 @@ #include "aws/sqs/model/GetQueueAttributesRequest.h" #include "aws/sqs/model/GetQueueAttributesResult.h" #include "aws/sqs/model/ReceiveMessageRequest.h" -#include "aws/sqs/model/ReceiveMessageResult.h" #include "aws/sqs/model/SetQueueAttributesRequest.h" +#include "aws/sqs/model/TagQueueRequest.h" #include "components/data/common/msg_svc.h" #include "components/data/common/msg_svc_util.h" #include "components/errors/error_util_aws.h" +#include "src/util/status_macro/status_macros.h" namespace kv_server { namespace { @@ -59,16 +60,21 @@ constexpr char kPolicyTemplate[] = R"({ constexpr char kFilterPolicyTemplate[] = R"({ "shard_num": ["%d"] })"; +constexpr char kEnvironmentTag[] = "environment"; class AwsMessageService : public MessageService { public: // `prefix` is the prefix of randomly generated SQS Queue name. // The queue is subscribed to the topic at `sns_arn`. 
- AwsMessageService(std::string prefix, std::string sns_arn, - std::optional shard_num) + AwsMessageService( + std::string prefix, std::string sns_arn, std::string environment, + std::optional shard_num, + privacy_sandbox::server_common::log::PSLogContext& log_context) : prefix_(std::move(prefix)), sns_arn_(std::move(sns_arn)), - shard_num_(shard_num) {} + environment_(std::move(environment)), + shard_num_(shard_num), + log_context_(log_context) {} bool IsSetupComplete() const { absl::ReaderMutexLock lock(&mutex_); @@ -85,30 +91,21 @@ class AwsMessageService : public MessageService { absl::Status SetupQueue() { absl::MutexLock lock(&mutex_); if (sqs_url_.empty()) { - absl::StatusOr url = CreateQueue(sqs_client_, prefix_); - if (!url.ok()) { - return url.status(); - } - sqs_url_ = std::move(*url); + PS_ASSIGN_OR_RETURN(sqs_url_, CreateQueue(sqs_client_, prefix_)); } // TODO: Any non-retryable status from this point on should result in a // reset. if (sqs_arn_.empty()) { - absl::StatusOr arn = GetQueueArn(sqs_client_, sqs_url_); - if (!arn.ok()) { - return arn.status(); - } - sqs_arn_ = std::move(*arn); + PS_ASSIGN_OR_RETURN(sqs_arn_, GetQueueArn(sqs_client_, sqs_url_)); } if (!are_attributes_set_) { - auto result = - SetQueueAttributes(sqs_client_, sns_arn_, sqs_arn_, sqs_url_); - - if (!result.ok()) { - return result; - } + PS_RETURN_IF_ERROR( + SetQueueAttributes(sqs_client_, sns_arn_, sqs_arn_, sqs_url_)); are_attributes_set_ = true; } + if (!environment_.empty()) { + PS_RETURN_IF_ERROR(TagQueue(sqs_client_, sqs_url_)); + } const absl::Status status = SubscribeQueue(sns_client_, sns_arn_, sqs_arn_); if (status.ok()) { is_set_up_ = true; @@ -150,17 +147,28 @@ class AwsMessageService : public MessageService { : AwsErrorToStatus(outcome.GetError()); } + absl::Status TagQueue(Aws::SQS::SQSClient& sqs, const std::string& sqs_url) { + Aws::SQS::Model::TagQueueRequest request; + request.SetQueueUrl(sqs_url); + request.AddTags(kEnvironmentTag, environment_); + + const 
auto outcome = sqs.TagQueue(request); + return outcome.IsSuccess() ? absl::OkStatus() + : AwsErrorToStatus(outcome.GetError()); + } + absl::StatusOr GetQueueArn(Aws::SQS::SQSClient& sqs, const std::string& sqs_url) { Aws::SQS::Model::GetQueueAttributesRequest req; req.SetQueueUrl(sqs_url); req.AddAttributeNames(Aws::SQS::Model::QueueAttributeName::QueueArn); - const auto outcome = sqs.GetQueueAttributes(req); - if (outcome.IsSuccess()) { + + if (const auto outcome = sqs.GetQueueAttributes(req); outcome.IsSuccess()) { return outcome.GetResult().GetAttributes().at( Aws::SQS::Model::QueueAttributeName::QueueArn); + } else { + return AwsErrorToStatus(outcome.GetError()); } - return AwsErrorToStatus(outcome.GetError()); } absl::Status SubscribeQueue(Aws::SNS::SNSClient& sns, @@ -184,22 +192,26 @@ class AwsMessageService : public MessageService { Aws::SNS::SNSClient sns_client_; const std::string prefix_; const std::string sns_arn_; + const std::string environment_; bool is_set_up_ = false; std::string sqs_url_; std::string sqs_arn_; bool are_attributes_set_ = false; std::optional shard_num_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> MessageService::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto metadata = std::get(notifier_metadata); auto shard_num = (metadata.num_shards > 1 ? 
std::optional(metadata.shard_num) : std::nullopt); return std::make_unique( - std::move(metadata.queue_prefix), std::move(metadata.sns_arn), shard_num); + std::move(metadata.queue_prefix), std::move(metadata.sns_arn), + std::move(metadata.environment), shard_num, log_context); } } // namespace kv_server diff --git a/components/data/common/msg_svc_gcp.cc b/components/data/common/msg_svc_gcp.cc index 892c298e..bbae6788 100644 --- a/components/data/common/msg_svc_gcp.cc +++ b/components/data/common/msg_svc_gcp.cc @@ -34,19 +34,24 @@ using ::google::cloud::pubsub::Subscription; using ::google::cloud::pubsub::SubscriptionBuilder; using ::google::cloud::pubsub::Topic; +constexpr char kEnvironmentTag[] = "environment"; constexpr char kFilterPolicyTemplate[] = R"(attributes.shard_num="%d")"; class GcpMessageService : public MessageService { public: - GcpMessageService(std::string prefix, std::string topic_id, - std::string project_id, - pubsub::SubscriptionAdminClient subscription_admin_client, - std::optional shard_num) + GcpMessageService( + std::string prefix, std::string topic_id, std::string project_id, + std::string environment, + pubsub::SubscriptionAdminClient subscription_admin_client, + std::optional shard_num, + privacy_sandbox::server_common::log::PSLogContext& log_context) : prefix_(std::move(prefix)), topic_id_(std::move(topic_id)), project_id_(project_id), + environment_(environment), subscription_admin_client_(subscription_admin_client), - shard_num_(shard_num) {} + shard_num_(shard_num), + log_context_(log_context) {} bool IsSetupComplete() const { absl::ReaderMutexLock lock(&mutex_); @@ -82,12 +87,14 @@ class GcpMessageService : public MessageService { absl::StatusOr CreateQueue() { std::string subscription_id = GenerateQueueName(prefix_); SubscriptionBuilder subscription_builder; + subscription_builder.add_label(kEnvironmentTag, environment_); if (prefix_ == "QueueNotifier_" && shard_num_.has_value()) { subscription_builder.set_filter( 
absl::StrFormat(kFilterPolicyTemplate, shard_num_.value())); } - VLOG(1) << "Creating a subscription for project id " << project_id_ - << " with subsciprition id " << subscription_id; + PS_VLOG(1, log_context_) + << "Creating a subscription for project id " << project_id_ + << " with subsciprition id " << subscription_id; auto sub = subscription_admin_client_.CreateSubscription( Topic(project_id_, std::move(topic_id_)), Subscription(project_id_, subscription_id), subscription_builder); @@ -96,7 +103,7 @@ class GcpMessageService : public MessageService { } sub_id_ = subscription_id; - VLOG(1) << "Subscription created " << sub_id_; + PS_VLOG(1, log_context_) << "Subscription created " << sub_id_; return subscription_id; } @@ -104,6 +111,7 @@ class GcpMessageService : public MessageService { const std::string prefix_; const std::string topic_id_; const std::string project_id_; + const std::string environment_; bool is_set_up_ = false; std::string sub_id_; @@ -111,12 +119,14 @@ class GcpMessageService : public MessageService { bool are_attributes_set_ = false; std::optional shard_num_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> MessageService::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto metadata = std::get(notifier_metadata); auto shard_num = (metadata.num_shards > 1 ? 
std::optional(metadata.shard_num) @@ -126,7 +136,7 @@ absl::StatusOr> MessageService::Create( return std::make_unique( std::move(metadata.queue_prefix), std::move(metadata.topic_id), - std::move(metadata.project_id), std::move(subscription_admin_client), - shard_num); + std::move(metadata.project_id), std::move(metadata.environment), + std::move(subscription_admin_client), shard_num, log_context); } } // namespace kv_server diff --git a/components/data/common/msg_svc_local.cc b/components/data/common/msg_svc_local.cc index 7529a2b8..f8bda3ed 100644 --- a/components/data/common/msg_svc_local.cc +++ b/components/data/common/msg_svc_local.cc @@ -20,7 +20,10 @@ namespace kv_server { namespace { class LocalMessageService : public MessageService { public: - explicit LocalMessageService(std::string local_directory) {} + explicit LocalMessageService( + std::string local_directory, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) {} bool IsSetupComplete() const { return true; } const QueueMetadata GetQueueMetadata() const { AwsQueueMetadata metadata; @@ -28,14 +31,16 @@ class LocalMessageService : public MessageService { } absl::Status SetupQueue() { return absl::OkStatus(); } void Reset() {} + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> MessageService::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto metadata = std::get(notifier_metadata); return std::make_unique( - std::move(metadata.local_directory)); + std::move(metadata.local_directory), log_context); } } // namespace kv_server diff --git a/components/data/common/notifier_metadata.h b/components/data/common/notifier_metadata.h index 1260b8d1..d4f2a826 100644 --- a/components/data/common/notifier_metadata.h +++ b/components/data/common/notifier_metadata.h @@ -39,6 +39,7 @@ struct AwsNotifierMetadata { 
MessageService* queue_manager; int32_t num_shards = 1; int32_t shard_num; + std::string environment; // If this is set then it will be used instead of a real SQSClient. The // ChangeNotifier takes ownership of this. @@ -52,6 +53,7 @@ struct GcpNotifierMetadata { std::string queue_prefix; std::string project_id; std::string topic_id; + std::string environment; int32_t num_threads = 1; int32_t num_shards = 1; int32_t shard_num; diff --git a/components/data/common/thread_manager.cc b/components/data/common/thread_manager.cc index 59c6c6b0..f1c7f13e 100644 --- a/components/data/common/thread_manager.cc +++ b/components/data/common/thread_manager.cc @@ -32,14 +32,17 @@ namespace kv_server { namespace { class ThreadManagerImpl : public ThreadManager { public: - explicit ThreadManagerImpl(std::string thread_name) - : thread_name_(std::move(thread_name)) {} + explicit ThreadManagerImpl( + std::string thread_name, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : thread_name_(std::move(thread_name)), log_context_(log_context) {} ~ThreadManagerImpl() { if (!IsRunning()) return; - VLOG(8) << thread_name_ << " In destructor. Attempting to stop the thread."; + PS_VLOG(8, log_context_) + << thread_name_ << " In destructor. 
Attempting to stop the thread."; if (const auto s = Stop(); !s.ok()) { - LOG(ERROR) << thread_name_ << " failed to stop: " << s; + PS_LOG(ERROR, log_context_) << thread_name_ << " failed to stop: " << s; } } @@ -47,20 +50,21 @@ class ThreadManagerImpl : public ThreadManager { if (IsRunning()) { return absl::FailedPreconditionError("Already running"); } - LOG(INFO) << thread_name_ << " Creating thread for processing"; + PS_LOG(INFO, log_context_) + << thread_name_ << " Creating thread for processing"; thread_ = std::make_unique(watch); return absl::OkStatus(); } absl::Status Stop() override { - VLOG(8) << thread_name_ << "Stop called"; + PS_VLOG(8, log_context_) << thread_name_ << "Stop called"; if (!IsRunning()) { - LOG(ERROR) << thread_name_ << " not running"; + PS_LOG(ERROR, log_context_) << thread_name_ << " not running"; return absl::FailedPreconditionError("Not currently running"); } should_stop_ = true; thread_->join(); - VLOG(8) << thread_name_ << " joined"; + PS_VLOG(8, log_context_) << thread_name_ << " joined"; thread_.reset(); should_stop_ = false; return absl::OkStatus(); @@ -74,12 +78,16 @@ class ThreadManagerImpl : public ThreadManager { std::unique_ptr thread_; std::atomic should_stop_ = false; std::string thread_name_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace -std::unique_ptr ThreadManager::Create(std::string thread_name) { - return std::make_unique(std::move(thread_name)); +std::unique_ptr ThreadManager::Create( + std::string thread_name, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(std::move(thread_name), + log_context); } } // namespace kv_server diff --git a/components/data/common/thread_manager.h b/components/data/common/thread_manager.h index 123e9d43..9b4c55e8 100644 --- a/components/data/common/thread_manager.h +++ b/components/data/common/thread_manager.h @@ -21,6 +21,7 @@ #include #include "components/errors/retry.h" +#include 
"src/logger/request_context_logger.h" namespace kv_server { @@ -43,7 +44,11 @@ class ThreadManager { virtual bool ShouldStop() = 0; - static std::unique_ptr Create(std::string thread_name); + static std::unique_ptr Create( + std::string thread_name, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/realtime/BUILD.bazel b/components/data/realtime/BUILD.bazel index f8dc66c1..22cb01da 100644 --- a/components/data/realtime/BUILD.bazel +++ b/components/data/realtime/BUILD.bazel @@ -101,6 +101,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) diff --git a/components/data/realtime/delta_file_record_change_notifier.h b/components/data/realtime/delta_file_record_change_notifier.h index d48cda20..895e495d 100644 --- a/components/data/realtime/delta_file_record_change_notifier.h +++ b/components/data/realtime/delta_file_record_change_notifier.h @@ -70,7 +70,10 @@ class DeltaFileRecordChangeNotifier { const std::function& should_stop_callback) = 0; static std::unique_ptr Create( - std::unique_ptr change_notifier); + std::unique_ptr change_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/realtime/delta_file_record_change_notifier_aws.cc b/components/data/realtime/delta_file_record_change_notifier_aws.cc index 5d0cb768..54d63d03 100644 --- a/components/data/realtime/delta_file_record_change_notifier_aws.cc +++ b/components/data/realtime/delta_file_record_change_notifier_aws.cc @@ -43,8 +43,10 @@ struct ParsedBody { class AwsDeltaFileRecordChangeNotifier : public DeltaFileRecordChangeNotifier { public: explicit 
AwsDeltaFileRecordChangeNotifier( - std::unique_ptr change_notifier) - : change_notifier_(std::move(change_notifier)) {} + std::unique_ptr change_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : change_notifier_(std::move(change_notifier)), + log_context_(log_context) {} absl::StatusOr GetNotifications( absl::Duration max_wait, @@ -63,8 +65,8 @@ class AwsDeltaFileRecordChangeNotifier : public DeltaFileRecordChangeNotifier { for (const auto& message : *notifications) { const auto parsedMessage = ParseObjectKeyFromJson(message); if (!parsedMessage.ok()) { - LOG(ERROR) << "Failed to parse JSON: " << message - << ", error: " << parsedMessage.status(); + PS_LOG(ERROR, log_context_) << "Failed to parse JSON: " << message + << ", error: " << parsedMessage.status(); LogServerErrorMetric(kDeltaFileRecordChangeNotifierParsingFailure); continue; } @@ -127,14 +129,16 @@ class AwsDeltaFileRecordChangeNotifier : public DeltaFileRecordChangeNotifier { } std::unique_ptr change_notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace std::unique_ptr DeltaFileRecordChangeNotifier::Create( - std::unique_ptr change_notifier) { + std::unique_ptr change_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique( - std::move(change_notifier)); + std::move(change_notifier), log_context); } } // namespace kv_server diff --git a/components/data/realtime/delta_file_record_change_notifier_local.cc b/components/data/realtime/delta_file_record_change_notifier_local.cc index f444cdf3..e176e9a6 100644 --- a/components/data/realtime/delta_file_record_change_notifier_local.cc +++ b/components/data/realtime/delta_file_record_change_notifier_local.cc @@ -32,8 +32,9 @@ class LocalDeltaFileRecordChangeNotifier : public DeltaFileRecordChangeNotifier { public: explicit LocalDeltaFileRecordChangeNotifier( - std::unique_ptr notifier) - : notifier_(std::move(notifier)) {} + std::unique_ptr 
notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : notifier_(std::move(notifier)), log_context_(log_context) {} absl::StatusOr GetNotifications( absl::Duration max_wait, @@ -60,15 +61,17 @@ class LocalDeltaFileRecordChangeNotifier private: std::unique_ptr notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace std::unique_ptr DeltaFileRecordChangeNotifier::Create( - std::unique_ptr change_notifier) { + std::unique_ptr change_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique( - std::move(change_notifier)); + std::move(change_notifier), log_context); } } // namespace kv_server diff --git a/components/data/realtime/realtime_notifier.h b/components/data/realtime/realtime_notifier.h index 88ed1726..1b2b58f9 100644 --- a/components/data/realtime/realtime_notifier.h +++ b/components/data/realtime/realtime_notifier.h @@ -25,6 +25,7 @@ #include "components/data/realtime/realtime_notifier_metadata.h" #include "components/errors/retry.h" #include "components/util/sleepfor.h" +#include "src/logger/request_context_logger.h" namespace kv_server { struct DataLoadingStats { @@ -58,7 +59,10 @@ class RealtimeNotifier { static absl::StatusOr> Create( NotifierMetadata notifier_metadata, // This parameter allows overrides that are used for tests - RealtimeNotifierMetadata realtime_notifier_metadata = {}); + RealtimeNotifierMetadata realtime_notifier_metadata = {}, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/realtime/realtime_notifier_aws.cc b/components/data/realtime/realtime_notifier_aws.cc index 56f4cc6d..cf283e1d 100644 --- a/components/data/realtime/realtime_notifier_aws.cc +++ b/components/data/realtime/realtime_notifier_aws.cc @@ -35,10 +35,12 @@ class RealtimeNotifierImpl : public RealtimeNotifier { public: 
explicit RealtimeNotifierImpl( std::unique_ptr sleep_for, - std::unique_ptr change_notifier) + std::unique_ptr change_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) : thread_manager_(ThreadManager::Create("Realtime notifier")), sleep_for_(std::move(sleep_for)), - change_notifier_(std::move(change_notifier)) {} + change_notifier_(std::move(change_notifier)), + log_context_(log_context) {} absl::Status Start( std::function(const std::string& key)> @@ -80,12 +82,13 @@ class RealtimeNotifierImpl : public RealtimeNotifier { ++sequential_failures; const absl::Duration backoff_time = ExponentialBackoffForRetry(sequential_failures); - LOG(ERROR) << "Failed to get realtime notifications: " - << updates.status() << ". Waiting for " << backoff_time; + PS_LOG(ERROR, log_context_) + << "Failed to get realtime notifications: " << updates.status() + << ". Waiting for " << backoff_time; LogServerErrorMetric(kRealtimeGetNotificationsFailure); if (!sleep_for_->Duration(backoff_time)) { - LOG(ERROR) << "Failed to sleep for " << backoff_time - << ". SleepFor invalid."; + PS_LOG(ERROR, log_context_) << "Failed to sleep for " << backoff_time + << ". 
SleepFor invalid."; LogServerErrorMetric(kRealtimeSleepFailure); } continue; @@ -95,7 +98,8 @@ class RealtimeNotifierImpl : public RealtimeNotifier { for (const auto& realtime_message : updates->realtime_messages) { if (auto count = callback(realtime_message.parsed_notification); !count.ok()) { - LOG(ERROR) << "Data loading callback failed: " << count.status(); + PS_LOG(ERROR, log_context_) + << "Data loading callback failed: " << count.status(); LogServerErrorMetric(kRealtimeMessageApplicationFailure); } auto e2e_cloud_provided_latency = absl::ToDoubleMicroseconds( @@ -143,13 +147,15 @@ class RealtimeNotifierImpl : public RealtimeNotifier { std::unique_ptr thread_manager_; std::unique_ptr sleep_for_; std::unique_ptr change_notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> RealtimeNotifier::Create( NotifierMetadata notifier_metadata, - RealtimeNotifierMetadata realtime_notifier_metadata) { + RealtimeNotifierMetadata realtime_notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto options = std::get_if(&realtime_notifier_metadata); std::unique_ptr @@ -173,7 +179,8 @@ absl::StatusOr> RealtimeNotifier::Create( sleep_for = std::make_unique(); } return std::make_unique( - std::move(sleep_for), std::move(delta_file_record_change_notifier)); + std::move(sleep_for), std::move(delta_file_record_change_notifier), + log_context); } } // namespace kv_server diff --git a/components/data/realtime/realtime_notifier_gcp.cc b/components/data/realtime/realtime_notifier_gcp.cc index 073ed9a4..3ec483b9 100644 --- a/components/data/realtime/realtime_notifier_gcp.cc +++ b/components/data/realtime/realtime_notifier_gcp.cc @@ -35,15 +35,18 @@ using ::google::cloud::pubsub::Subscriber; class RealtimeNotifierGcp : public RealtimeNotifier { public: - explicit RealtimeNotifierGcp(std::unique_ptr gcp_subscriber, - std::unique_ptr sleep_for) + explicit RealtimeNotifierGcp( + std::unique_ptr 
gcp_subscriber, + std::unique_ptr sleep_for, + privacy_sandbox::server_common::log::PSLogContext& log_context) : thread_manager_(ThreadManager::Create("Realtime notifier")), sleep_for_(std::move(sleep_for)), - gcp_subscriber_(std::move(gcp_subscriber)) {} + gcp_subscriber_(std::move(gcp_subscriber)), + log_context_(log_context) {} ~RealtimeNotifierGcp() { if (const auto s = Stop(); !s.ok()) { - LOG(ERROR) << "Realtime updater failed to stop: " << s; + PS_LOG(ERROR, log_context_) << "Realtime updater failed to stop: " << s; } } @@ -58,19 +61,19 @@ class RealtimeNotifierGcp : public RealtimeNotifier { absl::Status Stop() override { absl::Status status; - LOG(INFO) << "Realtime updater received stop signal."; + PS_LOG(INFO, log_context_) << "Realtime updater received stop signal."; { absl::MutexLock lock(&mutex_); if (session_.valid()) { - VLOG(8) << "Session valid."; + PS_VLOG(8, log_context_) << "Session valid."; session_.cancel(); - VLOG(8) << "Session cancelled."; + PS_VLOG(8, log_context_) << "Session cancelled."; } status = sleep_for_->Stop(); - VLOG(8) << "Sleep for just called stop."; + PS_VLOG(8, log_context_) << "Sleep for just called stop."; } status.Update(thread_manager_->Stop()); - LOG(INFO) << "Thread manager just called stop."; + PS_LOG(INFO, log_context_) << "Thread manager just called stop."; return status; } @@ -112,14 +115,24 @@ class RealtimeNotifierGcp : public RealtimeNotifier { callback) { auto start = absl::Now(); std::string string_decoded; + size_t message_size = m.data().size(); + LogIfError(KVServerContextMap() + ->SafeMetric() + .LogHistogram( + static_cast(message_size))); + LogIfError(KVServerContextMap() + ->SafeMetric() + .LogUpDownCounter(1)); if (!absl::Base64Unescape(m.data(), &string_decoded)) { LogServerErrorMetric(kRealtimeDecodeMessageFailure); - LOG(ERROR) << "The body of the message is not a base64 encoded string."; + PS_LOG(ERROR, log_context_) + << "The body of the message is not a base64 encoded string."; 
std::move(h).ack(); return; } if (auto count = callback(string_decoded); !count.ok()) { - LOG(ERROR) << "Data loading callback failed: " << count.status(); + PS_LOG(ERROR, log_context_) + << "Data loading callback failed: " << count.status(); LogServerErrorMetric(kRealtimeMessageApplicationFailure); } RecordGcpSuppliedE2ELatency(m); @@ -141,9 +154,9 @@ class RealtimeNotifierGcp : public RealtimeNotifier { OnMessageReceived(m, std::move(h), callback); }); } - LOG(INFO) << "Realtime updater initialized."; + PS_LOG(INFO, log_context_) << "Realtime updater initialized."; sleep_for_->Duration(absl::InfiniteDuration()); - LOG(INFO) << "Realtime updater stopped watching."; + PS_LOG(INFO, log_context_) << "Realtime updater stopped watching."; } std::unique_ptr thread_manager_; @@ -151,14 +164,16 @@ class RealtimeNotifierGcp : public RealtimeNotifier { future session_ ABSL_GUARDED_BY(mutex_); std::unique_ptr sleep_for_; std::unique_ptr gcp_subscriber_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; absl::StatusOr> CreateSubscriber( - NotifierMetadata metadata) { + NotifierMetadata metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { GcpNotifierMetadata notifier_metadata = std::get(metadata); auto realtime_message_service_status = - MessageService::Create(notifier_metadata); + MessageService::Create(notifier_metadata, log_context); if (!realtime_message_service_status.ok()) { return realtime_message_service_status.status(); } @@ -169,9 +184,10 @@ absl::StatusOr> CreateSubscriber( } auto queue_metadata = std::get(realtime_message_service->GetQueueMetadata()); - LOG(INFO) << "Listening to queue_id " << queue_metadata.queue_id - << " project id " << notifier_metadata.project_id << " with " - << notifier_metadata.num_threads << " threads."; + PS_LOG(INFO, log_context) + << "Listening to queue_id " << queue_metadata.queue_id << " project id " + << notifier_metadata.project_id << " with " + << notifier_metadata.num_threads << " 
threads."; return std::make_unique(pubsub::MakeSubscriberConnection( pubsub::Subscription(notifier_metadata.project_id, queue_metadata.queue_id), @@ -183,7 +199,8 @@ absl::StatusOr> CreateSubscriber( } // namespace absl::StatusOr> RealtimeNotifier::Create( - NotifierMetadata metadata, RealtimeNotifierMetadata realtime_metadata) { + NotifierMetadata metadata, RealtimeNotifierMetadata realtime_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto realtime_notifier_metadata = std::get_if(&realtime_metadata); std::unique_ptr sleep_for; @@ -199,14 +216,14 @@ absl::StatusOr> RealtimeNotifier::Create( gcp_subscriber.reset( realtime_notifier_metadata->gcp_subscriber_for_unit_testing); } else { - auto maybe_gcp_subscriber = CreateSubscriber(metadata); + auto maybe_gcp_subscriber = CreateSubscriber(metadata, log_context); if (!maybe_gcp_subscriber.ok()) { return maybe_gcp_subscriber.status(); } gcp_subscriber = std::move(*maybe_gcp_subscriber); } - return std::make_unique(std::move(gcp_subscriber), - std::move(sleep_for)); + return std::make_unique( + std::move(gcp_subscriber), std::move(sleep_for), log_context); } } // namespace kv_server diff --git a/components/data/realtime/realtime_thread_pool_manager.h b/components/data/realtime/realtime_thread_pool_manager.h index 07371c0e..1997b46d 100644 --- a/components/data/realtime/realtime_thread_pool_manager.h +++ b/components/data/realtime/realtime_thread_pool_manager.h @@ -48,7 +48,10 @@ class RealtimeThreadPoolManager { static absl::StatusOr> Create( NotifierMetadata notifier_metadata, int32_t num_threads, // This parameter allows overrides that are used for tests - std::vector realtime_notifier_metadata = {}); + std::vector realtime_notifier_metadata = {}, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data/realtime/realtime_thread_pool_manager_aws.cc 
b/components/data/realtime/realtime_thread_pool_manager_aws.cc index 82e2c33e..bd51f643 100644 --- a/components/data/realtime/realtime_thread_pool_manager_aws.cc +++ b/components/data/realtime/realtime_thread_pool_manager_aws.cc @@ -25,8 +25,10 @@ namespace { class RealtimeThreadPoolManagerAws : public RealtimeThreadPoolManager { public: explicit RealtimeThreadPoolManagerAws( - std::vector> realtime_notifiers) - : realtime_notifiers_(std::move(realtime_notifiers)) {} + std::vector> realtime_notifiers, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : realtime_notifiers_(std::move(realtime_notifiers)), + log_context_(log_context) {} ~RealtimeThreadPoolManagerAws() override { Stop(); } @@ -38,7 +40,7 @@ class RealtimeThreadPoolManagerAws : public RealtimeThreadPoolManager { std::string error_message = "Realtime realtime_notifier is nullptr, realtime data " "loading disabled."; - LOG(ERROR) << error_message; + PS_LOG(ERROR, log_context_) << error_message; return absl::InvalidArgumentError(std::move(error_message)); } auto status = realtime_notifier->Start(callback); @@ -53,14 +55,14 @@ class RealtimeThreadPoolManagerAws : public RealtimeThreadPoolManager { absl::Status status = absl::OkStatus(); for (auto& realtime_notifier : realtime_notifiers_) { if (realtime_notifier == nullptr) { - LOG(ERROR) << "Realtime realtime_notifier is nullptr"; + PS_LOG(ERROR, log_context_) << "Realtime realtime_notifier is nullptr"; continue; } if (realtime_notifier->IsRunning()) { auto current_status = realtime_notifier->Stop(); status.Update(current_status); if (!current_status.ok()) { - LOG(ERROR) << current_status.message(); + PS_LOG(ERROR, log_context_) << current_status.message(); // we still want to try to stop others continue; } @@ -71,13 +73,15 @@ class RealtimeThreadPoolManagerAws : public RealtimeThreadPoolManager { private: std::vector> realtime_notifiers_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> 
RealtimeThreadPoolManager::Create( NotifierMetadata notifier_metadata, int32_t num_threads, - std::vector realtime_notifier_metadata) { + std::vector realtime_notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { std::vector> realtime_notifier; for (int i = 0; i < num_threads; i++) { RealtimeNotifierMetadata realtime_notifier_metadatum = @@ -85,14 +89,14 @@ RealtimeThreadPoolManager::Create( ? RealtimeNotifierMetadata{} : std::move(realtime_notifier_metadata[i]); auto maybe_realtime_notifier = RealtimeNotifier::Create( - notifier_metadata, std::move(realtime_notifier_metadatum)); + notifier_metadata, std::move(realtime_notifier_metadatum), log_context); if (!maybe_realtime_notifier.ok()) { return maybe_realtime_notifier.status(); } realtime_notifier.push_back(std::move(*maybe_realtime_notifier)); } return std::make_unique( - std::move(realtime_notifier)); + std::move(realtime_notifier), log_context); } } // namespace kv_server diff --git a/components/data/realtime/realtime_thread_pool_manager_gcp.cc b/components/data/realtime/realtime_thread_pool_manager_gcp.cc index 3e43ff7c..cc81f381 100644 --- a/components/data/realtime/realtime_thread_pool_manager_gcp.cc +++ b/components/data/realtime/realtime_thread_pool_manager_gcp.cc @@ -26,8 +26,10 @@ namespace { class RealtimeThreadPoolManagerGCP : public RealtimeThreadPoolManager { public: explicit RealtimeThreadPoolManagerGCP( - std::unique_ptr realtime_notifier) - : realtime_notifier_(std::move(realtime_notifier)) {} + std::unique_ptr realtime_notifier, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : realtime_notifier_(std::move(realtime_notifier)), + log_context_(log_context) {} ~RealtimeThreadPoolManagerGCP() override { Stop(); } absl::Status Start( @@ -40,7 +42,7 @@ class RealtimeThreadPoolManagerGCP : public RealtimeThreadPoolManager { if (realtime_notifier_->IsRunning()) { auto status = realtime_notifier_->Stop(); if (!status.ok()) { - LOG(ERROR) << 
status.message(); + PS_LOG(ERROR, log_context_) << status.message(); } return status; } @@ -49,24 +51,27 @@ class RealtimeThreadPoolManagerGCP : public RealtimeThreadPoolManager { private: std::unique_ptr realtime_notifier_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace absl::StatusOr> RealtimeThreadPoolManager::Create( NotifierMetadata notifier_metadata, int32_t num_threads, - std::vector realtime_notifier_metadata) { + std::vector realtime_notifier_metadata, + privacy_sandbox::server_common::log::PSLogContext& log_context) { RealtimeNotifierMetadata realtime_notifier_metadatum = realtime_notifier_metadata.empty() ? RealtimeNotifierMetadata{} : std::move(realtime_notifier_metadata[0]); auto maybe_realtime_notifier = RealtimeNotifier::Create( - std::move(notifier_metadata), std::move(realtime_notifier_metadatum)); + std::move(notifier_metadata), std::move(realtime_notifier_metadatum), + log_context); if (!maybe_realtime_notifier.ok()) { return maybe_realtime_notifier.status(); } return std::make_unique( - std::move(*maybe_realtime_notifier)); + std::move(*maybe_realtime_notifier), log_context); } } // namespace kv_server diff --git a/components/data_server/cache/BUILD.bazel b/components/data_server/cache/BUILD.bazel index d6eaecc1..7def83d5 100644 --- a/components/data_server/cache/BUILD.bazel +++ b/components/data_server/cache/BUILD.bazel @@ -19,6 +19,32 @@ package(default_visibility = [ "//tools:__subpackages__", ]) +cc_library( + name = "uint32_value_set", + srcs = ["uint32_value_set.cc"], + hdrs = ["uint32_value_set.h"], + deps = [ + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@roaring_bitmap//:c_roaring", + ], +) + +cc_test( + name = "uint32_value_set_test", + size = "small", + srcs = [ + "uint32_value_set_test.cc", + ], + deps = [ + ":uint32_value_set", + "@com_google_googletest//:gtest", + 
"@com_google_googletest//:gtest_main", + "@roaring_bitmap//:c_roaring", + ], +) + cc_library( name = "get_key_value_set_result_impl", srcs = [ @@ -28,6 +54,8 @@ cc_library( "get_key_value_set_result.h", ], deps = [ + ":uint32_value_set", + "//components/container:thread_safe_hash_map", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], @@ -57,6 +85,8 @@ cc_library( deps = [ ":cache", ":get_key_value_set_result_impl", + ":uint32_value_set", + "//components/container:thread_safe_hash_map", "//public:base_types_cc_proto", "@com_google_absl//absl/base", "@com_google_absl//absl/container:btree", @@ -90,6 +120,8 @@ cc_library( hdrs = ["mocks.h"], deps = [ ":cache", + ":uint32_value_set", + "//components/container:thread_safe_hash_map", "@com_google_googletest//:gtest", ], ) diff --git a/components/data_server/cache/cache.h b/components/data_server/cache/cache.h index 4341cff7..06081599 100644 --- a/components/data_server/cache/cache.h +++ b/components/data_server/cache/cache.h @@ -17,12 +17,9 @@ #ifndef COMPONENTS_DATA_SERVER_CACHE_CACHE_H_ #define COMPONENTS_DATA_SERVER_CACHE_CACHE_H_ -#include #include #include #include -#include -#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -47,34 +44,58 @@ class Cache { const RequestContext& request_context, const absl::flat_hash_set& key_set) const = 0; + // Looks up and returns key-value set result for the given key set. 
+ virtual std::unique_ptr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const = 0; + // Inserts or updates the key with the new value for a given prefix - virtual void UpdateKeyValue(std::string_view key, std::string_view value, - int64_t logical_commit_time, - std::string_view prefix = "") = 0; + virtual void UpdateKeyValue( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, std::string_view value, int64_t logical_commit_time, + std::string_view prefix = "") = 0; // Inserts or updates values in the set for a given key and prefix, if a value // exists, updates its timestamp to the latest logical commit time. - virtual void UpdateKeyValueSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix = "") = 0; + virtual void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; + + // Inserts or updates values in the set for a given key and prefix, if a value + // exists, updates its timestamp to the latest logical commit time. + virtual void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; // Deletes a particular (key, value) pair for a given prefix. - virtual void DeleteKey(std::string_view key, int64_t logical_commit_time, - std::string_view prefix = "") = 0; + virtual void DeleteKey( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, int64_t logical_commit_time, + std::string_view prefix = "") = 0; + + // Deletes values in the set for a given key and prefix. The deletion, this + // object still exist and is marked "deleted", in case there are late-arriving + // updates to this value. 
+ virtual void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; // Deletes values in the set for a given key and prefix. The deletion, this // object still exist and is marked "deleted", in case there are late-arriving // updates to this value. - virtual void DeleteValuesInSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix = "") = 0; + virtual void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") = 0; // Removes the values that were deleted before the specified // logical_commit_time for a given prefix. - virtual void RemoveDeletedKeys(int64_t logical_commit_time, - std::string_view prefix = "") = 0; + virtual void RemoveDeletedKeys( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix = "") = 0; }; } // namespace kv_server diff --git a/components/data_server/cache/get_key_value_set_result.h b/components/data_server/cache/get_key_value_set_result.h index c2f14af1..f19aed41 100644 --- a/components/data_server/cache/get_key_value_set_result.h +++ b/components/data_server/cache/get_key_value_set_result.h @@ -18,10 +18,11 @@ #define COMPONENTS_DATA_SERVER_CACHE_GET_KEY_VALUE_SET_RESULT_H_ #include -#include -#include +#include #include "absl/container/flat_hash_set.h" +#include "components/container/thread_safe_hash_map.h" +#include "components/data_server/cache/uint32_value_set.h" namespace kv_server { // Class that holds the data retrieved from cache lookup and read locks for @@ -33,6 +34,8 @@ class GetKeyValueSetResult { // Looks up and returns key-value set result for the given key set. 
virtual absl::flat_hash_set GetValueSet( std::string_view key) const = 0; + virtual const UInt32ValueSet* GetUInt32ValueSet( + std::string_view key) const = 0; private: // Adds key, value_set to the result data map, mantains the lock on `key` @@ -40,6 +43,10 @@ class GetKeyValueSetResult { virtual void AddKeyValueSet( std::string_view key, absl::flat_hash_set value_set, std::unique_ptr key_lock) = 0; + virtual void AddUInt32ValueSet( + std::string_view key, + ThreadSafeHashMap::ConstLockedNodePtr + value_set_node) = 0; static std::unique_ptr Create(); diff --git a/components/data_server/cache/get_key_value_set_result_impl.cc b/components/data_server/cache/get_key_value_set_result_impl.cc index 66123c6c..8e3b44ce 100644 --- a/components/data_server/cache/get_key_value_set_result_impl.cc +++ b/components/data_server/cache/get_key_value_set_result_impl.cc @@ -25,12 +25,21 @@ namespace kv_server { namespace { +using UInt32ValueSetNodePtr = + ThreadSafeHashMap::ConstLockedNodePtr; + // Class that holds the data retrieved from cache lookup and read locks for // the lookup keys class GetKeyValueSetResultImpl : public GetKeyValueSetResult { public: GetKeyValueSetResultImpl() {} + GetKeyValueSetResultImpl(const GetKeyValueSetResultImpl&) = delete; + GetKeyValueSetResultImpl& operator=(const GetKeyValueSetResultImpl&) = delete; + GetKeyValueSetResultImpl(GetKeyValueSetResultImpl&& other) = default; + GetKeyValueSetResultImpl& operator=(GetKeyValueSetResultImpl&& other) = + default; + // Looks up the key in the data map and returns value set. If the value_set // for the key is missing, returns empty set. absl::flat_hash_set GetValueSet( @@ -41,17 +50,15 @@ class GetKeyValueSetResultImpl : public GetKeyValueSetResult { return key_itr == data_map_.end() ? 
*kEmptySet : key_itr->second; } - GetKeyValueSetResultImpl(const GetKeyValueSetResultImpl&) = delete; - GetKeyValueSetResultImpl& operator=(const GetKeyValueSetResultImpl&) = delete; - GetKeyValueSetResultImpl(GetKeyValueSetResultImpl&& other) = default; - GetKeyValueSetResultImpl& operator=(GetKeyValueSetResultImpl&& other) = - default; + const UInt32ValueSet* GetUInt32ValueSet(std::string_view key) const override { + if (auto iter = uin32t_sets_map_.find(key); + iter != uin32t_sets_map_.end() && iter->second.is_present()) { + return iter->second.value(); + } + return nullptr; + } private: - std::vector> read_locks_; - absl::flat_hash_map> - data_map_; - // Adds key, value_set to the result data map, creates a read lock for // the key mutex void AddKeyValueSet( @@ -60,6 +67,16 @@ class GetKeyValueSetResultImpl : public GetKeyValueSetResult { read_locks_.push_back(std::move(key_lock)); data_map_.emplace(key, std::move(value_set)); } + + void AddUInt32ValueSet(std::string_view key, + UInt32ValueSetNodePtr value_set_ptr) override { + uin32t_sets_map_.emplace(key, std::move(value_set_ptr)); + } + + std::vector> read_locks_; + absl::flat_hash_map> + data_map_; + absl::flat_hash_map uin32t_sets_map_; }; } // namespace diff --git a/components/data_server/cache/key_value_cache.cc b/components/data_server/cache/key_value_cache.cc index f984de36..fb2cf22a 100644 --- a/components/data_server/cache/key_value_cache.cc +++ b/components/data_server/cache/key_value_cache.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "components/data_server/cache/key_value_cache.h" -#include #include #include #include @@ -40,8 +39,9 @@ absl::flat_hash_map KeyValueCache::GetKeyValuePairs( if (key_iter == map_.end() || key_iter->second.value == nullptr) { continue; } else { - VLOG(9) << "Get called for " << key - << ". returning value: " << *(key_iter->second.value); + PS_VLOG(9, request_context.GetPSLogContext()) + << "Get called for " << key + << ". 
 returning value: " << *(key_iter->second.value); kv_pairs.insert_or_assign(key, *(key_iter->second.value)); } } @@ -64,7 +64,7 @@ std::unique_ptr KeyValueCache::GetKeyValueSet( auto result = GetKeyValueSetResult::Create(); bool cache_hit = false; for (const auto& key : key_set) { - VLOG(8) << "Getting key: " << key; + PS_VLOG(8, request_context.GetPSLogContext()) << "Getting key: " << key; const auto key_itr = key_to_value_set_map_.find(key); if (key_itr != key_to_value_set_map_.end()) { absl::flat_hash_set value_set; @@ -88,24 +88,40 @@ std::unique_ptr KeyValueCache::GetKeyValueSet( return result; } +// Looks up and returns uint32 value set result for the given key set. +std::unique_ptr KeyValueCache::GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetInternalLookupMetricsContext()); + auto result = GetKeyValueSetResult::Create(); + for (const auto& key : key_set) { + result->AddUInt32ValueSet(key, uint32_sets_map_.CGet(key)); + } + return result; +} + // Replaces the current key-value entry with the new key-value entry. -void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, - int64_t logical_commit_time, - std::string_view prefix) { +void KeyValueCache::UpdateKeyValue( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, std::string_view value, int64_t logical_commit_time, + std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); - VLOG(9) << "Received update for [" << key << "] at " << logical_commit_time - << ". value will be set to: " << value; + PS_VLOG(9, log_context) << "Received update for [" << key << "] at " + << logical_commit_time + << ". 
value will be set to: " << value; absl::MutexLock lock(&mutex_); auto max_cleanup_logical_commit_time = max_cleanup_logical_commit_time_map_[prefix]; if (logical_commit_time <= max_cleanup_logical_commit_time) { - VLOG(1) << "Skipping the update as its logical_commit_time: " - << logical_commit_time - << " is not newer than the current cutoff time:" - << max_cleanup_logical_commit_time; + PS_VLOG(1, log_context) + << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is not newer than the current cutoff time:" + << max_cleanup_logical_commit_time; return; } @@ -114,10 +130,10 @@ void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, if (key_iter != map_.end() && key_iter->second.last_logical_commit_time >= logical_commit_time) { - VLOG(1) << "Skipping the update as its logical_commit_time: " - << logical_commit_time - << " is not newer than the current value's time:" - << key_iter->second.last_logical_commit_time; + PS_VLOG(1, log_context) + << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is not newer than the current value's time:" + << key_iter->second.last_logical_commit_time; return; } @@ -142,12 +158,14 @@ void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, } void KeyValueCache::UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, std::string_view key, absl::Span input_value_set, int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); - VLOG(9) << "Received update for [" << key << "] at " << logical_commit_time; + PS_VLOG(9, log_context) << "Received update for [" << key << "] at " + << logical_commit_time; std::unique_ptr key_lock; absl::flat_hash_map* existing_value_set; // The max cleanup time needs to be locked before doing this comparison @@ -155,21 +173,22 @@ void KeyValueCache::UpdateKeyValueSet( absl::MutexLock 
lock_map(&set_map_mutex_); auto max_cleanup_logical_commit_time = - max_cleanup_logical_commit_time_map_for_set_cache_[prefix]; + set_cache_max_cleanup_logical_commit_time_[prefix]; if (logical_commit_time <= max_cleanup_logical_commit_time) { - VLOG(1) << "Skipping the update as its logical_commit_time: " - << logical_commit_time - << " is older than the current cutoff time:" - << max_cleanup_logical_commit_time; + PS_VLOG(1, log_context) + << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is older than the current cutoff time:" + << max_cleanup_logical_commit_time; return; } else if (input_value_set.empty()) { - VLOG(1) << "Skipping the update as it has no value in the set."; + PS_VLOG(1, log_context) + << "Skipping the update as it has no value in the set."; return; } auto key_itr = key_to_value_set_map_.find(key); if (key_itr == key_to_value_set_map_.end()) { - VLOG(9) << key << " is a new key. Adding it"; + PS_VLOG(9, log_context) << key << " is a new key. Adding it"; // There is no existing value set for the given key, // simply insert the key value set to the map, no need to update deleted // set nodes @@ -205,9 +224,31 @@ void KeyValueCache::UpdateKeyValueSet( } // end locking key } +void KeyValueCache::UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + ScopeLatencyMetricsRecorder + latency_recorder(KVServerContextMap()->SafeMetric()); + if (auto prefix_max_time_node = + uint32_sets_max_cleanup_commit_time_map_.CGet(prefix); + prefix_max_time_node.is_present() && + logical_commit_time <= *prefix_max_time_node.value()) { + return; // Skip old updates. 
+ } + auto cached_set_node = uint32_sets_map_.Get(key); + if (!cached_set_node.is_present()) { + auto result = uint32_sets_map_.PutIfAbsent(key, UInt32ValueSet()); + cached_set_node = std::move(result.first); + } + cached_set_node.value()->Add(value_set, logical_commit_time); +} -void KeyValueCache::DeleteKey(std::string_view key, int64_t logical_commit_time, - std::string_view prefix) { +void KeyValueCache::DeleteKey( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, int64_t logical_commit_time, + std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); absl::MutexLock lock(&mutex_); @@ -226,14 +267,14 @@ void KeyValueCache::DeleteKey(std::string_view key, int64_t logical_commit_time, map_.insert_or_assign( key, {.value = nullptr, .last_logical_commit_time = logical_commit_time}); - auto result = deleted_nodes_map_[prefix].emplace(logical_commit_time, key); + deleted_nodes_map_[prefix].emplace(logical_commit_time, key); } } -void KeyValueCache::DeleteValuesInSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix) { +void KeyValueCache::DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); @@ -243,7 +284,7 @@ void KeyValueCache::DeleteValuesInSet(std::string_view key, { absl::MutexLock lock_map(&set_map_mutex_); auto max_cleanup_logical_commit_time = - max_cleanup_logical_commit_time_map_for_set_cache_[prefix]; + set_cache_max_cleanup_logical_commit_time_[prefix]; if (logical_commit_time <= max_cleanup_logical_commit_time || value_set.empty()) { return; @@ -298,17 +339,55 @@ void KeyValueCache::DeleteValuesInSet(std::string_view key, } } -void KeyValueCache::RemoveDeletedKeys(int64_t logical_commit_time, - 
std::string_view prefix) { +void KeyValueCache::DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) { + ScopeLatencyMetricsRecorder + latency_recorder(KVServerContextMap()->SafeMetric()); + if (auto prefix_max_time_node = + uint32_sets_max_cleanup_commit_time_map_.CGet(prefix); + prefix_max_time_node.is_present() && + logical_commit_time <= *prefix_max_time_node.value()) { + return; // Skip old deletes. + } + { + auto cached_set_node = uint32_sets_map_.Get(key); + if (!cached_set_node.is_present()) { + auto result = uint32_sets_map_.PutIfAbsent(key, UInt32ValueSet()); + cached_set_node = std::move(result.first); + } + cached_set_node.value()->Remove(value_set, logical_commit_time); + } + { + // Mark set as having deleted elements. + auto prefix_deleted_sets_node = deleted_uint32_sets_map_.Get(prefix); + if (!prefix_deleted_sets_node.is_present()) { + auto result = deleted_uint32_sets_map_.PutIfAbsent( + prefix, absl::btree_map>()); + prefix_deleted_sets_node = std::move(result.first); + } + auto* commit_time_sets = + &(*prefix_deleted_sets_node.value())[logical_commit_time]; + commit_time_sets->insert(std::string(key)); + } +} + +void KeyValueCache::RemoveDeletedKeys( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); - CleanUpKeyValueMap(logical_commit_time, prefix); - CleanUpKeyValueSetMap(logical_commit_time, prefix); + CleanUpKeyValueMap(log_context, logical_commit_time, prefix); + CleanUpKeyValueSetMap(log_context, logical_commit_time, prefix); + CleanUpUInt32SetMap(log_context, logical_commit_time, prefix); } -void KeyValueCache::CleanUpKeyValueMap(int64_t logical_commit_time, - std::string_view prefix) { +void KeyValueCache::CleanUpKeyValueMap( + 
privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); @@ -343,16 +422,16 @@ void KeyValueCache::CleanUpKeyValueMap(int64_t logical_commit_time, } } -void KeyValueCache::CleanUpKeyValueSetMap(int64_t logical_commit_time, - std::string_view prefix) { +void KeyValueCache::CleanUpKeyValueSetMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) { ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); absl::MutexLock lock_set_map(&set_map_mutex_); - if (max_cleanup_logical_commit_time_map_for_set_cache_[prefix] < + if (set_cache_max_cleanup_logical_commit_time_[prefix] < logical_commit_time) { - max_cleanup_logical_commit_time_map_for_set_cache_[prefix] = - logical_commit_time; + set_cache_max_cleanup_logical_commit_time_[prefix] = logical_commit_time; } auto deleted_nodes_per_prefix = deleted_set_nodes_map_.find(prefix); if (deleted_nodes_per_prefix == deleted_set_nodes_map_.end()) { @@ -366,15 +445,17 @@ void KeyValueCache::CleanUpKeyValueSetMap(int64_t logical_commit_time, for (const auto& [key, values] : delete_itr->second) { if (auto key_itr = key_to_value_set_map_.find(key); key_itr != key_to_value_set_map_.end()) { - absl::MutexLock(&key_itr->second->first); - for (const auto& v_to_delete : values) { - auto existing_value_itr = key_itr->second->second.find(v_to_delete); - if (existing_value_itr != key_itr->second->second.end() && - existing_value_itr->second.is_deleted && - existing_value_itr->second.last_logical_commit_time <= - logical_commit_time) { - // Delete the existing value that is marked deleted from set - key_itr->second->second.erase(existing_value_itr); + { + absl::MutexLock key_lock(&key_itr->second->first); + for (const auto& v_to_delete : values) { + auto existing_value_itr = 
key_itr->second->second.find(v_to_delete); + if (existing_value_itr != key_itr->second->second.end() && + existing_value_itr->second.is_deleted && + existing_value_itr->second.last_logical_commit_time <= + logical_commit_time) { + // Delete the existing value that is marked deleted from set + key_itr->second->second.erase(existing_value_itr); + } } } if (key_itr->second->second.empty()) { @@ -392,6 +473,49 @@ void KeyValueCache::CleanUpKeyValueSetMap(int64_t logical_commit_time, } } +void KeyValueCache::CleanUpUInt32SetMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) { + ScopeLatencyMetricsRecorder + latency_recorder(KVServerContextMap()->SafeMetric()); + { + if (auto max_cleanup_time_node = + uint32_sets_max_cleanup_commit_time_map_.PutIfAbsent( + prefix, logical_commit_time); + *max_cleanup_time_node.first.value() < logical_commit_time) { + *max_cleanup_time_node.first.value() = logical_commit_time; + } + } + absl::flat_hash_set cleanup_sets; + { + auto prefix_deleted_sets_node = deleted_uint32_sets_map_.Get(prefix); + if (!prefix_deleted_sets_node.is_present()) { + return; // nothing to cleanup for this prefix. 
 + } + absl::flat_hash_set cleanup_commit_times; + for (const auto& [commit_time, deleted_sets] : + *prefix_deleted_sets_node.value()) { + if (commit_time > logical_commit_time) { + break; + } + cleanup_commit_times.insert(commit_time); + cleanup_sets.insert(deleted_sets.begin(), deleted_sets.end()); + } + for (auto commit_time : cleanup_commit_times) { + prefix_deleted_sets_node.value()->erase( + prefix_deleted_sets_node.value()->find(commit_time)); + } + } + { + for (const auto& set : cleanup_sets) { + if (auto set_node = uint32_sets_map_.Get(set); set_node.is_present()) { + set_node.value()->Cleanup(logical_commit_time); + } + } + } +} + void KeyValueCache::LogCacheAccessMetrics( const RequestContext& request_context, std::string_view cache_access_event) const { diff --git a/components/data_server/cache/key_value_cache.h b/components/data_server/cache/key_value_cache.h index aeae908c..61e08c91 100644 --- a/components/data_server/cache/key_value_cache.h +++ b/components/data_server/cache/key_value_cache.h @@ -17,21 +17,19 @@ #ifndef COMPONENTS_DATA_SERVER_CACHE_KEY_VALUE_CACHE_H_ #define COMPONENTS_DATA_SERVER_CACHE_KEY_VALUE_CACHE_H_ -#include #include #include #include #include -#include #include -#include #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "components/container/thread_safe_hash_map.h" #include "components/data_server/cache/cache.h" #include "components/data_server/cache/get_key_value_set_result.h" -#include "public/base_types.pb.h" +#include "components/data_server/cache/uint32_value_set.h" namespace kv_server { // In-memory datastore. @@ -48,36 +46,59 @@ class KeyValueCache : public Cache { const RequestContext& request_context, const absl::flat_hash_set& key_set) const override; + // Looks up and returns uint32 value set result for the given key set. 
+ std::unique_ptr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override; + // Inserts or updates the key with the new value for a given prefix - void UpdateKeyValue(std::string_view key, std::string_view value, - int64_t logical_commit_time, - std::string_view prefix = "") override; + void UpdateKeyValue( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, std::string_view value, int64_t logical_commit_time, + std::string_view prefix = "") override; // Inserts or updates values in the set for a given key and prefix, if a value // exists, updates its timestamp to the latest logical commit time. - void UpdateKeyValueSet(std::string_view key, - absl::Span input_value_set, - int64_t logical_commit_time, - std::string_view prefix = "") override; + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span input_value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; + + // Inserts or updates values in the set for a given key and prefix, if a value + // exists, updates its timestamp to the latest logical commit time. + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; // Deletes a particular (key, value) pair for a given prefix. - void DeleteKey(std::string_view key, int64_t logical_commit_time, + void DeleteKey(privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, int64_t logical_commit_time, std::string_view prefix = "") override; // Deletes values in the set for a given key and prefix. The deletion, this // object still exist and is marked "deleted", in case there are late-arriving // updates to this value. 
- void DeleteValuesInSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix = "") override; + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; + + // Deletes values in the set for a given key and prefix. The deletion, this + // object still exist and is marked "deleted", in case there are late-arriving + // updates to this value. + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override; // Removes the values that were deleted before the specified // logical_commit_time for a given prefix. // TODO: b/267182790 -- Cache cleanup should be done periodically from a // background thread - void RemoveDeletedKeys(int64_t logical_commit_time, - std::string_view prefix = "") override; + void RemoveDeletedKeys( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix = "") override; static std::unique_ptr Create(); @@ -105,6 +126,7 @@ class KeyValueCache : public Cache { SetValueMeta(int64_t logical_commit_time, bool deleted) : last_logical_commit_time(logical_commit_time), is_deleted(deleted) {} }; + // mutex for key value map; mutable absl::Mutex mutex_; // mutex for key value set map; @@ -131,7 +153,7 @@ class KeyValueCache : public Cache { // guarded b mutex, if not, we may want to remove it and use one // max_cleanup_logical_commit_time in update/deletion for both maps absl::flat_hash_map - max_cleanup_logical_commit_time_map_for_set_cache_ + set_cache_max_cleanup_logical_commit_time_ ABSL_GUARDED_BY(set_map_mutex_); // Mapping from a key to its value map. 
The key in the inner map is the @@ -156,12 +178,28 @@ absl::flat_hash_map>>> deleted_set_nodes_map_ ABSL_GUARDED_BY(set_map_mutex_); + // Maps set key to its uint32_t value set. + ThreadSafeHashMap uint32_sets_map_; + ThreadSafeHashMap + uint32_sets_max_cleanup_commit_time_map_; + ThreadSafeHashMap>> + deleted_uint32_sets_map_; + // Removes deleted keys from key-value map for a given prefix - void CleanUpKeyValueMap(int64_t logical_commit_time, std::string_view prefix); + void CleanUpKeyValueMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); // Removes deleted key-values from key-value_set map for a given prefix - void CleanUpKeyValueSetMap(int64_t logical_commit_time, - std::string_view prefix); + void CleanUpKeyValueSetMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); + + void CleanUpUInt32SetMap( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix); + // Logs cache access metrics for cache hit or miss counts.
The cache access // event name is defined in server_definition.h file void LogCacheAccessMetrics(const RequestContext& request_context, diff --git a/components/data_server/cache/key_value_cache_test.cc b/components/data_server/cache/key_value_cache_test.cc index 12ee3d6f..be8620f2 100644 --- a/components/data_server/cache/key_value_cache_test.cc +++ b/components/data_server/cache/key_value_cache_test.cc @@ -51,6 +51,16 @@ class KeyValueCacheTestPeer { return c.map_; } + static const auto& ReadUint32Nodes(KeyValueCache& cache, + std::string_view prefix = "") { + return cache.uint32_sets_map_; + } + + static const auto& ReadDeletedUint32Nodes(KeyValueCache& cache, + std::string_view prefix = "") { + return cache.deleted_uint32_sets_map_; + } + static int GetDeletedSetNodesMapSize(const KeyValueCache& c, std::string prefix = "") { absl::MutexLock lock(&c.set_map_mutex_); @@ -89,8 +99,10 @@ class KeyValueCacheTestPeer { return iter->second->second.size(); } - static void CallCacheCleanup(KeyValueCache& c, int64_t logical_commit_time) { - c.RemoveDeletedKeys(logical_commit_time); + static void CallCacheCleanup( + privacy_sandbox::server_common::log::SafePathContext& log_context, + KeyValueCache& c, int64_t logical_commit_time) { + c.RemoveDeletedKeys(log_context, logical_commit_time); } }; @@ -98,23 +110,28 @@ namespace { using privacy_sandbox::server_common::TelemetryProvider; using testing::UnorderedElementsAre; +using testing::UnorderedElementsAreArray; + +class SafePathTestLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + public: + SafePathTestLogContext() = default; +}; class CacheTest : public ::testing::Test { protected: CacheTest() { InitMetricsContextMap(); - scope_metrics_context_ = std::make_unique(); - request_context_ = - std::make_unique(*scope_metrics_context_); + request_context_ = std::make_shared(); } - RequestContext& GetRequestContext() { return *request_context_; } - std::unique_ptr scope_metrics_context_; - 
std::unique_ptr request_context_; + const RequestContext& GetRequestContext() { return *request_context_; } + std::shared_ptr request_context_; + SafePathTestLogContext safe_path_log_context_; }; TEST_F(CacheTest, RetrievesMatchingEntry) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); absl::flat_hash_set keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), keys); @@ -126,9 +143,9 @@ TEST_F(CacheTest, RetrievesMatchingEntry) { TEST_F(CacheTest, GetWithMultipleKeysReturnsMatchingValues) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("key1", "value1", 1); - cache->UpdateKeyValue("key2", "value2", 2); - cache->UpdateKeyValue("key3", "value3", 3); + cache->UpdateKeyValue(safe_path_log_context_, "key1", "value1", 1); + cache->UpdateKeyValue(safe_path_log_context_, "key2", "value2", 2); + cache->UpdateKeyValue(safe_path_log_context_, "key3", "value3", 3); absl::flat_hash_set full_keys = {"key1", "key2"}; @@ -141,7 +158,7 @@ TEST_F(CacheTest, GetWithMultipleKeysReturnsMatchingValues) { TEST_F(CacheTest, GetAfterUpdateReturnsNewValue) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); absl::flat_hash_set keys = {"my_key"}; @@ -149,7 +166,7 @@ TEST_F(CacheTest, GetAfterUpdateReturnsNewValue) { cache->GetKeyValuePairs(GetRequestContext(), keys); EXPECT_THAT(kv_pairs, UnorderedElementsAre(KVPairEq("my_key", "my_value"))); - cache->UpdateKeyValue("my_key", "my_new_value", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_new_value", 2); kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), keys); EXPECT_EQ(kv_pairs.size(), 1); @@ -159,8 +176,8 @@ TEST_F(CacheTest, GetAfterUpdateReturnsNewValue) { TEST_F(CacheTest, 
GetAfterUpdateDifferentKeyReturnsSameValue) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); - cache->UpdateKeyValue("new_key", "new_value", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "new_key", "new_value", 2); absl::flat_hash_set keys = {"my_key"}; @@ -180,7 +197,8 @@ TEST_F(CacheTest, GetForEmptyCacheReturnsEmptyList) { TEST_F(CacheTest, GetForCacheReturnsValueSet) { std::unique_ptr cache = KeyValueCache::Create(); std::vector values = {"v1", "v2"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -190,7 +208,8 @@ TEST_F(CacheTest, GetForCacheReturnsValueSet) { TEST_F(CacheTest, GetForCacheMissingKeyReturnsEmptySet) { std::unique_ptr cache = KeyValueCache::Create(); std::vector values = {"v1", "v2"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); auto get_key_value_set_result = cache->GetKeyValueSet(GetRequestContext(), {"missing_key", "my_key"}); EXPECT_EQ(get_key_value_set_result->GetValueSet("missing_key").size(), 0); @@ -200,8 +219,8 @@ TEST_F(CacheTest, GetForCacheMissingKeyReturnsEmptySet) { TEST_F(CacheTest, DeleteKeyTestRemovesKeyEntry) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); - cache->DeleteKey("my_key", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); + cache->DeleteKey(safe_path_log_context_, "my_key", 2); absl::flat_hash_set full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), full_keys); @@ -210,8 +229,8 @@ TEST_F(CacheTest, DeleteKeyTestRemovesKeyEntry) { TEST_F(CacheTest, 
DeleteKeyValueSetWrongkeyDoesNotRemoveEntry) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); - cache->DeleteKey("wrong_key", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); + cache->DeleteKey(safe_path_log_context_, "wrong_key", 1); absl::flat_hash_set keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), keys); @@ -222,8 +241,9 @@ TEST_F(CacheTest, DeleteKeyValueSetRemovesValueEntry) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1", "v2", "v3"}; std::vector values_to_delete = {"v1", "v2"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", absl::Span(values_to_delete), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) @@ -249,8 +269,9 @@ TEST_F(CacheTest, DeleteKeyValueSetWrongKeyDoesNotRemoveKeyValueEntry) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1", "v2", "v3"}; std::vector values_to_delete = {"v1"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("wrong_key", + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "wrong_key", absl::Span(values_to_delete), 2); std::unique_ptr result = cache->GetKeyValueSet(GetRequestContext(), {"my_key", "wrong_key"}); @@ -279,8 +300,9 @@ TEST_F(CacheTest, DeleteKeyValueSetWrongValueDoesNotRemoveEntry) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1", "v2", "v3"}; std::vector values_to_delete = {"v4"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 
1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", absl::Span(values_to_delete), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) @@ -303,7 +325,7 @@ TEST_F(CacheTest, DeleteKeyValueSetWrongValueDoesNotRemoveEntry) { TEST_F(CacheTest, OutOfOrderUpdateAfterUpdateWorks) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 2); absl::flat_hash_set keys = {"my_key"}; @@ -311,7 +333,7 @@ TEST_F(CacheTest, OutOfOrderUpdateAfterUpdateWorks) { cache->GetKeyValuePairs(GetRequestContext(), keys); EXPECT_THAT(kv_pairs, UnorderedElementsAre(KVPairEq("my_key", "my_value"))); - cache->UpdateKeyValue("my_key", "my_new_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_new_value", 1); kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), keys); EXPECT_EQ(kv_pairs.size(), 1); @@ -320,8 +342,8 @@ TEST_F(CacheTest, OutOfOrderUpdateAfterUpdateWorks) { TEST_F(CacheTest, DeleteKeyOutOfOrderDeleteAfterUpdateWorks) { std::unique_ptr cache = KeyValueCache::Create(); - cache->DeleteKey("my_key", 2); - cache->UpdateKeyValue("my_key", "my_value", 1); + cache->DeleteKey(safe_path_log_context_, "my_key", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); absl::flat_hash_set full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), full_keys); @@ -330,8 +352,8 @@ TEST_F(CacheTest, DeleteKeyOutOfOrderDeleteAfterUpdateWorks) { TEST_F(CacheTest, DeleteKeyOutOfOrderUpdateAfterDeleteWorks) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 2); - cache->DeleteKey("my_key", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 2); + cache->DeleteKey(safe_path_log_context_, "my_key", 1); absl::flat_hash_set full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = 
cache->GetKeyValuePairs(GetRequestContext(), full_keys); @@ -341,8 +363,8 @@ TEST_F(CacheTest, DeleteKeyOutOfOrderUpdateAfterDeleteWorks) { TEST_F(CacheTest, DeleteKeyInOrderUpdateAfterDeleteWorks) { std::unique_ptr cache = KeyValueCache::Create(); - cache->DeleteKey("my_key", 1); - cache->UpdateKeyValue("my_key", "my_value", 2); + cache->DeleteKey(safe_path_log_context_, "my_key", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 2); absl::flat_hash_set full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), full_keys); @@ -352,8 +374,8 @@ TEST_F(CacheTest, DeleteKeyInOrderUpdateAfterDeleteWorks) { TEST_F(CacheTest, DeleteKeyInOrderDeleteAfterUpdateWorks) { std::unique_ptr cache = KeyValueCache::Create(); - cache->UpdateKeyValue("my_key", "my_value", 1); - cache->DeleteKey("my_key", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); + cache->DeleteKey(safe_path_log_context_, "my_key", 2); absl::flat_hash_set full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), full_keys); @@ -363,8 +385,10 @@ TEST_F(CacheTest, DeleteKeyInOrderDeleteAfterUpdateWorks) { TEST_F(CacheTest, UpdateSetTestUpdateAfterUpdateWithSameValue) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -379,10 +403,10 @@ TEST_F(CacheTest, UpdateSetTestUpdateAfterUpdateWithDifferentValue) { std::unique_ptr cache = std::make_unique(); std::vector first_value = {"v1"}; std::vector second_value = {"v2"}; - cache->UpdateKeyValueSet("my_key", 
absl::Span(first_value), - 1); - cache->UpdateKeyValueSet("my_key", absl::Span(second_value), - 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(first_value), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(second_value), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -400,8 +424,10 @@ TEST_F(CacheTest, UpdateSetTestUpdateAfterUpdateWithDifferentValue) { TEST_F(CacheTest, InOrderUpdateSetInsertAfterDeleteExpectInsert) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1"}; - cache->DeleteValuesInSet("my_key", absl::Span(values), 1); - cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -415,8 +441,10 @@ TEST_F(CacheTest, InOrderUpdateSetInsertAfterDeleteExpectInsert) { TEST_F(CacheTest, InOrderUpdateSetDeleteAfterInsert) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -430,8 +458,10 @@ TEST_F(CacheTest, InOrderUpdateSetDeleteAfterInsert) { TEST_F(CacheTest, OutOfOrderUpdateSetInsertAfterDeleteExpectNoInsert) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1"}; - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); - cache->UpdateKeyValueSet("my_key", absl::Span(values), 
1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -445,8 +475,10 @@ TEST_F(CacheTest, OutOfOrderUpdateSetInsertAfterDeleteExpectNoInsert) { TEST_F(CacheTest, OutOfOrderUpdateSetDeleteAfterInsertExpectNoDelete) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); - cache->DeleteValuesInSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); absl::flat_hash_set value_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) ->GetValueSet("my_key"); @@ -459,7 +491,7 @@ TEST_F(CacheTest, OutOfOrderUpdateSetDeleteAfterInsertExpectNoDelete) { TEST_F(CacheTest, CleanupTimestampsInsertAKeyDoesntUpdateDeletedNodes) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("my_key", "my_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); auto deleted_nodes = KeyValueCacheTestPeer::ReadDeletedNodes(*cache); EXPECT_EQ(deleted_nodes.size(), 0); @@ -467,10 +499,10 @@ TEST_F(CacheTest, CleanupTimestampsInsertAKeyDoesntUpdateDeletedNodes) { TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeysRemovesOldRecords) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("my_key", "my_value", 1); - cache->DeleteKey("my_key", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 1); + cache->DeleteKey(safe_path_log_context_, "my_key", 2); - cache->RemoveDeletedKeys(3); + cache->RemoveDeletedKeys(safe_path_log_context_, 3); auto deleted_nodes = KeyValueCacheTestPeer::ReadDeletedNodes(*cache); EXPECT_EQ(deleted_nodes.size(), 0); @@ 
-481,10 +513,10 @@ TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeysRemovesOldRecords) { TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeysDoesntAffectNewRecords) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("my_key", "my_value", 5); - cache->DeleteKey("my_key", 6); + cache->UpdateKeyValue(safe_path_log_context_, "my_key", "my_value", 5); + cache->DeleteKey(safe_path_log_context_, "my_key", 6); - cache->RemoveDeletedKeys(2); + cache->RemoveDeletedKeys(safe_path_log_context_, 2); auto deleted_nodes = KeyValueCacheTestPeer::ReadDeletedNodes(*cache); EXPECT_EQ(deleted_nodes.size(), 1); @@ -496,18 +528,18 @@ TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeysDoesntAffectNewRecords) { TEST_F(CacheTest, CleanupRemoveDeletedKeysRemovesOldRecordsDoesntAffectNewRecords) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("my_key1", "my_value", 1); - cache->UpdateKeyValue("my_key2", "my_value", 2); - cache->UpdateKeyValue("my_key3", "my_value", 3); - cache->UpdateKeyValue("my_key4", "my_value", 4); - cache->UpdateKeyValue("my_key5", "my_value", 5); + cache->UpdateKeyValue(safe_path_log_context_, "my_key1", "my_value", 1); + cache->UpdateKeyValue(safe_path_log_context_, "my_key2", "my_value", 2); + cache->UpdateKeyValue(safe_path_log_context_, "my_key3", "my_value", 3); + cache->UpdateKeyValue(safe_path_log_context_, "my_key4", "my_value", 4); + cache->UpdateKeyValue(safe_path_log_context_, "my_key5", "my_value", 5); - cache->DeleteKey("my_key3", 8); - cache->DeleteKey("key_tombstone", 8); - cache->DeleteKey("my_key1", 6); - cache->DeleteKey("my_key2", 7); + cache->DeleteKey(safe_path_log_context_, "my_key3", 8); + cache->DeleteKey(safe_path_log_context_, "key_tombstone", 8); + cache->DeleteKey(safe_path_log_context_, "my_key1", 6); + cache->DeleteKey(safe_path_log_context_, "my_key2", 7); - cache->RemoveDeletedKeys(7); + cache->RemoveDeletedKeys(safe_path_log_context_, 7); auto deleted_nodes = 
KeyValueCacheTestPeer::ReadDeletedNodes(*cache); EXPECT_EQ(deleted_nodes.size(), 2); @@ -530,14 +562,14 @@ TEST_F(CacheTest, TEST_F(CacheTest, CleanupTimestampsCantInsertOldRecordsAfterCleanup) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("my_key1", "my_value", 10); - cache->DeleteKey("my_key1", 12); - cache->RemoveDeletedKeys(13); + cache->UpdateKeyValue(safe_path_log_context_, "my_key1", "my_value", 10); + cache->DeleteKey(safe_path_log_context_, "my_key1", 12); + cache->RemoveDeletedKeys(safe_path_log_context_, 13); auto deleted_nodes = KeyValueCacheTestPeer::ReadDeletedNodes(*cache); EXPECT_EQ(deleted_nodes.size(), 0); - cache->UpdateKeyValue("my_key1", "my_value", 10); + cache->UpdateKeyValue(safe_path_log_context_, "my_key1", "my_value", 10); absl::flat_hash_set keys = {"my_key1"}; @@ -549,7 +581,8 @@ TEST_F(CacheTest, CleanupTimestampsCantInsertOldRecordsAfterCleanup) { TEST_F(CacheTest, CleanupTimestampsInsertKeyValueSetDoesntUpdateDeletedNodes) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 0); @@ -558,9 +591,10 @@ TEST_F(CacheTest, CleanupTimestampsInsertKeyValueSetDoesntUpdateDeletedNodes) { TEST_F(CacheTest, CleanupTimestampsDeleteKeyValueSetExpectUpdateDeletedNodes) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->DeleteValuesInSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("another_key", absl::Span(values), - 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "another_key", + absl::Span(values), 1); int deleted_nodes_map_size = 
KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 1); @@ -577,13 +611,15 @@ TEST_F(CacheTest, CleanupTimestampsDeleteKeyValueSetExpectUpdateDeletedNodes) { TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeyValuesRemovesOldRecords) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 1); - cache->RemoveDeletedKeys(3); + cache->RemoveDeletedKeys(safe_path_log_context_, 3); deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 0); @@ -594,10 +630,12 @@ TEST_F(CacheTest, CleanupTimestampsRemoveDeletedKeyValuesDoesntAffectNewRecords) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 5); - cache->DeleteValuesInSet("my_key", absl::Span(values), 6); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 5); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 6); - cache->RemoveDeletedKeys(2); + cache->RemoveDeletedKeys(safe_path_log_context_, 2); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); @@ -614,19 +652,23 @@ TEST_F( std::unique_ptr cache = std::make_unique(); std::vector values = {"v1", "v2"}; std::vector values_to_delete = {"v1"}; - cache->UpdateKeyValueSet("my_key1", absl::Span(values), 1); - cache->UpdateKeyValueSet("my_key2", absl::Span(values), 2); - cache->UpdateKeyValueSet("my_key3", absl::Span(values), 3); - 
cache->UpdateKeyValueSet("my_key4", absl::Span(values), 4); - - cache->DeleteValuesInSet("my_key3", + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key1", + absl::Span(values), 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key2", + absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key3", + absl::Span(values), 3); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key4", + absl::Span(values), 4); + + cache->DeleteValuesInSet(safe_path_log_context_, "my_key3", absl::Span(values_to_delete), 4); - cache->DeleteValuesInSet("my_key1", + cache->DeleteValuesInSet(safe_path_log_context_, "my_key1", absl::Span(values_to_delete), 5); - cache->DeleteValuesInSet("my_key2", + cache->DeleteValuesInSet(safe_path_log_context_, "my_key2", absl::Span(values_to_delete), 6); - cache->RemoveDeletedKeys(5); + cache->RemoveDeletedKeys(safe_path_log_context_, 5); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); @@ -648,16 +690,19 @@ TEST_F( TEST_F(CacheTest, CleanupTimestampsSetCacheCantInsertOldRecordsAfterCleanup) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); - cache->RemoveDeletedKeys(3); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); + cache->RemoveDeletedKeys(safe_path_log_context_, 3); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 0); EXPECT_EQ(KeyValueCacheTestPeer::GetCacheKeyValueSetMapSize(*cache), 0); - cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); absl::flat_hash_set kv_set = cache->GetKeyValueSet(GetRequestContext(), {"my_key"}) @@ 
-668,9 +713,11 @@ TEST_F(CacheTest, CleanupTimestampsSetCacheCantInsertOldRecordsAfterCleanup) { TEST_F(CacheTest, CleanupTimestampsCantAddOldDeletedRecordsAfterCleanup) { std::unique_ptr cache = std::make_unique(); std::vector values = {"my_value"}; - cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); - cache->RemoveDeletedKeys(3); + cache->UpdateKeyValueSet(safe_path_log_context_, "my_key", + absl::Span(values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); + cache->RemoveDeletedKeys(safe_path_log_context_, 3); int deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); @@ -678,14 +725,16 @@ TEST_F(CacheTest, CleanupTimestampsCantAddOldDeletedRecordsAfterCleanup) { EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache), 0); // Old delete - cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 2); deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 0); EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache), 0); // New delete - cache->DeleteValuesInSet("my_key", absl::Span(values), 4); + cache->DeleteValuesInSet(safe_path_log_context_, "my_key", + absl::Span(values), 4); deleted_nodes_map_size = KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); EXPECT_EQ(deleted_nodes_map_size, 1); @@ -706,12 +755,12 @@ TEST_F(CacheTest, ConcurrentGetAndGet) { absl::flat_hash_set keys_lookup_request = {"key1", "key2"}; std::vector values_for_key1 = {"v1"}; std::vector values_for_key2 = {"v2"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 1); - cache->UpdateKeyValueSet("key2", + cache->UpdateKeyValueSet(safe_path_log_context_, "key2", absl::Span(values_for_key2), 1); absl::Notification 
start; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_fn = [&cache, &keys_lookup_request, &start, &request_context]() { start.WaitForNotification(); auto result = cache->GetKeyValueSet(request_context, keys_lookup_request); @@ -733,10 +782,10 @@ TEST_F(CacheTest, ConcurrentGetAndUpdateExpectNoUpdate) { auto cache = std::make_unique(); absl::flat_hash_set keys = {"key1"}; std::vector existing_values = {"v1"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(existing_values), 3); absl::Notification start; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_fn = [&cache, &keys, &start, &request_context]() { start.WaitForNotification(); EXPECT_THAT( @@ -744,10 +793,10 @@ TEST_F(CacheTest, ConcurrentGetAndUpdateExpectNoUpdate) { UnorderedElementsAre("v1")); }; std::vector new_values = {"v1"}; - auto update_fn = [&cache, &new_values, &start]() { + auto update_fn = [&cache, &new_values, &start, this]() { start.WaitForNotification(); - cache->UpdateKeyValueSet("key1", absl::Span(new_values), - 1); + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", + absl::Span(new_values), 1); }; std::vector threads; for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); @@ -765,10 +814,10 @@ TEST_F(CacheTest, ConcurrentGetAndUpdateExpectUpdate) { auto cache = std::make_unique(); absl::flat_hash_set keys = {"key1", "key2"}; std::vector existing_values = {"v1"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(existing_values), 1); absl::Notification start; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_fn = [&cache, &keys, &start, &request_context]() { start.WaitForNotification(); EXPECT_THAT( @@ -776,11 +825,12 @@ TEST_F(CacheTest, ConcurrentGetAndUpdateExpectUpdate) { 
UnorderedElementsAre("v1")); }; std::vector new_values_for_key2 = {"v2"}; - auto update_fn = [&cache, &new_values_for_key2, &start]() { + auto update_fn = [&cache, &new_values_for_key2, &start, this]() { // expect new value is inserted for key2 start.WaitForNotification(); - cache->UpdateKeyValueSet( - "key2", absl::Span(new_values_for_key2), 2); + cache->UpdateKeyValueSet(safe_path_log_context_, "key2", + absl::Span(new_values_for_key2), + 2); }; std::vector threads; for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); @@ -798,10 +848,10 @@ TEST_F(CacheTest, ConcurrentGetAndDeleteExpectNoDelete) { auto cache = std::make_unique(); absl::flat_hash_set keys = {"key1"}; std::vector existing_values = {"v1"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(existing_values), 3); absl::Notification start; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_fn = [&cache, &keys, &start, &request_context]() { start.WaitForNotification(); EXPECT_THAT( @@ -809,10 +859,10 @@ TEST_F(CacheTest, ConcurrentGetAndDeleteExpectNoDelete) { UnorderedElementsAre("v1")); }; std::vector delete_values = {"v1"}; - auto delete_fn = [&cache, &delete_values, &start]() { + auto delete_fn = [&cache, &delete_values, &start, this]() { // expect no delete start.WaitForNotification(); - cache->DeleteValuesInSet("key1", + cache->DeleteValuesInSet(safe_path_log_context_, "key1", absl::Span(delete_values), 1); }; std::vector threads; @@ -831,14 +881,14 @@ TEST_F(CacheTest, ConcurrentGetAndCleanUp) { auto cache = std::make_unique(); absl::flat_hash_set keys = {"key1", "key2"}; std::vector existing_values = {"v1"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(existing_values), 3); - cache->UpdateKeyValueSet("key2", + cache->UpdateKeyValueSet(safe_path_log_context_, "key2", absl::Span(existing_values), 1); - 
cache->DeleteValuesInSet("key2", + cache->DeleteValuesInSet(safe_path_log_context_, "key2", absl::Span(existing_values), 2); absl::Notification start; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_fn = [&cache, &keys, &start, &request_context]() { start.WaitForNotification(); EXPECT_THAT( @@ -849,10 +899,10 @@ TEST_F(CacheTest, ConcurrentGetAndCleanUp) { .size(), 0); }; - auto cleanup_fn = [&cache, &start]() { + auto cleanup_fn = [&cache, &start, this]() { // clean up old records start.WaitForNotification(); - KeyValueCacheTestPeer::CallCacheCleanup(*cache, 3); + KeyValueCacheTestPeer::CallCacheCleanup(safe_path_log_context_, *cache, 3); }; std::vector threads; for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); @@ -871,23 +921,23 @@ TEST_F(CacheTest, ConcurrentUpdateAndUpdateExpectUpdateBoth) { absl::flat_hash_set keys = {"key1", "key2"}; std::vector values_for_key1 = {"v1"}; absl::Notification start; - auto request_context = GetRequestContext(); - auto update_key1 = [&cache, &keys, &values_for_key1, &start, - &request_context]() { + auto& request_context = GetRequestContext(); + auto update_key1 = [&cache, &keys, &values_for_key1, &start, &request_context, + this]() { start.WaitForNotification(); // expect new value is inserted for key1 - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 1); EXPECT_THAT( cache->GetKeyValueSet(request_context, keys)->GetValueSet("key1"), UnorderedElementsAre("v1")); }; std::vector values_for_key2 = {"v2"}; - auto update_key2 = [&cache, &keys, &values_for_key2, &start, - &request_context]() { + auto update_key2 = [&cache, &keys, &values_for_key2, &start, &request_context, + this]() { // expect new value is inserted for key2 start.WaitForNotification(); - cache->UpdateKeyValueSet("key2", + cache->UpdateKeyValueSet(safe_path_log_context_, "key2", absl::Span(values_for_key2), 2); 
EXPECT_THAT( cache->GetKeyValueSet(request_context, keys)->GetValueSet("key2"), @@ -910,12 +960,12 @@ TEST_F(CacheTest, ConcurrentUpdateAndDelete) { absl::flat_hash_set keys = {"key1", "key2"}; std::vector values_for_key1 = {"v1"}; absl::Notification start; - auto request_context = GetRequestContext(); - auto update_key1 = [&cache, &keys, &values_for_key1, &start, - &request_context]() { + auto& request_context = GetRequestContext(); + auto update_key1 = [&cache, &keys, &values_for_key1, &start, &request_context, + this]() { start.WaitForNotification(); // expect new value is inserted for key1 - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 1); EXPECT_THAT( cache->GetKeyValueSet(request_context, keys)->GetValueSet("key1"), @@ -924,15 +974,17 @@ TEST_F(CacheTest, ConcurrentUpdateAndDelete) { // Update existing value for key2 std::vector existing_values_for_key2 = {"v1", "v2"}; cache->UpdateKeyValueSet( - "key2", absl::Span(existing_values_for_key2), 1); + safe_path_log_context_, "key2", + absl::Span(existing_values_for_key2), 1); std::vector values_to_delete_for_key2 = {"v1"}; auto delete_key2 = [&cache, &keys, &values_to_delete_for_key2, &start, - &request_context]() { + &request_context, this]() { start.WaitForNotification(); // expect value is deleted for key2 cache->DeleteValuesInSet( - "key2", absl::Span(values_to_delete_for_key2), 2); + safe_path_log_context_, "key2", + absl::Span(values_to_delete_for_key2), 2); EXPECT_THAT( cache->GetKeyValueSet(request_context, keys)->GetValueSet("key2"), UnorderedElementsAre("v2")); @@ -955,19 +1007,19 @@ TEST_F(CacheTest, ConcurrentUpdateAndCleanUp) { absl::flat_hash_set keys = {"key1"}; std::vector values_for_key1 = {"v1"}; absl::Notification start; - auto request_context = GetRequestContext(); - auto update_fn = [&cache, &keys, &values_for_key1, &start, - &request_context]() { + auto& request_context = GetRequestContext(); + auto update_fn = 
[&cache, &keys, &values_for_key1, &start, &request_context, + this]() { start.WaitForNotification(); - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 2); EXPECT_THAT( cache->GetKeyValueSet(request_context, keys)->GetValueSet("key1"), UnorderedElementsAre("v1")); }; - auto cleanup_fn = [&cache, &start]() { + auto cleanup_fn = [&cache, &start, this]() { start.WaitForNotification(); - KeyValueCacheTestPeer::CallCacheCleanup(*cache, 1); + KeyValueCacheTestPeer::CallCacheCleanup(safe_path_log_context_, *cache, 1); }; std::vector threads; @@ -986,24 +1038,24 @@ TEST_F(CacheTest, ConcurrentDeleteAndCleanUp) { auto cache = std::make_unique(); absl::flat_hash_set keys = {"key1"}; std::vector values_for_key1 = {"v1"}; - cache->UpdateKeyValueSet("key1", + cache->UpdateKeyValueSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 1); absl::Notification start; - auto request_context = GetRequestContext(); - auto delete_fn = [&cache, &keys, &values_for_key1, &start, - &request_context]() { + auto& request_context = GetRequestContext(); + auto delete_fn = [&cache, &keys, &values_for_key1, &start, &request_context, + this]() { start.WaitForNotification(); // expect new value is deleted for key1 - cache->DeleteValuesInSet("key1", + cache->DeleteValuesInSet(safe_path_log_context_, "key1", absl::Span(values_for_key1), 2); EXPECT_EQ(cache->GetKeyValueSet(request_context, keys) ->GetValueSet("key1") .size(), 0); }; - auto cleanup_fn = [&cache, &start]() { + auto cleanup_fn = [&cache, &start, this]() { start.WaitForNotification(); - KeyValueCacheTestPeer::CallCacheCleanup(*cache, 1); + KeyValueCacheTestPeer::CallCacheCleanup(safe_path_log_context_, *cache, 1); }; std::vector threads; for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); @@ -1023,28 +1075,32 @@ TEST_F(CacheTest, ConcurrentGetUpdateDeleteCleanUp) { std::vector existing_values_for_key1 = {"v1"}; std::vector 
existing_values_for_key2 = {"v1"}; cache->UpdateKeyValueSet( - "key1", absl::Span(existing_values_for_key1), 1); + safe_path_log_context_, "key1", + absl::Span(existing_values_for_key1), 1); cache->UpdateKeyValueSet( - "key2", absl::Span(existing_values_for_key2), 1); + safe_path_log_context_, "key2", + absl::Span(existing_values_for_key2), 1); std::vector values_to_insert_for_key2 = {"v2"}; std::vector values_to_delete_for_key2 = {"v1"}; absl::Notification start; - auto insert_for_key2 = [&cache, &values_to_insert_for_key2, &start]() { + auto insert_for_key2 = [&cache, &values_to_insert_for_key2, &start, this]() { start.WaitForNotification(); cache->UpdateKeyValueSet( - "key2", absl::Span(values_to_insert_for_key2), 2); + safe_path_log_context_, "key2", + absl::Span(values_to_insert_for_key2), 2); }; - auto delete_for_key2 = [&cache, &values_to_delete_for_key2, &start]() { + auto delete_for_key2 = [&cache, &values_to_delete_for_key2, &start, this]() { start.WaitForNotification(); cache->DeleteValuesInSet( - "key2", absl::Span(values_to_delete_for_key2), 2); + safe_path_log_context_, "key2", + absl::Span(values_to_delete_for_key2), 2); }; - auto cleanup = [&cache, &start]() { + auto cleanup = [&cache, &start, this]() { start.WaitForNotification(); - KeyValueCacheTestPeer::CallCacheCleanup(*cache, 1); + KeyValueCacheTestPeer::CallCacheCleanup(safe_path_log_context_, *cache, 1); }; - auto request_context = GetRequestContext(); + auto& request_context = GetRequestContext(); auto lookup_for_key1 = [&cache, &keys, &start, &request_context]() { start.WaitForNotification(); EXPECT_THAT( @@ -1073,9 +1129,11 @@ TEST_F(CacheTest, MultiplePrefixKeyValueUpdates) { std::unique_ptr cache = KeyValueCache::Create(); // Call remove deleted keys for prefix1 to update the max delete cutoff // timestamp - cache->RemoveDeletedKeys(1, "prefix1"); - cache->UpdateKeyValue("prefix1-key", "value1", 2, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value2", 1, "prefix2"); + 
cache->RemoveDeletedKeys(safe_path_log_context_, 1, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value1", 2, + "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value2", 1, + "prefix2"); absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), {"prefix1-key", "prefix2-key"}); @@ -1088,10 +1146,12 @@ TEST_F(CacheTest, MultiplePrefixKeyValueNoUpdateForAnother) { std::unique_ptr cache = KeyValueCache::Create(); // Call remove deleted keys for prefix1 to update the max delete cutoff // timestamp - cache->RemoveDeletedKeys(2, "prefix1"); + cache->RemoveDeletedKeys(safe_path_log_context_, 2, "prefix1"); // Expect no update for prefix1 - cache->UpdateKeyValue("prefix1-key", "value1", 1, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value2", 1, "prefix2"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value1", 1, + "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value2", 1, + "prefix2"); absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), {"prefix1-key", "prefix2-key"}); @@ -1103,11 +1163,13 @@ TEST_F(CacheTest, MultiplePrefixKeyValueNoDeleteForAnother) { std::unique_ptr cache = KeyValueCache::Create(); // Call remove deleted keys for prefix1 to update the max delete cutoff // timestamp - cache->RemoveDeletedKeys(2, "prefix1"); - cache->UpdateKeyValue("prefix1-key", "value1", 3, "prefix1"); + cache->RemoveDeletedKeys(safe_path_log_context_, 2, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value1", 3, + "prefix1"); // Expect no deletion - cache->DeleteKey("prefix1-key", 1, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value2", 1, "prefix2"); + cache->DeleteKey(safe_path_log_context_, "prefix1-key", 1, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value2", 1, + "prefix2"); absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), 
{"prefix1-key", "prefix2-key"}); @@ -1118,9 +1180,11 @@ TEST_F(CacheTest, MultiplePrefixKeyValueNoDeleteForAnother) { TEST_F(CacheTest, MultiplePrefixKeyValueDeletesAndUpdates) { std::unique_ptr cache = std::make_unique(); - cache->DeleteKey("prefix1-key", 2, "prefix1"); - cache->UpdateKeyValue("prefix1-key", "value1", 1, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value2", 1, "prefix2"); + cache->DeleteKey(safe_path_log_context_, "prefix1-key", 2, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value1", 1, + "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value2", 1, + "prefix2"); absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), {"prefix1-key", "prefix2-key"}); @@ -1135,10 +1199,12 @@ TEST_F(CacheTest, MultiplePrefixKeyValueDeletesAndUpdates) { TEST_F(CacheTest, MultiplePrefixKeyValueUpdatesAndDeletes) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("prefix1-key", "value1", 2, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value1", 2, + "prefix1"); // Expects no deletes - cache->DeleteKey("prefix1-key", 1, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value2", 1, "prefix2"); + cache->DeleteKey(safe_path_log_context_, "prefix1-key", 1, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value2", 1, + "prefix2"); absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(GetRequestContext(), {"prefix1-key", "prefix2-key"}); @@ -1155,11 +1221,11 @@ TEST_F(CacheTest, MultiplePrefixKeyValueSetUpdates) { std::vector values2 = {"v3", "v4"}; // Call remove deleted keys for prefix1 to update the max delete cutoff // timestamp - cache->RemoveDeletedKeys(1, "prefix1"); - cache->UpdateKeyValueSet("prefix1-key", absl::Span(values1), - 2, "prefix1"); - cache->UpdateKeyValueSet("prefix2-key", absl::Span(values2), - 1, "prefix2"); + cache->RemoveDeletedKeys(safe_path_log_context_, 1, "prefix1"); + 
cache->UpdateKeyValueSet(safe_path_log_context_, "prefix1-key", + absl::Span(values1), 2, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix2-key", + absl::Span(values2), 1, "prefix2"); auto get_value_set_result = cache->GetKeyValueSet( GetRequestContext(), {"prefix1-key", "prefix2-key"}); @@ -1175,11 +1241,11 @@ TEST_F(CacheTest, MultipleKeyValueSetNoUpdateForAnother) { std::vector values2 = {"v3", "v4"}; // Call remove deleted keys for prefix1 to update the max delete cutoff // timestamp - cache->RemoveDeletedKeys(2, "prefix1"); - cache->UpdateKeyValueSet("prefix1-key", absl::Span(values1), - 1, "prefix1"); - cache->UpdateKeyValueSet("prefix2-key", absl::Span(values2), - 1, "prefix2"); + cache->RemoveDeletedKeys(safe_path_log_context_, 2, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix1-key", + absl::Span(values1), 1, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix2-key", + absl::Span(values2), 1, "prefix2"); auto get_value_set_result = cache->GetKeyValueSet( GetRequestContext(), {"prefix1-key", "prefix2-key"}); EXPECT_EQ(get_value_set_result->GetValueSet("prefix1-key").size(), 0); @@ -1192,13 +1258,13 @@ TEST_F(CacheTest, MultiplePrefixKeyValueSetDeletesAndUpdates) { std::vector values1 = {"v1", "v2"}; std::vector values_to_delete = {"v1"}; std::vector values2 = {"v3", "v4"}; - cache->DeleteValuesInSet("prefix1-key", + cache->DeleteValuesInSet(safe_path_log_context_, "prefix1-key", absl::Span(values_to_delete), 2, "prefix1"); - cache->UpdateKeyValueSet("prefix1-key", absl::Span(values1), - 1, "prefix1"); - cache->UpdateKeyValueSet("prefix2-key", absl::Span(values2), - 1, "prefix2"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix1-key", + absl::Span(values1), 1, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix2-key", + absl::Span(values2), 1, "prefix2"); auto get_value_set_result = cache->GetKeyValueSet( GetRequestContext(), {"prefix1-key", "prefix2-key"}); 
EXPECT_THAT(get_value_set_result->GetValueSet("prefix1-key"), @@ -1217,12 +1283,12 @@ TEST_F(CacheTest, MultiplePrefixKeyValueSetUpdatesAndDeletes) { std::vector values_to_delete = {"v1"}; std::vector values2 = {"v3", "v4"}; - cache->UpdateKeyValueSet("prefix1-key", absl::Span(values1), - 2, "prefix1"); - cache->UpdateKeyValueSet("prefix2-key", absl::Span(values2), - 1, "prefix2"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix1-key", + absl::Span(values1), 2, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix2-key", + absl::Span(values2), 1, "prefix2"); // Expect no deletes - cache->DeleteValuesInSet("prefix1-key", + cache->DeleteValuesInSet(safe_path_log_context_, "prefix1-key", absl::Span(values_to_delete), 1, "prefix1"); auto get_value_set_result = cache->GetKeyValueSet( @@ -1239,15 +1305,17 @@ TEST_F(CacheTest, MultiplePrefixKeyValueSetUpdatesAndDeletes) { TEST_F(CacheTest, MultiplePrefixTimestampKeyValueCleanUps) { std::unique_ptr cache = std::make_unique(); - cache->UpdateKeyValue("prefix1-key", "value", 2, "prefix1"); - cache->DeleteKey("prefix1-key", 3, "prefix1"); - cache->UpdateKeyValue("prefix2-key", "value", 2, "prefix2"); - cache->DeleteKey("prefix2-key", 5, "prefix2"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix1-key", "value", 2, + "prefix1"); + cache->DeleteKey(safe_path_log_context_, "prefix1-key", 3, "prefix1"); + cache->UpdateKeyValue(safe_path_log_context_, "prefix2-key", "value", 2, + "prefix2"); + cache->DeleteKey(safe_path_log_context_, "prefix2-key", 5, "prefix2"); auto deleted_nodes_for_prefix1 = KeyValueCacheTestPeer::ReadDeletedNodes(*cache, "prefix1"); EXPECT_EQ(deleted_nodes_for_prefix1.size(), 1); - cache->RemoveDeletedKeys(4, "prefix1"); - cache->RemoveDeletedKeys(4, "prefix2"); + cache->RemoveDeletedKeys(safe_path_log_context_, 4, "prefix1"); + cache->RemoveDeletedKeys(safe_path_log_context_, 4, "prefix2"); deleted_nodes_for_prefix1 = KeyValueCacheTestPeer::ReadDeletedNodes(*cache, 
"prefix1"); EXPECT_EQ(deleted_nodes_for_prefix1.size(), 0); @@ -1259,22 +1327,127 @@ TEST_F(CacheTest, MultiplePrefixTimestampKeyValueSetCleanUps) { std::unique_ptr cache = std::make_unique(); std::vector values = {"v1", "v2"}; std::vector values_to_delete = {"v1"}; - cache->UpdateKeyValueSet("prefix1-key", absl::Span(values), - 2, "prefix1"); - cache->UpdateKeyValueSet("prefix2-key", absl::Span(values), - 2, "prefix2"); - cache->DeleteValuesInSet("prefix1-key", + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix1-key", + absl::Span(values), 2, "prefix1"); + cache->UpdateKeyValueSet(safe_path_log_context_, "prefix2-key", + absl::Span(values), 2, "prefix2"); + cache->DeleteValuesInSet(safe_path_log_context_, "prefix1-key", absl::Span(values_to_delete), 3, "prefix1"); - cache->DeleteValuesInSet("prefix2-key", + cache->DeleteValuesInSet(safe_path_log_context_, "prefix2-key", absl::Span(values_to_delete), 5, "prefix2"); - cache->RemoveDeletedKeys(4, "prefix1"); - cache->RemoveDeletedKeys(4, "prefix2"); + cache->RemoveDeletedKeys(safe_path_log_context_, 4, "prefix1"); + cache->RemoveDeletedKeys(safe_path_log_context_, 4, "prefix2"); EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache, "prefix1"), 0); EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache, "prefix2"), 1); } + +TEST_F(CacheTest, VerifyUpdatingUInt32Sets) { + auto cache = KeyValueCache::Create(); + auto& request_context = GetRequestContext(); + auto keys = absl::flat_hash_set({"set1", "set2"}); + { + auto result = cache->GetUInt32ValueSet(request_context, keys); + for (const auto& key : keys) { + auto* set = result->GetUInt32ValueSet(key); + EXPECT_EQ(set, nullptr); + } + } + { + auto set1_values = std::vector({1, 2, 3, 4, 5}); + cache->UpdateKeyValueSet(safe_path_log_context_, "set1", + absl::MakeSpan(set1_values), 1); + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set1"); + ASSERT_TRUE(set != nullptr); + 
EXPECT_THAT(set->GetValues(), UnorderedElementsAreArray(set1_values)); + } + { + auto set2_values = std::vector({6, 7, 8, 9, 10}); + cache->UpdateKeyValueSet(safe_path_log_context_, "set2", + absl::MakeSpan(set2_values), 1); + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set2"); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAreArray(set2_values)); + } +} + +TEST_F(CacheTest, VerifyDeletingUInt32Sets) { + auto cache = KeyValueCache::Create(); + auto& request_context = GetRequestContext(); + auto keys = absl::flat_hash_set({"set1", "set2"}); + auto delete_values = std::vector({1, 2, 6, 7}); + { + auto set1_values = std::vector({1, 2, 3, 4, 5}); + cache->UpdateKeyValueSet(safe_path_log_context_, "set1", + absl::MakeSpan(set1_values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "set1", + absl::MakeSpan(delete_values), 2); + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set1"); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAre(3, 4, 5)); + } + { + auto set2_values = std::vector({6, 7, 8, 9, 10}); + cache->UpdateKeyValueSet(safe_path_log_context_, "set2", + absl::MakeSpan(set2_values), 1); + cache->DeleteValuesInSet(safe_path_log_context_, "set2", + absl::MakeSpan(delete_values), 2); + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set2"); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAre(8, 9, 10)); + } +} + +TEST_F(CacheTest, VerifyCleaningUpUInt32Sets) { + std::unique_ptr cache = std::make_unique(); + auto& request_context = GetRequestContext(); + auto keys = absl::flat_hash_set({"set1"}); + auto set1_values = std::vector({1, 2, 3, 4, 5}); + auto delete_values = std::vector({1, 2}); + { + cache->UpdateKeyValueSet(safe_path_log_context_, "set1", + absl::MakeSpan(set1_values), 1); + 
cache->DeleteValuesInSet(safe_path_log_context_, "set1", + absl::MakeSpan(delete_values), 2); + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set1"); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAre(3, 4, 5)); + EXPECT_THAT(set->GetRemovedValues(), + UnorderedElementsAreArray(delete_values)); + } + const auto& deleted_nodes = + KeyValueCacheTestPeer::ReadDeletedUint32Nodes(*cache); + { + auto prefix_nodes = deleted_nodes.CGet(""); + ASSERT_TRUE(prefix_nodes.is_present()); + auto iter = prefix_nodes.value()->find(2); + ASSERT_NE(iter, prefix_nodes.value()->end()); + EXPECT_THAT(iter->first, 2); + EXPECT_TRUE(iter->second.contains("set1")); + } + { + cache->RemoveDeletedKeys(safe_path_log_context_, 3); + auto prefix_nodes = deleted_nodes.CGet(""); + ASSERT_TRUE(prefix_nodes.is_present()); + auto iter = prefix_nodes.value()->find(2); + ASSERT_EQ(iter, prefix_nodes.value()->end()); + } + { + auto result = cache->GetUInt32ValueSet(request_context, keys); + auto* set = result->GetUInt32ValueSet("set1"); + ASSERT_TRUE(set != nullptr); + EXPECT_THAT(set->GetValues(), UnorderedElementsAre(3, 4, 5)); + EXPECT_TRUE(set->GetRemovedValues().empty()); + } +} + } // namespace } // namespace kv_server diff --git a/components/data_server/cache/mocks.h b/components/data_server/cache/mocks.h index 134c71fc..5661b6a7 100644 --- a/components/data_server/cache/mocks.h +++ b/components/data_server/cache/mocks.h @@ -17,10 +17,10 @@ #include #include -#include -#include +#include "components/container/thread_safe_hash_map.h" #include "components/data_server/cache/cache.h" +#include "components/data_server/cache/uint32_value_set.h" #include "gmock/gmock.h" namespace kv_server { @@ -33,29 +33,48 @@ MATCHER_P2(KVPairEq, key, value, "") { class MockCache : public Cache { public: MOCK_METHOD((absl::flat_hash_map), GetKeyValuePairs, - (const RequestContext& request_context, + (const RequestContext&, 
const absl::flat_hash_set&), (const, override)); MOCK_METHOD((std::unique_ptr), GetKeyValueSet, - (const RequestContext& request_context, + (const RequestContext&, + const absl::flat_hash_set&), + (const, override)); + MOCK_METHOD((std::unique_ptr), GetUInt32ValueSet, + (const RequestContext&, const absl::flat_hash_set&), (const, override)); MOCK_METHOD(void, UpdateKeyValue, - (std::string_view key, std::string_view value, int64_t ts, - std::string_view prefix), + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, std::string_view, int64_t, std::string_view), + (override)); + MOCK_METHOD(void, UpdateKeyValueSet, + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), (override)); MOCK_METHOD(void, UpdateKeyValueSet, - (std::string_view key, absl::Span value_set, - int64_t logical_commit_time, std::string_view prefix), + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), + (override)); + MOCK_METHOD(void, DeleteValuesInSet, + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), (override)); MOCK_METHOD(void, DeleteValuesInSet, - (std::string_view key, absl::Span value_set, - int64_t logical_commit_time, std::string_view prefix), + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, absl::Span, int64_t, + std::string_view), (override)); MOCK_METHOD(void, DeleteKey, - (std::string_view key, int64_t ts, std::string_view prefix), + (privacy_sandbox::server_common::log::PSLogContext&, + std::string_view, int64_t, std::string_view), (override)); - MOCK_METHOD(void, RemoveDeletedKeys, (int64_t ts, std::string_view prefix), + MOCK_METHOD(void, RemoveDeletedKeys, + (privacy_sandbox::server_common::log::PSLogContext&, int64_t, + std::string_view), (override)); }; @@ -67,6 +86,13 @@ class MockGetKeyValueSetResult : public GetKeyValueSetResult { 
(std::string_view, absl::flat_hash_set, std::unique_ptr), (override)); + MOCK_METHOD((const UInt32ValueSet*), GetUInt32ValueSet, (std::string_view), + (const override)); + MOCK_METHOD( + void, AddUInt32ValueSet, + (std::string_view, + (ThreadSafeHashMap::ConstLockedNodePtr)), + (override)); }; } // namespace kv_server diff --git a/components/data_server/cache/noop_key_value_cache.h b/components/data_server/cache/noop_key_value_cache.h index 5fc2afd9..aa79f7e1 100644 --- a/components/data_server/cache/noop_key_value_cache.h +++ b/components/data_server/cache/noop_key_value_cache.h @@ -18,7 +18,6 @@ #include #include -#include #include "components/data_server/cache/cache.h" @@ -35,21 +34,37 @@ class NoOpKeyValueCache : public Cache { const absl::flat_hash_set& key_set) const override { return std::make_unique(); } - void UpdateKeyValue(std::string_view key, std::string_view value, - int64_t logical_commit_time, - std::string_view prefix) override {} - void UpdateKeyValueSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix) override {} - void DeleteKey(std::string_view key, int64_t logical_commit_time, + std::unique_ptr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override { + return std::make_unique(); + } + void UpdateKeyValue( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, std::string_view value, int64_t logical_commit_time, + std::string_view prefix) override {} + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) override {} + void UpdateKeyValueSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override {} + void 
DeleteKey(privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, int64_t logical_commit_time, std::string_view prefix) override {} - void DeleteValuesInSet(std::string_view key, - absl::Span value_set, - int64_t logical_commit_time, - std::string_view prefix) override {} - void RemoveDeletedKeys(int64_t logical_commit_time, - std::string_view prefix) override {} + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix) override {} + void DeleteValuesInSet( + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::string_view key, absl::Span value_set, + int64_t logical_commit_time, std::string_view prefix = "") override {} + void RemoveDeletedKeys( + privacy_sandbox::server_common::log::PSLogContext& log_context, + int64_t logical_commit_time, std::string_view prefix) override {} static std::unique_ptr Create() { return std::make_unique(); } @@ -60,9 +75,17 @@ class NoOpKeyValueCache : public Cache { std::string_view key) const override { return {}; } + const UInt32ValueSet* GetUInt32ValueSet( + std::string_view key) const override { + return nullptr; + } void AddKeyValueSet( std::string_view key, absl::flat_hash_set value_set, std::unique_ptr key_lock) override {} + void AddUInt32ValueSet( + std::string_view key, + ThreadSafeHashMap::ConstLockedNodePtr + value_set_node) override {} }; }; diff --git a/components/data_server/cache/uint32_value_set.cc b/components/data_server/cache/uint32_value_set.cc new file mode 100644 index 00000000..9f6e02a4 --- /dev/null +++ b/components/data_server/cache/uint32_value_set.cc @@ -0,0 +1,100 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "components/data_server/cache/uint32_value_set.h" + +#include + +namespace kv_server { + +absl::flat_hash_set UInt32ValueSet::GetValues() const { + absl::flat_hash_set values; + values.reserve(values_bitset_.cardinality()); + for (const auto& [value, metadata] : values_metadata_) { + if (!metadata.is_deleted) { + values.insert(value); + } + } + return values; +} + +const roaring::Roaring& UInt32ValueSet::GetValuesBitSet() const { + return values_bitset_; +} + +absl::flat_hash_set UInt32ValueSet::GetRemovedValues() const { + absl::flat_hash_set removed_values; + for (const auto& [_, values] : deleted_values_) { + for (auto value : values) { + removed_values.insert(value); + } + } + return removed_values; +} + +void UInt32ValueSet::AddOrRemove(absl::Span values, + int64_t logical_commit_time, bool is_deleted) { + for (auto value : values) { + auto* metadata = &values_metadata_[value]; + if (metadata->logical_commit_time >= logical_commit_time) { + continue; + } + metadata->logical_commit_time = logical_commit_time; + metadata->is_deleted = is_deleted; + if (is_deleted) { + values_bitset_.remove(value); + deleted_values_[logical_commit_time].insert(value); + } else { + values_bitset_.add(value); + deleted_values_[logical_commit_time].erase(value); + } + } + values_bitset_.runOptimize(); +} + +void UInt32ValueSet::Add(absl::Span values, + int64_t logical_commit_time) { + AddOrRemove(values, logical_commit_time, /*is_deleted=*/false); +} + +void UInt32ValueSet::Remove(absl::Span values, + int64_t logical_commit_time) { + 
AddOrRemove(values, logical_commit_time, /*is_deleted=*/true); +} + +void UInt32ValueSet::Cleanup(int64_t cutoff_logical_commit_time) { + for (const auto& [logical_commit_time, values] : deleted_values_) { + if (logical_commit_time > cutoff_logical_commit_time) { + break; + } + for (auto value : values) { + values_metadata_.erase(value); + } + } + deleted_values_.erase( + deleted_values_.begin(), + deleted_values_.upper_bound(cutoff_logical_commit_time)); +} + +absl::flat_hash_set BitSetToUint32Set( + const roaring::Roaring& bitset) { + auto num_values = bitset.cardinality(); + auto data = std::make_unique(num_values); + bitset.toUint32Array(data.get()); + return absl::flat_hash_set(data.get(), data.get() + num_values); +} + +} // namespace kv_server diff --git a/components/data_server/cache/uint32_value_set.h b/components/data_server/cache/uint32_value_set.h new file mode 100644 index 00000000..49b1844a --- /dev/null +++ b/components/data_server/cache/uint32_value_set.h @@ -0,0 +1,75 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ +#define COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ + +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" + +#include "roaring.hh" + +namespace kv_server { + +// Stores a set of `uint32_t` values associated with a `logical_commit_time`. 
+// The `logical_commit_time` is used to support out of order set mutations, +// i.e., calling `Remove({1, 2, 3}, 5)` and then `Add({1, 2, 3}, 3)` will result +// in an empty set. +// +// The values in the set are also projected to `roaring::Roaring` bitset which +// can be used for efficient set operations such as union, intersection, .e.t.c. +class UInt32ValueSet { + public: + // Returns values not marked as removed from the set. + absl::flat_hash_set GetValues() const; + // Returns values not marked as removed from the set as a bitset. + const roaring::Roaring& GetValuesBitSet() const; + // Returns values marked as removed from the set. + absl::flat_hash_set GetRemovedValues() const; + + // Adds values associated with `logical_commit_time` to the set. If a value + // with the same or greater `logical_commit_time` already exists in the set, + // then this is a noop. + void Add(absl::Span values, int64_t logical_commit_time); + // Marks values associated with `logical_commit_time` as removed from the set. + // If a value with the same or greater `logical_commit_time` already exists in + // the set, then this is a noop. + void Remove(absl::Span values, int64_t logical_commit_time); + // Cleans up space occupied by values (including value metadata) matching the + // condition `logical_commit_time` <= `cutoff_logical_commit_time` and are + // marked as removed. 
+ void Cleanup(int64_t cutoff_logical_commit_time); + + private: + struct ValueMetadata { + int64_t logical_commit_time; + bool is_deleted; + }; + + void AddOrRemove(absl::Span values, int64_t logical_commit_time, + bool is_deleted); + + roaring::Roaring values_bitset_; + absl::flat_hash_map values_metadata_; + absl::btree_map> deleted_values_; +}; + +absl::flat_hash_set BitSetToUint32Set(const roaring::Roaring& bitset); + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_CACHE_UINT32_VALUE_SET_H_ diff --git a/components/data_server/cache/uint32_value_set_test.cc b/components/data_server/cache/uint32_value_set_test.cc new file mode 100644 index 00000000..7af32601 --- /dev/null +++ b/components/data_server/cache/uint32_value_set_test.cc @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/cache/uint32_value_set.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using testing::UnorderedElementsAre; + +TEST(UInt32ValueSet, VerifyAddingValues) { + UInt32ValueSet value_set; + auto values = std::vector{1, 2, 3, 4, 5}; + value_set.Add(absl::MakeSpan(values), 1); + EXPECT_THAT(value_set.GetValues(), UnorderedElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(value_set.GetValuesBitSet(), roaring::Roaring({1, 2, 3, 4, 5})); +} + +TEST(UInt32ValueSet, VerifyBitSetToUint32Set) { + roaring::Roaring bitset({1, 2, 3, 4, 5}); + EXPECT_THAT(BitSetToUint32Set(bitset), UnorderedElementsAre(1, 2, 3, 4, 5)); +} + +TEST(UInt32ValueSet, VerifyRemovingValues) { + UInt32ValueSet value_set; + auto values = std::vector{1, 2, 3, 4, 5}; + value_set.Add(absl::MakeSpan(values), 1); + values.erase(values.begin()); + values.erase(values.begin()); + // Should not remove anything because logical_commit_time did not change. + value_set.Remove(absl::MakeSpan(values), 1); + EXPECT_THAT(value_set.GetValues(), UnorderedElementsAre(1, 2, 3, 4, 5)); + EXPECT_EQ(value_set.GetValuesBitSet(), roaring::Roaring({1, 2, 3, 4, 5})); + // Should now remove: {3, 4, 5}. + value_set.Remove(absl::MakeSpan(values), 2); + EXPECT_THAT(value_set.GetValues(), UnorderedElementsAre(1, 2)); + EXPECT_EQ(value_set.GetValuesBitSet(), roaring::Roaring({1, 2})); +} + +TEST(UInt32ValueSet, VerifyCleaningUpValues) { + UInt32ValueSet value_set; + auto values = std::vector{1, 2, 3, 4, 5}; + value_set.Add(absl::MakeSpan(values), 1); + values.erase(values.begin()); + values.erase(values.begin()); + value_set.Remove(absl::MakeSpan(values), 2); + EXPECT_THAT(value_set.GetRemovedValues(), UnorderedElementsAre(3, 4, 5)); + value_set.Cleanup(1); // Does nothing. 
+ EXPECT_THAT(value_set.GetRemovedValues(), UnorderedElementsAre(3, 4, 5)); + value_set.Cleanup(2); + EXPECT_TRUE(value_set.GetRemovedValues().empty()); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/data_loading/data_orchestrator.cc b/components/data_server/data_loading/data_orchestrator.cc index f787b309..1ca9d43b 100644 --- a/components/data_server/data_loading/data_orchestrator.cc +++ b/components/data_server/data_loading/data_orchestrator.cc @@ -76,18 +76,26 @@ void LogDataLoadingMetrics(std::string_view source, data_loading_stats.total_dropped_records)}})); } -absl::Status ApplyUpdateMutation(std::string_view prefix, - const KeyValueMutationRecord& record, - Cache& cache) { +absl::Status ApplyUpdateMutation( + std::string_view prefix, const KeyValueMutationRecord& record, Cache& cache, + privacy_sandbox::server_common::log::PSLogContext& log_context) { if (record.value_type() == Value::StringValue) { - cache.UpdateKeyValue(record.key()->string_view(), + cache.UpdateKeyValue(log_context, record.key()->string_view(), GetRecordValue(record), record.logical_commit_time(), prefix); return absl::OkStatus(); } if (record.value_type() == Value::StringSet) { auto values = GetRecordValue>(record); - cache.UpdateKeyValueSet(record.key()->string_view(), absl::MakeSpan(values), + cache.UpdateKeyValueSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), + record.logical_commit_time(), prefix); + return absl::OkStatus(); + } + if (record.value_type() == Value::UInt32Set) { + auto values = GetRecordValue>(record); + cache.UpdateKeyValueSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), record.logical_commit_time(), prefix); return absl::OkStatus(); } @@ -96,17 +104,25 @@ absl::Status ApplyUpdateMutation(std::string_view prefix, " has unsupported value type: ", record.value_type())); } -absl::Status ApplyDeleteMutation(std::string_view prefix, - const KeyValueMutationRecord& record, - Cache& cache) { 
+absl::Status ApplyDeleteMutation( + std::string_view prefix, const KeyValueMutationRecord& record, Cache& cache, + privacy_sandbox::server_common::log::PSLogContext& log_context) { if (record.value_type() == Value::StringValue) { - cache.DeleteKey(record.key()->string_view(), record.logical_commit_time(), - prefix); + cache.DeleteKey(log_context, record.key()->string_view(), + record.logical_commit_time(), prefix); return absl::OkStatus(); } if (record.value_type() == Value::StringSet) { auto values = GetRecordValue>(record); - cache.DeleteValuesInSet(record.key()->string_view(), absl::MakeSpan(values), + cache.DeleteValuesInSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), + record.logical_commit_time(), prefix); + return absl::OkStatus(); + } + if (record.value_type() == Value::UInt32Set) { + auto values = GetRecordValue>(record); + cache.DeleteValuesInSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), record.logical_commit_time(), prefix); return absl::OkStatus(); } @@ -115,10 +131,11 @@ absl::Status ApplyDeleteMutation(std::string_view prefix, " has unsupported value type: ", record.value_type())); } -bool ShouldProcessRecord(const KeyValueMutationRecord& record, - int64_t num_shards, int64_t server_shard_num, - const KeySharder& key_sharder, - DataLoadingStats& data_loading_stats) { +bool ShouldProcessRecord( + const KeyValueMutationRecord& record, int64_t num_shards, + int64_t server_shard_num, const KeySharder& key_sharder, + DataLoadingStats& data_loading_stats, + privacy_sandbox::server_common::log::PSLogContext& log_context) { if (num_shards <= 1) { return true; } @@ -139,10 +156,11 @@ bool ShouldProcessRecord(const KeyValueMutationRecord& record, absl::Status ApplyKeyValueMutationToCache( std::string_view prefix, const KeyValueMutationRecord& record, Cache& cache, - int64_t& max_timestamp, DataLoadingStats& data_loading_stats) { + int64_t& max_timestamp, DataLoadingStats& data_loading_stats, + 
privacy_sandbox::server_common::log::PSLogContext& log_context) { switch (record.mutation_type()) { case KeyValueMutationType::Update: { - if (auto status = ApplyUpdateMutation(prefix, record, cache); + if (auto status = ApplyUpdateMutation(prefix, record, cache, log_context); !status.ok()) { return status; } @@ -151,7 +169,7 @@ absl::Status ApplyKeyValueMutationToCache( break; } case KeyValueMutationType::Delete: { - if (auto status = ApplyDeleteMutation(prefix, record, cache); + if (auto status = ApplyDeleteMutation(prefix, record, cache, log_context); !status.ok()) { return status; } @@ -171,35 +189,43 @@ absl::StatusOr LoadCacheWithData( std::string_view data_source, std::string_view prefix, StreamRecordReader& record_reader, Cache& cache, int64_t& max_timestamp, const int32_t server_shard_num, const int32_t num_shards, - UdfClient& udf_client, const KeySharder& key_sharder) { + UdfClient& udf_client, const KeySharder& key_sharder, + privacy_sandbox::server_common::log::PSLogContext& log_context) { DataLoadingStats data_loading_stats; - const auto process_data_record_fn = - [prefix, &cache, &max_timestamp, &data_loading_stats, server_shard_num, - num_shards, &udf_client, &key_sharder](const DataRecord& data_record) { - if (data_record.record_type() == Record::KeyValueMutationRecord) { - const auto* record = data_record.record_as_KeyValueMutationRecord(); - if (!ShouldProcessRecord(*record, num_shards, server_shard_num, - key_sharder, data_loading_stats)) { - // NOTE: currently upstream logic retries on non-ok status - // this will get us in a loop - return absl::OkStatus(); - } - return ApplyKeyValueMutationToCache( - prefix, *record, cache, max_timestamp, data_loading_stats); - } else if (data_record.record_type() == - Record::UserDefinedFunctionsConfig) { - const auto* udf_config = - data_record.record_as_UserDefinedFunctionsConfig(); - VLOG(3) << "Setting UDF code snippet for version: " - << udf_config->version(); - return 
udf_client.SetCodeObject(CodeConfig{ + const auto process_data_record_fn = [prefix, &cache, &max_timestamp, + &data_loading_stats, server_shard_num, + num_shards, &udf_client, &key_sharder, + &log_context]( + const DataRecord& data_record) { + if (data_record.record_type() == Record::KeyValueMutationRecord) { + const auto* record = data_record.record_as_KeyValueMutationRecord(); + if (!ShouldProcessRecord(*record, num_shards, server_shard_num, + key_sharder, data_loading_stats, log_context)) { + // NOTE: currently upstream logic retries on non-ok status + // this will get us in a loop + return absl::OkStatus(); + } + return ApplyKeyValueMutationToCache(prefix, *record, cache, max_timestamp, + data_loading_stats, log_context); + } else if (data_record.record_type() == + Record::UserDefinedFunctionsConfig) { + const auto* udf_config = + data_record.record_as_UserDefinedFunctionsConfig(); + PS_VLOG(3, log_context) + << "Setting UDF code snippet for version: " << udf_config->version() + << ", handler: " << udf_config->handler_name()->str() + << ", code length: " << udf_config->code_snippet()->str().size(); + return udf_client.SetCodeObject( + CodeConfig{ .js = udf_config->code_snippet()->str(), .udf_handler_name = udf_config->handler_name()->str(), .logical_commit_time = udf_config->logical_commit_time(), - .version = udf_config->version()}); - } - return absl::InvalidArgumentError("Received unsupported record."); - }; + .version = udf_config->version(), + }, + log_context); + } + return absl::InvalidArgumentError("Received unsupported record."); + }; // TODO(b/314302953): ReadStreamRecords will skip over individual records that // have errors. We should pass the file name to the function so that it will // appear in error logs. 
@@ -215,7 +241,7 @@ absl::StatusOr LoadCacheWithData( absl::StatusOr LoadCacheWithDataFromFile( const BlobStorageClient::DataLocation& location, const DataOrchestrator::Options& options) { - LOG(INFO) << "Loading " << location; + PS_LOG(INFO, options.log_context) << "Loading " << location; int64_t max_timestamp = 0; auto& cache = options.cache; auto record_reader = @@ -228,10 +254,10 @@ absl::StatusOr LoadCacheWithDataFromFile( _ << "Blob " << location); if (metadata.has_sharding_metadata() && metadata.sharding_metadata().shard_num() != options.shard_num) { - LOG(INFO) << "Blob " << location << " belongs to shard num " - << metadata.sharding_metadata().shard_num() - << " but server shard num is " << options.shard_num - << " Skipping it."; + PS_LOG(INFO, options.log_context) + << "Blob " << location << " belongs to shard num " + << metadata.sharding_metadata().shard_num() + << " but server shard num is " << options.shard_num << " Skipping it."; return DataLoadingStats{ .total_updated_records = 0, .total_deleted_records = 0, @@ -246,9 +272,10 @@ absl::StatusOr LoadCacheWithDataFromFile( auto data_loading_stats, LoadCacheWithData(file_name, location.prefix, *record_reader, cache, max_timestamp, options.shard_num, options.num_shards, - options.udf_client, options.key_sharder), + options.udf_client, options.key_sharder, + options.log_context), _ << "Blob: " << location); - cache.RemoveDeletedKeys(max_timestamp, location.prefix); + cache.RemoveDeletedKeys(options.log_context, max_timestamp, location.prefix); return data_loading_stats; } @@ -281,16 +308,18 @@ class DataOrchestratorImpl : public DataOrchestrator { absl::MutexLock l(&mu_); stop_ = true; } - LOG(INFO) << "Sent cancel signal to data loader thread"; - LOG(INFO) << "Stopping loading new data from " << options_.data_bucket; + PS_LOG(INFO, options_.log_context) + << "Sent cancel signal to data loader thread"; + PS_LOG(INFO, options_.log_context) + << "Stopping loading new data from " << options_.data_bucket; if 
(options_.delta_notifier.IsRunning()) { if (const auto s = options_.delta_notifier.Stop(); !s.ok()) { - LOG(ERROR) << "Failed to stop notify: " << s; + PS_LOG(ERROR, options_.log_context) << "Failed to stop notify: " << s; } } - LOG(INFO) << "Delta notifier stopped"; + PS_LOG(INFO, options_.log_context) << "Delta notifier stopped"; data_loader_thread_->join(); - LOG(INFO) << "Stopped loading new data"; + PS_LOG(INFO, options_.log_context) << "Stopped loading new data"; } static absl::StatusOr> Init( @@ -311,14 +340,16 @@ class DataOrchestratorImpl : public DataOrchestrator { if (!maybe_filenames.ok()) { return maybe_filenames.status(); } - LOG(INFO) << "Initializing cache with " << maybe_filenames->size() - << " delta files from " << location; + PS_LOG(INFO, options.log_context) + << "Initializing cache with " << maybe_filenames->size() + << " delta files from " << location; for (auto&& basename : std::move(*maybe_filenames)) { auto blob = BlobStorageClient::DataLocation{ .bucket = options.data_bucket, .prefix = prefix, .key = basename}; if (!IsDeltaFilename(blob.key)) { - LOG(WARNING) << "Saw a file " << blob - << " not in delta file format. Skipping it."; + PS_LOG(WARNING, options.log_context) + << "Saw a file " << blob + << " not in delta file format. 
Skipping it."; continue; } (*ending_delta_files)[prefix] = blob.key; @@ -326,7 +357,7 @@ class DataOrchestratorImpl : public DataOrchestrator { !s.ok()) { return s.status(); } - LOG(INFO) << "Done loading " << blob; + PS_LOG(INFO, options.log_context) << "Done loading " << blob; } } return ending_delta_files; @@ -336,7 +367,8 @@ class DataOrchestratorImpl : public DataOrchestrator { if (data_loader_thread_) { return absl::OkStatus(); } - LOG(INFO) << "Transitioning to state ContinuouslyLoadNewData"; + PS_LOG(INFO, options_.log_context) + << "Transitioning to state ContinuouslyLoadNewData"; auto prefix_last_basenames = prefix_last_basenames_; absl::Status status = options_.delta_notifier.Start( options_.change_notifier, {.bucket = options_.data_bucket}, @@ -350,13 +382,13 @@ class DataOrchestratorImpl : public DataOrchestrator { absl::bind_front(&DataOrchestratorImpl::ProcessNewFiles, this)); return options_.realtime_thread_pool_manager.Start( - [this, &cache = options_.cache, + [this, &cache = options_.cache, &log_context = options_.log_context, &delta_stream_reader_factory = options_.delta_stream_reader_factory]( const std::string& message_body) { return LoadCacheWithHighPriorityUpdates( kDefaultDataSourceForRealtimeUpdates, kDefaultPrefixForRealTimeUpdates, delta_stream_reader_factory, - message_body, cache); + message_body, cache, log_context); }); } @@ -370,7 +402,8 @@ class DataOrchestratorImpl : public DataOrchestrator { // On failure, puts the file back to the end of the queue and retry at a // later point. 
void ProcessNewFiles() { - LOG(INFO) << "Thread for new file processing started"; + PS_LOG(INFO, options_.log_context) + << "Thread for new file processing started"; absl::Condition has_new_event(this, &DataOrchestratorImpl::HasNewEventToProcess); while (true) { @@ -378,21 +411,23 @@ class DataOrchestratorImpl : public DataOrchestrator { { absl::MutexLock l(&mu_, has_new_event); if (stop_) { - LOG(INFO) << "Thread for new file processing stopped"; + PS_LOG(INFO, options_.log_context) + << "Thread for new file processing stopped"; return; } basename = std::move(unprocessed_basenames_.back()); unprocessed_basenames_.pop_back(); } - LOG(INFO) << "Loading " << basename; + PS_LOG(INFO, options_.log_context) << "Loading " << basename; auto blob = ParseBlobName(basename); if (!IsDeltaFilename(blob.key)) { - LOG(WARNING) << "Received file with invalid name: " << basename; + PS_LOG(WARNING, options_.log_context) + << "Received file with invalid name: " << basename; continue; } if (!options_.blob_prefix_allowlist.Contains(blob.prefix)) { - LOG(WARNING) << "Received file with prefix not allowlisted: " - << basename; + PS_LOG(WARNING, options_.log_context) + << "Received file with prefix not allowlisted: " << basename; continue; } RetryUntilOk( @@ -405,7 +440,8 @@ class DataOrchestratorImpl : public DataOrchestrator { .key = blob.key}, options_); }, - "LoadNewFile", LogStatusSafeMetricsFn()); + "LoadNewFile", LogStatusSafeMetricsFn(), + options_.log_context); } } @@ -413,7 +449,8 @@ class DataOrchestratorImpl : public DataOrchestrator { void EnqueueNewFilesToProcess(const std::string& basename) { absl::MutexLock l(&mu_); unprocessed_basenames_.push_front(basename); - LOG(INFO) << "queued " << basename << " for loading"; + PS_LOG(INFO, options_.log_context) + << "queued " << basename << " for loading"; // TODO: block if the queue is too large: consumption is too slow. 
} @@ -425,8 +462,8 @@ class DataOrchestratorImpl : public DataOrchestrator { for (const auto& prefix : options.blob_prefix_allowlist.Prefixes()) { auto location = BlobStorageClient::DataLocation{ .bucket = options.data_bucket, .prefix = prefix}; - LOG(INFO) << "Initializing cache with snapshot file(s) from: " - << location; + PS_LOG(INFO, options.log_context) + << "Initializing cache with snapshot file(s) from: " << location; PS_ASSIGN_OR_RETURN( auto snapshot_group, FindMostRecentFileGroup( @@ -435,7 +472,8 @@ class DataOrchestratorImpl : public DataOrchestrator { .status = FileGroup::FileStatus::kComplete}, options.blob_client)); if (!snapshot_group.has_value()) { - LOG(INFO) << "No snapshot files found in: " << location; + PS_LOG(INFO, options.log_context) + << "No snapshot files found in: " << location; continue; } for (const auto& snapshot : snapshot_group->Filenames()) { @@ -450,13 +488,15 @@ class DataOrchestratorImpl : public DataOrchestrator { PS_ASSIGN_OR_RETURN(auto metadata, record_reader->GetKVFileMetadata()); if (metadata.has_sharding_metadata() && metadata.sharding_metadata().shard_num() != options.shard_num) { - LOG(INFO) << "Snapshot " << snapshot_blob << " belongs to shard num " - << metadata.sharding_metadata().shard_num() - << " but server shard num is " << options.shard_num - << ". Skipping it."; + PS_LOG(INFO, options.log_context) + << "Snapshot " << snapshot_blob << " belongs to shard num " + << metadata.sharding_metadata().shard_num() + << " but server shard num is " << options.shard_num + << ". 
Skipping it."; continue; } - LOG(INFO) << "Loading snapshot file: " << snapshot_blob; + PS_LOG(INFO, options.log_context) + << "Loading snapshot file: " << snapshot_blob; PS_ASSIGN_OR_RETURN( auto stats, TraceLoadCacheWithDataFromFile(snapshot_blob, options)); if (auto iter = ending_delta_files.find(prefix); @@ -464,7 +504,8 @@ class DataOrchestratorImpl : public DataOrchestrator { metadata.snapshot().ending_delta_file() > iter->second) { ending_delta_files[prefix] = metadata.snapshot().ending_delta_file(); } - LOG(INFO) << "Done loading snapshot file: " << snapshot_blob; + PS_LOG(INFO, options.log_context) + << "Done loading snapshot file: " << snapshot_blob; } } return ending_delta_files; @@ -473,14 +514,15 @@ class DataOrchestratorImpl : public DataOrchestrator { absl::StatusOr LoadCacheWithHighPriorityUpdates( std::string_view data_source, std::string_view prefix, StreamRecordReaderFactory& delta_stream_reader_factory, - const std::string& record_string, Cache& cache) { + const std::string& record_string, Cache& cache, + privacy_sandbox::server_common::log::PSLogContext& log_context) { std::istringstream is(record_string); int64_t max_timestamp = 0; auto record_reader = delta_stream_reader_factory.CreateReader(is); return LoadCacheWithData(data_source, prefix, *record_reader, cache, max_timestamp, options_.shard_num, options_.num_shards, options_.udf_client, - options_.key_sharder); + options_.key_sharder, log_context); } const Options options_; diff --git a/components/data_server/data_loading/data_orchestrator.h b/components/data_server/data_loading/data_orchestrator.h index d9a17c51..c2e09f76 100644 --- a/components/data_server/data_loading/data_orchestrator.h +++ b/components/data_server/data_loading/data_orchestrator.h @@ -35,7 +35,6 @@ #include "public/sharding/key_sharder.h" namespace kv_server { - // Coordinate data loading. // // This class is intended to be used in a single thread. 
@@ -58,6 +57,7 @@ class DataOrchestrator { const int32_t num_shards = 1; const KeySharder key_sharder; BlobPrefixAllowlist blob_prefix_allowlist; + privacy_sandbox::server_common::log::PSLogContext& log_context; }; // Creates initial state. Scans the bucket and initializes the cache with data diff --git a/components/data_server/data_loading/data_orchestrator_test.cc b/components/data_server/data_loading/data_orchestrator_test.cc index 913e6324..7864a91b 100644 --- a/components/data_server/data_loading/data_orchestrator_test.cc +++ b/components/data_server/data_loading/data_orchestrator_test.cc @@ -107,7 +107,8 @@ class DataOrchestratorTest : public ::testing::Test { .realtime_thread_pool_manager = realtime_thread_pool_manager_, .key_sharder = kv_server::KeySharder(kv_server::ShardingFunction{/*seed=*/""}), - .blob_prefix_allowlist = kv_server::BlobPrefixAllowlist("")}) {} + .blob_prefix_allowlist = kv_server::BlobPrefixAllowlist(""), + .log_context = log_context_}) {} MockBlobStorageClient blob_client_; MockDeltaFileNotifier notifier_; @@ -117,6 +118,7 @@ class DataOrchestratorTest : public ::testing::Test { MockCache cache_; MockRealtimeThreadPoolManager realtime_thread_pool_manager_; DataOrchestrator::Options options_; + privacy_sandbox::server_common::log::NoOpContext log_context_; }; TEST_F(DataOrchestratorTest, InitCacheListRetriesOnFailure) { @@ -357,9 +359,9 @@ TEST_F(DataOrchestratorTest, InitCacheSuccess) { .WillOnce(Return(ByMove(std::move(update_reader)))) .WillOnce(Return(ByMove(std::move(delete_reader)))); - EXPECT_CALL(cache_, UpdateKeyValue("bar", "bar value", 3, _)).Times(1); - EXPECT_CALL(cache_, DeleteKey("bar", 3, _)).Times(1); - EXPECT_CALL(cache_, RemoveDeletedKeys(3, _)).Times(2); + EXPECT_CALL(cache_, UpdateKeyValue(_, "bar", "bar value", 3, _)).Times(1); + EXPECT_CALL(cache_, DeleteKey(_, "bar", 3, _)).Times(1); + EXPECT_CALL(cache_, RemoveDeletedKeys(_, 3, _)).Times(2); auto maybe_orchestrator = DataOrchestrator::TryCreate(options_); 
ASSERT_TRUE(maybe_orchestrator.ok()); @@ -411,7 +413,8 @@ TEST_F(DataOrchestratorTest, UpdateUdfCodeSuccess) { EXPECT_CALL(udf_client_, SetCodeObject(CodeConfig{.js = "function hello(){}", .udf_handler_name = "hello", - .logical_commit_time = 1})) + .logical_commit_time = 1}, + _)) .WillOnce(Return(absl::OkStatus())); auto maybe_orchestrator = DataOrchestrator::TryCreate(options_); ASSERT_TRUE(maybe_orchestrator.ok()); @@ -463,7 +466,8 @@ TEST_F(DataOrchestratorTest, UpdateUdfCodeFails_OrchestratorContinues) { EXPECT_CALL(udf_client_, SetCodeObject(CodeConfig{.js = "function hello(){}", .udf_handler_name = "hello", - .logical_commit_time = 1})) + .logical_commit_time = 1}, + _)) .WillOnce(Return(absl::UnknownError("Some error."))); auto maybe_orchestrator = DataOrchestrator::TryCreate(options_); ASSERT_TRUE(maybe_orchestrator.ok()); @@ -540,9 +544,9 @@ TEST_F(DataOrchestratorTest, StartLoading) { .WillOnce(Return(ByMove(std::move(update_reader)))) .WillOnce(Return(ByMove(std::move(delete_reader)))); - EXPECT_CALL(cache_, UpdateKeyValue("bar", "bar value", 3, _)).Times(1); - EXPECT_CALL(cache_, DeleteKey("bar", 3, _)).Times(1); - EXPECT_CALL(cache_, RemoveDeletedKeys(3, _)).Times(2); + EXPECT_CALL(cache_, UpdateKeyValue(_, "bar", "bar value", 3, _)).Times(1); + EXPECT_CALL(cache_, DeleteKey(_, "bar", 3, _)).Times(1); + EXPECT_CALL(cache_, RemoveDeletedKeys(_, 3, _)).Times(2); EXPECT_TRUE(orchestrator->Start().ok()); LOG(INFO) << "Created ContinuouslyLoadNewData"; @@ -617,9 +621,9 @@ TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .WillOnce(Return(ByMove(std::move(update_reader)))) .WillOnce(Return(ByMove(std::move(delete_reader)))); - EXPECT_CALL(strict_cache, RemoveDeletedKeys(0, _)).Times(1); - EXPECT_CALL(strict_cache, DeleteKey("shard2", 3, _)).Times(1); - EXPECT_CALL(strict_cache, RemoveDeletedKeys(3, _)).Times(1); + EXPECT_CALL(strict_cache, RemoveDeletedKeys(_, 0, _)).Times(1); + EXPECT_CALL(strict_cache, DeleteKey(_, "shard2", 3, 
_)).Times(1); + EXPECT_CALL(strict_cache, RemoveDeletedKeys(_, 3, _)).Times(1); auto sharded_options = DataOrchestrator::Options{ .data_bucket = GetTestLocation().bucket, @@ -634,7 +638,8 @@ TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .num_shards = 2, .key_sharder = kv_server::KeySharder(kv_server::ShardingFunction{/*seed=*/""}), - .blob_prefix_allowlist = BlobPrefixAllowlist("")}; + .blob_prefix_allowlist = BlobPrefixAllowlist(""), + .log_context = log_context_}; auto maybe_orchestrator = DataOrchestrator::TryCreate(sharded_options); ASSERT_TRUE(maybe_orchestrator.ok()); diff --git a/components/data_server/request_handler/BUILD.bazel b/components/data_server/request_handler/BUILD.bazel index 292869da..66a2f1cb 100644 --- a/components/data_server/request_handler/BUILD.bazel +++ b/components/data_server/request_handler/BUILD.bazel @@ -72,6 +72,8 @@ cc_library( ], deps = [ ":compression", + ":framing_utils", + ":get_values_v2_status", ":ohttp_server_encryptor", "//components/data_server/cache", "//components/telemetry:server_definition", @@ -87,11 +89,37 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", + "@google_privacysandbox_servers_common//src/communication:encoding_utils", "@google_privacysandbox_servers_common//src/telemetry", "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", ], ) +cc_library( + name = "framing_utils", + srcs = [ + "framing_utils.cc", + ], + hdrs = [ + "framing_utils.h", + ], + deps = [ + "@com_google_absl//absl/numeric:bits", + ], +) + +cc_test( + name = "framing_utils_test", + size = "small", + srcs = [ + "framing_utils_test.cc", + ], + deps = [ + ":framing_utils", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "compression", srcs = [ @@ -147,11 +175,13 @@ cc_test( "//components/udf:udf_client", "//public/query/v2:get_values_v2_cc_grpc", "//public/test_util:proto_matcher", + 
"//public/test_util:request_example", "@com_github_google_quiche//quiche:binary_http_unstable_api", "@com_github_google_quiche//quiche:oblivious_http_unstable_api", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", + "@google_privacysandbox_servers_common//src/communication:encoding_utils", "@google_privacysandbox_servers_common//src/encryption/key_fetcher:fake_key_fetcher_manager", "@nlohmann_json//:lib", ], @@ -247,11 +277,11 @@ cc_library( hdrs = [ "ohttp_client_encryptor.h", ], - copts = select({ - "//:aws_platform": ["-DCLOUD_PLATFORM_AWS=1"], - "//:gcp_platform": ["-DCLOUD_PLATFORM_GCP=1"], - "//conditions:default": [], - }), + visibility = [ + "//components/data_server:__subpackages__", + "//components/internal_server:__subpackages__", + "//components/tools:__subpackages__", + ], deps = [ "//public:constants", "@com_github_google_quiche//quiche:oblivious_http_unstable_api", @@ -291,3 +321,25 @@ cc_test( "@google_privacysandbox_servers_common//src/encryption/key_fetcher:fake_key_fetcher_manager", ], ) + +cc_library( + name = "get_values_v2_status", + srcs = select({ + "//:nonprod_mode": [ + "get_values_v2_status_nonprod.cc", + ], + "//:prod_mode": [ + "get_values_v2_status.cc", + ], + "//conditions:default": [ + "get_values_v2_status.cc", + ], + }), + hdrs = ["get_values_v2_status.h"], + deps = [ + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/util/status_macro:status_macros", + ], +) diff --git a/components/data_server/request_handler/framing_utils.cc b/components/data_server/request_handler/framing_utils.cc new file mode 100644 index 00000000..7db740bc --- /dev/null +++ b/components/data_server/request_handler/framing_utils.cc @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "components/data_server/request_handler/framing_utils.h" + +#include + +#include "absl/numeric/bits.h" + +namespace kv_server { + +// 1 byte for version + compression details. +constexpr int kVersionCompressionSize = 1; + +// 4-bytes specifying the size of the actual payload. +constexpr int kPayloadLength = 4; + +// Minimum size of the returned response in bytes. +// TODO: b/348613920 - Move framing utils to the common repo, and as part of +// that figure out if this needs to be inline with B&A. +inline constexpr size_t kMinResultBytes = 0; + +// Gets size of the complete payload including the preamble expected by +// android, which is: 1 byte (containing version, compression details), 4 bytes +// indicating the length of the actual encoded response and any other padding +// required to make the complete payload a power of 2. +size_t GetEncodedDataSize(size_t encapsulated_payload_size) { + size_t total_payload_size = + kVersionCompressionSize + kPayloadLength + encapsulated_payload_size; + // Ensure that the payload size is a power of 2. 
+ return std::max(absl::bit_ceil(total_payload_size), kMinResultBytes); +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/framing_utils.h b/components/data_server/request_handler/framing_utils.h new file mode 100644 index 00000000..5b146c2a --- /dev/null +++ b/components/data_server/request_handler/framing_utils.h @@ -0,0 +1,33 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ + +#include + +// TODO: b/348613920 - Move framing utils to the common repo +namespace kv_server { + +// Gets size of the complete payload including the preamble expected by +// client, which is: 1 byte (containing version, compression details), 4 bytes +// indicating the length of the actual encoded response and any other padding +// required to make the complete payload a power of 2. 
+size_t GetEncodedDataSize(size_t encapsulated_payload_size); + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_FRAMING_UTILS_H_ diff --git a/components/data_server/request_handler/framing_utils_test.cc b/components/data_server/request_handler/framing_utils_test.cc new file mode 100644 index 00000000..254b6c1f --- /dev/null +++ b/components/data_server/request_handler/framing_utils_test.cc @@ -0,0 +1,32 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/framing_utils.h" + +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(FramingUtilsTest, EncodedDataSizeMatchesTheSpec) { + EXPECT_EQ(GetEncodedDataSize(0), 8); + EXPECT_EQ(GetEncodedDataSize(1), 8); + EXPECT_EQ(GetEncodedDataSize(3), 8); + EXPECT_EQ(GetEncodedDataSize(4), 16); + EXPECT_EQ(GetEncodedDataSize(100), 128); + EXPECT_EQ(GetEncodedDataSize(1000), 1024); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/request_handler/get_values_adapter.cc b/components/data_server/request_handler/get_values_adapter.cc index 0c86b5c9..72972d0d 100644 --- a/components/data_server/request_handler/get_values_adapter.cc +++ b/components/data_server/request_handler/get_values_adapter.cc @@ -22,11 +22,14 @@ #include #include "absl/log/log.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "components/data_server/request_handler/v2_response_data.pb.h" #include "google/protobuf/util/json_util.h" #include "public/api_schema.pb.h" #include "public/applications/pa/api_overlay.pb.h" #include "public/applications/pa/response_utils.h" +#include "public/constants.h" #include "src/util/status_macro/status_macros.h" namespace kv_server { @@ -38,6 +41,7 @@ using google::protobuf::util::JsonStringToMessage; constexpr char kKeysTag[] = "keys"; constexpr char kRenderUrlsTag[] = "renderUrls"; +constexpr char kInterestGroupNamesTag[] = "interestGroupNames"; constexpr char kAdComponentRenderUrlsTag[] = "adComponentRenderUrls"; constexpr char kKvInternalTag[] = "kvInternal"; constexpr char kCustomTag[] = "custom"; @@ -51,7 +55,10 @@ UDFArgument BuildArgument(const RepeatedPtrField& keys, arg.mutable_tags()->add_values()->set_string_value(namespace_tag); auto* key_list = arg.mutable_data()->mutable_list_value(); for (const auto& key : keys) { - key_list->add_values()->set_string_value(key); + for (absl::string_view individual_key : + absl::StrSplit(key, 
kQueryArgDelimiter)) { + key_list->add_values()->set_string_value(individual_key); + } } return arg; } @@ -65,6 +72,10 @@ v2::GetValuesRequest BuildV2Request(const v1::GetValuesRequest& v1_request) { if (v1_request.keys_size() > 0) { *partition->add_arguments() = BuildArgument(v1_request.keys(), kKeysTag); } + if (v1_request.interest_group_names_size() > 0) { + *partition->add_arguments() = BuildArgument( + v1_request.interest_group_names(), kInterestGroupNamesTag); + } if (v1_request.render_urls_size() > 0) { *partition->add_arguments() = BuildArgument(v1_request.render_urls(), kRenderUrlsTag); @@ -140,6 +151,10 @@ void ProcessKeyGroupOutput(application_pa::KeyGroupOutput key_group_output, if (tag_namespace_status_or.value() == kKeysTag) { ProcessKeyValues(std::move(key_group_output), *v1_response.mutable_keys()); } + if (tag_namespace_status_or.value() == kInterestGroupNamesTag) { + ProcessKeyValues(std::move(key_group_output), + *v1_response.mutable_per_interest_group_data()); + } if (tag_namespace_status_or.value() == kRenderUrlsTag) { ProcessKeyValues(std::move(key_group_output), *v1_response.mutable_render_urls()); @@ -188,17 +203,29 @@ class GetValuesAdapterImpl : public GetValuesAdapter { explicit GetValuesAdapterImpl(std::unique_ptr v2_handler) : v2_handler_(std::move(v2_handler)) {} - grpc::Status CallV2Handler(const v1::GetValuesRequest& v1_request, + grpc::Status CallV2Handler(RequestContextFactory& request_context_factory, + const v1::GetValuesRequest& v1_request, v1::GetValuesResponse& v1_response) const { + privacy_sandbox::server_common::Stopwatch stopwatch; v2::GetValuesRequest v2_request = BuildV2Request(v1_request); - VLOG(7) << "Converting V1 request " << v1_request.DebugString() - << " to v2 request " << v2_request.DebugString(); + PS_VLOG(7, request_context_factory.Get().GetPSLogContext()) + << "Converting V1 request " << v1_request.DebugString() + << " to v2 request " << v2_request.DebugString(); v2::GetValuesResponse v2_response; - if (auto 
status = v2_handler_->GetValues(v2_request, &v2_response); + ExecutionMetadata execution_metadata; + if (auto status = + v2_handler_->GetValues(request_context_factory, v2_request, + &v2_response, execution_metadata); !status.ok()) { return status; } - VLOG(7) << "Received v2 response: " << v2_response.DebugString(); + int duration_ms = + static_cast(absl::ToInt64Milliseconds(stopwatch.GetElapsedTime())); + LogIfError(KVServerContextMap() + ->SafeMetric() + .LogHistogram(duration_ms)); + PS_VLOG(7, request_context_factory.Get().GetPSLogContext()) + << "Received v2 response: " << v2_response.DebugString(); return privacy_sandbox::server_common::FromAbslStatus( ConvertToV1Response(v2_response, v1_response)); } diff --git a/components/data_server/request_handler/get_values_adapter.h b/components/data_server/request_handler/get_values_adapter.h index aa4343ec..14ca6132 100644 --- a/components/data_server/request_handler/get_values_adapter.h +++ b/components/data_server/request_handler/get_values_adapter.h @@ -33,6 +33,7 @@ class GetValuesAdapter { // Calls the V2 GetValues Handler for a V1 GetValuesRequest. Converts between // V1 and V2 request/responses. 
virtual grpc::Status CallV2Handler( + RequestContextFactory& request_context_factory, const v1::GetValuesRequest& v1_request, v1::GetValuesResponse& v1_response) const = 0; diff --git a/components/data_server/request_handler/get_values_adapter_test.cc b/components/data_server/request_handler/get_values_adapter_test.cc index 15d77b47..71022f5e 100644 --- a/components/data_server/request_handler/get_values_adapter_test.cc +++ b/components/data_server/request_handler/get_values_adapter_test.cc @@ -54,6 +54,7 @@ class GetValuesAdapterTest : public ::testing::Test { mock_udf_client_, fake_key_fetcher_manager_); get_values_adapter_ = GetValuesAdapter::Create(std::move(v2_handler_)); InitMetricsContextMap(); + request_context_factory_ = std::make_unique(); } privacy_sandbox::server_common::FakeKeyFetcherManager @@ -61,6 +62,7 @@ class GetValuesAdapterTest : public ::testing::Test { std::unique_ptr get_values_adapter_; std::unique_ptr v2_handler_; MockUdfClient mock_udf_client_; + std::unique_ptr request_context_factory_; }; TEST_F(GetValuesAdapterTest, EmptyRequestReturnsEmptyResponse) { @@ -68,14 +70,15 @@ TEST_F(GetValuesAdapterTest, EmptyRequestReturnsEmptyResponse) { TextFormat::ParseFromString(kEmptyMetadata, &udf_metadata); nlohmann::json output = nlohmann::json::parse(R"({"keyGroupOutputs": {}})"); - EXPECT_CALL( - mock_udf_client_, - ExecuteCode(testing::_, EqualsProto(udf_metadata), testing::IsEmpty())) + EXPECT_CALL(mock_udf_client_, + ExecuteCode(testing::_, EqualsProto(udf_metadata), + testing::IsEmpty(), testing::_)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -132,7 +135,7 @@ data { &key_group_outputs); 
EXPECT_CALL(mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), - testing::ElementsAre(EqualsProto(arg)))) + testing::ElementsAre(EqualsProto(arg)), testing::_)) .WillOnce(Return( application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); @@ -140,7 +143,92 @@ data { v1_request.add_keys("key1"); v1_request.add_keys("key2"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString( + R"pb( + keys { + key: "key1" + value { + value { + string_value: "value1" + } + } + } + keys { + key: "key2" + value { + value { + string_value: "value2" + } + } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, V1RequestSeparatesTwoKeysReturnsOk) { + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(kEmptyMetadata, &udf_metadata); + UDFArgument arg; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "keys" + } +} +data { + list_value { + values { + string_value: "key1" + } + values { + string_value: "key2" + } + } +})", + &arg); + application_pa::KeyGroupOutputs key_group_outputs; + TextFormat::ParseFromString(R"( + key_group_outputs: { + tags: "custom" + tags: "keys" + key_values: { + key: "key1" + value: { + value: { + string_value: "value1" + } + } + } + key_values: { + key: "key2" + value: { + value: { + string_value: "value2" + } + } + } + } +)", + &key_group_outputs); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(testing::_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg)), testing::_)) + .WillOnce(Return( + application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1,key2"); 
+ v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( @@ -235,7 +323,8 @@ data { EXPECT_CALL( mock_udf_client_, ExecuteCode(testing::_, EqualsProto(udf_metadata), - testing::ElementsAre(EqualsProto(arg1), EqualsProto(arg2)))) + testing::ElementsAre(EqualsProto(arg1), EqualsProto(arg2)), + testing::_)) .WillOnce(Return( application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); @@ -243,7 +332,8 @@ data { v1_request.add_render_urls("key1"); v1_request.add_ad_component_render_urls("key2"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb( @@ -263,13 +353,14 @@ TEST_F(GetValuesAdapterTest, V2ResponseIsNullReturnsError) { nlohmann::json output = R"({ "keyGroupOutpus": [] })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_FALSE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -284,13 +375,14 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithEmptyKVsReturnsOk) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; 
v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -305,13 +397,14 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithInvalidNamespaceTagIsIgnored) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -326,13 +419,14 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoCustomTagIsIgnored) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -347,13 +441,14 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoNamespaceTagIsIgnored) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); 
v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb()pb", &v1_expected); @@ -373,13 +468,14 @@ TEST_F(GetValuesAdapterTest, }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString(R"pb( @@ -407,13 +503,14 @@ TEST_F(GetValuesAdapterTest, KeyGroupOutputHasDifferentValueTypesReturnsOk) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( @@ -487,13 +584,14 @@ TEST_F(GetValuesAdapterTest, ValueWithStatusSuccess) { }], "udfOutputApiVersion": 1 })"_json; - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _)) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, _, _)) .WillOnce(Return(output.dump())); v1::GetValuesRequest v1_request; v1_request.add_keys("key1"); 
v1::GetValuesResponse v1_response; - auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); EXPECT_TRUE(status.ok()); v1::GetValuesResponse v1_expected; TextFormat::ParseFromString( @@ -526,5 +624,98 @@ TEST_F(GetValuesAdapterTest, ValueWithStatusSuccess) { EXPECT_THAT(v1_response, EqualsProto(v1_expected)); } +TEST_F(GetValuesAdapterTest, V1RequestWithInterestGroupNamesReturnsOk) { + UDFExecutionMetadata udf_metadata; + TextFormat::ParseFromString(kEmptyMetadata, &udf_metadata); + UDFArgument arg; + TextFormat::ParseFromString(R"( +tags { + values { + string_value: "custom" + } + values { + string_value: "interestGroupNames" + } +} +data { + list_value { + values { + string_value: "interestGroup1" + } + values { + string_value: "interestGroup2" + } + } +})", + &arg); + application_pa::KeyGroupOutputs key_group_outputs; + TextFormat::ParseFromString(R"( + key_group_outputs: { + tags: "custom" + tags: "interestGroupNames" + key_values: { + key: "interestGroup1" + value: { + value: { + string_value: "value1" + } + } + } + key_values: { + key: "interestGroup2" + value: { + value: { + string_value: "{\"priorityVector\":{\"signal1\":1}}" + } + } + } + } +)", + &key_group_outputs); + EXPECT_CALL(mock_udf_client_, + ExecuteCode(testing::_, EqualsProto(udf_metadata), + testing::ElementsAre(EqualsProto(arg)), testing::_)) + .WillOnce(Return( + application_pa::KeyGroupOutputsToJson(key_group_outputs).value())); + + v1::GetValuesRequest v1_request; + v1_request.add_interest_group_names("interestGroup1"); + v1_request.add_interest_group_names("interestGroup2"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(*request_context_factory_, + v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString( + R"pb( + per_interest_group_data { + key: 
"interestGroup1" + value { value { string_value: "value1" } } + } + per_interest_group_data { + key: "interestGroup2" + value { + value { + struct_value { + fields { + key: "priorityVector" + value { + struct_value { + fields { + key: "signal1" + value { number_value: 1 } + } + + } + } + } + } + } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + } // namespace } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_handler.cc b/components/data_server/request_handler/get_values_handler.cc index f8cce490..d866fc64 100644 --- a/components/data_server/request_handler/get_values_handler.cc +++ b/components/data_server/request_handler/get_values_handler.cc @@ -91,31 +91,43 @@ void ProcessKeys( } // namespace -grpc::Status GetValuesHandler::GetValues(const RequestContext& request_context, - const GetValuesRequest& request, - GetValuesResponse* response) const { +grpc::Status GetValuesHandler::GetValues( + RequestContextFactory& request_context_factory, + const GetValuesRequest& request, GetValuesResponse* response) const { + const auto& request_context = request_context_factory.Get(); if (use_v2_) { - VLOG(5) << "Using V2 adapter for " << request.DebugString(); - return adapter_.CallV2Handler(request, *response); + PS_VLOG(5, request_context.GetPSLogContext()) + << "Using V2 adapter for " << request.DebugString(); + return adapter_.CallV2Handler(request_context_factory, request, *response); } if (!request.kv_internal().empty()) { - VLOG(5) << "Processing kv_internal for " << request.DebugString(); + PS_VLOG(5, request_context.GetPSLogContext()) + << "Processing kv_internal for " << request.DebugString(); ProcessKeys(request_context, request.kv_internal(), cache_, *response->mutable_kv_internal(), add_missing_keys_v1_); } if (!request.keys().empty()) { - VLOG(5) << "Processing keys for " << request.DebugString(); + PS_VLOG(5, request_context.GetPSLogContext()) + << "Processing keys for " << 
request.DebugString(); ProcessKeys(request_context, request.keys(), cache_, *response->mutable_keys(), add_missing_keys_v1_); } + if (!request.interest_group_names().empty()) { + PS_VLOG(5, request_context.GetPSLogContext()) + << "Processing interest_group_names for " << request.DebugString(); + ProcessKeys(request_context, request.interest_group_names(), cache_, + *response->mutable_per_interest_group_data(), + add_missing_keys_v1_); + } if (!request.render_urls().empty()) { - VLOG(5) << "Processing render_urls for " << request.DebugString(); + PS_VLOG(5, request_context.GetPSLogContext()) + << "Processing render_urls for " << request.DebugString(); ProcessKeys(request_context, request.render_urls(), cache_, *response->mutable_render_urls(), add_missing_keys_v1_); } if (!request.ad_component_render_urls().empty()) { - VLOG(5) << "Processing ad_component_render_urls for " - << request.DebugString(); + PS_VLOG(5, request_context.GetPSLogContext()) + << "Processing ad_component_render_urls for " << request.DebugString(); ProcessKeys(request_context, request.ad_component_render_urls(), cache_, *response->mutable_ad_component_render_urls(), add_missing_keys_v1_); diff --git a/components/data_server/request_handler/get_values_handler.h b/components/data_server/request_handler/get_values_handler.h index ef2be6cf..7c312922 100644 --- a/components/data_server/request_handler/get_values_handler.h +++ b/components/data_server/request_handler/get_values_handler.h @@ -40,7 +40,7 @@ class GetValuesHandler { add_missing_keys_v1_(add_missing_keys_v1) {} // TODO: Implement hostname, ad/render url lookups. 
- grpc::Status GetValues(const RequestContext& request_context, + grpc::Status GetValues(RequestContextFactory& request_context_factory, const v1::GetValuesRequest& request, v1::GetValuesResponse* response) const; diff --git a/components/data_server/request_handler/get_values_handler_test.cc b/components/data_server/request_handler/get_values_handler_test.cc index 76f0d40a..87e52d1c 100644 --- a/components/data_server/request_handler/get_values_handler_test.cc +++ b/components/data_server/request_handler/get_values_handler_test.cc @@ -47,15 +47,17 @@ class GetValuesHandlerTest : public ::testing::Test { protected: GetValuesHandlerTest() { InitMetricsContextMap(); - scope_metrics_context_ = std::make_unique(); - request_context_ = - std::make_unique(*scope_metrics_context_); + request_context_factory_ = std::make_unique(); + request_context_factory_->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); } MockCache mock_cache_; MockGetValuesAdapter mock_get_values_adapter_; - RequestContext& GetRequestContext() { return *request_context_; } - std::unique_ptr scope_metrics_context_; - std::unique_ptr request_context_; + RequestContextFactory& GetRequestContextFactory() { + return *request_context_factory_; + } + std::unique_ptr request_context_factory_; }; TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { @@ -69,7 +71,7 @@ TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false); const auto result = - handler.GetValues(GetRequestContext(), request, &response); + handler.GetValues(GetRequestContextFactory(), request, &response); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -82,7 +84,8 @@ TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { &expected); EXPECT_THAT(response, EqualsProto(expected)); - 
ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); EXPECT_THAT(response, EqualsProto(expected)); } @@ -97,7 +100,8 @@ TEST_F(GetValuesHandlerTest, RepeatedKeys) { GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); GetValuesResponse expected; TextFormat::ParseFromString( @@ -130,7 +134,8 @@ TEST_F(GetValuesHandlerTest, RepeatedKeysSkipEmpty) { GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false, /*add_missing_keys_v1=*/false); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); GetValuesResponse expected; TextFormat::ParseFromString( @@ -156,7 +161,8 @@ TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysSameNamespace) { GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); GetValuesResponse expected; TextFormat::ParseFromString(R"pb( @@ -187,7 +193,8 @@ TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysDifferentNamespace) { GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); GetValuesResponse expected; TextFormat::ParseFromString(R"pb(render_urls { @@ -291,7 +298,8 @@ TEST_F(GetValuesHandlerTest, 
TestResponseOnDifferentValueFormats) { GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/false); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); GetValuesResponse expected_from_pb; TextFormat::ParseFromString(response_pb_string, &expected_from_pb); EXPECT_THAT(response, EqualsProto(expected_from_pb)); @@ -310,18 +318,48 @@ TEST_F(GetValuesHandlerTest, CallsV2Adapter) { } })pb", &adapter_response); - EXPECT_CALL(mock_get_values_adapter_, CallV2Handler(_, _)) + EXPECT_CALL(mock_get_values_adapter_, CallV2Handler(_, _, _)) .WillOnce( - DoAll(SetArgReferee<1>(adapter_response), Return(grpc::Status::OK))); + DoAll(SetArgReferee<2>(adapter_response), Return(grpc::Status::OK))); GetValuesRequest request; request.add_keys("key1"); GetValuesResponse response; GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, /*use_v2=*/true); - ASSERT_TRUE(handler.GetValues(GetRequestContext(), request, &response).ok()); + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); EXPECT_THAT(response, EqualsProto(adapter_response)); } +TEST_F(GetValuesHandlerTest, ReturnsPerInterestGroupData) { + EXPECT_CALL(mock_cache_, GetKeyValuePairs(_, UnorderedElementsAre("my_key"))) + .Times(2) + .WillRepeatedly(Return(absl::flat_hash_map{ + {"my_key", "my_value"}})); + GetValuesRequest request; + request.add_interest_group_names("my_key"); + GetValuesResponse response; + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + /*use_v2=*/false); + const auto result = + handler.GetValues(GetRequestContextFactory(), request, &response); + ASSERT_TRUE(result.ok()) << "code: " << result.error_code() + << ", msg: " << result.error_message(); + + GetValuesResponse expected; + TextFormat::ParseFromString( + R"pb(per_interest_group_data { + key: "my_key" + value { value { 
string_value: "my_value" } } + })pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); + + ASSERT_TRUE( + handler.GetValues(GetRequestContextFactory(), request, &response).ok()); + EXPECT_THAT(response, EqualsProto(expected)); +} + } // namespace } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_v2_handler.cc b/components/data_server/request_handler/get_values_v2_handler.cc index 06d3b5c7..2268a038 100644 --- a/components/data_server/request_handler/get_values_v2_handler.cc +++ b/components/data_server/request_handler/get_values_v2_handler.cc @@ -24,6 +24,8 @@ #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "components/data_server/request_handler/framing_utils.h" +#include "components/data_server/request_handler/get_values_v2_status.h" #include "components/data_server/request_handler/ohttp_server_encryptor.h" #include "components/telemetry/server_definition.h" #include "google/protobuf/util/json_util.h" @@ -34,6 +36,7 @@ #include "quiche/binary_http/binary_http_message.h" #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" #include "quiche/oblivious_http/oblivious_http_gateway.h" +#include "src/communication/encoding_utils.h" #include "src/telemetry/telemetry.h" #include "src/util/status_macro/status_macros.h" @@ -67,15 +70,20 @@ CompressionGroupConcatenator::CompressionType GetResponseCompressionType( } // namespace grpc::Status GetValuesV2Handler::GetValuesHttp( - const GetValuesHttpRequest& request, - google::api::HttpBody* response) const { + RequestContextFactory& request_context_factory, + const std::multimap& headers, + const GetValuesHttpRequest& request, google::api::HttpBody* response, + ExecutionMetadata& execution_metadata) const { return FromAbslStatus( - GetValuesHttp(request.raw_body().data(), *response->mutable_data())); + GetValuesHttp(request_context_factory, request.raw_body().data(), + *response->mutable_data(), 
execution_metadata, + GetContentType(headers, ContentType::kJson))); } -absl::Status GetValuesV2Handler::GetValuesHttp(std::string_view request, - std::string& response, - ContentType content_type) const { +absl::Status GetValuesV2Handler::GetValuesHttp( + RequestContextFactory& request_context_factory, std::string_view request, + std::string& response, ExecutionMetadata& execution_metadata, + ContentType content_type) const { v2::GetValuesRequest request_proto; if (content_type == ContentType::kJson) { PS_RETURN_IF_ERROR( @@ -84,31 +92,39 @@ absl::Status GetValuesV2Handler::GetValuesHttp(std::string_view request, if (!request_proto.ParseFromString(request)) { auto error_message = "Cannot parse request as a valid serilized proto object."; - VLOG(4) << error_message; + PS_VLOG(4, request_context_factory.Get().GetPSLogContext()) + << error_message; return absl::InvalidArgumentError(error_message); } } - VLOG(9) << "Converted the http request to proto: " - << request_proto.DebugString(); + PS_VLOG(9) << "Converted the http request to proto: " + << request_proto.DebugString(); v2::GetValuesResponse response_proto; - PS_RETURN_IF_ERROR(GetValues(request_proto, &response_proto)); + PS_RETURN_IF_ERROR( + + GetValues(request_context_factory, request_proto, &response_proto, + execution_metadata)); if (content_type == ContentType::kJson) { return MessageToJsonString(response_proto, &response); } // content_type == proto if (!response_proto.SerializeToString(&response)) { auto error_message = "Cannot serialize the response as a proto."; - VLOG(4) << error_message; + PS_VLOG(4, request_context_factory.Get().GetPSLogContext()) + << error_message; return absl::InvalidArgumentError(error_message); } return absl::OkStatus(); } grpc::Status GetValuesV2Handler::BinaryHttpGetValues( + RequestContextFactory& request_context_factory, const v2::BinaryHttpGetValuesRequest& bhttp_request, - google::api::HttpBody* response) const { - return 
FromAbslStatus(BinaryHttpGetValues(bhttp_request.raw_body().data(), - *response->mutable_data())); + google::api::HttpBody* response, + ExecutionMetadata& execution_metadata) const { + return FromAbslStatus(BinaryHttpGetValues( + request_context_factory, bhttp_request.raw_body().data(), + *response->mutable_data(), execution_metadata)); } GetValuesV2Handler::ContentType GetValuesV2Handler::GetContentType( @@ -123,18 +139,46 @@ GetValuesV2Handler::ContentType GetValuesV2Handler::GetContentType( return ContentType::kJson; } +GetValuesV2Handler::ContentType GetValuesV2Handler::GetContentType( + const std::multimap& headers, + ContentType default_content_type) const { + for (const auto& [header_name, header_value] : headers) { + if (absl::AsciiStrToLower(std::string_view( + header_name.data(), header_name.size())) == kKVContentTypeHeader) { + if (absl::AsciiStrToLower( + std::string_view(header_value.data(), header_value.size())) == + kContentEncodingBhttpHeaderValue) { + return ContentType::kBhttp; + } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), + header_value.size())) == + kContentEncodingProtoHeaderValue) { + return ContentType::kProto; + } else if (absl::AsciiStrToLower(std::string_view(header_value.data(), + header_value.size())) == + kContentEncodingJsonHeaderValue) { + return ContentType::kJson; + } + } + } + return default_content_type; +} + absl::StatusOr GetValuesV2Handler::BuildSuccessfulGetValuesBhttpResponse( - std::string_view bhttp_request_body) const { - VLOG(9) << "Handling the binary http layer"; + RequestContextFactory& request_context_factory, + std::string_view bhttp_request_body, + ExecutionMetadata& execution_metadata) const { + PS_VLOG(9) << "Handling the binary http layer"; PS_ASSIGN_OR_RETURN(quiche::BinaryHttpRequest deserialized_req, quiche::BinaryHttpRequest::Create(bhttp_request_body), _ << "Failed to deserialize binary http request"); - VLOG(3) << "BinaryHttpGetValues request: " << deserialized_req.DebugString(); 
+ PS_VLOG(3) << "BinaryHttpGetValues request: " + << deserialized_req.DebugString(); std::string response; auto content_type = GetContentType(deserialized_req); - PS_RETURN_IF_ERROR( - GetValuesHttp(deserialized_req.body(), response, content_type)); + PS_RETURN_IF_ERROR(GetValuesHttp(request_context_factory, + deserialized_req.body(), response, + execution_metadata, content_type)); quiche::BinaryHttpResponse bhttp_response(200); if (content_type == ContentType::kProto) { bhttp_response.AddHeaderField({ @@ -147,12 +191,15 @@ GetValuesV2Handler::BuildSuccessfulGetValuesBhttpResponse( } absl::Status GetValuesV2Handler::BinaryHttpGetValues( - std::string_view bhttp_request_body, std::string& response) const { + RequestContextFactory& request_context_factory, + std::string_view bhttp_request_body, std::string& response, + ExecutionMetadata& execution_metadata) const { static quiche::BinaryHttpResponse const* kDefaultBhttpResponse = new quiche::BinaryHttpResponse(500); const quiche::BinaryHttpResponse* bhttp_response = kDefaultBhttpResponse; absl::StatusOr maybe_successful_bhttp_response = - BuildSuccessfulGetValuesBhttpResponse(bhttp_request_body); + BuildSuccessfulGetValuesBhttpResponse( + request_context_factory, bhttp_request_body, execution_metadata); if (maybe_successful_bhttp_response.ok()) { bhttp_response = &(maybe_successful_bhttp_response.value()); } @@ -160,27 +207,53 @@ absl::Status GetValuesV2Handler::BinaryHttpGetValues( bhttp_response->Serialize()); response = std::move(serialized_bhttp_response); - VLOG(9) << "BinaryHttpGetValues finished successfully"; + PS_VLOG(9) << "BinaryHttpGetValues finished successfully"; return absl::OkStatus(); } grpc::Status GetValuesV2Handler::ObliviousGetValues( + RequestContextFactory& request_context_factory, + const std::multimap& headers, const ObliviousGetValuesRequest& oblivious_request, - google::api::HttpBody* oblivious_response) const { - VLOG(9) << "Received ObliviousGetValues request. 
"; + google::api::HttpBody* oblivious_response, + ExecutionMetadata& execution_metadata) const { + PS_VLOG(9) << "Received ObliviousGetValues request. "; OhttpServerEncryptor encryptor(key_fetcher_manager_); auto maybe_plain_text = - encryptor.DecryptRequest(oblivious_request.raw_body().data()); + encryptor.DecryptRequest(oblivious_request.raw_body().data(), + request_context_factory.Get().GetPSLogContext()); if (!maybe_plain_text.ok()) { return FromAbslStatus(maybe_plain_text.status()); } - // Now process the binary http request std::string response; - if (const auto s = BinaryHttpGetValues(*maybe_plain_text, response); - !s.ok()) { - return FromAbslStatus(s); + auto content_type = GetContentType(headers, ContentType::kBhttp); + if (content_type == ContentType::kBhttp) { + // Now process the binary http request + if (const auto s = + BinaryHttpGetValues(request_context_factory, *maybe_plain_text, + response, execution_metadata); + !s.ok()) { + return FromAbslStatus(s); + } + } else { + if (const auto s = + GetValuesHttp(request_context_factory, *maybe_plain_text, response, + execution_metadata, content_type); + !s.ok()) { + return FromAbslStatus(s); + } + } + auto encoded_data_size = GetEncodedDataSize(response.size()); + auto maybe_padded_response = + privacy_sandbox::server_common::EncodeResponsePayload( + privacy_sandbox::server_common::CompressionType::kUncompressed, + std::move(response), encoded_data_size); + if (!maybe_padded_response.ok()) { + return FromAbslStatus(maybe_padded_response.status()); } - auto encrypted_response = encryptor.EncryptResponse(std::move(response)); + auto encrypted_response = encryptor.EncryptResponse( + std::move(*maybe_padded_response), + request_context_factory.Get().GetPSLogContext()); if (!encrypted_response.ok()) { return grpc::Status(grpc::StatusCode::INTERNAL, absl::StrCat(encrypted_response.status().code(), " : ", @@ -191,38 +264,45 @@ grpc::Status GetValuesV2Handler::ObliviousGetValues( return grpc::Status::OK; } -void 
GetValuesV2Handler::ProcessOnePartition( - RequestContext request_context, +absl::Status GetValuesV2Handler::ProcessOnePartition( + const RequestContextFactory& request_context_factory, const google::protobuf::Struct& req_metadata, const v2::RequestPartition& req_partition, - v2::ResponsePartition& resp_partition) const { + v2::ResponsePartition& resp_partition, + ExecutionMetadata& execution_metadata) const { resp_partition.set_id(req_partition.id()); UDFExecutionMetadata udf_metadata; *udf_metadata.mutable_request_metadata() = req_metadata; + const auto maybe_output_string = udf_client_.ExecuteCode( - std::move(request_context), std::move(udf_metadata), - req_partition.arguments()); + std::move(request_context_factory), std::move(udf_metadata), + req_partition.arguments(), execution_metadata); if (!maybe_output_string.ok()) { resp_partition.mutable_status()->set_code( static_cast(maybe_output_string.status().code())); resp_partition.mutable_status()->set_message( maybe_output_string.status().message()); - } else { - VLOG(5) << "UDF output: " << maybe_output_string.value(); - resp_partition.set_string_output(std::move(maybe_output_string).value()); + return maybe_output_string.status(); } + PS_VLOG(5, request_context_factory.Get().GetPSLogContext()) + << "UDF output: " << maybe_output_string.value(); + resp_partition.set_string_output(std::move(maybe_output_string).value()); + return absl::OkStatus(); } grpc::Status GetValuesV2Handler::GetValues( - const v2::GetValuesRequest& request, - v2::GetValuesResponse* response) const { - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContextFactory& request_context_factory, + const v2::GetValuesRequest& request, v2::GetValuesResponse* response, + ExecutionMetadata& execution_metadata) const { + PS_VLOG(9) << "Update log context " << request.log_context() << ";" + << request.consented_debug_config(); + 
request_context_factory.UpdateLogContext(request.log_context(), + request.consented_debug_config()); if (request.partitions().size() == 1) { - ProcessOnePartition(std::move(request_context), request.metadata(), - request.partitions(0), - *response->mutable_single_partition()); - return grpc::Status::OK; + const auto partition_status = ProcessOnePartition( + request_context_factory, request.metadata(), request.partitions(0), + *response->mutable_single_partition(), execution_metadata); + return GetExternalStatusForV2(partition_status); } if (request.partitions().empty()) { return grpc::Status(StatusCode::INTERNAL, diff --git a/components/data_server/request_handler/get_values_v2_handler.h b/components/data_server/request_handler/get_values_v2_handler.h index 52f1c0d6..51f198d7 100644 --- a/components/data_server/request_handler/get_values_v2_handler.h +++ b/components/data_server/request_handler/get_values_v2_handler.h @@ -17,6 +17,8 @@ #ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_GET_VALUES_V2_HANDLER_H_ #define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_GET_VALUES_V2_HANDLER_H_ +#include +#include #include #include #include @@ -39,12 +41,19 @@ namespace kv_server { // Content Type Header Name. Can be set for bhttp request to proto or json // values below. inline constexpr std::string_view kContentTypeHeader = "content-type"; +// Header in clear text http request/response that indicates which format is +// used by the payload. The more common "Content-Type" header is not used +// because most importantly that has CORS implications, and in addition, may not +// be forwarded by Envoy to gRPC. +inline constexpr std::string_view kKVContentTypeHeader = "kv-content-type"; // Protobuf Content Type Header Value. inline constexpr std::string_view kContentEncodingProtoHeaderValue = "application/protobuf"; // Json Content Type Header Value. 
inline constexpr std::string_view kContentEncodingJsonHeaderValue = "application/json"; +inline constexpr std::string_view kContentEncodingBhttpHeaderValue = + "message/bhttp"; // Handles the request family of *GetValues. // See the Service proto definition for details. @@ -63,15 +72,22 @@ class GetValuesV2Handler { std::move(create_compression_group_concatenator)), key_fetcher_manager_(key_fetcher_manager) {} - grpc::Status GetValuesHttp(const v2::GetValuesHttpRequest& request, - google::api::HttpBody* response) const; + grpc::Status GetValuesHttp( + RequestContextFactory& request_context_factory, + const std::multimap& headers, + const v2::GetValuesHttpRequest& request, google::api::HttpBody* response, + ExecutionMetadata& execution_metadata) const; - grpc::Status GetValues(const v2::GetValuesRequest& request, - v2::GetValuesResponse* response) const; + grpc::Status GetValues(RequestContextFactory& request_context_factory, + const v2::GetValuesRequest& request, + v2::GetValuesResponse* response, + ExecutionMetadata& execution_metadata) const; grpc::Status BinaryHttpGetValues( + RequestContextFactory& request_context_factory, const v2::BinaryHttpGetValuesRequest& request, - google::api::HttpBody* response) const; + google::api::HttpBody* response, + ExecutionMetadata& execution_metadata) const; // Supports requests encrypted with a fixed key for debugging/demoing. // X25519 Secret key (priv key). 
@@ -84,21 +100,31 @@ class GetValuesV2Handler { // HPKE Configuration must be: // KEM: DHKEM(X25519, HKDF-SHA256) 0x0020 // KDF: HKDF-SHA256 0x0001 - // AEAD: AES-128-GCM 0X0001 + // AEAD: AES-256-GCM 0X0002 // (https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md#encryption) - grpc::Status ObliviousGetValues(const v2::ObliviousGetValuesRequest& request, - google::api::HttpBody* response) const; + grpc::Status ObliviousGetValues( + RequestContextFactory& request_context_factory, + const std::multimap& headers, + const v2::ObliviousGetValuesRequest& request, + google::api::HttpBody* response, + ExecutionMetadata& execution_metadata) const; private: enum class ContentType { kJson = 0, kProto, + kBhttp, }; ContentType GetContentType( const quiche::BinaryHttpRequest& deserialized_req) const; + ContentType GetContentType( + const std::multimap& headers, + ContentType default_content_type) const; + absl::Status GetValuesHttp( - std::string_view request, std::string& json_response, + RequestContextFactory& request_context_factory, std::string_view request, + std::string& json_response, ExecutionMetadata& execution_metadata, ContentType content_type = ContentType::kJson) const; // On success, returns a BinaryHttpResponse with a successful response. The @@ -107,19 +133,25 @@ class GetValuesV2Handler { // this function fails, the final grpc code may still be ok. absl::StatusOr BuildSuccessfulGetValuesBhttpResponse( - std::string_view bhttp_request_body) const; + RequestContextFactory& request_context_factory, + std::string_view bhttp_request_body, + ExecutionMetadata& execution_metadata) const; // Returns error only if the response cannot be serialized into Binary HTTP // response. For all other failures, the error status will be inside the // Binary HTTP message. 
- absl::Status BinaryHttpGetValues(std::string_view bhttp_request_body, - std::string& response) const; + absl::Status BinaryHttpGetValues( + RequestContextFactory& request_context_factory, + std::string_view bhttp_request_body, std::string& response, + ExecutionMetadata& execution_metadata) const; // Invokes UDF to process one partition. - void ProcessOnePartition(RequestContext request_context, - const google::protobuf::Struct& req_metadata, - const v2::RequestPartition& req_partition, - v2::ResponsePartition& resp_partition) const; + absl::Status ProcessOnePartition( + const RequestContextFactory& request_context_factory, + const google::protobuf::Struct& req_metadata, + const v2::RequestPartition& req_partition, + v2::ResponsePartition& resp_partition, + ExecutionMetadata& execution_metadata) const; const UdfClient& udf_client_; std::function diff --git a/components/data_server/request_handler/get_values_v2_handler_test.cc b/components/data_server/request_handler/get_values_v2_handler_test.cc index 9c7376f7..f9fd95a3 100644 --- a/components/data_server/request_handler/get_values_v2_handler_test.cc +++ b/components/data_server/request_handler/get_values_v2_handler_test.cc @@ -30,9 +30,11 @@ #include "nlohmann/json.hpp" #include "public/constants.h" #include "public/test_util/proto_matcher.h" +#include "public/test_util/request_example.h" #include "quiche/binary_http/binary_http_message.h" #include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" #include "quiche/oblivious_http/oblivious_http_client.h" +#include "src/communication/encoding_utils.h" #include "src/encryption/key_fetcher/fake_key_fetcher_manager.h" namespace kv_server { @@ -57,13 +59,19 @@ enum class ProtocolType { struct TestingParameters { ProtocolType protocol_type; const std::string_view content_type; + const std::string_view core_request_body; + const bool is_consented; }; class GetValuesHandlerTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - 
void SetUp() override { InitMetricsContextMap(); } + void SetUp() override { + privacy_sandbox::server_common::log::ServerToken( + kExampleConsentedDebugToken); + InitMetricsContextMap(); + } template bool IsUsing() { auto param = GetParam(); @@ -75,6 +83,16 @@ class GetValuesHandlerTest return param.content_type == kContentEncodingProtoHeaderValue; } + bool IsRequestExpectConsented() { + auto param = GetParam(); + return param.is_consented; + } + + std::string GetTestRequestBody() { + auto param = GetParam(); + return std::string(param.core_request_body); + } + class PlainRequest { public: explicit PlainRequest(std::string plain_request_body) @@ -126,18 +144,38 @@ class GetValuesHandlerTest class BHTTPResponse { public: google::api::HttpBody& RawResponse() { return response_; } - int16_t ResponseCode() const { + int16_t ResponseCode(bool is_using_bhttp) const { + std::string response; + if (is_using_bhttp) { + response = response_.data(); + } else { + auto deframed_req = + privacy_sandbox::server_common::DecodeRequestPayload( + response_.data()); + EXPECT_TRUE(deframed_req.ok()) << deframed_req.status(); + response = deframed_req->compressed_data; + } const absl::StatusOr maybe_res_bhttp_layer = - quiche::BinaryHttpResponse::Create(response_.data()); + quiche::BinaryHttpResponse::Create(response); EXPECT_TRUE(maybe_res_bhttp_layer.ok()) << "quiche::BinaryHttpResponse::Create failed: " << maybe_res_bhttp_layer.status(); return maybe_res_bhttp_layer->status_code(); } - std::string Unwrap(bool is_protobuf_content) const { + std::string Unwrap(bool is_protobuf_content, bool is_using_bhttp) const { + std::string response; + if (is_using_bhttp) { + response = response_.data(); + } else { + auto deframed_req = + privacy_sandbox::server_common::DecodeRequestPayload( + response_.data()); + EXPECT_TRUE(deframed_req.ok()) << deframed_req.status(); + response = deframed_req->compressed_data; + } const absl::StatusOr maybe_res_bhttp_layer = - 
quiche::BinaryHttpResponse::Create(response_.data()); + quiche::BinaryHttpResponse::Create(response); EXPECT_TRUE(maybe_res_bhttp_layer.ok()) << "quiche::BinaryHttpResponse::Create failed: " << maybe_res_bhttp_layer.status(); @@ -233,43 +271,51 @@ class GetValuesHandlerTest // For Non-plain protocols, test request and response data are converted // to/from the corresponding request/responses. - grpc::Status GetValuesBasedOnProtocol(std::string request_body, - google::api::HttpBody* response, - int16_t* bhttp_response_code, - GetValuesV2Handler* handler) { + grpc::Status GetValuesBasedOnProtocol( + RequestContextFactory& request_context_factory, std::string request_body, + google::api::HttpBody* response, int16_t* bhttp_response_code, + GetValuesV2Handler* handler) { PlainRequest plain_request(std::move(request_body)); + ExecutionMetadata execution_metadata; + std::multimap headers = { + {"kv-content-type", "application/json"}}; if (IsUsing()) { *bhttp_response_code = 200; - return handler->GetValuesHttp(plain_request.Build(), response); + return handler->GetValuesHttp(request_context_factory, headers, + plain_request.Build(), response, + execution_metadata); } BHTTPRequest bhttp_request(std::move(plain_request), IsProtobufContent()); BHTTPResponse bresponse; if (IsUsing()) { - if (const auto s = handler->BinaryHttpGetValues(bhttp_request.Build(), - &bresponse.RawResponse()); + if (const auto s = handler->BinaryHttpGetValues( + request_context_factory, bhttp_request.Build(), + &bresponse.RawResponse(), execution_metadata); !s.ok()) { LOG(ERROR) << "BinaryHttpGetValues failed: " << s.error_message(); return s; } - *bhttp_response_code = bresponse.ResponseCode(); + *bhttp_response_code = bresponse.ResponseCode(true); } else if (IsUsing()) { OHTTPRequest ohttp_request(std::move(bhttp_request)); // get ObliviousGetValuesRequest, OHTTPResponseUnwrapper auto [request, response_unwrapper] = ohttp_request.Build(); if (const auto s = handler->ObliviousGetValues( - request, 
&response_unwrapper.RawResponse()); + request_context_factory, {{"kv-content-type", "message/bhttp"}}, + request, &response_unwrapper.RawResponse(), execution_metadata); !s.ok()) { LOG(ERROR) << "ObliviousGetValues failed: " << s.error_message(); return s; } bresponse = response_unwrapper.Unwrap(); - *bhttp_response_code = bresponse.ResponseCode(); + *bhttp_response_code = bresponse.ResponseCode(false); } - response->set_data(bresponse.Unwrap(IsProtobufContent())); + response->set_data(bresponse.Unwrap(IsProtobufContent(), + IsUsing())); return grpc::Status::OK; } @@ -284,22 +330,97 @@ INSTANTIATE_TEST_SUITE_P( TestingParameters{ .protocol_type = ProtocolType::kPlain, .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleV2RequestInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kPlain, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kPlain, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = + kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kBinaryHttp, .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleV2RequestInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kBinaryHttp, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kBinaryHttp, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = + kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, .content_type = 
kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleV2RequestInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingJsonHeaderValue, + .core_request_body = + kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kBinaryHttp, .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = kv_server::kExampleV2RequestInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kBinaryHttp, + .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kBinaryHttp, + .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = + kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .is_consented = true, }, TestingParameters{ .protocol_type = ProtocolType::kObliviousHttp, .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = kv_server::kExampleV2RequestInJson, + .is_consented = false, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = kv_server::kExampleConsentedV2RequestInJson, + .is_consented = true, + }, + TestingParameters{ + .protocol_type = ProtocolType::kObliviousHttp, + .content_type = kContentEncodingProtoHeaderValue, + .core_request_body = + kv_server::kExampleConsentedV2RequestWithLogContextInJson, + .is_consented = true, })); TEST_P(GetValuesHandlerTest, Success) { @@ -377,43 +498,11 @@ data { EXPECT_CALL( mock_udf_client_, ExecuteCode(_, 
EqualsProto(udf_metadata), - testing::ElementsAre(EqualsProto(arg1), EqualsProto(arg2)))) + testing::ElementsAre(EqualsProto(arg1), EqualsProto(arg2)), + _)) .WillOnce(Return(output.dump())); - std::string core_request_body = R"( -{ - "metadata": { - "hostname": "example.com" - }, - "partitions": [ - { - "id": 0, - "compressionGroupId": 0, - "arguments": [ - { - "tags": [ - "structured", - "groupNames" - ], - "data": [ - "hello" - ] - }, - { - "tags": [ - "custom", - "keys" - ], - "data": [ - "key1" - ] - } - ] - } - ] -} - )"; - + std::string core_request_body = GetTestRequestBody(); google::api::HttpBody response; GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); int16_t bhttp_response_code = 0; @@ -423,10 +512,13 @@ data { &request_proto) .ok()); ASSERT_TRUE(request_proto.SerializeToString(&core_request_body)); + EXPECT_EQ(request_proto.consented_debug_config().is_consented(), + IsRequestExpectConsented()); } - - const auto result = GetValuesBasedOnProtocol(core_request_body, &response, - &bhttp_response_code, &handler); + auto request_context_factory = std::make_unique(); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &bhttp_response_code, &handler); ASSERT_EQ(bhttp_response_code, 200); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -462,8 +554,10 @@ TEST_P(GetValuesHandlerTest, NoPartition) { .ok()); ASSERT_TRUE(request_proto.SerializeToString(&core_request_body)); } - const auto result = GetValuesBasedOnProtocol(core_request_body, &response, - &bhttp_response_code, &handler); + auto request_context_factory = std::make_unique(); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &bhttp_response_code, &handler); if (IsUsing()) { ASSERT_FALSE(result.ok()); EXPECT_EQ(result.error_code(), grpc::StatusCode::INTERNAL); @@ -474,7 +568,7 @@ TEST_P(GetValuesHandlerTest, 
NoPartition) { } TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { - EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, testing::IsEmpty())) + EXPECT_CALL(mock_udf_client_, ExecuteCode(_, _, testing::IsEmpty(), _)) .WillOnce(Return(absl::InternalError("UDF execution error"))); std::string core_request_body = R"( @@ -498,9 +592,10 @@ TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { .ok()); ASSERT_TRUE(request_proto.SerializeToString(&core_request_body)); } - - const auto result = GetValuesBasedOnProtocol(core_request_body, &response, - &bhttp_response_code, &handler); + auto request_context_factory = std::make_unique(); + const auto result = + GetValuesBasedOnProtocol(*request_context_factory, core_request_body, + &response, &bhttp_response_code, &handler); ASSERT_EQ(bhttp_response_code, 200); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -523,6 +618,7 @@ TEST_P(GetValuesHandlerTest, UdfFailureForOnePartition) { TEST_F(GetValuesHandlerTest, PureGRPCTest) { v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; TextFormat::ParseFromString( R"pb(partitions { id: 9 @@ -530,13 +626,16 @@ TEST_F(GetValuesHandlerTest, PureGRPCTest) { })pb", &req); GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - EXPECT_CALL(mock_udf_client_, - ExecuteCode(_, _, - testing::ElementsAre( - EqualsProto(req.partitions(0).arguments(0))))) + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, _, + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) .WillOnce(Return("ECHO")); v2::GetValuesResponse resp; - const auto result = handler.GetValues(req, &resp); + auto request_context_factory = std::make_unique(); + const auto result = handler.GetValues(*request_context_factory, req, &resp, + execution_metadata); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -548,6 +647,7 @@ TEST_F(GetValuesHandlerTest, PureGRPCTest) { 
TEST_F(GetValuesHandlerTest, PureGRPCTestFailure) { v2::GetValuesRequest req; + ExecutionMetadata execution_metadata; TextFormat::ParseFromString( R"pb(partitions { id: 9 @@ -555,13 +655,16 @@ TEST_F(GetValuesHandlerTest, PureGRPCTestFailure) { })pb", &req); GetValuesV2Handler handler(mock_udf_client_, fake_key_fetcher_manager_); - EXPECT_CALL(mock_udf_client_, - ExecuteCode(_, _, - testing::ElementsAre( - EqualsProto(req.partitions(0).arguments(0))))) + EXPECT_CALL( + mock_udf_client_, + ExecuteCode( + _, _, + testing::ElementsAre(EqualsProto(req.partitions(0).arguments(0))), _)) .WillOnce(Return(absl::InternalError("UDF execution error"))); v2::GetValuesResponse resp; - const auto result = handler.GetValues(req, &resp); + auto request_context_factory = std::make_unique(); + const auto result = handler.GetValues(*request_context_factory, req, &resp, + execution_metadata); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); diff --git a/components/data_server/request_handler/get_values_v2_status.cc b/components/data_server/request_handler/get_values_v2_status.cc new file mode 100644 index 00000000..458f92c8 --- /dev/null +++ b/components/data_server/request_handler/get_values_v2_status.cc @@ -0,0 +1,24 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/get_values_v2_status.h" + +namespace kv_server { + +grpc::Status GetExternalStatusForV2(const absl::Status& status) { + // Return OK status regardless of UDF outcome + return grpc::Status::OK; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/get_values_v2_status.h b/components/data_server/request_handler/get_values_v2_status.h new file mode 100644 index 00000000..992fc1da --- /dev/null +++ b/components/data_server/request_handler/get_values_v2_status.h @@ -0,0 +1,29 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_GET_VALUES_V2_STATUS_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_GET_VALUES_V2_STATUS_H_ + +#include "absl/log/log.h" +#include "grpcpp/grpcpp.h" + +namespace kv_server { + +grpc::Status GetExternalStatusForV2(const absl::Status& status); + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_GET_VALUES_V2_STATUS_H_ diff --git a/components/data_server/request_handler/get_values_v2_status_nonprod.cc b/components/data_server/request_handler/get_values_v2_status_nonprod.cc new file mode 100644 index 00000000..9d08c5f1 --- /dev/null +++ b/components/data_server/request_handler/get_values_v2_status_nonprod.cc @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/flag.h" +#include "src/util/status_macro/status_macros.h" + +ABSL_FLAG(bool, propagate_v2_error_status, false, + "Whether to propagate an error status to V2. 
This flag is only " + "available in nonprod mode."); + +namespace kv_server { + +using privacy_sandbox::server_common::FromAbslStatus; + +grpc::Status GetExternalStatusForV2(const absl::Status& status) { + if (absl::GetFlag(FLAGS_propagate_v2_error_status)) { + return FromAbslStatus(status); + } + return grpc::Status::OK; +} + +} // namespace kv_server diff --git a/components/data_server/request_handler/mocks.h b/components/data_server/request_handler/mocks.h index 1b23035f..8e30abe4 100644 --- a/components/data_server/request_handler/mocks.h +++ b/components/data_server/request_handler/mocks.h @@ -27,7 +27,8 @@ namespace kv_server { class MockGetValuesAdapter : public GetValuesAdapter { public: MOCK_METHOD((grpc::Status), CallV2Handler, - (const v1::GetValuesRequest& v1_request, + (RequestContextFactory & request_context_factory, + const v1::GetValuesRequest& v1_request, v1::GetValuesResponse& v1_response), (const, override)); }; diff --git a/components/data_server/request_handler/ohttp_client_encryptor.cc b/components/data_server/request_handler/ohttp_client_encryptor.cc index 81f33d0c..54cf955e 100644 --- a/components/data_server/request_handler/ohttp_client_encryptor.cc +++ b/components/data_server/request_handler/ohttp_client_encryptor.cc @@ -37,16 +37,9 @@ absl::StatusOr StringToUint8(absl::string_view str) { } // namespace absl::StatusOr OhttpClientEncryptor::EncryptRequest( - std::string payload) { - auto key = key_fetcher_manager_.GetPublicKey(cloud_platform_); - if (!key.ok()) { - const std::string error = - absl::StrCat("Could not get public key to use for HPKE encryption: ", - key.status().message()); - LOG(ERROR) << error; - return absl::InternalError(error); - } - auto key_id = StringToUint8(key->key_id()); + std::string payload, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + auto key_id = StringToUint8(public_key_.key_id()); if (!key_id.ok()) { return key_id.status(); } @@ -55,12 +48,13 @@ absl::StatusOr 
OhttpClientEncryptor::EncryptRequest( if (!maybe_config.ok()) { return absl::InternalError(std::string(maybe_config.status().message())); } - std::string public_key; - VLOG(9) << "Encrypting with public key id: " << key->key_id() - << " uint8 key id " << *key_id << "public key " << key->public_key(); - absl::Base64Unescape(key->public_key(), &public_key); + std::string public_key_string; + PS_VLOG(9, log_context) << "Encrypting with public key id: " + << public_key_.key_id() << " uint8 key id " << *key_id + << "public key " << public_key_.public_key(); + absl::Base64Unescape(public_key_.public_key(), &public_key_string); auto http_client_maybe = - quiche::ObliviousHttpClient::Create(public_key, *maybe_config); + quiche::ObliviousHttpClient::Create(public_key_string, *maybe_config); if (!http_client_maybe.ok()) { return absl::InternalError( std::string(http_client_maybe.status().message())); @@ -78,7 +72,8 @@ absl::StatusOr OhttpClientEncryptor::EncryptRequest( } absl::StatusOr OhttpClientEncryptor::DecryptResponse( - std::string encrypted_payload) { + std::string encrypted_payload, + privacy_sandbox::server_common::log::PSLogContext& log_context) { if (!http_client_.has_value() || !http_request_context_.has_value()) { return absl::InternalError( "Emtpy `http_client_` or `http_request_context_`. 
You should call " diff --git a/components/data_server/request_handler/ohttp_client_encryptor.h b/components/data_server/request_handler/ohttp_client_encryptor.h index 1e023516..8e9fcf81 100644 --- a/components/data_server/request_handler/ohttp_client_encryptor.h +++ b/components/data_server/request_handler/ohttp_client_encryptor.h @@ -31,28 +31,26 @@ namespace kv_server { class OhttpClientEncryptor { public: explicit OhttpClientEncryptor( - privacy_sandbox::server_common::KeyFetcherManagerInterface& - key_fetcher_manager) - : key_fetcher_manager_(key_fetcher_manager) { -#if defined(CLOUD_PLATFORM_AWS) - cloud_platform_ = privacy_sandbox::server_common::CloudPlatform::kAws; -#elif defined(CLOUD_PLATFORM_GCP) - cloud_platform_ = privacy_sandbox::server_common::CloudPlatform::kGcp; -#endif - } + google::cmrt::sdk::public_key_service::v1::PublicKey& public_key) + : public_key_(public_key) {} // Encrypts ougoing request. - absl::StatusOr EncryptRequest(std::string payload); + absl::StatusOr EncryptRequest( + std::string payload, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); // Decrypts incoming reponse. Since OHTTP is stateful, this method should be // called after EncryptRequest. 
- absl::StatusOr DecryptResponse(std::string encrypted_payload); + absl::StatusOr DecryptResponse( + std::string encrypted_payload, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); private: - ::privacy_sandbox::server_common::CloudPlatform cloud_platform_ = - ::privacy_sandbox::server_common::CloudPlatform::kLocal; std::optional http_client_; std::optional http_request_context_; - privacy_sandbox::server_common::KeyFetcherManagerInterface& - key_fetcher_manager_; + google::cmrt::sdk::public_key_service::v1::PublicKey& public_key_; }; } // namespace kv_server diff --git a/components/data_server/request_handler/ohttp_encryptor_test.cc b/components/data_server/request_handler/ohttp_encryptor_test.cc index 73cdd15f..6cad5850 100644 --- a/components/data_server/request_handler/ohttp_encryptor_test.cc +++ b/components/data_server/request_handler/ohttp_encryptor_test.cc @@ -23,11 +23,14 @@ namespace kv_server { namespace { +using privacy_sandbox::server_common::CloudPlatform::kLocal; + TEST(OhttpEncryptorTest, FullCircleSuccess) { const std::string kTestRequest = "request to encrypt"; privacy_sandbox::server_common::FakeKeyFetcherManager fake_key_fetcher_manager; - OhttpClientEncryptor client_encryptor(fake_key_fetcher_manager); + auto public_key = fake_key_fetcher_manager.GetPublicKey(kLocal); + OhttpClientEncryptor client_encryptor(public_key.value()); OhttpServerEncryptor server_encryptor(fake_key_fetcher_manager); auto request_encrypted_status = client_encryptor.EncryptRequest(kTestRequest); ASSERT_TRUE(request_encrypted_status.ok()); @@ -58,7 +61,8 @@ TEST(OhttpEncryptorTest, ClientDecryptFails) { privacy_sandbox::server_common::FakeKeyFetcherManager fake_key_fetcher_manager; const std::string kTestRequest = "request to encrypt"; - OhttpClientEncryptor client_encryptor(fake_key_fetcher_manager); + auto public_key = fake_key_fetcher_manager.GetPublicKey(kLocal); + 
OhttpClientEncryptor client_encryptor(public_key.value()); auto request_encrypted_status = client_encryptor.EncryptRequest(kTestRequest); ASSERT_TRUE(request_encrypted_status.ok()); auto response_decrypted_status = client_encryptor.DecryptResponse("garbage"); @@ -83,7 +87,8 @@ TEST(OhttpEncryptorTest, ClientDecryptResponseFails) { privacy_sandbox::server_common::FakeKeyFetcherManager fake_key_fetcher_manager; const std::string kTestRequest = "request to decrypt"; - OhttpClientEncryptor client_encryptor(fake_key_fetcher_manager); + auto public_key = fake_key_fetcher_manager.GetPublicKey(kLocal); + OhttpClientEncryptor client_encryptor(public_key.value()); auto request_encrypted_status = client_encryptor.DecryptResponse(kTestRequest); ASSERT_FALSE(request_encrypted_status.ok()); diff --git a/components/data_server/request_handler/ohttp_server_encryptor.cc b/components/data_server/request_handler/ohttp_server_encryptor.cc index 8280e845..a8893a4c 100644 --- a/components/data_server/request_handler/ohttp_server_encryptor.cc +++ b/components/data_server/request_handler/ohttp_server_encryptor.cc @@ -21,7 +21,8 @@ namespace kv_server { absl::StatusOr OhttpServerEncryptor::DecryptRequest( - absl::string_view encrypted_payload) { + absl::string_view encrypted_payload, + privacy_sandbox::server_common::log::PSLogContext& log_context) { const absl::StatusOr maybe_req_key_id = quiche::ObliviousHttpHeaderKeyConfig:: ParseKeyIdFromObliviousHttpRequestPayload(encrypted_payload); @@ -37,12 +38,13 @@ absl::StatusOr OhttpServerEncryptor::DecryptRequest( } auto private_key_id = std::to_string(*maybe_req_key_id); - VLOG(9) << "Decrypting for the public key id: " << private_key_id; + PS_VLOG(9, log_context) << "Decrypting for the public key id: " + << private_key_id; auto private_key = key_fetcher_manager_.GetPrivateKey(private_key_id); if (!private_key.has_value()) { const std::string error = absl::StrCat( "Unable to retrieve private key for key ID: ", *maybe_req_key_id); - LOG(ERROR) 
<< error; + PS_LOG(ERROR, log_context) << error; return absl::InternalError(error); } @@ -62,7 +64,8 @@ absl::StatusOr OhttpServerEncryptor::DecryptRequest( } absl::StatusOr OhttpServerEncryptor::EncryptResponse( - std::string payload) { + std::string payload, + privacy_sandbox::server_common::log::PSLogContext& log_context) { if (!ohttp_gateway_.has_value() || !decrypted_request_.has_value()) { return absl::InternalError( "Emtpy `ohttp_gateway_` or `decrypted_request_`. You should call " diff --git a/components/data_server/request_handler/ohttp_server_encryptor.h b/components/data_server/request_handler/ohttp_server_encryptor.h index f888284f..644e89c3 100644 --- a/components/data_server/request_handler/ohttp_server_encryptor.h +++ b/components/data_server/request_handler/ohttp_server_encryptor.h @@ -39,10 +39,17 @@ class OhttpServerEncryptor { // lifetime is tied to that object, which lifetime is in turn tied to the // instance of OhttpEncryptor. absl::StatusOr DecryptRequest( - absl::string_view encrypted_payload); + absl::string_view encrypted_payload, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); // Encrypts outgoing response. Since OHTTP is stateful, this method should be // called after DecryptRequest. 
- absl::StatusOr EncryptResponse(std::string payload); + absl::StatusOr EncryptResponse( + std::string payload, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); private: std::optional ohttp_gateway_; diff --git a/components/data_server/server/BUILD.bazel b/components/data_server/server/BUILD.bazel index c1cb1072..fc2a103a 100644 --- a/components/data_server/server/BUILD.bazel +++ b/components/data_server/server/BUILD.bazel @@ -41,6 +41,7 @@ cc_library( hdrs = ["parameter_fetcher.h"], visibility = [ "//components/sharding:__subpackages__", + "//components/tools:__subpackages__", "//production/packaging:__subpackages__", ], deps = [ @@ -85,6 +86,7 @@ cc_library( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) @@ -130,6 +132,7 @@ cc_library( ":lifecycle_heartbeat", ":parameter_fetcher", ":server_initializer", + ":server_log_init", "//components/cloud_config:instance_client", "//components/cloud_config:parameter_client", "//components/data/blob_storage:blob_storage_client", @@ -156,6 +159,7 @@ cc_library( "//components/udf/hooks:get_values_hook", "//components/util:periodic_closure", "//components/util:platform_initializer", + "//components/util:safe_path_log_context", "//components/util:version_linkstamp", "//public:base_types_cc_proto", "//public:constants", @@ -194,6 +198,7 @@ cc_test( "@com_google_absl//absl/flags:parse", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", + "@io_opentelemetry_cpp//exporters/ostream:ostream_log_record_exporter", ], ) @@ -208,6 +213,7 @@ cc_binary( deps = [ ":key_fetcher_factory", ":server_lib", + ":server_log_init", "//components/sharding:shard_manager", "//components/util:version_linkstamp", "@com_google_absl//absl/debugging:failure_signal_handler", @@ -285,6 +291,8 @@ cc_library( 
"//:gcp_platform": [ "key_fetcher_utils_gcp.h", ], + "//conditions:default": [], + }) + select({ "//:nonprod_mode": [ "nonprod_key_fetcher_factory_cloud.h", ], @@ -292,6 +300,7 @@ cc_library( }) + [ "key_fetcher_factory.h", ], + visibility = ["//components/tools:__subpackages__"], deps = select({ "//:gcp_platform": [ @@ -344,3 +353,35 @@ cc_library( "@google_privacysandbox_servers_common//src/encryption/key_fetcher:key_fetcher_manager", ], ) + +cc_library( + name = "server_log_init", + srcs = select({ + "//:nonprod_mode": [ + "nonprod_server_log_init.cc", + ], + "//conditions:default": [ + "prod_server_log_init.cc", + ], + }) + + select({ + "//:local_instance": [ + "local_server_log_init.cc", + ], + "//conditions:default": [ + "cloud_server_log_init.cc", + ], + }), + hdrs = ["server_log_init.h"], + deps = + select({ + "//:nonprod_mode": [ + "@google_privacysandbox_servers_common//src/logger:request_context_logger", + ], + "//conditions:default": [], + }) + [ + ":parameter_fetcher", + "//components/cloud_config:parameter_client", + "@com_google_absl//absl/log:initialize", + ], +) diff --git a/components/data_server/server/cloud_server_log_init.cc b/components/data_server/server/cloud_server_log_init.cc new file mode 100644 index 00000000..fb00928c --- /dev/null +++ b/components/data_server/server/cloud_server_log_init.cc @@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/log/initialize.h" +#include "components/data_server/server/parameter_fetcher.h" +#include "components/data_server/server/server_log_init.h" + +namespace kv_server { + +namespace { + +constexpr absl::string_view kUseExternalMetricsCollectorEndpointSuffix = + "use-external-metrics-collector-endpoint"; +constexpr absl::string_view kMetricsCollectorEndpointSuffix = + "metrics-collector-endpoint"; + +} // namespace + +absl::optional GetMetricsCollectorEndPoint( + const ParameterClient& parameter_client, const std::string& environment) { + absl::optional metrics_collection_endpoint; + ParameterFetcher parameter_fetcher(environment, parameter_client); + auto should_connect_to_external_metrics_collector = + parameter_fetcher.GetBoolParameter( + kUseExternalMetricsCollectorEndpointSuffix); + if (should_connect_to_external_metrics_collector) { + std::string metrics_collector_endpoint_value = + parameter_fetcher.GetParameter(kMetricsCollectorEndpointSuffix); + LOG(INFO) << "Retrieved " << kMetricsCollectorEndpointSuffix + << " parameter: " << metrics_collector_endpoint_value; + metrics_collection_endpoint = std::move(metrics_collector_endpoint_value); + } + return metrics_collection_endpoint; +} + +} // namespace kv_server diff --git a/components/data_server/server/key_fetcher_factory.h b/components/data_server/server/key_fetcher_factory.h index e5f61f73..228a45a5 100644 --- a/components/data_server/server/key_fetcher_factory.h +++ b/components/data_server/server/key_fetcher_factory.h @@ -38,7 +38,10 @@ class KeyFetcherFactory { privacy_sandbox::server_common::KeyFetcherManagerInterface> CreateKeyFetcherManager(const ParameterFetcher& parameter_fetcher) const = 0; // Constructs a KeyFetcherFactory. - static std::unique_ptr Create(); + static std::unique_ptr Create( + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; // Constructs CloudKeyFetcherManager. 
CloudKeyFetcherManager has common logic @@ -46,12 +49,16 @@ class KeyFetcherFactory { // should be used. class CloudKeyFetcherFactory : public KeyFetcherFactory { public: + explicit CloudKeyFetcherFactory( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : log_context_(log_context) {} // Creates KeyFetcherManager. std::unique_ptr CreateKeyFetcherManager( const ParameterFetcher& parameter_fetcher) const override; protected: + privacy_sandbox::server_common::log::PSLogContext& log_context_; virtual google::scp::cpio::PrivateKeyVendingEndpoint GetPrimaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const; diff --git a/components/data_server/server/key_fetcher_factory_aws.cc b/components/data_server/server/key_fetcher_factory_aws.cc index 687172d7..4ac11b92 100644 --- a/components/data_server/server/key_fetcher_factory_aws.cc +++ b/components/data_server/server/key_fetcher_factory_aws.cc @@ -15,8 +15,9 @@ #include "components/data_server/server/key_fetcher_factory.h" namespace kv_server { -std::unique_ptr KeyFetcherFactory::Create() { - return std::make_unique(); +std::unique_ptr KeyFetcherFactory::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/data_server/server/key_fetcher_factory_cloud.cc b/components/data_server/server/key_fetcher_factory_cloud.cc index 810b17ed..83117dba 100644 --- a/components/data_server/server/key_fetcher_factory_cloud.cc +++ b/components/data_server/server/key_fetcher_factory_cloud.cc @@ -54,7 +54,7 @@ constexpr std::string_view constexpr std::string_view kPrimaryCoordinatorRegionParameterSuffix = "primary-coordinator-region"; constexpr std::string_view - kSecondaryCoordinatoPrivateKeyEndpointParameterSuffix = + kSecondaryCoordinatorPrivateKeyEndpointParameterSuffix = "secondary-coordinator-private-key-endpoint"; constexpr std::string_view kSecondaryCoordinatorRegionParameterSuffix = 
"secondary-coordinator-region"; @@ -65,6 +65,7 @@ constexpr absl::Duration kPrivateKeyCacheTtl = absl::Hours(24 * 45); // 45 days constexpr absl::Duration kKeyRefreshFlowRunFrequency = absl::Hours(3); PrivateKeyVendingEndpoint GetKeyFetchingEndpoint( + privacy_sandbox::server_common::log::PSLogContext& log_context, const ParameterFetcher& parameter_fetcher, std::string_view account_identity_prefix, std::string_view private_key_endpoint_prefix, @@ -72,14 +73,14 @@ PrivateKeyVendingEndpoint GetKeyFetchingEndpoint( PrivateKeyVendingEndpoint endpoint; endpoint.account_identity = parameter_fetcher.GetParameter(account_identity_prefix); - LOG(INFO) << "Retrieved " << account_identity_prefix - << " parameter: " << endpoint.account_identity; + PS_LOG(INFO, log_context) << "Retrieved " << account_identity_prefix + << " parameter: " << endpoint.account_identity; endpoint.private_key_vending_service_endpoint = parameter_fetcher.GetParameter(private_key_endpoint_prefix); - LOG(INFO) << "Service endpoint: " - << endpoint.private_key_vending_service_endpoint; + PS_LOG(INFO, log_context) + << "Service endpoint: " << endpoint.private_key_vending_service_endpoint; endpoint.service_region = parameter_fetcher.GetParameter(region_prefix); - LOG(INFO) << "Region: " << endpoint.service_region; + PS_LOG(INFO, log_context) << "Region: " << endpoint.service_region; return endpoint; } } // namespace @@ -89,19 +90,21 @@ CloudKeyFetcherFactory::CreateKeyFetcherManager( const ParameterFetcher& parameter_fetcher) const { if (!parameter_fetcher.GetBoolParameter( kUseRealCoordinatorsParameterSuffix)) { - LOG(INFO) << "Not using real coordinators. Using hardcoded unsafe public " - "and private keys"; + PS_LOG(INFO, log_context_) + << "Not using real coordinators. 
Using hardcoded unsafe public " + "and private keys"; return std::make_unique(); } std::vector endpoints = GetPublicKeyFetchingEndpoint(parameter_fetcher); std::unique_ptr public_key_fetcher = - PublicKeyFetcherFactory::Create({{GetCloudPlatform(), endpoints}}); + PublicKeyFetcherFactory::Create({{GetCloudPlatform(), endpoints}}, + log_context_); auto primary = GetPrimaryKeyFetchingEndpoint(parameter_fetcher); auto secondary = GetSecondaryKeyFetchingEndpoint(parameter_fetcher); std::unique_ptr private_key_fetcher = PrivateKeyFetcherFactory::Create(primary, {secondary}, - kPrivateKeyCacheTtl); + kPrivateKeyCacheTtl, log_context_); auto event_engine = std::make_unique( grpc_event_engine::experimental::GetDefaultEventEngine()); std::unique_ptr manager = @@ -116,8 +119,8 @@ CloudKeyFetcherFactory::CreateKeyFetcherManager( std::vector CloudKeyFetcherFactory::GetPublicKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const { auto publicKeyEndpointParameter = absl::GetFlag(FLAGS_public_key_endpoint); - LOG(INFO) << "Retrieved public_key_endpoint parameter: " - << publicKeyEndpointParameter; + PS_LOG(INFO, log_context_) << "Retrieved public_key_endpoint parameter: " + << publicKeyEndpointParameter; std::vector endpoints = {publicKeyEndpointParameter}; return endpoints; } @@ -125,7 +128,8 @@ std::vector CloudKeyFetcherFactory::GetPublicKeyFetchingEndpoint( PrivateKeyVendingEndpoint CloudKeyFetcherFactory::GetPrimaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const { return GetKeyFetchingEndpoint( - parameter_fetcher, kPrimaryCoordinatorAccountIdentityParameterSuffix, + log_context_, parameter_fetcher, + kPrimaryCoordinatorAccountIdentityParameterSuffix, kPrimaryCoordinatorPrivateKeyEndpointParameterSuffix, kPrimaryCoordinatorRegionParameterSuffix); } @@ -134,8 +138,9 @@ PrivateKeyVendingEndpoint CloudKeyFetcherFactory::GetSecondaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const { return GetKeyFetchingEndpoint( - 
parameter_fetcher, kSecondaryCoordinatorAccountIdentityParameterSuffix, - kPrimaryCoordinatorAccountIdentityParameterSuffix, + log_context_, parameter_fetcher, + kSecondaryCoordinatorAccountIdentityParameterSuffix, + kSecondaryCoordinatorPrivateKeyEndpointParameterSuffix, kSecondaryCoordinatorRegionParameterSuffix); } diff --git a/components/data_server/server/key_fetcher_factory_gcp.cc b/components/data_server/server/key_fetcher_factory_gcp.cc index d4fe2ada..afdabab1 100644 --- a/components/data_server/server/key_fetcher_factory_gcp.cc +++ b/components/data_server/server/key_fetcher_factory_gcp.cc @@ -22,12 +22,18 @@ using ::google::scp::cpio::PrivateKeyVendingEndpoint; using ::privacy_sandbox::server_common::CloudPlatform; class KeyFetcherFactoryGcp : public CloudKeyFetcherFactory { + public: + explicit KeyFetcherFactoryGcp( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : CloudKeyFetcherFactory(log_context) {} + + protected: PrivateKeyVendingEndpoint GetPrimaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const override { PrivateKeyVendingEndpoint endpoint = CloudKeyFetcherFactory::GetPrimaryKeyFetchingEndpoint( parameter_fetcher); - UpdatePrimaryGcpEndpoint(endpoint, parameter_fetcher); + UpdatePrimaryGcpEndpoint(endpoint, parameter_fetcher, log_context_); return endpoint; } @@ -36,7 +42,7 @@ class KeyFetcherFactoryGcp : public CloudKeyFetcherFactory { PrivateKeyVendingEndpoint endpoint = CloudKeyFetcherFactory::GetSecondaryKeyFetchingEndpoint( parameter_fetcher); - UpdateSecondaryGcpEndpoint(endpoint, parameter_fetcher); + UpdateSecondaryGcpEndpoint(endpoint, parameter_fetcher, log_context_); return endpoint; } @@ -46,7 +52,8 @@ class KeyFetcherFactoryGcp : public CloudKeyFetcherFactory { }; } // namespace -std::unique_ptr KeyFetcherFactory::Create() { - return std::make_unique(); +std::unique_ptr KeyFetcherFactory::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return 
std::make_unique(log_context); } } // namespace kv_server diff --git a/components/data_server/server/key_fetcher_factory_local.cc b/components/data_server/server/key_fetcher_factory_local.cc index 9c1545b0..dcf68525 100644 --- a/components/data_server/server/key_fetcher_factory_local.cc +++ b/components/data_server/server/key_fetcher_factory_local.cc @@ -40,7 +40,8 @@ class LocalKeyFetcherFactory : public KeyFetcherFactory { } }; -std::unique_ptr KeyFetcherFactory::Create() { +std::unique_ptr KeyFetcherFactory::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique(); } diff --git a/components/data_server/server/key_fetcher_utils_gcp.cc b/components/data_server/server/key_fetcher_utils_gcp.cc index 8fedbfa8..3f20cd7c 100644 --- a/components/data_server/server/key_fetcher_utils_gcp.cc +++ b/components/data_server/server/key_fetcher_utils_gcp.cc @@ -20,31 +20,38 @@ namespace kv_server { -void SetGcpSpecificParameters(PrivateKeyVendingEndpoint& endpoint, - const ParameterFetcher& parameter_fetcher, - const std::string_view cloudfunction_prefix, - const std::string_view wip_provider) { +void SetGcpSpecificParameters( + PrivateKeyVendingEndpoint& endpoint, + const ParameterFetcher& parameter_fetcher, + const std::string_view cloudfunction_prefix, + const std::string_view wip_provider, + privacy_sandbox::server_common::log::PSLogContext& log_context) { endpoint.gcp_private_key_vending_service_cloudfunction_url = parameter_fetcher.GetParameter(cloudfunction_prefix); - LOG(INFO) << "Retrieved " << cloudfunction_prefix << " parameter: " - << endpoint.gcp_private_key_vending_service_cloudfunction_url; + PS_LOG(INFO, log_context) + << "Retrieved " << cloudfunction_prefix << " parameter: " + << endpoint.gcp_private_key_vending_service_cloudfunction_url; endpoint.gcp_wip_provider = parameter_fetcher.GetParameter(wip_provider); - LOG(INFO) << "Retrieved " << wip_provider - << " parameter: " << endpoint.gcp_wip_provider; + PS_LOG(INFO, 
log_context) << "Retrieved " << wip_provider + << " parameter: " << endpoint.gcp_wip_provider; } -void UpdatePrimaryGcpEndpoint(PrivateKeyVendingEndpoint& endpoint, - const ParameterFetcher& parameter_fetcher) { - SetGcpSpecificParameters(endpoint, parameter_fetcher, - kPrimaryKeyServiceCloudFunctionUrlSuffix, - kPrimaryWorkloadIdentityPoolProviderSuffix); +void UpdatePrimaryGcpEndpoint( + PrivateKeyVendingEndpoint& endpoint, + const ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + SetGcpSpecificParameters( + endpoint, parameter_fetcher, kPrimaryKeyServiceCloudFunctionUrlSuffix, + kPrimaryWorkloadIdentityPoolProviderSuffix, log_context); } -void UpdateSecondaryGcpEndpoint(PrivateKeyVendingEndpoint& endpoint, - const ParameterFetcher& parameter_fetcher) { - SetGcpSpecificParameters(endpoint, parameter_fetcher, - kSecondaryKeyServiceCloudFunctionUrlSuffix, - kSecondaryWorkloadIdentityPoolProviderSuffix); +void UpdateSecondaryGcpEndpoint( + PrivateKeyVendingEndpoint& endpoint, + const ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + SetGcpSpecificParameters( + endpoint, parameter_fetcher, kSecondaryKeyServiceCloudFunctionUrlSuffix, + kSecondaryWorkloadIdentityPoolProviderSuffix, log_context); } } // namespace kv_server diff --git a/components/data_server/server/key_fetcher_utils_gcp.h b/components/data_server/server/key_fetcher_utils_gcp.h index 1b508aae..9d53c6b3 100644 --- a/components/data_server/server/key_fetcher_utils_gcp.h +++ b/components/data_server/server/key_fetcher_utils_gcp.h @@ -33,11 +33,15 @@ constexpr std::string_view kSecondaryKeyServiceCloudFunctionUrlSuffix = constexpr std::string_view kSecondaryWorkloadIdentityPoolProviderSuffix = "secondary-workload-identity-pool-provider"; -void UpdatePrimaryGcpEndpoint(PrivateKeyVendingEndpoint& endpoint, - const ParameterFetcher& parameter_fetcher); - -void 
UpdateSecondaryGcpEndpoint(PrivateKeyVendingEndpoint& endpoint, - const ParameterFetcher& parameter_fetcher); +void UpdatePrimaryGcpEndpoint( + PrivateKeyVendingEndpoint& endpoint, + const ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context); + +void UpdateSecondaryGcpEndpoint( + PrivateKeyVendingEndpoint& endpoint, + const ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context); } // namespace kv_server diff --git a/components/data_server/server/key_value_service_impl.cc b/components/data_server/server/key_value_service_impl.cc index 35fc0e08..2b080bcf 100644 --- a/components/data_server/server/key_value_service_impl.cc +++ b/components/data_server/server/key_value_service_impl.cc @@ -30,16 +30,32 @@ using v1::GetValuesRequest; using v1::GetValuesResponse; using v1::KeyValueService; +namespace { +// The V1 request API should have no concept of consented debugging, default +// all V1 requests to consented requests +privacy_sandbox::server_common::ConsentedDebugConfiguration +GetDefaultConsentedDebugConfigForV1Request() { + privacy_sandbox::server_common::ConsentedDebugConfiguration config; + config.set_is_consented(true); + config.set_token(privacy_sandbox::server_common::log::ServerToken()); + return config; +} +} // namespace + grpc::ServerUnaryReactor* KeyValueServiceImpl::GetValues( CallbackServerContext* context, const GetValuesRequest* request, GetValuesResponse* response) { - auto request_received_time = absl::Now(); - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); - grpc::Status status = handler_.GetValues(request_context, *request, response); + privacy_sandbox::server_common::Stopwatch stopwatch; + std::unique_ptr request_context_factory = + std::make_unique(); + request_context_factory->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + GetDefaultConsentedDebugConfigForV1Request()); + 
grpc::Status status = + handler_.GetValues(*request_context_factory, *request, response); auto* reactor = context->DefaultReactor(); reactor->Finish(status); - LogRequestCommonSafeMetrics(request, response, status, request_received_time); + LogV1RequestCommonSafeMetrics(request, response, status, stopwatch); return reactor; } diff --git a/components/data_server/server/key_value_service_v2_impl.cc b/components/data_server/server/key_value_service_v2_impl.cc index 36932962..e426a917 100644 --- a/components/data_server/server/key_value_service_v2_impl.cc +++ b/components/data_server/server/key_value_service_v2_impl.cc @@ -28,19 +28,40 @@ using v2::GetValuesHttpRequest; using v2::KeyValueService; template -using HandlerFunctionT = grpc::Status (GetValuesV2Handler::*)(const RequestT&, - ResponseT*) const; +using HandlerFunctionT = grpc::Status (GetValuesV2Handler::*)( + RequestContextFactory&, const RequestT&, ResponseT*, + ExecutionMetadata& execution_metadata) const; + +inline void LogTotalExecutionWithoutCustomCodeMetric( + const privacy_sandbox::server_common::Stopwatch& stopwatch, + std::optional custom_code_total_execution_time_micros, + RequestContextFactory& request_context_factory) { + auto duration_micros = absl::ToDoubleMicroseconds(stopwatch.GetElapsedTime()); + if (custom_code_total_execution_time_micros.has_value()) { + duration_micros -= *custom_code_total_execution_time_micros; + } + UdfRequestMetricsContext& metrics_context = + request_context_factory.Get().GetUdfRequestMetricsContext(); + LogIfError(metrics_context.LogHistogram( + (duration_micros))); +} template grpc::ServerUnaryReactor* HandleRequest( + RequestContextFactory& request_context_factory, CallbackServerContext* context, const RequestT* request, ResponseT* response, const GetValuesV2Handler& handler, HandlerFunctionT handler_function) { - auto request_received_time = absl::Now(); - grpc::Status status = (handler.*handler_function)(*request, response); + 
privacy_sandbox::server_common::Stopwatch stopwatch; + ExecutionMetadata execution_metadata; + grpc::Status status = (handler.*handler_function)( + request_context_factory, *request, response, execution_metadata); auto* reactor = context->DefaultReactor(); reactor->Finish(status); - LogRequestCommonSafeMetrics(request, response, status, request_received_time); + LogRequestCommonSafeMetrics(request, response, status, stopwatch); + LogTotalExecutionWithoutCustomCodeMetric( + stopwatch, execution_metadata.custom_code_total_execution_time_micros, + request_context_factory); return reactor; } @@ -49,30 +70,54 @@ grpc::ServerUnaryReactor* HandleRequest( grpc::ServerUnaryReactor* KeyValueServiceV2Impl::GetValuesHttp( CallbackServerContext* context, const GetValuesHttpRequest* request, google::api::HttpBody* response) { - return HandleRequest(context, request, response, handler_, - &GetValuesV2Handler::GetValuesHttp); + privacy_sandbox::server_common::Stopwatch stopwatch; + auto request_context_factory = std::make_unique(); + ExecutionMetadata execution_metadata; + grpc::Status status = handler_.GetValuesHttp( + *request_context_factory, context->client_metadata(), *request, response, + execution_metadata); + auto* reactor = context->DefaultReactor(); + reactor->Finish(status); + LogRequestCommonSafeMetrics(request, response, status, stopwatch); + LogTotalExecutionWithoutCustomCodeMetric( + stopwatch, execution_metadata.custom_code_total_execution_time_micros, + *request_context_factory); + return reactor; } grpc::ServerUnaryReactor* KeyValueServiceV2Impl::GetValues( grpc::CallbackServerContext* context, const v2::GetValuesRequest* request, v2::GetValuesResponse* response) { - return HandleRequest(context, request, response, handler_, - &GetValuesV2Handler::GetValues); + auto request_context_factory = std::make_unique(); + return HandleRequest(*request_context_factory, context, request, response, + handler_, &GetValuesV2Handler::GetValues); } grpc::ServerUnaryReactor* 
KeyValueServiceV2Impl::BinaryHttpGetValues( CallbackServerContext* context, const v2::BinaryHttpGetValuesRequest* request, google::api::HttpBody* response) { - return HandleRequest(context, request, response, handler_, - &GetValuesV2Handler::BinaryHttpGetValues); + auto request_context_factory = std::make_unique(); + return HandleRequest(*request_context_factory, context, request, response, + handler_, &GetValuesV2Handler::BinaryHttpGetValues); } grpc::ServerUnaryReactor* KeyValueServiceV2Impl::ObliviousGetValues( CallbackServerContext* context, const v2::ObliviousGetValuesRequest* request, google::api::HttpBody* response) { - return HandleRequest(context, request, response, handler_, - &GetValuesV2Handler::ObliviousGetValues); + privacy_sandbox::server_common::Stopwatch stopwatch; + auto request_context_factory = std::make_unique(); + ExecutionMetadata execution_metadata; + grpc::Status status = handler_.ObliviousGetValues( + *request_context_factory, context->client_metadata(), *request, response, + execution_metadata); + auto* reactor = context->DefaultReactor(); + reactor->Finish(status); + LogRequestCommonSafeMetrics(request, response, status, stopwatch); + LogTotalExecutionWithoutCustomCodeMetric( + stopwatch, execution_metadata.custom_code_total_execution_time_micros, + *request_context_factory); + return reactor; } } // namespace kv_server diff --git a/components/data_server/server/lifecycle_heartbeat.cc b/components/data_server/server/lifecycle_heartbeat.cc index 593faf80..6c72ff61 100644 --- a/components/data_server/server/lifecycle_heartbeat.cc +++ b/components/data_server/server/lifecycle_heartbeat.cc @@ -29,9 +29,13 @@ constexpr absl::Duration kLifecycleHeartbeatFrequency = absl::Seconds(30); class LifecycleHeartbeatImpl : public LifecycleHeartbeat { public: - explicit LifecycleHeartbeatImpl(std::unique_ptr heartbeat, - InstanceClient& instance_client) - : heartbeat_(std::move(heartbeat)), instance_client_(instance_client) {} + explicit 
LifecycleHeartbeatImpl( + std::unique_ptr heartbeat, + InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : heartbeat_(std::move(heartbeat)), + instance_client_(instance_client), + log_context_(log_context) {} ~LifecycleHeartbeatImpl() { if (is_running_) { @@ -46,15 +50,16 @@ class LifecycleHeartbeatImpl : public LifecycleHeartbeat { } launch_hook_name_ = parameter_fetcher.GetParameter(kLaunchHookParameterSuffix); - LOG(INFO) << "Retrieved " << kLaunchHookParameterSuffix - << " parameter: " << launch_hook_name_; + PS_LOG(INFO, log_context_) << "Retrieved " << kLaunchHookParameterSuffix + << " parameter: " << launch_hook_name_; absl::Status status = heartbeat_->StartDelayed(kLifecycleHeartbeatFrequency, [this] { if (const absl::Status status = instance_client_.RecordLifecycleHeartbeat(launch_hook_name_); !status.ok()) { - LOG(WARNING) << "Failed to record lifecycle heartbeat: " << status; + PS_LOG(WARNING, log_context_) + << "Failed to record lifecycle heartbeat: " << status; } }); if (status.ok()) { @@ -72,9 +77,10 @@ class LifecycleHeartbeatImpl : public LifecycleHeartbeat { [this] { return instance_client_.CompleteLifecycle(launch_hook_name_); }, - "CompleteLifecycle", - LogStatusSafeMetricsFn()); - LOG(INFO) << "Completed lifecycle hook " << launch_hook_name_; + "CompleteLifecycle", LogStatusSafeMetricsFn(), + log_context_); + PS_LOG(INFO, log_context_) + << "Completed lifecycle hook " << launch_hook_name_; } private: @@ -82,20 +88,22 @@ class LifecycleHeartbeatImpl : public LifecycleHeartbeat { InstanceClient& instance_client_; std::string launch_hook_name_; bool is_running_ = false; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace std::unique_ptr LifecycleHeartbeat::Create( - std::unique_ptr heartbeat, - InstanceClient& instance_client) { + std::unique_ptr heartbeat, InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return 
std::make_unique(std::move(heartbeat), - instance_client); + instance_client, log_context); } std::unique_ptr LifecycleHeartbeat::Create( - InstanceClient& instance_client) { + InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return std::make_unique(PeriodicClosure::Create(), - instance_client); + instance_client, log_context); } } // namespace kv_server diff --git a/components/data_server/server/lifecycle_heartbeat.h b/components/data_server/server/lifecycle_heartbeat.h index c6c54512..6f00b24c 100644 --- a/components/data_server/server/lifecycle_heartbeat.h +++ b/components/data_server/server/lifecycle_heartbeat.h @@ -23,6 +23,7 @@ #include "components/cloud_config/instance_client.h" #include "components/data_server/server/parameter_fetcher.h" #include "components/util/periodic_closure.h" +#include "src/logger/request_context_logger.h" namespace kv_server { @@ -34,12 +35,18 @@ class LifecycleHeartbeat { virtual void Finish() = 0; static std::unique_ptr Create( - InstanceClient& instance_client); + InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); // For testing static std::unique_ptr Create( std::unique_ptr heartbeat, - InstanceClient& instance_client); + InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server diff --git a/components/data_server/server/local_server_log_init.cc b/components/data_server/server/local_server_log_init.cc new file mode 100644 index 00000000..dbd1dd82 --- /dev/null +++ b/components/data_server/server/local_server_log_init.cc @@ -0,0 +1,26 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/log/initialize.h" +#include "components/data_server/server/server_log_init.h" + +namespace kv_server { + +absl::optional GetMetricsCollectorEndPoint( + const ParameterClient& parameter_client, const std::string& environment) { + absl::optional metrics_collection_endpoint; + return metrics_collection_endpoint; +} + +} // namespace kv_server diff --git a/components/data_server/server/main.cc b/components/data_server/server/main.cc index 15caa8aa..f6be2725 100644 --- a/components/data_server/server/main.cc +++ b/components/data_server/server/main.cc @@ -22,6 +22,7 @@ #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "components/data_server/server/server.h" +#include "components/data_server/server/server_log_init.h" #include "components/util/build_info.h" #include "src/util/rlimit_core_config.h" @@ -46,7 +47,7 @@ int main(int argc, char** argv) { absl::InstallFailureSignalHandler(options); } - absl::InitializeLog(); + kv_server::InitLog(); absl::SetProgramUsageMessage(absl::StrCat( "FLEDGE Key/Value Server. 
Sample usage:\n", argv[0], " --port=50051")); absl::ParseCommandLine(argc, argv); diff --git a/components/data_server/server/mocks.h b/components/data_server/server/mocks.h index 8734fce2..78a5922a 100644 --- a/components/data_server/server/mocks.h +++ b/components/data_server/server/mocks.h @@ -42,6 +42,9 @@ class MockInstanceClient : public InstanceClient { (DescribeInstanceGroupInput & input), (override)); MOCK_METHOD(absl::StatusOr>, DescribeInstances, (const absl::flat_hash_set&), (override)); + MOCK_METHOD(void, UpdateLogContext, + (privacy_sandbox::server_common::log::PSLogContext & log_context), + (override)); }; class MockParameterClient : public ParameterClient { @@ -54,6 +57,9 @@ class MockParameterClient : public ParameterClient { (std::string_view parameter_name), (const, override)); MOCK_METHOD(absl::StatusOr, GetBoolParameter, (std::string_view parameter_name), (const, override)); + MOCK_METHOD(void, UpdateLogContext, + (privacy_sandbox::server_common::log::PSLogContext & log_context), + (override)); }; class MockParameterFetcher : public ParameterFetcher { diff --git a/components/data_server/server/nonprod_key_fetcher_factory_aws.cc b/components/data_server/server/nonprod_key_fetcher_factory_aws.cc index d50eb289..de12ebdc 100644 --- a/components/data_server/server/nonprod_key_fetcher_factory_aws.cc +++ b/components/data_server/server/nonprod_key_fetcher_factory_aws.cc @@ -15,7 +15,8 @@ #include "components/data_server/server/nonprod_key_fetcher_factory_cloud.h" namespace kv_server { -std::unique_ptr KeyFetcherFactory::Create() { - return std::make_unique(); +std::unique_ptr KeyFetcherFactory::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/data_server/server/nonprod_key_fetcher_factory_cloud.cc b/components/data_server/server/nonprod_key_fetcher_factory_cloud.cc index a877ec57..c537b916 100644 --- 
a/components/data_server/server/nonprod_key_fetcher_factory_cloud.cc +++ b/components/data_server/server/nonprod_key_fetcher_factory_cloud.cc @@ -40,8 +40,8 @@ NonprodCloudKeyFetcherFactory::GetPublicKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const { auto publicKeyEndpointParameter = parameter_fetcher.GetParameter(kPublicKeyEndpointParameterSuffix); - LOG(INFO) << "Retrieved public_key_endpoint parameter: " - << publicKeyEndpointParameter; + PS_LOG(INFO, log_context_) << "Retrieved public_key_endpoint parameter: " + << publicKeyEndpointParameter; std::vector endpoints = {publicKeyEndpointParameter}; return endpoints; } diff --git a/components/data_server/server/nonprod_key_fetcher_factory_cloud.h b/components/data_server/server/nonprod_key_fetcher_factory_cloud.h index 8f9bb4a5..b05451fd 100644 --- a/components/data_server/server/nonprod_key_fetcher_factory_cloud.h +++ b/components/data_server/server/nonprod_key_fetcher_factory_cloud.h @@ -31,15 +31,14 @@ namespace kv_server { // endpoint. This is a security risk for produciton build. Which is why this // implementation is only allowed in `nonprod` build mode. 
class NonprodCloudKeyFetcherFactory : public CloudKeyFetcherFactory { + public: + explicit NonprodCloudKeyFetcherFactory( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : CloudKeyFetcherFactory(log_context) {} + protected: - google::scp::cpio::PrivateKeyVendingEndpoint GetPrimaryKeyFetchingEndpoint( - const ParameterFetcher& parameter_fetcher) const override; - google::scp::cpio::PrivateKeyVendingEndpoint GetSecondaryKeyFetchingEndpoint( - const ParameterFetcher& parameter_fetcher) const override; std::vector GetPublicKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const override; - ::privacy_sandbox::server_common::CloudPlatform GetCloudPlatform() - const override; }; } // namespace kv_server diff --git a/components/data_server/server/nonprod_key_fetcher_factory_gcp.cc b/components/data_server/server/nonprod_key_fetcher_factory_gcp.cc index 47880038..66200cff 100644 --- a/components/data_server/server/nonprod_key_fetcher_factory_gcp.cc +++ b/components/data_server/server/nonprod_key_fetcher_factory_gcp.cc @@ -23,21 +23,27 @@ using ::google::scp::cpio::PrivateKeyVendingEndpoint; using ::privacy_sandbox::server_common::CloudPlatform; class KeyFetcherFactoryGcpNonProd : public NonprodCloudKeyFetcherFactory { + public: + explicit KeyFetcherFactoryGcpNonProd( + privacy_sandbox::server_common::log::PSLogContext& log_context) + : NonprodCloudKeyFetcherFactory(log_context) {} + + protected: PrivateKeyVendingEndpoint GetPrimaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const override { PrivateKeyVendingEndpoint endpoint = - NonprodCloudKeyFetcherFactory::GetPrimaryKeyFetchingEndpoint( + CloudKeyFetcherFactory::GetPrimaryKeyFetchingEndpoint( parameter_fetcher); - UpdatePrimaryGcpEndpoint(endpoint, parameter_fetcher); + UpdatePrimaryGcpEndpoint(endpoint, parameter_fetcher, log_context_); return endpoint; } PrivateKeyVendingEndpoint GetSecondaryKeyFetchingEndpoint( const ParameterFetcher& parameter_fetcher) const 
override { PrivateKeyVendingEndpoint endpoint = - NonprodCloudKeyFetcherFactory::GetSecondaryKeyFetchingEndpoint( + CloudKeyFetcherFactory::GetSecondaryKeyFetchingEndpoint( parameter_fetcher); - UpdateSecondaryGcpEndpoint(endpoint, parameter_fetcher); + UpdateSecondaryGcpEndpoint(endpoint, parameter_fetcher, log_context_); return endpoint; } @@ -47,7 +53,8 @@ class KeyFetcherFactoryGcpNonProd : public NonprodCloudKeyFetcherFactory { }; } // namespace -std::unique_ptr KeyFetcherFactory::Create() { - return std::make_unique(); +std::unique_ptr KeyFetcherFactory::Create( + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique(log_context); } } // namespace kv_server diff --git a/components/data_server/server/nonprod_server_log_init.cc b/components/data_server/server/nonprod_server_log_init.cc new file mode 100644 index 00000000..a0fc25db --- /dev/null +++ b/components/data_server/server/nonprod_server_log_init.cc @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/log/initialize.h" +#include "components/data_server/server/server_log_init.h" +#include "src/logger/request_context_logger.h" + +namespace kv_server { + +void InitLog() { + absl::InitializeLog(); + // Turn on all otel logging for nonprod build regardless of consented or not + // This line must be called before the first PS_LOG/PS_VLOG line. 
+ privacy_sandbox::server_common::log::AlwaysLogOtel(true); +} + +} // namespace kv_server diff --git a/components/data_server/server/parameter_fetcher.cc b/components/data_server/server/parameter_fetcher.cc index b8e36b55..79f895b1 100644 --- a/components/data_server/server/parameter_fetcher.cc +++ b/components/data_server/server/parameter_fetcher.cc @@ -27,10 +27,12 @@ namespace kv_server { ParameterFetcher::ParameterFetcher( std::string environment, const ParameterClient& parameter_client, absl::AnyInvocable - metrics_callback) + metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context) : environment_(std::move(environment)), parameter_client_(parameter_client), - metrics_callback_(std::move(metrics_callback)) {} + metrics_callback_(std::move(metrics_callback)), + log_context_(log_context) {} std::string ParameterFetcher::GetParameter( std::string_view parameter_suffix, @@ -40,7 +42,7 @@ std::string ParameterFetcher::GetParameter( [this, ¶m_name, &default_value] { return parameter_client_.GetParameter(param_name, default_value); }, - "GetParameter", metrics_callback_, {{"param", param_name}}); + "GetParameter", metrics_callback_, log_context_, {{"param", param_name}}); } int32_t ParameterFetcher::GetInt32Parameter( @@ -50,7 +52,7 @@ int32_t ParameterFetcher::GetInt32Parameter( [this, ¶m_name] { return parameter_client_.GetInt32Parameter(param_name); }, - "GetParameter", metrics_callback_, {{"param", param_name}}); + "GetParameter", metrics_callback_, log_context_, {{"param", param_name}}); } bool ParameterFetcher::GetBoolParameter( @@ -60,7 +62,7 @@ bool ParameterFetcher::GetBoolParameter( [this, ¶m_name] { return parameter_client_.GetBoolParameter(param_name); }, - "GetParameter", metrics_callback_, {{"param", param_name}}); + "GetParameter", metrics_callback_, log_context_, {{"param", param_name}}); } std::string ParameterFetcher::GetParamName( diff --git a/components/data_server/server/parameter_fetcher.h 
b/components/data_server/server/parameter_fetcher.h index bc3a344d..0b368a23 100644 --- a/components/data_server/server/parameter_fetcher.h +++ b/components/data_server/server/parameter_fetcher.h @@ -38,7 +38,10 @@ class ParameterFetcher { ParameterFetcher( std::string environment, const ParameterClient& parameter_client, absl::AnyInvocable - metric_callback = LogMetricsNoOpCallback()); + metric_callback = LogMetricsNoOpCallback(), + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); virtual ~ParameterFetcher() = default; @@ -60,6 +63,9 @@ class ParameterFetcher { virtual NotifierMetadata GetRealtimeNotifierMetadata(int32_t num_shards, int32_t shard_num) const; + protected: + privacy_sandbox::server_common::log::PSLogContext& log_context_; + private: std::string GetParamName(std::string_view parameter_suffix) const; diff --git a/components/data_server/server/parameter_fetcher_aws.cc b/components/data_server/server/parameter_fetcher_aws.cc index 6d4f0e55..5c3b17e5 100644 --- a/components/data_server/server/parameter_fetcher_aws.cc +++ b/components/data_server/server/parameter_fetcher_aws.cc @@ -39,9 +39,12 @@ constexpr std::string_view kS3ClientMaxRangeBytesParameterSuffix = NotifierMetadata ParameterFetcher::GetBlobStorageNotifierMetadata() const { std::string bucket_sns_arn = GetParameter(kDataLoadingFileChannelBucketSNSParameterSuffix); - LOG(INFO) << "Retrieved " << kDataLoadingFileChannelBucketSNSParameterSuffix - << " parameter: " << bucket_sns_arn; - return AwsNotifierMetadata{"BlobNotifier_", std::move(bucket_sns_arn)}; + PS_LOG(INFO, log_context_) + << "Retrieved " << kDataLoadingFileChannelBucketSNSParameterSuffix + << " parameter: " << bucket_sns_arn; + return AwsNotifierMetadata{.queue_prefix = "BlobNotifier_", + .sns_arn = std::move(bucket_sns_arn), + .environment = environment_}; } BlobStorageClient::ClientOptions ParameterFetcher::GetBlobStorageClientOptions() @@ -49,12 
+52,14 @@ BlobStorageClient::ClientOptions ParameterFetcher::GetBlobStorageClientOptions() BlobStorageClient::ClientOptions client_options; client_options.max_connections = GetInt32Parameter(kS3ClientMaxConnectionsParameterSuffix); - LOG(INFO) << "Retrieved " << kS3ClientMaxConnectionsParameterSuffix - << " parameter: " << client_options.max_connections; + PS_LOG(INFO, log_context_) + << "Retrieved " << kS3ClientMaxConnectionsParameterSuffix + << " parameter: " << client_options.max_connections; client_options.max_range_bytes = GetInt32Parameter(kS3ClientMaxRangeBytesParameterSuffix); - LOG(INFO) << "Retrieved " << kS3ClientMaxRangeBytesParameterSuffix - << " parameter: " << client_options.max_range_bytes; + PS_LOG(INFO, log_context_) + << "Retrieved " << kS3ClientMaxRangeBytesParameterSuffix + << " parameter: " << client_options.max_range_bytes; return client_options; } @@ -62,10 +67,12 @@ NotifierMetadata ParameterFetcher::GetRealtimeNotifierMetadata( int32_t num_shards, int32_t shard_num) const { std::string realtime_sns_arn = GetParameter(kDataLoadingRealtimeChannelSNSParameterSuffix); - LOG(INFO) << "Retrieved " << kDataLoadingRealtimeChannelSNSParameterSuffix - << " parameter: " << realtime_sns_arn; + PS_LOG(INFO, log_context_) + << "Retrieved " << kDataLoadingRealtimeChannelSNSParameterSuffix + << " parameter: " << realtime_sns_arn; return AwsNotifierMetadata{"QueueNotifier_", std::move(realtime_sns_arn), - .num_shards = num_shards, .shard_num = shard_num}; + .num_shards = num_shards, .shard_num = shard_num, + .environment = environment_}; } } // namespace kv_server diff --git a/components/data_server/server/parameter_fetcher_gcp.cc b/components/data_server/server/parameter_fetcher_gcp.cc index 5bcc9fbf..159d7856 100644 --- a/components/data_server/server/parameter_fetcher_gcp.cc +++ b/components/data_server/server/parameter_fetcher_gcp.cc @@ -39,19 +39,23 @@ BlobStorageClient::ClientOptions ParameterFetcher::GetBlobStorageClientOptions() NotifierMetadata 
ParameterFetcher::GetRealtimeNotifierMetadata( int32_t num_shards, int32_t shard_num) const { std::string environment = GetParameter(kEnvironment); - LOG(INFO) << "Retrieved " << kEnvironment << " parameter: " << environment; + PS_LOG(INFO, log_context_) + << "Retrieved " << kEnvironment << " parameter: " << environment; auto realtime_thread_numbers = GetInt32Parameter(kRealtimeUpdaterThreadNumberParameterSuffix); - LOG(INFO) << "Retrieved " << kRealtimeUpdaterThreadNumberParameterSuffix - << " parameter: " << realtime_thread_numbers; + PS_LOG(INFO, log_context_) + << "Retrieved " << kRealtimeUpdaterThreadNumberParameterSuffix + << " parameter: " << realtime_thread_numbers; std::string topic_id = absl::StrFormat("kv-server-%s-realtime-pubsub", environment); std::string project_id = GetParameter(kProjectId); - LOG(INFO) << "Retrieved " << kProjectId << " parameter: " << project_id; + PS_LOG(INFO, log_context_) + << "Retrieved " << kProjectId << " parameter: " << project_id; return GcpNotifierMetadata{ .queue_prefix = "QueueNotifier_", .project_id = project_id, .topic_id = topic_id, + .environment = environment, .num_threads = realtime_thread_numbers, .num_shards = num_shards, .shard_num = shard_num, diff --git a/components/data_server/server/parameter_fetcher_local.cc b/components/data_server/server/parameter_fetcher_local.cc index b22e1e1c..083ade02 100644 --- a/components/data_server/server/parameter_fetcher_local.cc +++ b/components/data_server/server/parameter_fetcher_local.cc @@ -24,8 +24,8 @@ constexpr std::string_view kRealtimeDirectoryToWatch = "realtime-directory"; NotifierMetadata ParameterFetcher::GetBlobStorageNotifierMetadata() const { std::string directory = GetParameter(kLocalDirectoryToWatch); - LOG(INFO) << "Retrieved " << kLocalDirectoryToWatch - << " parameter: " << directory; + PS_LOG(INFO, log_context_) + << "Retrieved " << kLocalDirectoryToWatch << " parameter: " << directory; return LocalNotifierMetadata{.local_directory = 
std::move(directory)}; } @@ -37,8 +37,8 @@ BlobStorageClient::ClientOptions ParameterFetcher::GetBlobStorageClientOptions() NotifierMetadata ParameterFetcher::GetRealtimeNotifierMetadata( int32_t num_shards, int32_t shard_num) const { std::string directory = GetParameter(kRealtimeDirectoryToWatch); - LOG(INFO) << "Retrieved " << kRealtimeDirectoryToWatch - << " parameter: " << directory; + PS_LOG(INFO, log_context_) << "Retrieved " << kRealtimeDirectoryToWatch + << " parameter: " << directory; return LocalNotifierMetadata{.local_directory = std::move(directory)}; } diff --git a/components/data_server/server/prod_server_log_init.cc b/components/data_server/server/prod_server_log_init.cc new file mode 100644 index 00000000..d6f583c3 --- /dev/null +++ b/components/data_server/server/prod_server_log_init.cc @@ -0,0 +1,22 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "absl/log/initialize.h" +#include "components/data_server/server/server_log_init.h" + +namespace kv_server { + +void InitLog() { absl::InitializeLog(); } + +} // namespace kv_server diff --git a/components/data_server/server/server.cc b/components/data_server/server/server.cc index 2b88b059..69aaeee9 100644 --- a/components/data_server/server/server.cc +++ b/components/data_server/server/server.cc @@ -31,6 +31,7 @@ #include "components/data_server/server/key_fetcher_factory.h" #include "components/data_server/server/key_value_service_impl.h" #include "components/data_server/server/key_value_service_v2_impl.h" +#include "components/data_server/server/server_log_init.h" #include "components/errors/retry.h" #include "components/internal_server/constants.h" #include "components/internal_server/local_lookup.h" @@ -62,22 +63,18 @@ ABSL_FLAG(uint16_t, port, 50051, namespace kv_server { namespace { -using privacy_sandbox::server_common::ConfigureMetrics; using privacy_sandbox::server_common::ConfigurePrivateMetrics; using privacy_sandbox::server_common::ConfigureTracer; using privacy_sandbox::server_common::GetTracer; using privacy_sandbox::server_common::InitTelemetry; using privacy_sandbox::server_common::TelemetryProvider; +using privacy_sandbox::server_common::log::PSLogContext; using privacy_sandbox::server_common::telemetry::BuildDependentConfig; // TODO: Use config cpio client to get this from the environment constexpr absl::string_view kDataBucketParameterSuffix = "data-bucket-id"; constexpr absl::string_view kBackupPollFrequencySecsParameterSuffix = "backup-poll-frequency-secs"; -constexpr absl::string_view kUseExternalMetricsCollectorEndpointSuffix = - "use-external-metrics-collector-endpoint"; -constexpr absl::string_view kMetricsCollectorEndpointSuffix = - "metrics-collector-endpoint"; constexpr absl::string_view kMetricsExportIntervalMillisParameterSuffix = "metrics-export-interval-millis"; constexpr absl::string_view 
kMetricsExportTimeoutMillisParameterSuffix = @@ -94,6 +91,8 @@ constexpr absl::string_view kLoggingVerbosityLevelParameterSuffix = "logging-verbosity-level"; constexpr absl::string_view kUdfTimeoutMillisParameterSuffix = "udf-timeout-millis"; +constexpr absl::string_view kUdfUpdateTimeoutMillisParameterSuffix = + "udf-update-timeout-millis"; constexpr absl::string_view kUdfMinLogLevelParameterSuffix = "udf-min-log-level"; constexpr absl::string_view kUseShardingKeyRegexParameterSuffix = @@ -110,6 +109,8 @@ constexpr absl::string_view kEnableOtelLoggerParameterSuffix = constexpr std::string_view kDataLoadingBlobPrefixAllowlistSuffix = "data-loading-blob-prefix-allowlist"; constexpr std::string_view kTelemetryConfigSuffix = "telemetry-config"; +constexpr std::string_view kConsentedDebugTokenSuffix = "consented-debug-token"; +constexpr std::string_view kEnableConsentedLogSuffix = "enable-consented-log"; opentelemetry::sdk::metrics::PeriodicExportingMetricReaderOptions GetMetricsOptions(const ParameterClient& parameter_client, @@ -142,6 +143,7 @@ void CheckMetricsCollectorEndPointConnection( RetryUntilOk( [channel]() { if (channel->GetState(true) != GRPC_CHANNEL_READY) { + LOG(INFO) << "Metrics collector endpoint is not ready. 
Will retry."; return absl::UnavailableError("metrics collector is not connected"); } return absl::OkStatus(); @@ -149,23 +151,6 @@ void CheckMetricsCollectorEndPointConnection( "Checking connection to metrics collector", LogMetricsNoOpCallback()); } -absl::optional GetMetricsCollectorEndPoint( - const ParameterClient& parameter_client, const std::string& environment) { - absl::optional metrics_collection_endpoint; - ParameterFetcher parameter_fetcher(environment, parameter_client); - auto should_connect_to_external_metrics_collector = - parameter_fetcher.GetBoolParameter( - kUseExternalMetricsCollectorEndpointSuffix); - if (should_connect_to_external_metrics_collector) { - std::string metrics_collector_endpoint_value = - parameter_fetcher.GetParameter(kMetricsCollectorEndpointSuffix); - LOG(INFO) << "Retrieved " << kMetricsCollectorEndpointSuffix - << " parameter: " << metrics_collector_endpoint_value; - metrics_collection_endpoint = std::move(metrics_collector_endpoint_value); - } - return std::move(metrics_collection_endpoint); -} - privacy_sandbox::server_common::telemetry::TelemetryConfig GetServerTelemetryConfig(const ParameterClient& parameter_client, const std::string& environment) { @@ -182,11 +167,12 @@ GetServerTelemetryConfig(const ParameterClient& parameter_client, } BlobPrefixAllowlist GetBlobPrefixAllowlist( - const ParameterFetcher& parameter_fetcher) { + const ParameterFetcher& parameter_fetcher, PSLogContext& log_context) { const auto prefix_allowlist = parameter_fetcher.GetParameter( kDataLoadingBlobPrefixAllowlistSuffix, /*default_value=*/""); - LOG(INFO) << "Retrieved " << kDataLoadingBlobPrefixAllowlistSuffix - << " parameter: " << prefix_allowlist; + PS_LOG(INFO, log_context) + << "Retrieved " << kDataLoadingBlobPrefixAllowlistSuffix + << " parameter: " << prefix_allowlist; return BlobPrefixAllowlist(prefix_allowlist); } @@ -197,7 +183,8 @@ Server::Server() GetValuesHook::Create(GetValuesHook::OutputType::kString)), binary_get_values_hook_( 
GetValuesHook::Create(GetValuesHook::OutputType::kBinary)), - run_query_hook_(RunQueryHook::Create()) {} + run_set_query_int_hook_(RunSetQueryIntHook::Create()), + run_set_query_string_hook_(RunSetQueryStringHook::Create()) {} // Because the cache relies on telemetry, this function needs to be // called right after telemetry has been initialized but before anything that @@ -205,28 +192,35 @@ Server::Server() void Server::InitializeKeyValueCache() { cache_ = KeyValueCache::Create(); cache_->UpdateKeyValue( - "hi", + server_safe_log_context_, "hi", "Hello, world! If you are seeing this, it means you can " "query me successfully", /*logical_commit_time = */ 1); } -void Server::InitOtelLogger( - ::opentelemetry::sdk::resource::Resource server_info, - absl::optional collector_endpoint, - const ParameterFetcher& parameter_fetcher) { - const bool enable_otel_logger = - parameter_fetcher.GetBoolParameter(kEnableOtelLoggerParameterSuffix); - LOG(INFO) << "Retrieved " << kEnableOtelLoggerParameterSuffix - << " parameter: " << enable_otel_logger; - if (!enable_otel_logger) { - return; - } - log_provider_ = privacy_sandbox::server_common::ConfigurePrivateLogger( - server_info, collector_endpoint); - open_telemetry_sink_ = std::make_unique( - log_provider_->GetLogger(kServiceName.data())); - absl::AddLogSink(open_telemetry_sink_.get()); +void Server::InitLogger(::opentelemetry::sdk::resource::Resource server_info, + absl::optional collector_endpoint, + const ParameterFetcher& parameter_fetcher) { + // updating verbosity level flag as early as we can, as it affects all logging + // downstream. 
+ auto verbosity_level = parameter_fetcher.GetInt32Parameter( + kLoggingVerbosityLevelParameterSuffix); + absl::SetGlobalVLogLevel(verbosity_level); + privacy_sandbox::server_common::log::PS_VLOG_IS_ON(0, verbosity_level); + static auto* log_provider = + privacy_sandbox::server_common::ConfigurePrivateLogger(server_info, + collector_endpoint) + .release(); + privacy_sandbox::server_common::log::logger_private = + log_provider->GetLogger(kServiceName.data()).get(); + parameter_client_->UpdateLogContext(server_safe_log_context_); + instance_client_->UpdateLogContext(server_safe_log_context_); + if (const bool enable_consented_log = + parameter_fetcher.GetBoolParameter(kEnableConsentedLogSuffix); + enable_consented_log) { + privacy_sandbox::server_common::log::ServerToken( + parameter_fetcher.GetParameter(kConsentedDebugTokenSuffix, "")); + } } void Server::InitializeTelemetry(const ParameterClient& parameter_client, @@ -234,8 +228,16 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, std::string instance_id = RetryUntilOk( [&instance_client]() { return instance_client.GetInstanceId(); }, "GetInstanceId", LogMetricsNoOpCallback()); - - InitTelemetry(std::string(kServiceName), std::string(BuildVersion())); + ParameterFetcher parameter_fetcher(environment_, parameter_client); + const bool enable_otel_logger = + parameter_fetcher.GetBoolParameter(kEnableOtelLoggerParameterSuffix); + LOG(INFO) << "Retrieved " << kEnableOtelLoggerParameterSuffix + << " parameter: " << enable_otel_logger; + BuildDependentConfig telemetry_config( + GetServerTelemetryConfig(parameter_client, environment_)); + InitTelemetry(std::string(kServiceName), std::string(BuildVersion()), + telemetry_config.TraceAllowed(), + telemetry_config.MetricAllowed(), enable_otel_logger); auto metrics_options = GetMetricsOptions(parameter_client, environment_); auto metrics_collector_endpoint = GetMetricsCollectorEndPoint(parameter_client, environment_); @@ -243,8 +245,6 @@ void 
Server::InitializeTelemetry(const ParameterClient& parameter_client, CheckMetricsCollectorEndPointConnection(metrics_collector_endpoint.value()); } LOG(INFO) << "Done retrieving metrics collector endpoint"; - BuildDependentConfig telemetry_config( - GetServerTelemetryConfig(parameter_client, environment_)); auto* context_map = KVServerContextMap( telemetry_config, ConfigurePrivateMetrics( @@ -259,24 +259,17 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, CreateKVAttributes(instance_id, std::to_string(shard_num_), environment_), metrics_options, metrics_collector_endpoint)); - - // TODO(b/300137699): Deprecate ConfigureMetrics once all metrics are migrated - // to new telemetry API - ConfigureMetrics( - CreateKVAttributes(instance_id, std::to_string(shard_num_), environment_), - metrics_options, metrics_collector_endpoint); ConfigureTracer( CreateKVAttributes(instance_id, std::to_string(shard_num_), environment_), metrics_collector_endpoint); - ParameterFetcher parameter_fetcher(environment_, parameter_client); - InitOtelLogger(CreateKVAttributes(std::move(instance_id), - std::to_string(shard_num_), environment_), - metrics_collector_endpoint, parameter_fetcher); - LOG(INFO) << "Done init telemetry"; + InitLogger(CreateKVAttributes(std::move(instance_id), + std::to_string(shard_num_), environment_), + metrics_collector_endpoint, parameter_fetcher); + PS_LOG(INFO, server_safe_log_context_) << "Done init telemetry"; } absl::Status Server::CreateDefaultInstancesIfNecessaryAndGetEnvironment( - std::unique_ptr parameter_client, + std::unique_ptr parameter_client, std::unique_ptr instance_client, std::unique_ptr udf_client) { parameter_client_ = parameter_client == nullptr ? 
ParameterClient::Create() @@ -287,19 +280,35 @@ absl::Status Server::CreateDefaultInstancesIfNecessaryAndGetEnvironment( [this]() { return instance_client_->GetEnvironmentTag(); }, "GetEnvironment", LogMetricsNoOpCallback()); LOG(INFO) << "Retrieved environment: " << environment_; - ParameterFetcher parameter_fetcher(environment_, *parameter_client_); + const auto shard_num_status = instance_client_->GetShardNumTag(); + if (!shard_num_status.ok()) { + return shard_num_status.status(); + } + if (!absl::SimpleAtoi(*shard_num_status, &shard_num_)) { + std::string error = + absl::StrFormat("Failed converting shard id parameter: %s to int32.", + *shard_num_status); + LOG(ERROR) << error; + return absl::InvalidArgumentError(error); + } + LOG(INFO) << "Retrieved shard num: " << shard_num_; + + InitializeTelemetry(*parameter_client_, *instance_client_); + + ParameterFetcher parameter_fetcher( + environment_, *parameter_client_, + std::move(LogStatusSafeMetricsFn()), + server_safe_log_context_); int32_t number_of_workers = parameter_fetcher.GetInt32Parameter(kUdfNumWorkersParameterSuffix); int32_t udf_timeout_ms = parameter_fetcher.GetInt32Parameter(kUdfTimeoutMillisParameterSuffix); + int32_t udf_update_timeout_ms = parameter_fetcher.GetInt32Parameter( + kUdfUpdateTimeoutMillisParameterSuffix); int32_t udf_min_log_level = parameter_fetcher.GetInt32Parameter(kUdfMinLogLevelParameterSuffix); - // updating verbosity level flag as early as we can, as it affects all logging - // downstream. - absl::SetGlobalVLogLevel(parameter_fetcher.GetInt32Parameter( - kLoggingVerbosityLevelParameterSuffix)); if (udf_client != nullptr) { udf_client_ = std::move(udf_client); return absl::OkStatus(); @@ -309,24 +318,26 @@ absl::Status Server::CreateDefaultInstancesIfNecessaryAndGetEnvironment( // can be removed and we can own the unique ptr to the hooks. 
absl::StatusOr> udf_client_or_status = UdfClient::Create( - std::move(config_builder - .RegisterStringGetValuesHook(*string_get_values_hook_) - .RegisterBinaryGetValuesHook(*binary_get_values_hook_) - .RegisterRunQueryHook(*run_query_hook_) - .RegisterLoggingFunction() - .SetNumberOfWorkers(number_of_workers) - .Config()), - absl::Milliseconds(udf_timeout_ms), udf_min_log_level); + std::move( + config_builder + .RegisterStringGetValuesHook(*string_get_values_hook_) + .RegisterBinaryGetValuesHook(*binary_get_values_hook_) + .RegisterRunSetQueryIntHook(*run_set_query_int_hook_) + .RegisterRunSetQueryStringHook(*run_set_query_string_hook_) + .RegisterLoggingFunction() + .SetNumberOfWorkers(number_of_workers) + .Config()), + absl::Milliseconds(udf_timeout_ms), + absl::Milliseconds(udf_update_timeout_ms), udf_min_log_level); if (udf_client_or_status.ok()) { udf_client_ = std::move(*udf_client_or_status); } return udf_client_or_status.status(); } -absl::Status Server::Init( - std::unique_ptr parameter_client, - std::unique_ptr instance_client, - std::unique_ptr udf_client) { +absl::Status Server::Init(std::unique_ptr parameter_client, + std::unique_ptr instance_client, + std::unique_ptr udf_client) { { absl::Status status = CreateDefaultInstancesIfNecessaryAndGetEnvironment( std::move(parameter_client), std::move(instance_client), @@ -341,18 +352,21 @@ absl::Status Server::Init( return InitOnceInstancesAreCreated(); } -KeySharder GetKeySharder(const ParameterFetcher& parameter_fetcher) { +KeySharder GetKeySharder(const ParameterFetcher& parameter_fetcher, + PSLogContext& log_context) { const bool use_sharding_key_regex = parameter_fetcher.GetBoolParameter(kUseShardingKeyRegexParameterSuffix); - LOG(INFO) << "Retrieved " << kUseShardingKeyRegexParameterSuffix - << " parameter: " << use_sharding_key_regex; + PS_LOG(INFO, log_context) + << "Retrieved " << kUseShardingKeyRegexParameterSuffix + << " parameter: " << use_sharding_key_regex; ShardingFunction func(/*seed=*/""); 
std::optional shard_key_regex; if (use_sharding_key_regex) { std::string sharding_key_regex_value = parameter_fetcher.GetParameter(kShardingKeyRegexParameterSuffix); - LOG(INFO) << "Retrieved " << kShardingKeyRegexParameterSuffix - << " parameter: " << sharding_key_regex_value; + PS_LOG(INFO, log_context) + << "Retrieved " << kShardingKeyRegexParameterSuffix + << " parameter: " << sharding_key_regex_value; // https://en.cppreference.com/w/cpp/regex/syntax_option_type // optimize -- "Instructs the regular expression engine to make matching // faster, with the potential cost of making construction slower. For @@ -365,52 +379,44 @@ KeySharder GetKeySharder(const ParameterFetcher& parameter_fetcher) { } absl::Status Server::InitOnceInstancesAreCreated() { - const auto shard_num_status = instance_client_->GetShardNumTag(); - if (!shard_num_status.ok()) { - return shard_num_status.status(); - } - if (!absl::SimpleAtoi(*shard_num_status, &shard_num_)) { - std::string error = - absl::StrFormat("Failed converting shard id parameter: %s to int32.", - *shard_num_status); - LOG(ERROR) << error; - return absl::InvalidArgumentError(error); - } - LOG(INFO) << "Retrieved shard num: " << shard_num_; - InitializeTelemetry(*parameter_client_, *instance_client_); InitializeKeyValueCache(); auto span = GetTracer()->StartSpan("InitServer"); auto scope = opentelemetry::trace::Scope(span); - LOG(INFO) << "Creating lifecycle heartbeat..."; + PS_LOG(INFO, server_safe_log_context_) << "Creating lifecycle heartbeat..."; std::unique_ptr lifecycle_heartbeat = - LifecycleHeartbeat::Create(*instance_client_); + LifecycleHeartbeat::Create(*instance_client_, server_safe_log_context_); ParameterFetcher parameter_fetcher( environment_, *parameter_client_, - std::move(LogStatusSafeMetricsFn())); + std::move(LogStatusSafeMetricsFn()), + server_safe_log_context_); if (absl::Status status = lifecycle_heartbeat->Start(parameter_fetcher); status != absl::OkStatus()) { return status; } + PS_LOG(INFO, 
server_safe_log_context_) << "Setting default UDF."; if (absl::Status status = SetDefaultUdfCodeObject(); !status.ok()) { + PS_LOG(ERROR, server_safe_log_context_) << status; return absl::InternalError( "Error setting default UDF. Please contact Google to fix the default " "UDF or retry starting the server."); } num_shards_ = parameter_fetcher.GetInt32Parameter(kNumShardsParameterSuffix); - LOG(INFO) << "Retrieved " << kNumShardsParameterSuffix - << " parameter: " << num_shards_; + PS_LOG(INFO, server_safe_log_context_) + << "Retrieved " << kNumShardsParameterSuffix + << " parameter: " << num_shards_; blob_client_ = CreateBlobClient(parameter_fetcher); delta_stream_reader_factory_ = CreateStreamRecordReaderFactory(parameter_fetcher); notifier_ = CreateDeltaFileNotifier(parameter_fetcher); - auto factory = KeyFetcherFactory::Create(); + auto factory = KeyFetcherFactory::Create(server_safe_log_context_); key_fetcher_manager_ = factory->CreateKeyFetcherManager(parameter_fetcher); CreateGrpcServices(parameter_fetcher); auto metadata = parameter_fetcher.GetBlobStorageNotifierMetadata(); - auto message_service_status = MessageService::Create(metadata); + auto message_service_status = + MessageService::Create(metadata, server_safe_log_context_); if (!message_service_status.ok()) { return message_service_status.status(); } @@ -419,14 +425,15 @@ absl::Status Server::InitOnceInstancesAreCreated() { grpc_server_ = CreateAndStartGrpcServer(); local_lookup_ = CreateLocalLookup(*cache_); - auto key_sharder = GetKeySharder(parameter_fetcher); + auto key_sharder = GetKeySharder(parameter_fetcher, server_safe_log_context_); auto server_initializer = GetServerInitializer( num_shards_, *key_fetcher_manager_, *local_lookup_, environment_, - shard_num_, *instance_client_, *cache_, parameter_fetcher, key_sharder); + shard_num_, *instance_client_, *cache_, parameter_fetcher, key_sharder, + server_safe_log_context_); remote_lookup_ = server_initializer->CreateAndStartRemoteLookupServer(); 
{ - auto status_or_notifier = - BlobStorageChangeNotifier::Create(std::move(metadata)); + auto status_or_notifier = BlobStorageChangeNotifier::Create( + std::move(metadata), server_safe_log_context_); if (!status_or_notifier.ok()) { // The ChangeNotifier is required to read delta files, if it's not // available that's a critical error and so return immediately. @@ -436,8 +443,8 @@ absl::Status Server::InitOnceInstancesAreCreated() { } auto realtime_notifier_metadata = parameter_fetcher.GetRealtimeNotifierMetadata(num_shards_, shard_num_); - auto realtime_message_service_status = - MessageService::Create(realtime_notifier_metadata); + auto realtime_message_service_status = MessageService::Create( + realtime_notifier_metadata, server_safe_log_context_); if (!realtime_message_service_status.ok()) { return realtime_message_service_status.status(); } @@ -445,10 +452,12 @@ absl::Status Server::InitOnceInstancesAreCreated() { SetQueueManager(realtime_notifier_metadata, message_service_realtime_.get()); uint32_t realtime_thread_numbers = parameter_fetcher.GetInt32Parameter( kRealtimeUpdaterThreadNumberParameterSuffix); - LOG(INFO) << "Retrieved " << kRealtimeUpdaterThreadNumberParameterSuffix - << " parameter: " << realtime_thread_numbers; + PS_LOG(INFO, server_safe_log_context_) + << "Retrieved " << kRealtimeUpdaterThreadNumberParameterSuffix + << " parameter: " << realtime_thread_numbers; auto maybe_realtime_thread_pool_manager = RealtimeThreadPoolManager::Create( - realtime_notifier_metadata, realtime_thread_numbers); + realtime_notifier_metadata, realtime_thread_numbers, {}, + server_safe_log_context_); if (!maybe_realtime_thread_pool_manager.ok()) { return maybe_realtime_thread_pool_manager.status(); } @@ -457,7 +466,8 @@ absl::Status Server::InitOnceInstancesAreCreated() { data_orchestrator_ = CreateDataOrchestrator(parameter_fetcher, key_sharder); TraceRetryUntilOk([this] { return data_orchestrator_->Start(); }, "StartDataOrchestrator", - LogStatusSafeMetricsFn()); 
+ LogStatusSafeMetricsFn(), + server_safe_log_context_); if (num_shards_ > 1) { // At this point the server is healthy and the initialization is over. // The only missing piece is having a shard map, which is dependent on @@ -466,7 +476,8 @@ absl::Status Server::InitOnceInstancesAreCreated() { lifecycle_heartbeat->Finish(); } auto maybe_shard_state = server_initializer->InitializeUdfHooks( - *string_get_values_hook_, *binary_get_values_hook_, *run_query_hook_); + *string_get_values_hook_, *binary_get_values_hook_, + *run_set_query_string_hook_, *run_set_query_int_hook_); if (!maybe_shard_state.ok()) { return maybe_shard_state.status(); } @@ -495,7 +506,8 @@ absl::Status Server::MaybeShutdownNotifiers() { } void Server::GracefulShutdown(absl::Duration timeout) { - LOG(INFO) << "Graceful gRPC server shutdown requested, timeout: " << timeout; + PS_LOG(INFO, server_safe_log_context_) + << "Graceful gRPC server shutdown requested, timeout: " << timeout; if (internal_lookup_server_) { internal_lookup_server_->Shutdown(); } @@ -505,12 +517,14 @@ void Server::GracefulShutdown(absl::Duration timeout) { if (grpc_server_) { grpc_server_->Shutdown(absl::ToChronoTime(absl::Now() + timeout)); } else { - LOG(WARNING) << "Server was not started, cannot shut down."; + PS_LOG(WARNING, server_safe_log_context_) + << "Server was not started, cannot shut down."; } if (udf_client_) { const absl::Status status = udf_client_->Stop(); if (!status.ok()) { - LOG(ERROR) << "Failed to stop UDF client: " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to stop UDF client: " << status; } } if (shard_manager_state_.cluster_mappings_manager && @@ -518,17 +532,20 @@ void Server::GracefulShutdown(absl::Duration timeout) { const absl::Status status = shard_manager_state_.cluster_mappings_manager->Stop(); if (!status.ok()) { - LOG(ERROR) << "Failed to stop cluster mappings manager: " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to stop cluster mappings manager: " 
<< status; } } const absl::Status status = MaybeShutdownNotifiers(); if (!status.ok()) { - LOG(ERROR) << "Failed to shutdown notifiers. Got status " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to shutdown notifiers. Got status " << status; } } void Server::ForceShutdown() { - LOG(WARNING) << "Immediate gRPC server shutdown requested"; + PS_LOG(WARNING, server_safe_log_context_) + << "Immediate gRPC server shutdown requested"; if (internal_lookup_server_) { internal_lookup_server_->Shutdown(); } @@ -538,16 +555,19 @@ void Server::ForceShutdown() { if (grpc_server_) { grpc_server_->Shutdown(); } else { - LOG(WARNING) << "Server was not started, cannot shut down."; + PS_LOG(WARNING, server_safe_log_context_) + << "Server was not started, cannot shut down."; } const absl::Status status = MaybeShutdownNotifiers(); if (!status.ok()) { - LOG(ERROR) << "Failed to shutdown notifiers. Got status " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to shutdown notifiers. 
Got status " << status; } if (udf_client_) { const absl::Status status = udf_client_->Stop(); if (!status.ok()) { - LOG(ERROR) << "Failed to stop UDF client: " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to stop UDF client: " << status; } } if (shard_manager_state_.cluster_mappings_manager && @@ -555,7 +575,8 @@ void Server::ForceShutdown() { const absl::Status status = shard_manager_state_.cluster_mappings_manager->Stop(); if (!status.ok()) { - LOG(ERROR) << "Failed to stop cluster mappings manager: " << status; + PS_LOG(ERROR, server_safe_log_context_) + << "Failed to stop cluster mappings manager: " << status; } } } @@ -567,7 +588,7 @@ std::unique_ptr Server::CreateBlobClient( std::unique_ptr blob_storage_client_factory = BlobStorageClientFactory::Create(); return blob_storage_client_factory->CreateBlobStorageClient( - std::move(client_options)); + std::move(client_options), server_safe_log_context_); } std::unique_ptr @@ -582,11 +603,13 @@ Server::CreateStreamRecordReaderFactory( if (file_format == kFileFormats[static_cast(FileFormat::kAvro)]) { AvroConcurrentStreamRecordReader::Options options; options.num_worker_threads = data_loading_num_threads; + options.log_context = server_safe_log_context_; return std::make_unique(options); } else if (file_format == kFileFormats[static_cast(FileFormat::kRiegeli)]) { ConcurrentStreamRecordReader::Options options; options.num_worker_threads = data_loading_num_threads; + options.log_context = server_safe_log_context_; return std::make_unique(options); } } @@ -595,8 +618,9 @@ std::unique_ptr Server::CreateDataOrchestrator( const ParameterFetcher& parameter_fetcher, KeySharder key_sharder) { const std::string data_bucket = parameter_fetcher.GetParameter(kDataBucketParameterSuffix); - LOG(INFO) << "Retrieved " << kDataBucketParameterSuffix - << " parameter: " << data_bucket; + PS_LOG(INFO, server_safe_log_context_) + << "Retrieved " << kDataBucketParameterSuffix + << " parameter: " << data_bucket; auto 
metrics_callback = LogStatusSafeMetricsFn(); return TraceRetryUntilOk( @@ -613,17 +637,20 @@ std::unique_ptr Server::CreateDataOrchestrator( .shard_num = shard_num_, .num_shards = num_shards_, .key_sharder = std::move(key_sharder), - .blob_prefix_allowlist = GetBlobPrefixAllowlist(parameter_fetcher), + .blob_prefix_allowlist = GetBlobPrefixAllowlist( + parameter_fetcher, server_safe_log_context_), + .log_context = server_safe_log_context_, }); }, - "CreateDataOrchestrator", metrics_callback); + "CreateDataOrchestrator", metrics_callback, server_safe_log_context_); } void Server::CreateGrpcServices(const ParameterFetcher& parameter_fetcher) { const bool use_v2 = parameter_fetcher.GetBoolParameter(kRouteV1ToV2Suffix); const bool add_missing_keys_v1 = parameter_fetcher.GetBoolParameter(kAddMissingKeysV1Suffix); - LOG(INFO) << "Retrieved " << kRouteV1ToV2Suffix << " parameter: " << use_v2; + PS_LOG(INFO, server_safe_log_context_) + << "Retrieved " << kRouteV1ToV2Suffix << " parameter: " << use_v2; get_values_adapter_ = GetValuesAdapter::Create(std::make_unique( *udf_client_, *key_fetcher_manager_)); @@ -657,7 +684,8 @@ std::unique_ptr Server::CreateAndStartGrpcServer() { builder.RegisterService(service.get()); } // Finally assemble the server. 
- LOG(INFO) << "Server listening on " << server_address << std::endl; + PS_LOG(INFO, server_safe_log_context_) + << "Server listening on " << server_address << std::endl; auto server = builder.BuildAndStart(); server->GetHealthCheckService()->SetServingStatus( std::string(kAutoscalerHealthcheck), true); @@ -671,7 +699,8 @@ absl::Status Server::SetDefaultUdfCodeObject() { CodeConfig{.js = kDefaultUdfCodeSnippet, .udf_handler_name = kDefaultUdfHandlerName, .logical_commit_time = kDefaultLogicalCommitTime, - .version = kDefaultVersion}); + .version = kDefaultVersion}, + server_safe_log_context_); return status; } @@ -679,12 +708,14 @@ std::unique_ptr Server::CreateDeltaFileNotifier( const ParameterFetcher& parameter_fetcher) { uint32_t backup_poll_frequency_secs = parameter_fetcher.GetInt32Parameter( kBackupPollFrequencySecsParameterSuffix); - LOG(INFO) << "Retrieved " << kBackupPollFrequencySecsParameterSuffix - << " parameter: " << backup_poll_frequency_secs; - - return DeltaFileNotifier::Create(*blob_client_, - absl::Seconds(backup_poll_frequency_secs), - GetBlobPrefixAllowlist(parameter_fetcher)); + PS_LOG(INFO, server_safe_log_context_) + << "Retrieved " << kBackupPollFrequencySecsParameterSuffix + << " parameter: " << backup_poll_frequency_secs; + + return DeltaFileNotifier::Create( + *blob_client_, absl::Seconds(backup_poll_frequency_secs), + GetBlobPrefixAllowlist(parameter_fetcher, server_safe_log_context_), + server_safe_log_context_); } } // namespace kv_server diff --git a/components/data_server/server/server.h b/components/data_server/server/server.h index 494a2dd4..68e1dda5 100644 --- a/components/data_server/server/server.h +++ b/components/data_server/server/server.h @@ -42,6 +42,7 @@ #include "components/udf/hooks/run_query_hook.h" #include "components/udf/udf_client.h" #include "components/util/platform_initializer.h" +#include "components/util/safe_path_log_context.h" #include "grpcpp/grpcpp.h" #include "public/base_types.pb.h" #include 
"public/query/get_values.grpc.pb.h" @@ -56,10 +57,9 @@ class Server { // Arguments that are nullptr will be created, they may be passed in for // unit testing purposes. - absl::Status Init( - std::unique_ptr parameter_client = nullptr, - std::unique_ptr instance_client = nullptr, - std::unique_ptr udf_client = nullptr); + absl::Status Init(std::unique_ptr parameter_client = nullptr, + std::unique_ptr instance_client = nullptr, + std::unique_ptr udf_client = nullptr); // Wait for the server to shut down. Note that some other thread must be // responsible for shutting down the server for this call to ever return. @@ -73,7 +73,7 @@ class Server { private: // If objects were not passed in for unit testing purposes then create them. absl::Status CreateDefaultInstancesIfNecessaryAndGetEnvironment( - std::unique_ptr parameter_client, + std::unique_ptr parameter_client, std::unique_ptr instance_client, std::unique_ptr udf_client); @@ -103,14 +103,14 @@ class Server { void InitializeTelemetry(const ParameterClient& parameter_client, InstanceClient& instance_client); absl::Status CreateShardManager(); - void InitOtelLogger(::opentelemetry::sdk::resource::Resource server_info, - absl::optional collector_endpoint, - const ParameterFetcher& parameter_fetcher); + void InitLogger(::opentelemetry::sdk::resource::Resource server_info, + absl::optional collector_endpoint, + const ParameterFetcher& parameter_fetcher); // This must be first, otherwise the AWS SDK will crash when it's called: PlatformInitializer platform_initializer_; - std::unique_ptr parameter_client_; + std::unique_ptr parameter_client_; std::unique_ptr instance_client_; std::string environment_; std::vector> grpc_services_; @@ -119,7 +119,8 @@ class Server { std::unique_ptr get_values_adapter_; std::unique_ptr string_get_values_hook_; std::unique_ptr binary_get_values_hook_; - std::unique_ptr run_query_hook_; + std::unique_ptr run_set_query_int_hook_; + std::unique_ptr run_set_query_string_hook_; // 
BlobStorageClient must outlive DeltaFileNotifier std::unique_ptr blob_client_; @@ -154,8 +155,8 @@ class Server { std::unique_ptr key_fetcher_manager_; - std::unique_ptr log_provider_; std::unique_ptr open_telemetry_sink_; + KVServerSafeLogContext server_safe_log_context_; }; } // namespace kv_server diff --git a/components/data_server/server/server_initializer.cc b/components/data_server/server/server_initializer.cc index 4499017f..ce8ae4a0 100644 --- a/components/data_server/server/server_initializer.cc +++ b/components/data_server/server/server_initializer.cc @@ -32,19 +32,27 @@ using privacy_sandbox::server_common::KeyFetcherManagerInterface; absl::Status InitializeUdfHooksInternal( std::function()> get_lookup, GetValuesHook& string_get_values_hook, - GetValuesHook& binary_get_values_hook, RunQueryHook& run_query_hook) { - VLOG(9) << "Finishing getValues init"; + GetValuesHook& binary_get_values_hook, + RunSetQueryStringHook& run_query_hook, + RunSetQueryIntHook& run_set_query_int_hook, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + PS_VLOG(9, log_context) << "Finishing getValues init"; string_get_values_hook.FinishInit(get_lookup()); - VLOG(9) << "Finishing getValuesBinary init"; + PS_VLOG(9, log_context) << "Finishing getValuesBinary init"; binary_get_values_hook.FinishInit(get_lookup()); - VLOG(9) << "Finishing runQuery init"; + PS_VLOG(9, log_context) << "Finishing runQuery init"; run_query_hook.FinishInit(get_lookup()); + PS_VLOG(9, log_context) << "Finishing runSetQueryInt init"; + run_set_query_int_hook.FinishInit(get_lookup()); return absl::OkStatus(); } class NonshardedServerInitializer : public ServerInitializer { public: - explicit NonshardedServerInitializer(Cache& cache) : cache_(cache) {} + explicit NonshardedServerInitializer( + Cache& cache, + privacy_sandbox::server_common::log::PSLogContext& log_context) + : cache_(cache), log_context_(log_context) {} RemoteLookup CreateAndStartRemoteLookupServer() override { 
RemoteLookup remote_lookup; @@ -54,19 +62,22 @@ class NonshardedServerInitializer : public ServerInitializer { absl::StatusOr InitializeUdfHooks( GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, - RunQueryHook& run_query_hook) override { + RunSetQueryStringHook& run_query_hook, + RunSetQueryIntHook& run_set_query_int_hook) override { ShardManagerState shard_manager_state; auto lookup_supplier = [&cache = cache_]() { return CreateLocalLookup(cache); }; InitializeUdfHooksInternal(std::move(lookup_supplier), string_get_values_hook, binary_get_values_hook, - run_query_hook); + run_query_hook, run_set_query_int_hook, + log_context_); return shard_manager_state; } private: Cache& cache_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; class ShardedServerInitializer : public ServerInitializer { @@ -75,7 +86,8 @@ class ShardedServerInitializer : public ServerInitializer { KeyFetcherManagerInterface& key_fetcher_manager, Lookup& local_lookup, std::string environment, int32_t num_shards, int32_t current_shard_num, InstanceClient& instance_client, ParameterFetcher& parameter_fetcher, - KeySharder key_sharder) + KeySharder key_sharder, + privacy_sandbox::server_common::log::PSLogContext& log_context) : key_fetcher_manager_(key_fetcher_manager), local_lookup_(local_lookup), environment_(environment), @@ -83,7 +95,8 @@ class ShardedServerInitializer : public ServerInitializer { current_shard_num_(current_shard_num), instance_client_(instance_client), parameter_fetcher_(parameter_fetcher), - key_sharder_(std::move(key_sharder)) {} + key_sharder_(std::move(key_sharder)), + log_context_(log_context) {} RemoteLookup CreateAndStartRemoteLookupServer() override { RemoteLookup remote_lookup; @@ -96,17 +109,18 @@ class ShardedServerInitializer : public ServerInitializer { remoteLookupServerAddress, grpc::InsecureServerCredentials()); remote_lookup_server_builder.RegisterService( remote_lookup.remote_lookup_service.get()); - LOG(INFO) << 
"Remote lookup server listening on " - << remoteLookupServerAddress; + PS_LOG(INFO, log_context_) + << "Remote lookup server listening on " << remoteLookupServerAddress; remote_lookup.remote_lookup_server = remote_lookup_server_builder.BuildAndStart(); - return std::move(remote_lookup); + return remote_lookup; } absl::StatusOr InitializeUdfHooks( GetValuesHook& string_get_values_hook, GetValuesHook& binary_get_values_hook, - RunQueryHook& run_query_hook) override { + RunSetQueryStringHook& run_set_query_string_hook, + RunSetQueryIntHook& run_set_query_int_hook) override { auto maybe_shard_state = CreateShardManager(); if (!maybe_shard_state.ok()) { return maybe_shard_state.status(); @@ -121,22 +135,24 @@ class ShardedServerInitializer : public ServerInitializer { }; InitializeUdfHooksInternal(std::move(lookup_supplier), string_get_values_hook, binary_get_values_hook, - run_query_hook); + run_set_query_string_hook, + run_set_query_int_hook, log_context_); return std::move(*maybe_shard_state); } private: absl::StatusOr CreateShardManager() { ShardManagerState shard_manager_state; - VLOG(10) << "Creating shard manager"; + PS_VLOG(10, log_context_) << "Creating shard manager"; shard_manager_state.cluster_mappings_manager = ClusterMappingsManager::Create(environment_, num_shards_, - instance_client_, parameter_fetcher_); + instance_client_, parameter_fetcher_, + log_context_); shard_manager_state.shard_manager = TraceRetryUntilOk( [&cluster_mappings_manager = *shard_manager_state.cluster_mappings_manager, - &num_shards = num_shards_, - &key_fetcher_manager = key_fetcher_manager_] { + &num_shards = num_shards_, &key_fetcher_manager = key_fetcher_manager_, + &log_context = log_context_] { // It might be that the cluster mappings that are passed don't pass // validation. E.g. a particular cluster might not have any // replicas @@ -145,9 +161,10 @@ class ShardedServerInitializer : public ServerInitializer { // at that point in time might have new replicas spun up. 
return ShardManager::Create( num_shards, key_fetcher_manager, - cluster_mappings_manager.GetClusterMappings()); + cluster_mappings_manager.GetClusterMappings(), log_context); }, - "GetShardManager", LogStatusSafeMetricsFn()); + "GetShardManager", LogStatusSafeMetricsFn(), + log_context_); auto start_status = shard_manager_state.cluster_mappings_manager->Start( *shard_manager_state.shard_manager); if (!start_status.ok()) { @@ -163,6 +180,7 @@ class ShardedServerInitializer : public ServerInitializer { InstanceClient& instance_client_; ParameterFetcher& parameter_fetcher_; KeySharder key_sharder_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; } // namespace @@ -171,15 +189,16 @@ std::unique_ptr GetServerInitializer( int64_t num_shards, KeyFetcherManagerInterface& key_fetcher_manager, Lookup& local_lookup, std::string environment, int32_t current_shard_num, InstanceClient& instance_client, Cache& cache, - ParameterFetcher& parameter_fetcher, KeySharder key_sharder) { + ParameterFetcher& parameter_fetcher, KeySharder key_sharder, + privacy_sandbox::server_common::log::PSLogContext& log_context) { CHECK_GT(num_shards, 0) << "num_shards must be greater than 0"; if (num_shards == 1) { - return std::make_unique(cache); + return std::make_unique(cache, log_context); } return std::make_unique( key_fetcher_manager, local_lookup, environment, num_shards, current_shard_num, instance_client, parameter_fetcher, - std::move(key_sharder)); + std::move(key_sharder), log_context); } } // namespace kv_server diff --git a/components/data_server/server/server_initializer.h b/components/data_server/server/server_initializer.h index c490c377..5e30bfa7 100644 --- a/components/data_server/server/server_initializer.h +++ b/components/data_server/server/server_initializer.h @@ -54,7 +54,9 @@ class ServerInitializer { virtual RemoteLookup CreateAndStartRemoteLookupServer() = 0; virtual absl::StatusOr InitializeUdfHooks( GetValuesHook& string_get_values_hook, - 
GetValuesHook& binary_get_values_hook, RunQueryHook& run_query_hook) = 0; + GetValuesHook& binary_get_values_hook, + RunSetQueryStringHook& run_set_query_string_hook, + RunSetQueryIntHook& run_set_query_int_hook) = 0; }; std::unique_ptr GetServerInitializer( @@ -63,7 +65,10 @@ std::unique_ptr GetServerInitializer( key_fetcher_manager, Lookup& local_lookup, std::string environment, int32_t current_shard_num, InstanceClient& instance_client, Cache& cache, - ParameterFetcher& parameter_fetcher, KeySharder key_sharder); + ParameterFetcher& parameter_fetcher, KeySharder key_sharder, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); } // namespace kv_server #endif // COMPONENTS_DATA_SERVER_SERVER_INITIALIZER_H_ diff --git a/components/data_server/server/server_local_test.cc b/components/data_server/server/server_local_test.cc index d6ae6cd8..9fa33f5b 100644 --- a/components/data_server/server/server_local_test.cc +++ b/components/data_server/server/server_local_test.cc @@ -21,6 +21,9 @@ #include "components/udf/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "opentelemetry/exporters/ostream/log_record_exporter.h" +#include "opentelemetry/sdk/logs/logger_provider_factory.h" +#include "opentelemetry/sdk/logs/simple_log_record_processor_factory.h" #include "opentelemetry/sdk/resource/resource.h" namespace kv_server { @@ -31,9 +34,6 @@ using privacy_sandbox::server_common::ConfigureMetrics; using testing::_; void RegisterRequiredTelemetryExpectations(MockParameterClient& client) { - EXPECT_CALL(client, GetBoolParameter("kv-server-environment-use-external-" - "metrics-collector-endpoint")) - .WillOnce(::testing::Return(false)); EXPECT_CALL( client, GetInt32Parameter("kv-server-environment-metrics-export-interval-millis")) @@ -56,12 +56,15 @@ void RegisterRequiredTelemetryExpectations(MockParameterClient& client) { EXPECT_CALL(client, 
GetBoolParameter("kv-server-environment-enable-otel-logger")) .WillOnce(::testing::Return(false)); + EXPECT_CALL(client, + GetBoolParameter("kv-server-environment-enable-consented-log")) + .WillOnce(::testing::Return(false)); EXPECT_CALL(client, GetParameter("kv-server-environment-telemetry-config", testing::Eq(std::nullopt))) .WillOnce(::testing::Return("mode: EXPERIMENT")); } -void InitializeMetrics() { +void InitializeTelemetry() { opentelemetry::sdk::metrics::PeriodicExportingMetricReaderOptions metrics_options; // The defaults for these values are 30 and 60s and we don't want to wait that @@ -69,24 +72,36 @@ void InitializeMetrics() { metrics_options.export_interval_millis = std::chrono::milliseconds(200); metrics_options.export_timeout_millis = std::chrono::milliseconds(100); ConfigureMetrics(Resource::Create({}), metrics_options); + static auto* logger_provider = + opentelemetry::sdk::logs::LoggerProviderFactory::Create( + opentelemetry::sdk::logs::SimpleLogRecordProcessorFactory::Create( + std::make_unique< + opentelemetry::exporter::logs::OStreamLogRecordExporter>( + std::cout))) + .release(); + privacy_sandbox::server_common::log::logger_private = + logger_provider->GetLogger("test").get(); } -TEST(ServerLocalTest, WaitWithoutStart) { - InitializeMetrics(); +class ServerLocalTest : public ::testing::Test { + protected: + void SetUp() override { InitializeTelemetry(); } +}; + +TEST_F(ServerLocalTest, WaitWithoutStart) { kv_server::Server server; // This should be a no-op as the server was never started: server.Wait(); } -TEST(ServerLocalTest, ShutdownWithoutStart) { - InitializeMetrics(); +TEST_F(ServerLocalTest, ShutdownWithoutStart) { kv_server::Server server; // These should be a no-op as the server was never started: server.GracefulShutdown(absl::Seconds(1)); server.ForceShutdown(); } -TEST(ServerLocalTest, InitFailsWithNoDeltaDirectory) { +TEST_F(ServerLocalTest, InitFailsWithNoDeltaDirectory) { auto instance_client = std::make_unique(); auto 
parameter_client = std::make_unique(); RegisterRequiredTelemetryExpectations(*parameter_client); @@ -115,6 +130,10 @@ TEST(ServerLocalTest, InitFailsWithNoDeltaDirectory) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-timeout-millis")) .WillOnce(::testing::Return(5000)); + EXPECT_CALL( + *parameter_client, + GetInt32Parameter("kv-server-environment-udf-update-timeout-millis")) + .WillOnce(::testing::Return(5000)); EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-min-log-level")) .WillOnce(::testing::Return(0)); @@ -143,7 +162,7 @@ TEST(ServerLocalTest, InitFailsWithNoDeltaDirectory) { EXPECT_FALSE(status.ok()); } -TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { +TEST_F(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); RegisterRequiredTelemetryExpectations(*parameter_client); @@ -184,6 +203,10 @@ TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-timeout-millis")) .WillOnce(::testing::Return(5000)); + EXPECT_CALL( + *parameter_client, + GetInt32Parameter("kv-server-environment-udf-update-timeout-millis")) + .WillOnce(::testing::Return(5000)); EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-min-log-level")) .WillOnce(::testing::Return(0)); @@ -200,7 +223,7 @@ TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); - EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + EXPECT_CALL(*mock_udf_client, SetCodeObject(_, _)) .WillOnce(testing::Return(absl::OkStatus())); EXPECT_CALL( *parameter_client, @@ -215,7 +238,7 @@ TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { 
EXPECT_TRUE(status.ok()); } -TEST(ServerLocalTest, GracefulServerShutdown) { +TEST_F(ServerLocalTest, GracefulServerShutdown) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); RegisterRequiredTelemetryExpectations(*parameter_client); @@ -259,6 +282,10 @@ TEST(ServerLocalTest, GracefulServerShutdown) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-timeout-millis")) .WillOnce(::testing::Return(5000)); + EXPECT_CALL( + *parameter_client, + GetInt32Parameter("kv-server-environment-udf-update-timeout-millis")) + .WillOnce(::testing::Return(5000)); EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-min-log-level")) .WillOnce(::testing::Return(0)); @@ -271,7 +298,7 @@ TEST(ServerLocalTest, GracefulServerShutdown) { EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); - EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + EXPECT_CALL(*mock_udf_client, SetCodeObject(_, _)) .WillOnce(testing::Return(absl::OkStatus())); EXPECT_CALL( *parameter_client, @@ -289,7 +316,7 @@ TEST(ServerLocalTest, GracefulServerShutdown) { server_thread.join(); } -TEST(ServerLocalTest, ForceServerShutdown) { +TEST_F(ServerLocalTest, ForceServerShutdown) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); RegisterRequiredTelemetryExpectations(*parameter_client); @@ -330,6 +357,10 @@ TEST(ServerLocalTest, ForceServerShutdown) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-timeout-millis")) .WillOnce(::testing::Return(5000)); + EXPECT_CALL( + *parameter_client, + GetInt32Parameter("kv-server-environment-udf-update-timeout-millis")) + .WillOnce(::testing::Return(5000)); EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-udf-min-log-level")) .WillOnce(::testing::Return(0)); @@ -346,7 +377,7 @@ TEST(ServerLocalTest, ForceServerShutdown) { 
EXPECT_CALL(*parameter_client, GetBoolParameter("kv-server-environment-use-sharding-key-regex")) .WillOnce(::testing::Return(false)); - EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + EXPECT_CALL(*mock_udf_client, SetCodeObject(_, _)) .WillOnce(testing::Return(absl::OkStatus())); EXPECT_CALL( *parameter_client, diff --git a/components/data_server/server/server_log_init.h b/components/data_server/server/server_log_init.h new file mode 100644 index 00000000..dd27ea85 --- /dev/null +++ b/components/data_server/server/server_log_init.h @@ -0,0 +1,33 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_SERVER_SERVER_SERVER_LOG_INIT_H_H_ +#define COMPONENTS_DATA_SERVER_SERVER_SERVER_LOG_INIT_H_H_ + +#include +#include + +#include "components/cloud_config/parameter_client.h" + +namespace kv_server { + +void InitLog(); +absl::optional GetMetricsCollectorEndPoint( + const ParameterClient& parameter_client, const std::string& environment); + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_SERVER_SERVER_LOG_INIT_H_H_ diff --git a/components/envoy_proxy/envoy.yaml b/components/envoy_proxy/envoy.yaml index 1f38b634..8b41d3c8 100644 --- a/components/envoy_proxy/envoy.yaml +++ b/components/envoy_proxy/envoy.yaml @@ -46,6 +46,10 @@ static_resources: retry_policy: retry_on: 5xx num_retries: 10 + request_headers_to_add: + - header: + key: "kv-content-type" + value: "%DYNAMIC_METADATA(tkv:content-type)%" response_headers_to_add: - header: key: 'Ad-Auction-Allowed' @@ -56,6 +60,23 @@ static_resources: value: '2' append: false http_filters: + # Pass the content-type to the grpc server. By default the header is overwritten to "application/grpc" + # This extra filter must be used because the content-type must be copied to a metadata prior to grpc transcoding + # By the time the Router filter is invoked, the original content-type is already changed. 
+ - name: envoy.filters.http.header_to_metadata + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.http.header_to_metadata.v3.Config + request_rules: + - header: "content-type" + on_header_present: + metadata_namespace: tkv + key: content-type + type: STRING + on_header_missing: + metadata_namespace: tkv + key: content-type + value: 'unknown' + type: STRING - name: envoy.filters.http.grpc_stats typed_config: "@type": type.googleapis.com/envoy.extensions.filters.http.grpc_stats.v3.FilterConfig @@ -69,8 +90,6 @@ static_resources: - kv_server.v1.KeyValueService - kv_server.v2.KeyValueService - grpc.health.v1.Health - ignored_query_parameters: - - "interestGroupNames" print_options: add_whitespace: true always_print_primitive_fields: true diff --git a/components/errors/BUILD.bazel b/components/errors/BUILD.bazel index 29aaccd8..a44be19a 100644 --- a/components/errors/BUILD.bazel +++ b/components/errors/BUILD.bazel @@ -90,6 +90,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/telemetry:tracing", ], ) diff --git a/components/errors/retry.cc b/components/errors/retry.cc index 1c92fc3a..64eec247 100644 --- a/components/errors/retry.cc +++ b/components/errors/retry.cc @@ -34,16 +34,18 @@ absl::Duration ExponentialBackoffForRetry(uint32_t retries) { return std::min(backoff, kMaxRetryInterval); } -void TraceRetryUntilOk(std::function func, - std::string task_name, - const absl::AnyInvocable& metrics_callback) { +void TraceRetryUntilOk( + std::function func, std::string task_name, + const absl::AnyInvocable& + metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto span = GetTracer()->StartSpan("RetryUntilOk - " + task_name); auto scope = opentelemetry::trace::Scope(span); auto wrapped = [func = std::move(func), task_name]() { 
return TraceWithStatus(std::move(func), task_name); }; - RetryUntilOk(std::move(wrapped), std::move(task_name), metrics_callback); + RetryUntilOk(std::move(wrapped), std::move(task_name), metrics_callback, + log_context); } } // namespace kv_server diff --git a/components/errors/retry.h b/components/errors/retry.h index 8f033af0..732da84b 100644 --- a/components/errors/retry.h +++ b/components/errors/retry.h @@ -24,6 +24,7 @@ #include "absl/time/time.h" #include "components/telemetry/server_definition.h" #include "components/util/sleepfor.h" +#include "src/logger/request_context_logger.h" #include "src/telemetry/tracing.h" namespace kv_server { @@ -42,15 +43,18 @@ class RetryableWithMax { // If max_attempts <= 0, will retry until OK. // `metrics_callback` is optional. - RetryableWithMax(Func&& f, std::string task_name, int max_attempts, - const absl::AnyInvocable& metrics_callback, - const SleepFor& sleep_for) + RetryableWithMax( + Func&& f, std::string task_name, int max_attempts, + const absl::AnyInvocable& + metrics_callback, + const SleepFor& sleep_for, + privacy_sandbox::server_common::log::PSLogContext& log_context) : func_(std::forward(f)), task_name_(std::move(task_name)), max_attempts_(max_attempts <= 0 ? 
kUnlimitedRetry : max_attempts), metrics_callback_(metrics_callback), - sleep_for_(sleep_for) {} + sleep_for_(sleep_for), + log_context_(log_context) {} absl::Status ToStatus(absl::Status& result) { return result; } @@ -69,8 +73,9 @@ class RetryableWithMax { if (result.ok()) { return result; } else { - LOG(WARNING) << task_name_ << " failed with " << ToStatus(result) - << " for Attempt " << i; + PS_LOG(WARNING, log_context_) + << task_name_ << " failed with " << ToStatus(result) + << " for Attempt " << i; } const absl::Duration backoff = ExponentialBackoffForRetry(i); if (!sleep_for_.Duration(backoff)) { @@ -87,6 +92,7 @@ class RetryableWithMax { const absl::AnyInvocable& metrics_callback_; const SleepFor& sleep_for_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; // Retries functors that return an absl::StatusOr until they are `ok`. @@ -97,10 +103,13 @@ typename std::invoke_result_t>::value_type RetryUntilOk( Func&& f, std::string task_name, const absl::AnyInvocable& metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext), const UnstoppableSleepFor& sleep_for = UnstoppableSleepFor()) { return RetryableWithMax(std::forward(f), std::move(task_name), RetryableWithMax::kUnlimitedRetry, - metrics_callback, sleep_for)() + metrics_callback, sleep_for, log_context)() .value(); } // Same as above `RetryUntilOk`, wrapped in an `opentelemetry::trace::Span`. 
@@ -112,6 +121,9 @@ TraceRetryUntilOk( Func&& func, std::string task_name, const absl::AnyInvocable& metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext), std::vector attributes = {}) { auto span = privacy_sandbox::server_common::GetTracer()->StartSpan( @@ -122,7 +134,7 @@ TraceRetryUntilOk( return TraceWithStatusOr(std::move(func), task_name, std::move(attributes)); }; return RetryUntilOk(std::move(wrapped), std::move(task_name), - metrics_callback); + metrics_callback, log_context); } // Retries functors that return an absl::Status until they are `ok`. @@ -131,20 +143,26 @@ inline void RetryUntilOk( std::function func, std::string task_name, const absl::AnyInvocable& metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext), const UnstoppableSleepFor& sleep_for = UnstoppableSleepFor()) { RetryableWithMax(std::move(func), std::move(task_name), RetryableWithMax::kUnlimitedRetry, - metrics_callback, sleep_for)() + metrics_callback, sleep_for, log_context)() .IgnoreError(); } // Starts and `opentelemetry::trace::Span` and Calls `RetryUntilOk`. // Each individual retry of `func` is also traced. // `metrics_callback` is optional. -void TraceRetryUntilOk(std::function func, - std::string task_name, - const absl::AnyInvocable& metrics_callback); +void TraceRetryUntilOk( + std::function func, std::string task_name, + const absl::AnyInvocable& + metrics_callback, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); // Retries functors that return an absl::StatusOr until they are `ok` or // max_attempts is reached. Retry starts at 1. 
@@ -154,9 +172,13 @@ typename std::invoke_result_t> RetryWithMax( Func&& f, std::string task_name, int max_attempts, const absl::AnyInvocable& metrics_callback, - const SleepFor& sleep_for) { + const SleepFor& sleep_for, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) { return RetryableWithMax(std::forward(f), std::move(task_name), - max_attempts, metrics_callback, sleep_for)(); + max_attempts, metrics_callback, sleep_for, + log_context)(); } } // namespace kv_server diff --git a/components/errors/retry_test.cc b/components/errors/retry_test.cc index d5349bdb..2d5992c6 100644 --- a/components/errors/retry_test.cc +++ b/components/errors/retry_test.cc @@ -29,7 +29,12 @@ namespace { using ::testing::Return; -TEST(RetryTest, RetryUntilOk) { +class RetryTest : public ::testing::Test { + protected: + privacy_sandbox::server_common::log::NoOpContext log_context_; +}; + +TEST_F(RetryTest, RetryUntilOk) { testing::MockFunction()> func; EXPECT_CALL(func, Call) .Times(2) @@ -43,13 +48,14 @@ TEST(RetryTest, RetryUntilOk) { status_count_metric_callback = [](const absl::Status&, int) { // no-op }; - absl::StatusOr v = RetryUntilOk(func.AsStdFunction(), "TestFunc", - status_count_metric_callback, sleep_for); + absl::StatusOr v = + RetryUntilOk(func.AsStdFunction(), "TestFunc", + status_count_metric_callback, log_context_, sleep_for); EXPECT_TRUE(v.ok()); EXPECT_EQ(v.value(), 1); } -TEST(RetryTest, RetryUnilOkStatusOnly) { +TEST_F(RetryTest, RetryUnilOkStatusOnly) { testing::MockFunction func; EXPECT_CALL(func, Call) .Times(2) @@ -64,10 +70,10 @@ TEST(RetryTest, RetryUnilOkStatusOnly) { // no-op }; RetryUntilOk(func.AsStdFunction(), "TestFunc", status_count_metric_callback, - sleep_for); + log_context_, sleep_for); } -TEST(RetryTest, RetryWithMaxFailsWhenExceedingMax) { +TEST_F(RetryTest, RetryWithMaxFailsWhenExceedingMax) { testing::MockFunction()> func; EXPECT_CALL(func, 
Call).Times(2).WillRepeatedly([] { return absl::InvalidArgumentError("whatever"); @@ -84,13 +90,14 @@ TEST(RetryTest, RetryWithMaxFailsWhenExceedingMax) { status_count_metric_callback = [](const absl::Status&, int) { // no-op }; - absl::StatusOr v = RetryWithMax(func.AsStdFunction(), "TestFunc", 2, - status_count_metric_callback, sleep_for); + absl::StatusOr v = + RetryWithMax(func.AsStdFunction(), "TestFunc", 2, + status_count_metric_callback, sleep_for, log_context_); EXPECT_FALSE(v.ok()); EXPECT_EQ(v.status(), absl::InvalidArgumentError("whatever")); } -TEST(RetryTest, RetryWithMaxSucceedsOnMax) { +TEST_F(RetryTest, RetryWithMaxSucceedsOnMax) { testing::MockFunction()> func; EXPECT_CALL(func, Call) .Times(2) @@ -105,13 +112,14 @@ TEST(RetryTest, RetryWithMaxSucceedsOnMax) { status_count_metric_callback = [](const absl::Status&, int) { // no-op }; - absl::StatusOr v = RetryWithMax(func.AsStdFunction(), "TestFunc", 2, - status_count_metric_callback, sleep_for); + absl::StatusOr v = + RetryWithMax(func.AsStdFunction(), "TestFunc", 2, + status_count_metric_callback, sleep_for, log_context_); EXPECT_TRUE(v.ok()); EXPECT_EQ(v.value(), 1); } -TEST(RetryTest, RetryWithMaxSucceedsEarly) { +TEST_F(RetryTest, RetryWithMaxSucceedsEarly) { testing::MockFunction()> func; EXPECT_CALL(func, Call) .Times(2) @@ -125,8 +133,9 @@ TEST(RetryTest, RetryWithMaxSucceedsEarly) { status_count_metric_callback = [](const absl::Status&, int) { // no-op }; - absl::StatusOr v = RetryWithMax(func.AsStdFunction(), "TestFunc", 300, - status_count_metric_callback, sleep_for); + absl::StatusOr v = + RetryWithMax(func.AsStdFunction(), "TestFunc", 300, + status_count_metric_callback, sleep_for, log_context_); EXPECT_TRUE(v.ok()); EXPECT_EQ(v.value(), 1); } diff --git a/components/internal_server/BUILD.bazel b/components/internal_server/BUILD.bazel index ae82b089..44eb3fd2 100644 --- a/components/internal_server/BUILD.bazel +++ b/components/internal_server/BUILD.bazel @@ -38,6 +38,7 @@ cc_library( 
"//components/query:driver", "//components/query:scanner", "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/functional:any_invocable", "@com_google_protobuf//:protobuf", "@google_privacysandbox_servers_common//src/telemetry", ], @@ -139,8 +140,11 @@ cc_library( ":internal_lookup_cc_proto", ":lookup", "//components/data_server/cache", + "//components/data_server/cache:uint32_value_set", + "//components/errors:error_tag", "//components/query:driver", "//components/query:scanner", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", ], @@ -156,6 +160,7 @@ cc_library( ":internal_lookup_cc_proto", ":local_lookup", ":remote_lookup_client_impl", + "//components/data_server/cache:uint32_value_set", "//components/query:driver", "//components/query:scanner", "//components/sharding:shard_manager", @@ -194,6 +199,11 @@ cc_library( hdrs = [ "remote_lookup_client.h", ], + copts = select({ + "//:aws_platform": ["-DCLOUD_PLATFORM_AWS=1"], + "//:gcp_platform": ["-DCLOUD_PLATFORM_GCP=1"], + "//conditions:default": [], + }), deps = [ ":constants", ":internal_lookup_cc_grpc", diff --git a/components/internal_server/local_lookup.cc b/components/internal_server/local_lookup.cc index 45522e49..b9e535a1 100644 --- a/components/internal_server/local_lookup.cc +++ b/components/internal_server/local_lookup.cc @@ -18,10 +18,11 @@ #include #include #include -#include -#include "absl/log/log.h" +#include "absl/functional/any_invocable.h" #include "components/data_server/cache/cache.h" +#include "components/data_server/cache/uint32_value_set.h" +#include "components/errors/error_tag.h" #include "components/internal_server/lookup.h" #include "components/internal_server/lookup.pb.h" #include "components/query/driver.h" @@ -30,6 +31,8 @@ namespace kv_server { namespace { +enum class ErrorTag : int { kProcessValueSetKeys = 1 }; + class LocalLookup : public Lookup { public: explicit LocalLookup(const Cache& cache) : 
cache_(cache) {} @@ -43,12 +46,54 @@ class LocalLookup : public Lookup { absl::StatusOr GetKeyValueSet( const RequestContext& request_context, const absl::flat_hash_set& key_set) const override { - return ProcessKeysetKeys(request_context, key_set); + return ProcessValueSetKeys(request_context, key_set, + SingleLookupResult::kKeysetValues); + } + + absl::StatusOr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override { + return ProcessValueSetKeys(request_context, key_set, + SingleLookupResult::kUintsetValues); } absl::StatusOr RunQuery( const RequestContext& request_context, std::string query) const override { - return ProcessQuery(request_context, query); + return ProcessQuery>>( + request_context, std::move(query), + [](const RequestContext& request_context, const Driver& driver, + const Cache& cache) { + auto get_key_value_set_result = cache.GetKeyValueSet( + request_context, driver.GetRootNode()->Keys()); + return driver.EvaluateQuery>( + [&get_key_value_set_result](std::string_view key) { + return get_key_value_set_result->GetValueSet(key); + }); + }); + } + + absl::StatusOr RunSetQueryInt( + const RequestContext& request_context, std::string query) const override { + return ProcessQuery>>( + request_context, std::move(query), + [](const RequestContext& request_context, const Driver& driver, + const Cache& cache) + -> absl::StatusOr> { + auto get_key_value_set_result = cache.GetUInt32ValueSet( + request_context, driver.GetRootNode()->Keys()); + auto query_eval_result = driver.EvaluateQuery( + [&get_key_value_set_result](std::string_view key) { + auto set = get_key_value_set_result->GetUInt32ValueSet(key); + return set == nullptr ? 
roaring::Roaring() + : set->GetValuesBitSet(); + }); + if (!query_eval_result.ok()) { + return query_eval_result.status(); + } + return BitSetToUint32Set(*query_eval_result); + }); } private: @@ -79,9 +124,10 @@ class LocalLookup : public Lookup { return response; } - absl::StatusOr ProcessKeysetKeys( + absl::StatusOr ProcessValueSetKeys( const RequestContext& request_context, - const absl::flat_hash_set& key_set) const { + const absl::flat_hash_set& key_set, + SingleLookupResult::SingleLookupResultCase set_type) const { ScopeLatencyMetricsRecorder latency_recorder(request_context.GetInternalLookupMetricsContext()); @@ -89,59 +135,84 @@ class LocalLookup : public Lookup { if (key_set.empty()) { return response; } - auto key_value_set_result = cache_.GetKeyValueSet(request_context, key_set); + std::unique_ptr key_value_set_result; + if (set_type == SingleLookupResult::kKeysetValues) { + key_value_set_result = cache_.GetKeyValueSet(request_context, key_set); + } else if (set_type == SingleLookupResult::kUintsetValues) { + key_value_set_result = cache_.GetUInt32ValueSet(request_context, key_set); + } else { + return StatusWithErrorTag(absl::InvalidArgumentError(absl::StrCat( + "Unsupported set type: ", set_type)), + __FILE__, ErrorTag::kProcessValueSetKeys); + } for (const auto& key : key_set) { SingleLookupResult result; - const auto value_set = key_value_set_result->GetValueSet(key); - if (value_set.empty()) { + bool is_empty_value_set = false; + if (set_type == SingleLookupResult::kKeysetValues) { + if (const auto value_set = key_value_set_result->GetValueSet(key); + !value_set.empty()) { + auto* keyset_values = result.mutable_keyset_values(); + keyset_values->mutable_values()->Reserve(value_set.size()); + keyset_values->mutable_values()->Add(value_set.begin(), + value_set.end()); + } else { + is_empty_value_set = true; + } + } + if (set_type == SingleLookupResult::kUintsetValues) { + if (const auto value_set = key_value_set_result->GetUInt32ValueSet(key); + 
value_set != nullptr && !value_set->GetValues().empty()) { + auto uint32_values = value_set->GetValues(); + auto* result_values = result.mutable_uintset_values(); + result_values->mutable_values()->Reserve(uint32_values.size()); + result_values->mutable_values()->Add(uint32_values.begin(), + uint32_values.end()); + } else { + is_empty_value_set = true; + } + } + if (is_empty_value_set) { auto status = result.mutable_status(); status->set_code(static_cast(absl::StatusCode::kNotFound)); status->set_message(absl::StrCat("Key not found: ", key)); - } else { - auto keyset_values = result.mutable_keyset_values(); - keyset_values->mutable_values()->Add(value_set.begin(), - value_set.end()); } (*response.mutable_kv_pairs())[key] = std::move(result); } return response; } - absl::StatusOr ProcessQuery( - const RequestContext& request_context, std::string query) const { + template + absl::StatusOr ProcessQuery( + const RequestContext& request_context, std::string query, + absl::AnyInvocable + query_eval_fn) const { ScopeLatencyMetricsRecorder latency_recorder(request_context.GetInternalLookupMetricsContext()); if (query.empty()) return absl::OkStatus(); - std::unique_ptr get_key_value_set_result; - kv_server::Driver driver([&get_key_value_set_result](std::string_view key) { - return get_key_value_set_result->GetValueSet(key); - }); - - std::istringstream stream(query); + kv_server::Driver driver; + std::istringstream stream(std::move(query)); kv_server::Scanner scanner(stream); kv_server::Parser parse(driver, scanner); - int parse_result = parse(); - if (parse_result) { + if (int parse_result = parse(); parse_result) { LogInternalLookupRequestErrorMetric( request_context.GetInternalLookupMetricsContext(), kLocalRunQueryParsingFailure); return absl::InvalidArgumentError("Parsing failure."); } - get_key_value_set_result = - cache_.GetKeyValueSet(request_context, driver.GetRootNode()->Keys()); - - auto result = driver.GetResult(); + auto result = query_eval_fn(request_context, 
driver, cache_); if (!result.ok()) { LogInternalLookupRequestErrorMetric( request_context.GetInternalLookupMetricsContext(), kLocalRunQueryFailure); return result.status(); } - InternalRunQueryResponse response; + ResponseType response; response.mutable_elements()->Assign(result->begin(), result->end()); return response; } + const Cache& cache_; }; diff --git a/components/internal_server/local_lookup_test.cc b/components/internal_server/local_lookup_test.cc index bddb646b..0f8c8e97 100644 --- a/components/internal_server/local_lookup_test.cc +++ b/components/internal_server/local_lookup_test.cc @@ -37,13 +37,13 @@ class LocalLookupTest : public ::testing::Test { protected: LocalLookupTest() { InitMetricsContextMap(); - scope_metrics_context_ = std::make_unique(); - request_context_ = - std::make_unique(*scope_metrics_context_); + request_context_ = std::make_unique(); + request_context_->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); } RequestContext& GetRequestContext() { return *request_context_; } - std::unique_ptr scope_metrics_context_; - std::unique_ptr request_context_; + std::shared_ptr request_context_; MockCache mock_cache_; }; @@ -143,6 +143,46 @@ TEST_F(LocalLookupTest, GetKeyValueSets_KeysFound_Success) { testing::UnorderedElementsAreArray(expected_resulting_set)); } +TEST_F(LocalLookupTest, GetUInt32ValueSets_KeysFound_Success) { + auto values = std::vector({1000, 1001}); + UInt32ValueSet value_set; + value_set.Add(absl::MakeSpan(values), 1); + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetUInt32ValueSet("key1")) + .WillOnce(Return(&value_set)); + EXPECT_CALL(mock_cache_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = + local_lookup->GetUInt32ValueSet(GetRequestContext(), {"key1"}); + 
ASSERT_TRUE(response.ok()); + EXPECT_THAT(response.value().kv_pairs().at("key1").uintset_values().values(), + testing::UnorderedElementsAreArray(values)); +} + +TEST_F(LocalLookupTest, GetUInt32ValueSets_SetEmpty_Success) { + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetUInt32ValueSet("key1")) + .WillOnce(Return(nullptr)); + EXPECT_CALL(mock_cache_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = + local_lookup->GetUInt32ValueSet(GetRequestContext(), {"key1"}); + ASSERT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { status { code: 5 message: "Key not found: key1" } } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + TEST_F(LocalLookupTest, GetKeyValueSets_SetEmpty_Success) { auto mock_get_key_value_set_result = std::make_unique(); @@ -206,6 +246,33 @@ TEST_F(LocalLookupTest, RunQuery_ParsingError_Error) { EXPECT_EQ(response.status().code(), absl::StatusCode::kInvalidArgument); } +TEST_F(LocalLookupTest, Verify_RunSetQueryInt_Success) { + std::string query = "A"; + UInt32ValueSet value_set; + auto values = std::vector({10, 20, 30, 40, 50}); + value_set.Add(absl::MakeSpan(values), 1); + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetUInt32ValueSet("A")) + .WillOnce(Return(&value_set)); + EXPECT_CALL(mock_cache_, + GetUInt32ValueSet(_, absl::flat_hash_set{"A"})) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = local_lookup->RunSetQueryInt(GetRequestContext(), query); + ASSERT_TRUE(response.ok()) << response.status(); + EXPECT_THAT(response.value().elements(), + testing::UnorderedElementsAreArray(values.begin(), values.end())); +} + 
+TEST_F(LocalLookupTest, Verify_RunSetQueryInt_ParsingError_Error) { + std::string query = "someset|("; + auto local_lookup = CreateLocalLookup(mock_cache_); + auto response = local_lookup->RunSetQueryInt(GetRequestContext(), query); + EXPECT_FALSE(response.ok()); + EXPECT_EQ(response.status().code(), absl::StatusCode::kInvalidArgument); +} + } // namespace } // namespace kv_server diff --git a/components/internal_server/lookup.h b/components/internal_server/lookup.h index dd3356e7..363e1b7b 100644 --- a/components/internal_server/lookup.h +++ b/components/internal_server/lookup.h @@ -41,8 +41,15 @@ class Lookup { const RequestContext& request_context, const absl::flat_hash_set& key_set) const = 0; + virtual absl::StatusOr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const = 0; + virtual absl::StatusOr RunQuery( const RequestContext& request_context, std::string query) const = 0; + + virtual absl::StatusOr RunSetQueryInt( + const RequestContext& request_context, std::string query) const = 0; }; } // namespace kv_server diff --git a/components/internal_server/lookup.proto b/components/internal_server/lookup.proto index 11d5c0ee..a31b26a5 100644 --- a/components/internal_server/lookup.proto +++ b/components/internal_server/lookup.proto @@ -31,6 +31,10 @@ service InternalLookupService { // Endpoint for running a query on the server's internal datastore. Should // only be used within TEEs. rpc InternalRunQuery(InternalRunQueryRequest) returns (InternalRunQueryResponse) {} + + // Endpoint for running a set query over unsigned int sets in the server's + // internal datastore. Should only be used within TEEs. + rpc InternalRunSetQueryInt(InternalRunSetQueryIntRequest) returns (InternalRunSetQueryIntResponse) {} } // Lookup request for internal datastore. 
@@ -85,6 +89,7 @@ message SingleLookupResult { string value = 1; google.rpc.Status status = 2; KeysetValues keyset_values = 3; + UInt32SetValues uintset_values = 4; } } @@ -93,6 +98,11 @@ message KeysetValues { repeated string values = 1; } +// UInt32 set values +message UInt32SetValues { + repeated uint32 values = 1; +} + // Run Query request. message InternalRunQueryRequest { // Query to run. @@ -108,3 +118,18 @@ message InternalRunQueryResponse { // Set of elements returned. repeated string elements = 1; } + +// Request for running a set query using sets of unsigned ints as input. +message InternalRunSetQueryIntRequest { + // Query to run. + optional string query = 1; + // Context useful for logging and tracing requests + privacy_sandbox.server_common.LogContext log_context = 2; + // Consented debugging configuration + privacy_sandbox.server_common.ConsentedDebugConfiguration consented_debug_config = 3; +} + +// Response for running a set query using sets of unsigned ints as input. +message InternalRunSetQueryIntResponse { + repeated uint32 elements = 1; +} diff --git a/components/internal_server/lookup_server_impl.cc b/components/internal_server/lookup_server_impl.cc index ed9b5715..f4f6575c 100644 --- a/components/internal_server/lookup_server_impl.cc +++ b/components/internal_server/lookup_server_impl.cc @@ -14,32 +14,27 @@ #include "components/internal_server/lookup_server_impl.h" -#include -#include -#include #include #include -#include +#include "absl/functional/any_invocable.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "components/data_server/request_handler/ohttp_server_encryptor.h" -#include "components/internal_server/lookup.grpc.pb.h" #include "components/internal_server/lookup.h" #include "components/internal_server/string_padder.h" #include "google/protobuf/message.h" #include "grpcpp/grpcpp.h" namespace kv_server { -using google::protobuf::RepeatedPtrField; +using google::protobuf::RepeatedPtrField; using grpc::StatusCode; grpc::Status
LookupServiceImpl::ToInternalGrpcStatus( - const RequestContext& request_context, const absl::Status& status, + InternalLookupMetricsContext& metrics_context, const absl::Status& status, std::string_view error_code) const { - LogInternalLookupRequestErrorMetric( - request_context.GetInternalLookupMetricsContext(), error_code); + LogInternalLookupRequestErrorMetric(metrics_context, error_code); return grpc::Status(StatusCode::INTERNAL, absl::StrCat(status.code(), " : ", status.message())); } @@ -76,8 +71,7 @@ void LookupServiceImpl::ProcessKeysetKeys( grpc::Status LookupServiceImpl::InternalLookup( grpc::ServerContext* context, const InternalLookupRequest* request, InternalLookupResponse* response) { - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; if (context->IsCancelled()) { return grpc::Status(grpc::StatusCode::CANCELLED, "Deadline exceeded or client cancelled, abandoning."); @@ -90,8 +84,7 @@ grpc::Status LookupServiceImpl::SecureLookup( grpc::ServerContext* context, const SecureLookupRequest* secure_lookup_request, SecureLookupResponse* secure_response) { - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; LogIfError(request_context.GetInternalLookupMetricsContext() .AccumulateMetric(1)); ScopeLatencyMetricsRecorderohttp_request()); + encryptor.DecryptRequest(secure_lookup_request->ohttp_request(), + request_context.GetPSLogContext()); if (!padded_serialized_request_maybe.ok()) { - return ToInternalGrpcStatus(request_context, - padded_serialized_request_maybe.status(), - kRequestDecryptionFailure); + return ToInternalGrpcStatus( + request_context.GetInternalLookupMetricsContext(), + padded_serialized_request_maybe.status(), kRequestDecryptionFailure); } - VLOG(9) << "SecureLookup decrypted"; + PS_VLOG(9, request_context.GetPSLogContext()) << "SecureLookup decrypted"; auto 
serialized_request_maybe = kv_server::Unpad(*padded_serialized_request_maybe); if (!serialized_request_maybe.ok()) { - return ToInternalGrpcStatus(request_context, - serialized_request_maybe.status(), - kRequestUnpaddingError); + return ToInternalGrpcStatus( + request_context.GetInternalLookupMetricsContext(), + serialized_request_maybe.status(), kRequestUnpaddingError); } - VLOG(9) << "SecureLookup unpadded"; + PS_VLOG(9, request_context.GetPSLogContext()) << "SecureLookup unpadded"; InternalLookupRequest request; if (!request.ParseFromString(*serialized_request_maybe)) { return grpc::Status(grpc::StatusCode::INTERNAL, "Failed parsing incoming request"); } - + request_context.UpdateLogContext(request.log_context(), + request.consented_debug_config()); auto payload_to_encrypt = GetPayload(request_context, request.lookup_sets(), request.keys()); if (payload_to_encrypt.empty()) { @@ -135,12 +130,12 @@ grpc::Status LookupServiceImpl::SecureLookup( // to pad responses, so this branch will never be hit. 
return grpc::Status::OK; } - auto encrypted_response_payload = - encryptor.EncryptResponse(payload_to_encrypt); + auto encrypted_response_payload = encryptor.EncryptResponse( + payload_to_encrypt, request_context.GetPSLogContext()); if (!encrypted_response_payload.ok()) { - return ToInternalGrpcStatus(request_context, - encrypted_response_payload.status(), - kResponseEncryptionFailure); + return ToInternalGrpcStatus( + request_context.GetInternalLookupMetricsContext(), + encrypted_response_payload.status(), kResponseEncryptionFailure); } secure_response->set_ohttp_response(*encrypted_response_payload); return grpc::Status::OK; @@ -161,20 +156,23 @@ std::string LookupServiceImpl::GetPayload( grpc::Status LookupServiceImpl::InternalRunQuery( grpc::ServerContext* context, const InternalRunQueryRequest* request, InternalRunQueryResponse* response) { - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); - if (context->IsCancelled()) { - return grpc::Status(grpc::StatusCode::CANCELLED, - "Deadline exceeded or client cancelled, abandoning."); - } - const auto process_result = - lookup_.RunQuery(request_context, request->query()); - if (!process_result.ok()) { - return ToInternalGrpcStatus(request_context, process_result.status(), - kInternalRunQueryRequestFailure); - } - *response = *std::move(process_result); - return grpc::Status::OK; + return RunSetQuery( + context, request, response, + [this](const RequestContext& request_context, std::string query) { + return lookup_.RunQuery(request_context, query); + }); +} + +grpc::Status LookupServiceImpl::InternalRunSetQueryInt( + grpc::ServerContext* context, + const kv_server::InternalRunSetQueryIntRequest* request, + kv_server::InternalRunSetQueryIntResponse* response) { + return RunSetQuery( + context, request, response, + [this](const RequestContext& request_context, std::string query) { + return lookup_.RunSetQueryInt(request_context, query); + }); } } // namespace 
kv_server diff --git a/components/internal_server/lookup_server_impl.h b/components/internal_server/lookup_server_impl.h index 83fb9dfc..ac3040d0 100644 --- a/components/internal_server/lookup_server_impl.h +++ b/components/internal_server/lookup_server_impl.h @@ -18,13 +18,13 @@ #define COMPONENTS_INTERNAL_SERVER_LOOKUP_SERVER_IMPL_H_ #include +#include #include "components/internal_server/lookup.grpc.pb.h" #include "components/internal_server/lookup.h" #include "components/util/request_context.h" #include "grpcpp/grpcpp.h" #include "src/encryption/key_fetcher/interface/key_fetcher_manager_interface.h" -#include "src/telemetry/telemetry.h" namespace kv_server { // Implements the internal lookup service for the data store. @@ -52,6 +52,11 @@ class LookupServiceImpl final const kv_server::InternalRunQueryRequest* request, kv_server::InternalRunQueryResponse* response) override; + grpc::Status InternalRunSetQueryInt( + grpc::ServerContext* context, + const kv_server::InternalRunSetQueryIntRequest* request, + kv_server::InternalRunSetQueryIntResponse* response) override; + private: std::string GetPayload( const RequestContext& request_context, const bool lookup_sets, @@ -63,9 +68,32 @@ class LookupServiceImpl final const RequestContext& request_context, const google::protobuf::RepeatedPtrField& keys, InternalLookupResponse& response) const; - grpc::Status ToInternalGrpcStatus(const RequestContext& request_context, - const absl::Status& status, - std::string_view error_code) const; + grpc::Status ToInternalGrpcStatus( + InternalLookupMetricsContext& metrics_context, const absl::Status& status, + std::string_view error_code) const; + + template + grpc::Status RunSetQuery(grpc::ServerContext* context, + const RequestType* request, ResponseType* response, + absl::AnyInvocable( + const RequestContext&, std::string)> + run_set_query_fn) { + RequestContext request_context; + if (context->IsCancelled()) { + return grpc::Status(grpc::StatusCode::CANCELLED, + "Deadline exceeded 
or client cancelled, abandoning."); + } + const auto process_result = + run_set_query_fn(request_context, request->query()); + if (!process_result.ok()) { + return ToInternalGrpcStatus( + request_context.GetInternalLookupMetricsContext(), + process_result.status(), kInternalRunQueryRequestFailure); + } + *response = std::move(*process_result); + return grpc::Status::OK; + } + const Lookup& lookup_; privacy_sandbox::server_common::KeyFetcherManagerInterface& key_fetcher_manager_; diff --git a/components/internal_server/lookup_server_impl_test.cc b/components/internal_server/lookup_server_impl_test.cc index be2118fd..277d13e5 100644 --- a/components/internal_server/lookup_server_impl_test.cc +++ b/components/internal_server/lookup_server_impl_test.cc @@ -16,14 +16,8 @@ #include "components/internal_server/lookup_server_impl.h" #include -#include -#include -#include -#include "components/data_server/cache/key_value_cache.h" -#include "components/data_server/cache/mocks.h" #include "components/internal_server/mocks.h" -#include "components/internal_server/string_padder.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "grpcpp/grpcpp.h" @@ -37,7 +31,6 @@ namespace { using google::protobuf::TextFormat; using testing::_; using testing::Return; -using testing::ReturnRef; class LookupServiceImplTest : public ::testing::Test { protected: @@ -142,6 +135,34 @@ TEST_F(LookupServiceImplTest, SecureLookupFailure) { EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); } +TEST_F(LookupServiceImplTest, InternalRunSetQueryInt_Success) { + InternalRunSetQueryIntRequest request; + request.set_query("someset"); + InternalRunSetQueryIntResponse expected; + expected.add_elements(1000); + expected.add_elements(1001); + expected.add_elements(1002); + EXPECT_CALL(mock_lookup_, RunSetQueryInt(_, _)).WillOnce(Return(expected)); + InternalRunSetQueryIntResponse response; + grpc::ClientContext context; + grpc::Status status = + 
stub_->InternalRunSetQueryInt(&context, request, &response); + auto results = response.elements(); + EXPECT_THAT(results, testing::UnorderedElementsAreArray({1000, 1001, 1002})); +} + +TEST_F(LookupServiceImplTest, InternalRunSetQueryInt_LookupError_Failure) { + InternalRunSetQueryIntRequest request; + request.set_query("fail|||||now"); + EXPECT_CALL(mock_lookup_, RunSetQueryInt(_, _)) + .WillOnce(Return(absl::UnknownError("Some error"))); + InternalRunSetQueryIntResponse response; + grpc::ClientContext context; + grpc::Status status = + stub_->InternalRunSetQueryInt(&context, request, &response); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); +} + } // namespace } // namespace kv_server diff --git a/components/internal_server/mocks.h b/components/internal_server/mocks.h index 1d049b15..c72f5487 100644 --- a/components/internal_server/mocks.h +++ b/components/internal_server/mocks.h @@ -49,8 +49,14 @@ class MockLookup : public Lookup { (const RequestContext&, const absl::flat_hash_set&), (const, override)); + MOCK_METHOD(absl::StatusOr, GetUInt32ValueSet, + (const RequestContext&, + const absl::flat_hash_set&), + (const, override)); MOCK_METHOD(absl::StatusOr, RunQuery, (const RequestContext&, std::string query), (const, override)); + MOCK_METHOD(absl::StatusOr, RunSetQueryInt, + (const RequestContext&, std::string query), (const, override)); }; } // namespace kv_server diff --git a/components/internal_server/remote_lookup_client_impl.cc b/components/internal_server/remote_lookup_client_impl.cc index 8640f39d..a599daff 100644 --- a/components/internal_server/remote_lookup_client_impl.cc +++ b/components/internal_server/remote_lookup_client_impl.cc @@ -56,9 +56,19 @@ class RemoteLookupClientImpl : public RemoteLookupClient { ScopeLatencyMetricsRecorder latency_recorder(request_context.GetUdfRequestMetricsContext()); - OhttpClientEncryptor encryptor(key_fetcher_manager_); + auto maybe_public_key = + 
key_fetcher_manager_.GetPublicKey(GetCloudPlatform()); + if (!maybe_public_key.ok()) { + const std::string error = + absl::StrCat("Could not get public key to use for HPKE encryption:", + maybe_public_key.status().message()); + PS_LOG(ERROR, request_context.GetPSLogContext()) << error; + return absl::InternalError(error); + } + OhttpClientEncryptor encryptor(maybe_public_key.value()); auto encrypted_padded_serialized_request_maybe = - encryptor.EncryptRequest(Pad(serialized_message, padding_length)); + encryptor.EncryptRequest(Pad(serialized_message, padding_length), + request_context.GetPSLogContext()); if (!encrypted_padded_serialized_request_maybe.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kRemoteRequestEncryptionFailure); @@ -74,7 +84,8 @@ class RemoteLookupClientImpl : public RemoteLookupClient { if (!status.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kRemoteSecureLookupFailure); - LOG(ERROR) << status.error_code() << ": " << status.error_message(); + PS_LOG(ERROR, request_context.GetPSLogContext()) + << status.error_code() << ": " << status.error_message(); return absl::Status((absl::StatusCode)status.error_code(), status.error_message()); } @@ -85,7 +96,8 @@ class RemoteLookupClientImpl : public RemoteLookupClient { return response; } auto decrypted_response_maybe = - encryptor.DecryptResponse(std::move(secure_response.ohttp_response())); + encryptor.DecryptResponse(std::move(secure_response.ohttp_response()), + request_context.GetPSLogContext()); if (!decrypted_response_maybe.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kResponseEncryptionFailure); @@ -100,6 +112,14 @@ class RemoteLookupClientImpl : public RemoteLookupClient { std::string_view GetIpAddress() const override { return ip_address_; } private: + privacy_sandbox::server_common::CloudPlatform GetCloudPlatform() const { +#if defined(CLOUD_PLATFORM_AWS) + return 
privacy_sandbox::server_common::CloudPlatform::kAws; +#elif defined(CLOUD_PLATFORM_GCP) + return privacy_sandbox::server_common::CloudPlatform::kGcp; +#endif + return privacy_sandbox::server_common::CloudPlatform::kLocal; + } const std::string ip_address_; std::unique_ptr stub_; privacy_sandbox::server_common::KeyFetcherManagerInterface& diff --git a/components/internal_server/remote_lookup_client_impl_test.cc b/components/internal_server/remote_lookup_client_impl_test.cc index 56f9b7fe..d95320f3 100644 --- a/components/internal_server/remote_lookup_client_impl_test.cc +++ b/components/internal_server/remote_lookup_client_impl_test.cc @@ -43,9 +43,10 @@ class RemoteLookupClientImplTest : public ::testing::Test { server_->InProcessChannel(grpc::ChannelArguments())), fake_key_fetcher_manager_); InitMetricsContextMap(); - scope_metrics_context_ = std::make_unique(); - request_context_ = - std::make_unique(*scope_metrics_context_); + request_context_ = std::make_unique(); + request_context_->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); } ~RemoteLookupClientImplTest() { @@ -59,8 +60,7 @@ class RemoteLookupClientImplTest : public ::testing::Test { std::unique_ptr lookup_service_; std::unique_ptr server_; std::unique_ptr remote_lookup_client_; - std::unique_ptr scope_metrics_context_; - std::unique_ptr request_context_; + std::shared_ptr request_context_; }; TEST_F(RemoteLookupClientImplTest, EncryptedPaddedSuccessfulCall) { diff --git a/components/internal_server/sharded_lookup.cc b/components/internal_server/sharded_lookup.cc index b102cdca..2e98bf7a 100644 --- a/components/internal_server/sharded_lookup.cc +++ b/components/internal_server/sharded_lookup.cc @@ -22,7 +22,7 @@ #include #include "absl/log/check.h" -#include "absl/log/log.h" +#include "components/data_server/cache/uint32_value_set.h" #include "components/internal_server/lookup.h" #include 
"components/internal_server/lookup.pb.h" #include "components/internal_server/remote_lookup_client.h" @@ -30,13 +30,10 @@ #include "components/query/scanner.h" #include "components/sharding/shard_manager.h" #include "components/util/request_context.h" -#include "pir/hashing/sha256_hash_family.h" namespace kv_server { namespace { -using google::protobuf::RepeatedPtrField; - void UpdateResponse( const std::vector& key_list, ::google::protobuf::Map& @@ -56,7 +53,8 @@ void UpdateResponse( } void SetRequestFailed(const std::vector& key_list, - InternalLookupResponse& response) { + InternalLookupResponse& response, + const RequestContext& request_context) { SingleLookupResult result; auto status = result.mutable_status(); status->set_code(static_cast(absl::StatusCode::kInternal)); @@ -64,7 +62,8 @@ void SetRequestFailed(const std::vector& key_list, for (const auto& key : key_list) { (*response.mutable_kv_pairs())[key] = result; } - LOG(ERROR) << "Sharded lookup failed:" << response.DebugString(); + PS_LOG(ERROR, request_context.GetPSLogContext()) + << "Sharded lookup failed:" << response.DebugString(); } class ShardedLookup : public Lookup { @@ -101,39 +100,13 @@ class ShardedLookup : public Lookup { absl::StatusOr GetKeyValueSet( const RequestContext& request_context, const absl::flat_hash_set& keys) const override { - ScopeLatencyMetricsRecorder - latency_recorder(request_context.GetUdfRequestMetricsContext()); - InternalLookupResponse response; - if (keys.empty()) { - return response; - } - absl::flat_hash_map> key_sets; - auto get_key_value_set_result_maybe = - GetShardedKeyValueSet(request_context, keys); - if (!get_key_value_set_result_maybe.ok()) { - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedGetKeyValueSetKeySetRetrievalFailure); - return get_key_value_set_result_maybe.status(); - } - key_sets = *std::move(get_key_value_set_result_maybe); + return GetKeyValueSets(request_context, keys); + } - for (const auto& key : keys) { - 
SingleLookupResult result; - const auto key_iter = key_sets.find(key); - if (key_iter == key_sets.end()) { - auto status = result.mutable_status(); - status->set_code(static_cast(absl::StatusCode::kNotFound)); - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedGetKeyValueSetKeySetNotFound); - } else { - auto keyset_values = result.mutable_keyset_values(); - keyset_values->mutable_values()->Add(key_iter->second.begin(), - key_iter->second.end()); - } - (*response.mutable_kv_pairs())[key] = std::move(result); - } - return response; + absl::StatusOr GetUInt32ValueSet( + const RequestContext& request_context, + const absl::flat_hash_set& key_set) const override { + return GetKeyValueSets(request_context, key_set); } absl::StatusOr RunQuery( @@ -147,52 +120,51 @@ class ShardedLookup : public Lookup { kShardedRunQueryEmptyQuery); return response; } - - absl::flat_hash_map> keysets; - kv_server::Driver driver([&keysets, this, - &request_context](std::string_view key) { - const auto key_iter = keysets.find(key); - if (key_iter == keysets.end()) { - VLOG(8) << "Driver can't find " << key << "key_set. 
Returning empty."; - LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryMissingKeySet); - absl::flat_hash_set set; - return set; - } else { - absl::flat_hash_set set(key_iter->second.begin(), - key_iter->second.end()); - return set; - } - }); - std::istringstream stream(query); - kv_server::Scanner scanner(stream); - kv_server::Parser parse(driver, scanner); - int parse_result = parse(); - if (parse_result) { + auto result = + RunSetQuery, std::string>( + request_context, query); + if (!result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryParsingFailure); - return absl::InvalidArgumentError("Parsing failure."); + kShardedRunQueryFailure); + return result.status(); } - auto get_key_value_set_result_maybe = - GetShardedKeyValueSet(request_context, driver.GetRootNode()->Keys()); - if (!get_key_value_set_result_maybe.ok()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "Driver results for query " << query; + for (const auto& value : *result) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "Value: " << value << "\n"; + } + response.mutable_elements()->Assign(result->begin(), result->end()); + return response; + } + + absl::StatusOr RunSetQueryInt( + const RequestContext& request_context, std::string query) const override { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalRunSetQueryIntResponse response; + if (query.empty()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), - kShardedRunQueryKeySetRetrievalFailure); - return get_key_value_set_result_maybe.status(); + kShardedRunQueryEmptyQuery); + return response; } - keysets = std::move(*get_key_value_set_result_maybe); - auto result = driver.GetResult(); + auto result = + RunSetQuery(request_context, query); if (!result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kShardedRunQueryFailure); return 
result.status(); } - VLOG(8) << "Driver results for query " << query; + PS_VLOG(8, request_context.GetPSLogContext()) + << "Driver results for query " << query; for (const auto& value : *result) { - VLOG(8) << "Value: " << value << "\n"; + PS_VLOG(8, request_context.GetPSLogContext()) + << "Value: " << value << "\n"; } - - response.mutable_elements()->Assign(result->begin(), result->end()); + auto uint32_set = BitSetToUint32Set(*result); + response.mutable_elements()->Reserve(uint32_set.size()); + response.mutable_elements()->Assign(uint32_set.begin(), uint32_set.end()); return response; } @@ -210,27 +182,34 @@ class ShardedLookup : public Lookup { }; std::vector BucketKeys( + const RequestContext& request_context, const absl::flat_hash_set& keys) const { ShardLookupInput sli; std::vector lookup_inputs(num_shards_, sli); for (const auto key : keys) { auto sharding_result = key_sharder_.GetShardNumForKey(key, num_shards_); - VLOG(9) << "key: " << key - << ", shard number: " << sharding_result.shard_num - << ", sharding_key (if regex is present): " - << sharding_result.sharding_key; + PS_VLOG(9, request_context.GetPSLogContext()) + << "key: " << key << ", shard number: " << sharding_result.shard_num + << ", sharding_key (if regex is present): " + << sharding_result.sharding_key; lookup_inputs[sharding_result.shard_num].keys.emplace_back(key); } return lookup_inputs; } - void SerializeShardedRequests(std::vector& lookup_inputs, + void SerializeShardedRequests(const RequestContext& request_context, + std::vector& lookup_inputs, bool lookup_sets) const { for (auto& lookup_input : lookup_inputs) { InternalLookupRequest request; request.mutable_keys()->Assign(lookup_input.keys.begin(), lookup_input.keys.end()); request.set_lookup_sets(lookup_sets); + *request.mutable_consented_debug_config() = + request_context.GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + request_context.GetRequestLogContext().GetLogContext(); 
lookup_input.serialized_request = request.SerializeAsString(); } } @@ -248,10 +227,11 @@ class ShardedLookup : public Lookup { } std::vector ShardKeys( + const RequestContext& request_context, const absl::flat_hash_set& keys, bool lookup_sets) const { - auto lookup_inputs = BucketKeys(keys); - SerializeShardedRequests(lookup_inputs, lookup_sets); + auto lookup_inputs = BucketKeys(request_context, keys); + SerializeShardedRequests(request_context, lookup_inputs, lookup_sets); ComputePadding(lookup_inputs); return lookup_inputs; } @@ -295,31 +275,26 @@ class ShardedLookup : public Lookup { return responses; } - absl::StatusOr GetLocalValues( + // Local lookups will go away once we split the server into UDF and Data + // servers. + template + absl::StatusOr GetLocalLookupResponse( const RequestContext& request_context, const std::vector& key_list) const { - InternalLookupResponse response; + if (key_list.empty()) { + return InternalLookupResponse(); + } absl::flat_hash_set keys(key_list.begin(), key_list.end()); - return local_lookup_.GetKeyValues(request_context, keys); - } - - absl::StatusOr GetLocalKeyValuesSet( - const RequestContext& request_context, - const std::vector& key_list) const { - if (key_list.empty()) { - InternalLookupResponse response; - return response; + if constexpr (result_type == SingleLookupResult::kValue) { + return local_lookup_.GetKeyValues(request_context, keys); + } + if constexpr (result_type == SingleLookupResult::kKeysetValues) { + return local_lookup_.GetKeyValueSet(request_context, keys); + } + if constexpr (result_type == SingleLookupResult::kUintsetValues) { + return local_lookup_.GetUInt32ValueSet(request_context, keys); } - - // We have this conversion, because of the inconsistency how we look up - // keys in Cache -- GetKeyValuePairs vs GetKeyValueSet. GetKeyValuePairs - // should be refactored to flat_hash_set, and then this can be fixed. 
- // Additionally, this whole local branch will go away once we have a - // a sepration between UDF and Data servers. - absl::flat_hash_set key_list_set(key_list.begin(), - key_list.end()); - return local_lookup_.GetKeyValueSet(request_context, key_list_set); } absl::StatusOr ProcessShardedKeys( @@ -329,13 +304,14 @@ class ShardedLookup : public Lookup { if (keys.empty()) { return response; } - const auto shard_lookup_inputs = ShardKeys(keys, false); - auto responses = - GetLookupFutures(request_context, shard_lookup_inputs, - [this, &request_context]( - const std::vector& key_list) { - return GetLocalValues(request_context, key_list); - }); + const auto shard_lookup_inputs = ShardKeys(request_context, keys, false); + auto responses = GetLookupFutures( + request_context, shard_lookup_inputs, + [this, + &request_context](const std::vector& key_list) { + return GetLocalLookupResponse( + request_context, key_list); + }); if (!responses.ok()) { return responses.status(); } @@ -347,7 +323,7 @@ class ShardedLookup : public Lookup { // mark all keys as internal failure LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), kShardedKeyValueRequestFailure); - SetRequestFailed(shard_lookup_input.keys, response); + SetRequestFailed(shard_lookup_input.keys, response, request_context); continue; } auto kv_pairs = result->mutable_kv_pairs(); @@ -356,48 +332,68 @@ class ShardedLookup : public Lookup { return response; } + template void CollectKeySets( const RequestContext& request_context, - absl::flat_hash_map>& + absl::flat_hash_map>& key_sets, InternalLookupResponse& keysets_lookup_response) const { for (auto& [key, keyset_lookup_result] : (*(keysets_lookup_response.mutable_kv_pairs()))) { - switch (keyset_lookup_result.single_lookup_result_case()) { - case SingleLookupResult::kStatusFieldNumber: - // this means it wasn't found, no need to insert an empty set. 
- break; - case SingleLookupResult::kKeysetValuesFieldNumber: - absl::flat_hash_set value_set; + absl::flat_hash_set value_set; + if constexpr (std::is_same_v) { + if (keyset_lookup_result.single_lookup_result_case() == + SingleLookupResult::kKeysetValues) { for (auto& v : keyset_lookup_result.keyset_values().values()) { - VLOG(8) << "keyset name: " << key << " value: " << v; + PS_VLOG(8, request_context.GetPSLogContext()) + << "keyset name: " << key << " value: " << v; value_set.emplace(std::move(v)); } - auto [_, inserted] = - key_sets.insert_or_assign(key, std::move(value_set)); - if (!inserted) { - LogUdfRequestErrorMetric( - request_context.GetUdfRequestMetricsContext(), - kShardedKeyCollisionOnKeySetCollection); - LOG(ERROR) << "Key collision, when collecting results from shards: " - << key; + } + } + if constexpr (std::is_same_v) { + if (keyset_lookup_result.single_lookup_result_case() == + SingleLookupResult::kUintsetValues) { + for (auto& v : keyset_lookup_result.uintset_values().values()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "keyset name: " << key << " value: " << v; + value_set.emplace(std::move(v)); } - break; + } + } + if (!value_set.empty()) { + if (auto [_, inserted] = + key_sets.insert_or_assign(key, std::move(value_set)); + !inserted) { + LogUdfRequestErrorMetric( + request_context.GetUdfRequestMetricsContext(), + kShardedKeyCollisionOnKeySetCollection); + PS_LOG(ERROR, request_context.GetPSLogContext()) + << "Key collision, when collecting results from shards: " << key; + } } } } + template absl::StatusOr< - absl::flat_hash_map>> + absl::flat_hash_map>> GetShardedKeyValueSet( const RequestContext& request_context, const absl::flat_hash_set& key_set) const { - const auto shard_lookup_inputs = ShardKeys(key_set, true); + const auto shard_lookup_inputs = ShardKeys(request_context, key_set, true); auto responses = GetLookupFutures( request_context, shard_lookup_inputs, [this, &request_context](const std::vector& key_list) { - return 
GetLocalKeyValuesSet(request_context, key_list); + if constexpr (std::is_same_v) { + return GetLocalLookupResponse( + request_context, key_list); + } + if constexpr (std::is_same_v) { + return GetLocalLookupResponse( + request_context, key_list); + } }); if (!responses.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), @@ -405,9 +401,9 @@ class ShardedLookup : public Lookup { return responses.status(); } // process responses - absl::flat_hash_map> key_sets; + absl::flat_hash_map> + key_sets; for (int shard_num = 0; shard_num < num_shards_; shard_num++) { - auto& shard_lookup_input = shard_lookup_inputs[shard_num]; auto result = (*responses)[shard_num].get(); if (!result.ok()) { LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), @@ -419,6 +415,101 @@ class ShardedLookup : public Lookup { return key_sets; } + template + absl::StatusOr GetKeyValueSets( + const RequestContext& request_context, + const absl::flat_hash_set& keys) const { + ScopeLatencyMetricsRecorder + latency_recorder(request_context.GetUdfRequestMetricsContext()); + InternalLookupResponse response; + if (keys.empty()) { + return response; + } + absl::flat_hash_map> + key_sets; + auto get_key_value_set_result_maybe = + GetShardedKeyValueSet(request_context, keys); + if (!get_key_value_set_result_maybe.ok()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetKeyValueSetKeySetRetrievalFailure); + return get_key_value_set_result_maybe.status(); + } + key_sets = *std::move(get_key_value_set_result_maybe); + for (const auto& key : keys) { + SingleLookupResult result; + if (const auto key_iter = key_sets.find(key); + key_iter == key_sets.end()) { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedGetKeyValueSetKeySetNotFound); + } else { + if constexpr (std::is_same_v) { + auto* 
keyset_values = result.mutable_keyset_values(); + keyset_values->mutable_values()->Reserve(key_iter->second.size()); + keyset_values->mutable_values()->Add(key_iter->second.begin(), + key_iter->second.end()); + } + if constexpr (std::is_same_v) { + auto* uint32set_values = result.mutable_uintset_values(); + uint32set_values->mutable_values()->Reserve(key_iter->second.size()); + uint32set_values->mutable_values()->Add(key_iter->second.begin(), + key_iter->second.end()); + } + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } + return response; + } + + template + absl::StatusOr RunSetQuery(const RequestContext& request_context, + std::string query) const { + kv_server::Driver driver; + std::istringstream stream(query); + kv_server::Scanner scanner(stream); + kv_server::Parser parse(driver, scanner); + int parse_result = parse(); + if (parse_result) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedRunQueryParsingFailure); + return absl::InvalidArgumentError("Parsing failure."); + } + auto get_key_value_set_result_maybe = GetShardedKeyValueSet( + request_context, driver.GetRootNode()->Keys()); + if (!get_key_value_set_result_maybe.ok()) { + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedRunQueryKeySetRetrievalFailure); + return get_key_value_set_result_maybe.status(); + } + auto keysets = std::move(*get_key_value_set_result_maybe); + return driver.EvaluateQuery([&keysets, &request_context]( + std::string_view key) { + const auto key_iter = keysets.find(key); + if (key_iter == keysets.end()) { + PS_VLOG(8, request_context.GetPSLogContext()) + << "Driver can't find " << key << "key_set. 
Returning empty."; + LogUdfRequestErrorMetric(request_context.GetUdfRequestMetricsContext(), + kShardedRunQueryMissingKeySet); + return SetType(); + } + if constexpr (std::is_same_v>) { + return absl::flat_hash_set(key_iter->second.begin(), + key_iter->second.end()); + } + if constexpr (std::is_same_v) { + roaring::Roaring bitset; + for (const auto& element : key_iter->second) { + bitset.add(element); + } + bitset.runOptimize(); + return bitset; + } + }); + } + const Lookup& local_lookup_; const int32_t num_shards_; const int32_t current_shard_num_; diff --git a/components/internal_server/sharded_lookup_test.cc b/components/internal_server/sharded_lookup_test.cc index a1ddda53..81021255 100644 --- a/components/internal_server/sharded_lookup_test.cc +++ b/components/internal_server/sharded_lookup_test.cc @@ -19,7 +19,6 @@ #include #include -#include "components/data_server/cache/mocks.h" #include "components/internal_server/mocks.h" #include "components/sharding/mocks.h" #include "gmock/gmock.h" @@ -33,19 +32,19 @@ namespace { using google::protobuf::TextFormat; using testing::_; using testing::Return; -using testing::ReturnRef; class ShardedLookupTest : public ::testing::Test { protected: ShardedLookupTest() { InitMetricsContextMap(); - scope_metrics_context_ = std::make_unique(); - request_context_ = - std::make_unique(*scope_metrics_context_); + request_context_ = std::make_unique(); + request_context_->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); + request_context_ = std::make_shared(); } RequestContext& GetRequestContext() { return *request_context_; } - std::unique_ptr scope_metrics_context_; - std::unique_ptr request_context_; + std::shared_ptr request_context_; int32_t num_shards_ = 2; int32_t shard_num_ = 0; @@ -53,6 +52,64 @@ class ShardedLookupTest : public ::testing::Test { KeySharder key_sharder_ = KeySharder(ShardingFunction{/*seed=*/""}); }; 
+TEST_F(ShardedLookupTest, VerifyCorrectnessOfSerializedRequest) { + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + privacy_sandbox::server_common::ConsentedDebugConfiguration + consented_debug_config; + consented_debug_config.set_is_consented(true); + consented_debug_config.set_token("test_token"); + privacy_sandbox::server_common::LogContext log_context; + log_context.set_adtech_debug_id("debug_id"); + log_context.set_generation_id("generation_id"); + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), + [&consented_debug_config, &log_context](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1", "key5"}; + EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) + .WillOnce([=](const RequestContext& request_context, + const std::string_view serialized_message, + const int32_t padding_length) { + InternalLookupRequest request; + EXPECT_TRUE(request.ParseFromString(serialized_message)); + auto request_keys = std::vector( + request.keys().begin(), request.keys().end()); + EXPECT_THAT(request.keys(), + testing::UnorderedElementsAreArray(key_list_remote)); + EXPECT_THAT(request.log_context(), EqualsProto(log_context)); + EXPECT_THAT(request.consented_debug_config(), + EqualsProto(consented_debug_config)); + InternalLookupResponse resp; + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + (*resp.mutable_kv_pairs())["key1"] = result; + return resp; + }); + + return mock_remote_lookup_client_1; + }); + + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto request_log_context = + std::make_unique(log_context, consented_debug_config); + auto 
request_context = std::make_unique(); + request_context->UpdateLogContext(log_context, consented_debug_config); + EXPECT_TRUE( + sharded_lookup->GetKeyValues(*request_context, {"key1", "key4", "key5"}) + .ok()); +} + TEST_F(ShardedLookupTest, GetKeyValues_Success) { InternalLookupResponse local_lookup_response; TextFormat::ParseFromString(R"pb(kv_pairs { @@ -70,7 +127,7 @@ TEST_F(ShardedLookupTest, GetKeyValues_Success) { } auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -81,6 +138,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_Success) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, 0)) @@ -134,7 +197,7 @@ TEST_F(ShardedLookupTest, GetKeyValues_KeyMissing_ReturnsStatus) { auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -144,6 +207,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_KeyMissing_ReturnsStatus) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string 
serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) @@ -233,7 +302,7 @@ TEST_F(ShardedLookupTest, GetKeyValues_FailedDownstreamRequest) { } auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -243,6 +312,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_FailedDownstreamRequest) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, 0)) @@ -289,7 +364,7 @@ TEST_F(ShardedLookupTest, GetKeyValues_ReturnsKeysFromCachePadding) { keys.insert("longkey1"); keys.insert("randomkey3"); - int total_length = 22; + int total_length = 26; std::vector key_list = {"key4", "verylongkey2"}; InternalLookupResponse local_lookup_response; @@ -320,7 +395,7 @@ TEST_F(ShardedLookupTest, GetKeyValues_ReturnsKeysFromCachePadding) { auto shard_manager = ShardManager::Create( num_shards, std::move(cluster_mappings), std::make_unique(), - [total_length](const std::string& ip) { + [total_length, this](const std::string& ip) { if (ip == "1") { auto mock_remote_lookup_client_1 = std::make_unique(); @@ -329,6 +404,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_ReturnsKeysFromCachePadding) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + 
*request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, _)) .WillOnce([total_length, key_list_remote]( @@ -365,6 +446,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_ReturnsKeysFromCachePadding) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, _)) @@ -385,6 +472,12 @@ TEST_F(ShardedLookupTest, GetKeyValues_ReturnsKeysFromCachePadding) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, _)) .WillOnce([=](const RequestContext& request_context, @@ -478,7 +571,7 @@ TEST_F(ShardedLookupTest, GetKeyValueSets_KeysFound_Success) { auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -489,6 +582,12 @@ TEST_F(ShardedLookupTest, GetKeyValueSets_KeysFound_Success) { InternalLookupRequest request; request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); + 
*request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); request.set_lookup_sets(true); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) @@ -548,7 +647,7 @@ TEST_F(ShardedLookupTest, GetKeyValueSets_KeysMissing_ReturnsStatus) { auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -560,6 +659,12 @@ TEST_F(ShardedLookupTest, GetKeyValueSets_KeysMissing_ReturnsStatus) { request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) .WillOnce([=](const RequestContext& request_context, @@ -610,6 +715,175 @@ TEST_F(ShardedLookupTest, GetKeyValueSets_KeysMissing_ReturnsStatus) { EXPECT_THAT(response.value(), EqualsProto(expected)); } +TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysFound_Success) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uintset_values { values: 1000 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = 
ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [this](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + request.set_lookup_sets(true); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) + .WillOnce([&]() { + InternalLookupResponse resp; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { uintset_values { values: 2000 } } + } + )pb", + &resp); + return resp; + }); + return mock_remote_lookup_client_1; + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = + sharded_lookup->GetUInt32ValueSet(GetRequestContext(), {"key1", "key4"}); + ASSERT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { uintset_values { values: 2000 } } + } + kv_pairs { + key: "key4" + value { uintset_values { values: 1000 } } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + +TEST_F(ShardedLookupTest, GetUInt32ValueSets_KeysMissing_ReturnsStatus) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uintset_values { values: 1000 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + 
for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [this](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1", "key5"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, _, 0)) + .WillOnce([=](const RequestContext& request_context, + const std::string_view serialized_message, + const int32_t padding_length) { + InternalLookupRequest request; + EXPECT_TRUE(request.ParseFromString(serialized_message)); + auto request_keys = std::vector( + request.keys().begin(), request.keys().end()); + EXPECT_THAT(request.keys(), + testing::UnorderedElementsAreArray(key_list_remote)); + InternalLookupResponse resp; + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + (*resp.mutable_kv_pairs())["key1"] = result; + return resp; + }); + return mock_remote_lookup_client_1; + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->GetUInt32ValueSet(GetRequestContext(), + {"key1", "key4", "key5"}); + ASSERT_TRUE(response.ok()); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { status: { code: 5, message: "" } } + } + kv_pairs { + 
key: "key4" + value { uintset_values { values: 1000 } } + } + kv_pairs { + key: "key5" + value { status: { code: 5, message: "" } } + } + )pb", + &expected); + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + +TEST_F(ShardedLookupTest, GetUInt32ValueSet_EmptyRequest_ReturnsEmptyResponse) { + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + return std::make_unique(); + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->GetUInt32ValueSet(GetRequestContext(), {}); + EXPECT_TRUE(response.ok()); + + InternalLookupResponse expected; + EXPECT_THAT(response.value(), EqualsProto(expected)); +} + TEST_F(ShardedLookupTest, GetKeyValueSet_EmptyRequest_ReturnsEmptyResponse) { std::vector> cluster_mappings; for (int i = 0; i < 2; i++) { @@ -648,7 +922,7 @@ TEST_F(ShardedLookupTest, GetKeyValueSet_FailedDownstreamRequest) { } auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -659,6 +933,12 @@ TEST_F(ShardedLookupTest, GetKeyValueSet_FailedDownstreamRequest) { request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, 0)) @@ -694,7 +974,7 @@ 
TEST_F(ShardedLookupTest, RunQuery_Success) { } auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -706,6 +986,12 @@ TEST_F(ShardedLookupTest, RunQuery_Success) { request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, 0)) @@ -752,7 +1038,7 @@ TEST_F(ShardedLookupTest, RunQuery_MissingKeySet_IgnoresMissingSet_Success) { } auto shard_manager = ShardManager::Create( num_shards_, std::move(cluster_mappings), - std::make_unique(), [](const std::string& ip) { + std::make_unique(), [this](const std::string& ip) { if (ip != "1") { return std::make_unique(); } @@ -764,6 +1050,12 @@ TEST_F(ShardedLookupTest, RunQuery_MissingKeySet_IgnoresMissingSet_Success) { request.mutable_keys()->Assign(key_list_remote.begin(), key_list_remote.end()); request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); const std::string serialized_request = request.SerializeAsString(); EXPECT_CALL(*mock_remote_lookup_client_1, GetValues(_, serialized_request, 0)) @@ -861,6 +1153,113 @@ TEST_F(ShardedLookupTest, RunQuery_EmptyRequest_EmptyResponse) { EXPECT_TRUE(response.value().elements().empty()); } +TEST_F(ShardedLookupTest, RunSetQueryInt_Success) { + InternalLookupResponse local_lookup_response; 
+ TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uintset_values { values: 1000 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [this](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + request.set_lookup_sets(true); + *request.mutable_consented_debug_config() = + GetRequestContext() + .GetRequestLogContext() + .GetConsentedDebugConfiguration(); + *request.mutable_log_context() = + GetRequestContext().GetRequestLogContext().GetLogContext(); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(_, serialized_request, 0)) + .WillOnce([&]() { + InternalLookupResponse resp; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { uintset_values { values: 2000 } } + } + )pb", + &resp); + return resp; + }); + return mock_remote_lookup_client_1; + }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = + sharded_lookup->RunSetQueryInt(GetRequestContext(), "key1|key4"); + EXPECT_TRUE(response.ok()); + EXPECT_THAT(response.value().elements(), + testing::UnorderedElementsAreArray({1000, 2000})); +} + +TEST_F(ShardedLookupTest, RunSetQueryInt_ShardedLookupFails_Error) { + InternalLookupResponse local_lookup_response; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key4" + value { uintset_values { 
values: 1000 } } + } + )pb", + &local_lookup_response); + EXPECT_CALL(mock_local_lookup_, GetUInt32ValueSet(_, _)) + .WillOnce(Return(local_lookup_response)); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = + ShardManager::Create(num_shards_, std::move(cluster_mappings), + std::make_unique(), + [](const std::string& ip) { return nullptr; }); + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = + sharded_lookup->RunSetQueryInt(GetRequestContext(), "key1|key4"); + EXPECT_FALSE(response.ok()); + EXPECT_THAT(response.status().code(), absl::StatusCode::kInternal); +} + +TEST_F(ShardedLookupTest, RunSetQueryInt_EmptyRequest_EmptyResponse) { + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + return std::make_unique(); + }); + + auto sharded_lookup = + CreateShardedLookup(mock_local_lookup_, num_shards_, shard_num_, + *(*shard_manager), key_sharder_); + auto response = sharded_lookup->RunSetQueryInt(GetRequestContext(), ""); + EXPECT_TRUE(response.ok()); + EXPECT_TRUE(response.value().elements().empty()); +} + } // namespace } // namespace kv_server diff --git a/components/query/BUILD.bazel b/components/query/BUILD.bazel index a5c4e2e4..211ea2bc 100644 --- a/components/query/BUILD.bazel +++ b/components/query/BUILD.bazel @@ -23,12 +23,27 @@ package(default_visibility = [ cc_library( name = "sets", srcs = [ + "sets.cc", ], hdrs = [ "sets.h", ], deps = [ "@com_google_absl//absl/container:flat_hash_set", + "@roaring_bitmap//:c_roaring", + ], +) + +cc_test( + name = "sets_test", + size = "small", + srcs = [ + "sets_test.cc", + ], + deps = [ + ":sets", + 
"@com_google_googletest//:gtest_main", + "@roaring_bitmap//:c_roaring", ], ) @@ -58,6 +73,7 @@ cc_test( ":ast", "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest_main", + "@roaring_bitmap//:c_roaring", ], ) @@ -90,6 +106,7 @@ cc_test( ":parser", ":scanner", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:bind_front", "@com_google_googletest//:gtest_main", ], ) diff --git a/components/query/ast.cc b/components/query/ast.cc index 1eea45d5..a1d6e55d 100644 --- a/components/query/ast.cc +++ b/components/query/ast.cc @@ -20,16 +20,10 @@ #include #include "absl/container/flat_hash_set.h" -#include "components/query/sets.h" namespace kv_server { -namespace { -// Traverses the binary tree starting at root. -// Returns a vector of `Node`s in post order. -// This is represents the infix input as postfix. -// Postfix can then be more easily evaluated. -std::vector PostOrderTraversal(const Node* root) { +std::vector ComputePostfixOrder(const Node* root) { std::vector result; std::vector stack; stack.push_back(root); @@ -48,41 +42,9 @@ std::vector PostOrderTraversal(const Node* root) { return result; } -} // namespace - -void ASTStackVisitor::Visit(const OpNode& node, std::vector& stack) { - KVSetView right = std::move(stack.back()); - stack.pop_back(); - KVSetView left = std::move(stack.back()); - stack.pop_back(); - stack.emplace_back(node.Op(std::move(left), std::move(right))); -} - -void ASTStackVisitor::Visit(const ValueNode& node, - std::vector& stack) { - stack.emplace_back(node.Lookup()); -} - -KVSetView Compute(const std::vector& postorder) { - std::vector stack; - ASTStackVisitor visitor; - // Apply the operations on the postorder stack - for (const auto* node : postorder) { - node->Accept(visitor, stack); - } - return stack.back(); -} - -KVSetView Eval(const Node& node) { - std::vector postorder = PostOrderTraversal(&node); - return Compute(postorder); -} - -void OpNode::Accept(ASTStackVisitor& 
visitor, - std::vector& stack) const { - visitor.Visit(*this, stack); +std::string ValueNode::Accept(ASTStringVisitor& visitor) const { + return visitor.Visit(*this); } - std::string UnionNode::Accept(ASTStringVisitor& visitor) const { return visitor.Visit(*this); } @@ -93,6 +55,13 @@ std::string IntersectionNode::Accept(ASTStringVisitor& visitor) const { return visitor.Visit(*this); } +void ValueNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } +void UnionNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } +void IntersectionNode::Accept(ASTVisitor& visitor) const { + visitor.Visit(*this); +} +void DifferenceNode::Accept(ASTVisitor& visitor) const { visitor.Visit(*this); } + absl::flat_hash_set OpNode::Keys() const { std::vector nodes; absl::flat_hash_set key_set; @@ -118,21 +87,6 @@ absl::flat_hash_set OpNode::Keys() const { return key_set; } -ValueNode::ValueNode( - absl::AnyInvocable lookup_fn, - std::string key) - : lookup_fn_(absl::bind_front(std::move(lookup_fn), key)), - key_(std::move(key)) {} - -void ValueNode::Accept(ASTStackVisitor& visitor, - std::vector& stack) const { - visitor.Visit(*this, stack); -} - -std::string ValueNode::Accept(ASTStringVisitor& visitor) const { - return visitor.Visit(*this); -} - absl::flat_hash_set ValueNode::Keys() const { // Return a set containing a view into this instances, `key_`. // Be sure that the reference is not to any temp string. 
@@ -141,6 +95,4 @@ absl::flat_hash_set ValueNode::Keys() const { }; } -KVSetView ValueNode::Lookup() const { return lookup_fn_(); } - } // namespace kv_server diff --git a/components/query/ast.h b/components/query/ast.h index 50b0d3df..b01c0f00 100644 --- a/components/query/ast.h +++ b/components/query/ast.h @@ -16,6 +16,7 @@ #ifndef COMPONENTS_QUERY_AST_H_ #define COMPONENTS_QUERY_AST_H_ + #include #include #include @@ -24,17 +25,16 @@ #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" -#include "absl/functional/bind_front.h" #include "components/query/sets.h" namespace kv_server { -class ASTStackVisitor; -class ASTStringVisitor; +// All set operations using `KVStringSetView` operate on a reference to the data +// in the DB This means that the data in the DB must be locked throughout the +// lifetime of the result. +using KVStringSetView = absl::flat_hash_set; -// All set operations operate on a reference to the data in the DB -// This means that the data in the DB must be locked throughout the lifetime of -// the result. -using KVSetView = absl::flat_hash_set; +class ASTVisitor; +class ASTStringVisitor; class Node { public: @@ -43,26 +43,20 @@ class Node { virtual Node* Right() const { return nullptr; } // Return all Keys associated with ValueNodes in the tree. virtual absl::flat_hash_set Keys() const = 0; - // Uses the Visitor pattern for the concrete class - // to mutate the stack accordingly for `Eval` (ValueNode vs. OpNode) - virtual void Accept(ASTStackVisitor& visitor, - std::vector& stack) const = 0; + virtual void Accept(ASTVisitor& visitor) const = 0; virtual std::string Accept(ASTStringVisitor& visitor) const = 0; }; // The value associated with a `ValueNode` is the set with its associated `key`. 
class ValueNode : public Node { public: - ValueNode(absl::AnyInvocable lookup_fn, - std::string key); + explicit ValueNode(std::string key) : key_(std::move(key)) {} + std::string_view Key() const { return key_; } absl::flat_hash_set Keys() const override; - KVSetView Lookup() const; - void Accept(ASTStackVisitor& visitor, - std::vector& stack) const override; + void Accept(ASTVisitor& visitor) const override; std::string Accept(ASTStringVisitor& visitor) const override; private: - absl::AnyInvocable lookup_fn_; std::string key_; }; @@ -73,10 +67,6 @@ class OpNode : public Node { absl::flat_hash_set Keys() const override; inline Node* Left() const override { return left_.get(); } inline Node* Right() const override { return right_.get(); } - // Computes the operation over the `left` and `right` nodes. - virtual KVSetView Op(KVSetView left, KVSetView right) const = 0; - void Accept(ASTStackVisitor& visitor, - std::vector& stack) const override; private: std::unique_ptr left_; @@ -85,47 +75,29 @@ class OpNode : public Node { class UnionNode : public OpNode { public: - using OpNode::Accept; using OpNode::OpNode; - inline KVSetView Op(KVSetView left, KVSetView right) const override { - return Union(std::move(left), std::move(right)); - } + void Accept(ASTVisitor& visitor) const override; std::string Accept(ASTStringVisitor& visitor) const override; }; class IntersectionNode : public OpNode { public: - using OpNode::Accept; using OpNode::OpNode; - inline KVSetView Op(KVSetView left, KVSetView right) const override { - return Intersection(std::move(left), std::move(right)); - } + void Accept(ASTVisitor& visitor) const override; std::string Accept(ASTStringVisitor& visitor) const override; }; class DifferenceNode : public OpNode { public: - using OpNode::Accept; using OpNode::OpNode; - inline KVSetView Op(KVSetView left, KVSetView right) const override { - return Difference(std::move(left), std::move(right)); - } + void Accept(ASTVisitor& visitor) const override; 
std::string Accept(ASTStringVisitor& visitor) const override; }; -// Creates execution plan and runs it. -KVSetView Eval(const Node& node); - -// Responsible for mutating the stack with the given `Node`. -// Avoids downcasting for subclass specific behaviors. -class ASTStackVisitor { - public: - // Applies the operation to the top two values on the stack. - // Replaces the top two values with the result. - void Visit(const OpNode& node, std::vector& stack); - // Pushes the result of `Lookup` to the stack. - void Visit(const ValueNode& node, std::vector& stack); -}; +// Traverses the binary tree starting at root and returns a vector of `Node`s in +// post order. Represents the infix input as postfix which can then be more +// easily evaluated. +std::vector ComputePostfixOrder(const Node* root); // General purpose Visitor capable of returning a string representation of a Node // upon inspection. @@ -137,5 +109,75 @@ class ASTStringVisitor { virtual std::string Visit(const ValueNode&) = 0; }; +// Defines a general AST visitor interface which can be extended to implement +// concrete ast algorithms, e.g., ast evaluation. +class ASTVisitor { + public: + // Entrypoint for running the visitor algorithm on a given AST tree, `root`. + virtual void ConductVisit(const Node& root) = 0; + virtual void Visit(const ValueNode& node) = 0; + virtual void Visit(const UnionNode& node) = 0; + virtual void Visit(const DifferenceNode& node) = 0; + virtual void Visit(const IntersectionNode& node) = 0; +}; + +// Implements AST tree evaluation using iterative post order processing.
+template +class ASTPostOrderEvalVisitor final : public ASTVisitor { + public: + explicit ASTPostOrderEvalVisitor( + absl::AnyInvocable lookup_fn) + : lookup_fn_(std::move(lookup_fn)) {} + + void ConductVisit(const Node& root) override { + stack_.clear(); + for (const auto* node : ComputePostfixOrder(&root)) { + node->Accept(*this); + } + } + + void Visit(const ValueNode& node) override { + stack_.push_back(std::move(lookup_fn_(node.Key()))); + } + void Visit(const UnionNode& node) override { Visit(node, Union); } + + void Visit(const DifferenceNode& node) override { + Visit(node, Difference); + } + + void Visit(const IntersectionNode& node) override { + Visit(node, Intersection); + } + + ValueT GetResult() { + if (stack_.empty()) { + return ValueT(); + } + return stack_.back(); + } + + private: + void Visit(const OpNode& node, + absl::AnyInvocable op_fn) { + auto right = std::move(stack_.back()); + stack_.pop_back(); + auto left = std::move(stack_.back()); + stack_.pop_back(); + stack_.push_back(op_fn(std::move(left), std::move(right))); + } + + absl::AnyInvocable lookup_fn_; + std::vector stack_; +}; + +// Accepts an AST representing a set query, creates execution plan and runs it. 
+template +ValueT Eval(const Node& node, + absl::AnyInvocable lookup_fn) { + auto visitor = ASTPostOrderEvalVisitor(std::move(lookup_fn)); + visitor.ConductVisit(node); + return visitor.GetResult(); +} + } // namespace kv_server #endif // COMPONENTS_QUERY_AST_H_ diff --git a/components/query/ast_test.cc b/components/query/ast_test.cc index 8be4bbd2..1b18b9d4 100644 --- a/components/query/ast_test.cc +++ b/components/query/ast_test.cc @@ -19,6 +19,8 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "roaring.hh" + namespace kv_server { namespace { @@ -29,116 +31,128 @@ const absl::flat_hash_map> {"C", {"c", "d", "e"}}, {"D", {"d", "e", "f"}}, }; +const absl::flat_hash_map kBitsetDb = { + {"A", {1, 2, 3}}, + {"B", {2, 3, 4}}, + {"C", {3, 4, 5}}, + {"D", {4, 5, 6}}, +}; absl::flat_hash_set Lookup(std::string_view key) { - const auto& it = kDb.find(key); - if (it != kDb.end()) { + if (const auto& it = kDb.find(key); it != kDb.end()) { + return it->second; + } + return {}; +} + +roaring::Roaring BitsetLookup(std::string_view key) { + if (const auto& it = kBitsetDb.find(key); it != kBitsetDb.end()) { return it->second; } return {}; } TEST(AstTest, Value) { - ValueNode value(Lookup, "A"); - EXPECT_EQ(Eval(value), Lookup("A")); - ValueNode value2(Lookup, "B"); - EXPECT_EQ(Eval(value2), Lookup("B")); - ValueNode value3(Lookup, "C"); - EXPECT_EQ(Eval(value3), Lookup("C")); - ValueNode value4(Lookup, "D"); - EXPECT_EQ(Eval(value4), Lookup("D")); - ValueNode value5(Lookup, "E"); - EXPECT_EQ(Eval(value5), Lookup("E")); + ValueNode value("A"); + EXPECT_EQ(Eval(value, Lookup), Lookup("A")); + ValueNode value2("B"); + EXPECT_EQ(Eval(value2, Lookup), Lookup("B")); + ValueNode value3("C"); + EXPECT_EQ(Eval(value3, Lookup), Lookup("C")); + ValueNode value4("D"); + EXPECT_EQ(Eval(value4, Lookup), Lookup("D")); + ValueNode value5("E"); + EXPECT_EQ(Eval(value5, Lookup), Lookup("E")); } TEST(AstTest, Union) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - 
std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); UnionNode op(std::move(a), std::move(b)); absl::flat_hash_set expected = {"a", "b", "c", "d"}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); } TEST(AstTest, UnionSelf) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr a2 = std::make_unique("A"); UnionNode op(std::move(a), std::move(a2)); absl::flat_hash_set expected = {"a", "b", "c"}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); } TEST(AstTest, Intersection) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); IntersectionNode op(std::move(a), std::move(b)); absl::flat_hash_set expected = {"b", "c"}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); } TEST(AstTest, IntersectionSelf) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr a2 = std::make_unique("A"); IntersectionNode op(std::move(a), std::move(a2)); absl::flat_hash_set expected = {"a", "b", "c"}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); } TEST(AstTest, Difference) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); DifferenceNode op(std::move(a), std::move(b)); absl::flat_hash_set expected = {"a"}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); - std::unique_ptr a2 = std::make_unique(Lookup, "A"); - std::unique_ptr b2 = std::make_unique(Lookup, "B"); + 
std::unique_ptr a2 = std::make_unique("A"); + std::unique_ptr b2 = std::make_unique("B"); DifferenceNode op2(std::move(b2), std::move(a2)); absl::flat_hash_set expected2 = {"d"}; - EXPECT_EQ(Eval(op2), expected2); + EXPECT_EQ(Eval(op2, Lookup), expected2); } TEST(AstTest, DifferenceSelf) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr a2 = std::make_unique("A"); DifferenceNode op(std::move(a), std::move(a2)); absl::flat_hash_set expected = {}; - EXPECT_EQ(Eval(op), expected); + EXPECT_EQ(Eval(op, Lookup), expected); } TEST(AstTest, All) { // (A-B) | (C&D) = // {a} | {d,e} = // {a, d, e} - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr b = std::make_unique(Lookup, "B"); - std::unique_ptr c = std::make_unique(Lookup, "C"); - std::unique_ptr d = std::make_unique(Lookup, "D"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); + std::unique_ptr c = std::make_unique("C"); + std::unique_ptr d = std::make_unique("D"); std::unique_ptr left = std::make_unique(std::move(a), std::move(b)); std::unique_ptr right = std::make_unique(std::move(c), std::move(d)); UnionNode center(std::move(left), std::move(right)); absl::flat_hash_set expected = {"a", "d", "e"}; - EXPECT_EQ(Eval(center), expected); + EXPECT_EQ(Eval(center, Lookup), expected); } TEST(AstTest, ValueNodeKeys) { - ValueNode v(Lookup, "A"); + ValueNode v("A"); EXPECT_THAT(v.Keys(), testing::UnorderedElementsAre("A")); } TEST(AstTest, OpNodeKeys) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); DifferenceNode op(std::move(b), std::move(a)); EXPECT_THAT(op.Keys(), testing::UnorderedElementsAre("A", "B")); } TEST(AstTest, DupeNodeKeys) { - std::unique_ptr a = std::make_unique(Lookup, "A"); - 
std::unique_ptr b = std::make_unique(Lookup, "B"); - std::unique_ptr c = std::make_unique(Lookup, "C"); - std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr a = std::make_unique("A"); + std::unique_ptr b = std::make_unique("B"); + std::unique_ptr c = std::make_unique("C"); + std::unique_ptr a2 = std::make_unique("A"); std::unique_ptr left = std::make_unique(std::move(a), std::move(b)); std::unique_ptr right = @@ -147,5 +161,63 @@ TEST(AstTest, DupeNodeKeys) { EXPECT_THAT(center.Keys(), testing::UnorderedElementsAre("A", "B", "C")); } +TEST(ASTEvalTest, VerifyValueNodeEvaluation) { + { + ValueNode root("DOES_NOT_EXIST"); + EXPECT_TRUE(Eval(root, Lookup).empty()); + } + ValueNode root("A"); + EXPECT_THAT(Eval(root, Lookup), + testing::UnorderedElementsAre("a", "b", "c")); + EXPECT_EQ(Eval(root, BitsetLookup), + roaring::Roaring({1, 2, 3})); +} + +TEST(ASTEvalTest, VerifyUnionNodeEvaluation) { + auto a = std::make_unique("A"); + auto b = std::make_unique("B"); + UnionNode root(std::move(a), std::move(b)); + EXPECT_THAT(Eval(root, Lookup), + testing::UnorderedElementsAre("a", "b", "c", "d")); + EXPECT_EQ(Eval(root, BitsetLookup), + roaring::Roaring({1, 2, 3, 4})); +} + +TEST(ASTEvalTest, VerifyDifferenceNodeEvaluation) { + auto a = std::make_unique("A"); + auto b = std::make_unique("B"); + DifferenceNode root(std::move(a), std::move(b)); + EXPECT_THAT(Eval(root, Lookup), + testing::UnorderedElementsAre("a")); + EXPECT_EQ(Eval(root, BitsetLookup), roaring::Roaring({1})); +} + +TEST(ASTEvalTest, VerifyIntersectionNodeEvaluation) { + auto a = std::make_unique("A"); + auto b = std::make_unique("B"); + IntersectionNode root(std::move(a), std::move(b)); + EXPECT_THAT(Eval(root, Lookup), + testing::UnorderedElementsAre("b", "c")); + EXPECT_EQ(Eval(root, BitsetLookup), + roaring::Roaring({2, 3})); +} + +TEST(ASTEvalTest, VerifyComplexNodeEvaluation) { + // (A-B) | (C&D) = + // {a} | {d,e} = + // {a, d, e} + auto a = std::make_unique("A"); + auto b = 
std::make_unique("B"); + auto c = std::make_unique("C"); + auto d = std::make_unique("D"); + auto left = std::make_unique(std::move(a), std::move(b)); + auto right = std::make_unique(std::move(c), std::move(d)); + UnionNode root(std::move(left), std::move(right)); + EXPECT_THAT(Eval(root, Lookup), + testing::UnorderedElementsAre("a", "d", "e")); + EXPECT_THAT(Eval(root, BitsetLookup), + roaring::Roaring({1, 4, 5})); +} + } // namespace } // namespace kv_server diff --git a/components/query/driver.cc b/components/query/driver.cc index 20c53217..3d388326 100644 --- a/components/query/driver.cc +++ b/components/query/driver.cc @@ -14,38 +14,14 @@ #include "components/query/driver.h" -#include #include -#include "absl/container/flat_hash_set.h" -#include "absl/functional/bind_front.h" #include "components/query/ast.h" namespace kv_server { -Driver::Driver(absl::AnyInvocable( - std::string_view key) const> - lookup_fn) - : lookup_fn_(std::move(lookup_fn)) {} - -absl::flat_hash_set Driver::Lookup( - std::string_view key) const { - return lookup_fn_(key); -} - void Driver::SetAst(std::unique_ptr ast) { ast_ = std::move(ast); } -absl::StatusOr> Driver::GetResult() - const { - if (!status_.ok()) { - return status_; - } - if (ast_ == nullptr) { - return absl::flat_hash_set(); - } - return Eval(*ast_); -} - void Driver::SetError(std::string error) { status_ = absl::InvalidArgumentError(std::move(error)); } diff --git a/components/query/driver.h b/components/query/driver.h index 728b3593..794daa86 100644 --- a/components/query/driver.h +++ b/components/query/driver.h @@ -20,8 +20,8 @@ #include #include #include +#include -#include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -35,23 +35,19 @@ namespace kv_server { // * Executing the query // * Storing the result // Typical usage: -// Driver driver(LookupFn); +// Driver driver; // std::istringstream stream(query); // Scanner 
scanner(stream); // Parser parse(driver, scanner); // int parse_result = parse(); -// auto result = driver.GetResult(); +// auto result = driver.GetResult(LookupFn); +// parse_result is only expected to be non-zero when result is a failure. class Driver { public: - // `lookup_fn` returns the set associated with the provided key. - // If no key is present, an empty set should be returned. - explicit Driver(absl::AnyInvocable( - std::string_view key) const> - lookup_fn); - // The result contains views of the data within the DB. - absl::StatusOr> GetResult() const; + template + absl::StatusOr EvaluateQuery( + absl::AnyInvocable lookup_fn) const; // Returns the `Node` associated with `SetAst` // or nullptr if unset. @@ -62,16 +58,22 @@ class Driver { void SetError(std::string error); void ClearError() { status_ = absl::OkStatus(); } - // Looks up the set which contains a view of the DB data. - absl::flat_hash_set Lookup(std::string_view key) const; - private: - absl::AnyInvocable(std::string_view key) - const> - lookup_fn_; std::unique_ptr ast_; absl::Status status_ = absl::OkStatus(); }; +template +absl::StatusOr Driver::EvaluateQuery( + absl::AnyInvocable lookup_fn) const { + if (!status_.ok()) { + return status_; + } + if (ast_ == nullptr) { + return SetType(); + } + return Eval(*ast_, std::move(lookup_fn)); +} + } // namespace kv_server #endif // COMPONENTS_QUERY_DRIVER_H_ diff --git a/components/query/driver_test.cc b/components/query/driver_test.cc index 840478db..4404402a 100644 --- a/components/query/driver_test.cc +++ b/components/query/driver_test.cc @@ -27,22 +27,28 @@ namespace kv_server { namespace { +const absl::flat_hash_map> + kStringSetDB = { + {"A", {"a", "b", "c"}}, + {"B", {"b", "c", "d"}}, + {"C", {"c", "d", "e"}}, + {"D", {"d", "e", "f"}}, +}; + +absl::flat_hash_set Lookup(std::string_view key) { + if (const auto& it = kStringSetDB.find(key); it != kStringSetDB.end()) { + return it->second; + } + return {}; +} + +class DriverTest : public
::testing::Test { protected: void SetUp() override { - driver_ = - std::make_unique(absl::bind_front(&DriverTest::Lookup, this)); + driver_ = std::make_unique(); for (int i = 1000; i < 1; i++) { - drivers_.emplace_back(absl::bind_front(&DriverTest::Lookup, this)); - } - } - - absl::flat_hash_set Lookup(std::string_view key) { - const auto& it = db_.find(key); - if (it != db_.end()) { - return it->second; + drivers_.emplace_back(); } - return {}; } void Parse(const std::string& query) { @@ -54,19 +60,13 @@ class DriverTest : public ::testing::Test { std::unique_ptr driver_; std::vector drivers_; - const absl::flat_hash_map> - db_ = { - {"A", {"a", "b", "c"}}, - {"B", {"b", "c", "d"}}, - {"C", {"c", "d", "e"}}, - {"D", {"d", "e", "f"}}, - }; }; TEST_F(DriverTest, EmptyQuery) { Parse(""); EXPECT_EQ(driver_->GetRootNode(), nullptr); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); absl::flat_hash_set expected; EXPECT_EQ(*result, expected); @@ -75,104 +75,121 @@ TEST_F(DriverTest, EmptyQuery) { TEST_F(DriverTest, InvalidTokensQuery) { Parse("!! 
hi"); EXPECT_EQ(driver_->GetRootNode(), nullptr); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST_F(DriverTest, MissingOperatorVar) { Parse("A A"); EXPECT_EQ(driver_->GetRootNode(), nullptr); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST_F(DriverTest, MissingOperatorExp) { Parse("(A) (A)"); EXPECT_EQ(driver_->GetRootNode(), nullptr); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST_F(DriverTest, InvalidOp) { Parse("A UNION "); EXPECT_EQ(driver_->GetRootNode(), nullptr); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); } TEST_F(DriverTest, KeyOnly) { Parse("A"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c")); Parse("B"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c", "d")); } TEST_F(DriverTest, Union) { Parse("A UNION B"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); Parse("A | B"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); } TEST_F(DriverTest, Difference) { Parse("A - B"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); 
ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); Parse("A DIFFERENCE B"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); Parse("B - A"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); Parse("B DIFFERENCE A"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); } TEST_F(DriverTest, Intersection) { Parse("A INTERSECTION B"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); Parse("A & B"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); } TEST_F(DriverTest, OrderOfOperations) { Parse("A - B - C"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); Parse("A - (B - C)"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "c")); } TEST_F(DriverTest, MultipleOperations) { Parse("(A-B) | (C&D)"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); } @@ -186,7 +203,8 @@ TEST_F(DriverTest, MultipleThreads) { Scanner scanner(stream); Parser parse(*driver, scanner); parse(); - auto result = driver->GetResult(); + auto result = + driver->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_THAT(*result, 
testing::UnorderedElementsAre("a", "d", "e")); }; @@ -204,23 +222,27 @@ TEST_F(DriverTest, MultipleThreads) { TEST_F(DriverTest, EmptyResults) { // no overlap Parse("A & D"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_EQ(result->size(), 0); // missing key Parse("A & E"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); EXPECT_EQ(result->size(), 0); } TEST_F(DriverTest, DriverErrorsClearedOnParse) { Parse("A &"); - auto result = driver_->GetResult(); + auto result = + driver_->EvaluateQuery>(Lookup); ASSERT_FALSE(result.ok()); Parse("A"); - result = driver_->GetResult(); + result = + driver_->EvaluateQuery>(Lookup); ASSERT_TRUE(result.ok()); } diff --git a/components/query/parser.yy b/components/query/parser.yy index 18838a26..32f73221 100644 --- a/components/query/parser.yy +++ b/components/query/parser.yy @@ -41,7 +41,6 @@ #include "components/query/parser.h" #include "components/query/driver.h" #include "components/query/scanner.h" - #include "absl/functional/bind_front.h" #undef yylex #define yylex(x) scanner.yylex(x) @@ -83,7 +82,7 @@ exp: term {$$ = std::move($1);} | ERROR { driver.SetError("Invalid token: " + $1); YYERROR;} ; -term: VAR { $$ = std::make_unique(absl::bind_front(&Driver::Lookup, &driver), std::move($1)); } +term: VAR { $$ = std::make_unique(std::move($1)); } ; %% diff --git a/components/query/scanner_test.cc b/components/query/scanner_test.cc index 7cd75238..ce0c3739 100644 --- a/components/query/scanner_test.cc +++ b/components/query/scanner_test.cc @@ -16,7 +16,6 @@ #include #include -#include #include #include "absl/strings/str_join.h" @@ -26,16 +25,10 @@ namespace kv_server { namespace { -absl::flat_hash_set NeverUsedLookup(std::string_view key) { - // Should never be called - assert(0); - return {}; -} - TEST(ScannerTest, Empty) { std::istringstream stream(""); Scanner scanner(stream); - Driver 
driver(NeverUsedLookup); + Driver driver; auto t = scanner.yylex(driver); ASSERT_EQ(t.token(), Parser::token::YYEOF); } @@ -43,7 +36,7 @@ TEST(ScannerTest, Empty) { TEST(ScannerTest, Var) { std::istringstream stream("FOO foo Foo123"); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; // first token auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::VAR); @@ -67,7 +60,7 @@ TEST(ScannerTest, Var) { TEST(ScannerTest, Parens) { std::istringstream stream("()"); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::LPAREN); @@ -80,7 +73,7 @@ TEST(ScannerTest, Parens) { TEST(ScannerTest, WhitespaceVar) { std::istringstream stream(" FOO "); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; // first token auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::VAR); @@ -94,7 +87,7 @@ TEST(ScannerTest, NotAlphaNumVar) { std::string token_list = absl::StrJoin(expected_vars, " "); std::istringstream stream(token_list); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; for (const auto& expected_var : expected_vars) { auto token = scanner.yylex(driver); @@ -109,7 +102,7 @@ TEST(ScannerTest, QuotedVar) { std::istringstream stream( " \"A1:Stuff\" \"A-B:C&D=E|F\" \"A+B\" \"A/B\" \"A\" "); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::VAR); @@ -133,7 +126,7 @@ TEST(ScannerTest, QuotedVar) { TEST(ScannerTest, EmptyQuotedInvalid) { std::istringstream stream(" \"\" "); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; // Since it there is no valid match, we have 2 errors // for each of the double quotes. 
@@ -149,7 +142,7 @@ TEST(ScannerTest, EmptyQuotedInvalid) { TEST(ScannerTest, Operators) { std::istringstream stream("| UNION & INTERSECTION - DIFFERENCE"); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::UNION); @@ -173,7 +166,7 @@ TEST(ScannerTest, Operators) { TEST(ScannerTest, Error) { std::istringstream stream("!"); Scanner scanner(stream); - Driver driver(NeverUsedLookup); + Driver driver; auto t1 = scanner.yylex(driver); ASSERT_EQ(t1.token(), Parser::token::ERROR); diff --git a/components/query/sets.cc b/components/query/sets.cc new file mode 100644 index 00000000..2223e8e2 --- /dev/null +++ b/components/query/sets.cc @@ -0,0 +1,73 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "components/query/sets.h" + +#include + +namespace kv_server { + +template <> +absl::flat_hash_set Union( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + auto& small = left.size() <= right.size() ? left : right; + auto& big = left.size() <= right.size() ? right : left; + big.insert(small.begin(), small.end()); + return std::move(big); +} + +template <> +absl::flat_hash_set Intersection( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + auto& small = left.size() <= right.size() ? left : right; + const auto& big = left.size() <= right.size() ? 
right : left; + // Traverse the smaller set removing what is not in both. + absl::erase_if(small, [&big](const std::string_view& elem) { + return !big.contains(elem); + }); + return std::move(small); +} + +template <> +absl::flat_hash_set Difference( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + // Remove all elements in right from left. + for (const auto& element : right) { + left.erase(element); + } + return std::move(left); +} + +template <> +roaring::Roaring Union(roaring::Roaring&& left, roaring::Roaring&& right) { + return left | right; +} + +template <> +roaring::Roaring Intersection(roaring::Roaring&& left, + roaring::Roaring&& right) { + return left & right; +} + +template <> +roaring::Roaring Difference(roaring::Roaring&& left, roaring::Roaring&& right) { + return left - right; +} + +} // namespace kv_server diff --git a/components/query/sets.h b/components/query/sets.h index a29f10b1..a6fcecc8 100644 --- a/components/query/sets.h +++ b/components/query/sets.h @@ -17,39 +17,46 @@ #ifndef COMPONENTS_QUERY_SETS_H_ #define COMPONENTS_QUERY_SETS_H_ -#include - #include "absl/container/flat_hash_set.h" +#include "roaring.hh" + namespace kv_server { -template -absl::flat_hash_set Union(absl::flat_hash_set&& left, - absl::flat_hash_set&& right) { - auto& small = left.size() <= right.size() ? left : right; - auto& big = left.size() <= right.size() ? right : left; - big.insert(small.begin(), small.end()); - return std::move(big); -} - -template -absl::flat_hash_set Intersection(absl::flat_hash_set&& left, - absl::flat_hash_set&& right) { - auto& small = left.size() <= right.size() ? left : right; - const auto& big = left.size() <= right.size() ? right : left; - // Traverse the smaller set removing what is not in both. 
- absl::erase_if(small, [&big](const T& elem) { return !big.contains(elem); }); - return std::move(small); -} - -template -absl::flat_hash_set Difference(absl::flat_hash_set&& left, - absl::flat_hash_set&& right) { - // Remove all elements in right from left. - for (const auto& element : right) { - left.erase(element); - } - return std::move(left); -} + +template +SetT Union(SetT&&, SetT&&); + +template +SetT Intersection(SetT&&, SetT&&); + +template +SetT Difference(SetT&&, SetT&&); + +template <> +absl::flat_hash_set Union( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right); + +template <> +absl::flat_hash_set Intersection( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right); + +template <> +absl::flat_hash_set Difference( + absl::flat_hash_set&& left, + absl::flat_hash_set&& right); + +template <> +roaring::Roaring Union(roaring::Roaring&& left, roaring::Roaring&& right); + +template <> +roaring::Roaring Intersection(roaring::Roaring&& left, + roaring::Roaring&& right); + +// Subtracts `right` from `left`. +template <> +roaring::Roaring Difference(roaring::Roaring&& left, roaring::Roaring&& right); } // namespace kv_server #endif // COMPONENTS_QUERY_SETS_H_ diff --git a/components/query/sets_test.cc b/components/query/sets_test.cc new file mode 100644 index 00000000..6f4ee250 --- /dev/null +++ b/components/query/sets_test.cc @@ -0,0 +1,64 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "components/query/sets.h" + +#include + +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(SetsTest, VerifyBitwiseUnion) { + roaring::Roaring left({1, 2, 3, 4, 5}); + roaring::Roaring right({6, 7, 8, 9, 10}); + EXPECT_EQ(Union(std::move(left), std::move(right)), + roaring::Roaring({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); +} + +TEST(SetsTest, VerifyBitwiseIntersection) { + { + roaring::Roaring left({1, 2, 3, 4, 5}); + roaring::Roaring right({6, 7, 8, 9, 10}); + EXPECT_EQ(Intersection(std::move(left), std::move(right)), + roaring::Roaring()); + } + { + roaring::Roaring left({1, 2, 3, 4, 5}); + roaring::Roaring right({1, 2, 3, 9, 10}); + EXPECT_EQ(Intersection(std::move(left), std::move(right)), + roaring::Roaring({1, 2, 3})); + } +} + +TEST(SetsTest, VerifyBitwiseDifference) { + { + roaring::Roaring left({1, 2, 3, 4, 5}); + roaring::Roaring right({6, 7, 8, 9, 10}); + EXPECT_EQ(Difference(std::move(left), std::move(right)), + roaring::Roaring({1, 2, 3, 4, 5})); + } + { + roaring::Roaring left({1, 2, 3, 4, 5}); + roaring::Roaring right({1, 2, 3, 9, 10}); + EXPECT_EQ(Difference(std::move(left), std::move(right)), + roaring::Roaring({4, 5})); + } +} + +} // namespace +} // namespace kv_server diff --git a/components/sharding/BUILD.bazel b/components/sharding/BUILD.bazel index 940030e0..6a4c57b3 100644 --- a/components/sharding/BUILD.bazel +++ b/components/sharding/BUILD.bazel @@ -33,6 +33,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) @@ -80,6 +81,7 @@ cc_library( "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) diff --git a/components/sharding/cluster_mappings_manager.cc b/components/sharding/cluster_mappings_manager.cc index 
826422fc..ecf10b28 100644 --- a/components/sharding/cluster_mappings_manager.cc +++ b/components/sharding/cluster_mappings_manager.cc @@ -21,12 +21,14 @@ namespace kv_server { ClusterMappingsManager::ClusterMappingsManager( std::string environment, int32_t num_shards, - InstanceClient& instance_client, std::unique_ptr sleep_for, - int32_t update_interval_millis) + InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context, + std::unique_ptr sleep_for, int32_t update_interval_millis) : environment_{std::move(environment)}, num_shards_{num_shards}, instance_client_{instance_client}, thread_manager_(ThreadManager::Create("Cluster mappings updater")), + log_context_(log_context), sleep_for_(std::move(sleep_for)), update_interval_millis_(update_interval_millis) { CHECK_GT(num_shards, 1) << "num_shards for ShardedLookup must be > 1"; @@ -53,4 +55,9 @@ void ClusterMappingsManager::Watch(ShardManager& shard_manager) { shard_manager.InsertBatch(GetClusterMappings()); } } + +privacy_sandbox::server_common::log::PSLogContext& +ClusterMappingsManager::GetLogContext() const { + return log_context_; +} } // namespace kv_server diff --git a/components/sharding/cluster_mappings_manager.h b/components/sharding/cluster_mappings_manager.h index 0c6b8ebc..597e6032 100644 --- a/components/sharding/cluster_mappings_manager.h +++ b/components/sharding/cluster_mappings_manager.h @@ -46,6 +46,7 @@ class ClusterMappingsManager { ClusterMappingsManager( std::string environment, int32_t num_shards, InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context, std::unique_ptr sleep_for = std::make_unique(), int32_t update_interval_millis = 1000); // Retreives cluster mappings for the given `environment`, which are @@ -60,15 +61,20 @@ class ClusterMappingsManager { bool IsRunning() const; static std::unique_ptr Create( std::string environment, int32_t num_shards, - InstanceClient& instance_client, ParameterFetcher& 
parameter_fetcher); + InstanceClient& instance_client, ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); protected: void Watch(ShardManager& shard_manager); + privacy_sandbox::server_common::log::PSLogContext& GetLogContext() const; std::string environment_; int32_t num_shards_; InstanceClient& instance_client_; std::unique_ptr thread_manager_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; std::unique_ptr sleep_for_; int32_t update_interval_millis_; }; diff --git a/components/sharding/cluster_mappings_manager_aws.cc b/components/sharding/cluster_mappings_manager_aws.cc index f0ae14cc..e2f45b9f 100644 --- a/components/sharding/cluster_mappings_manager_aws.cc +++ b/components/sharding/cluster_mappings_manager_aws.cc @@ -23,10 +23,11 @@ class AwsClusterMappingsManager : public ClusterMappingsManager { AwsClusterMappingsManager( std::string environment, int32_t num_shards, InstanceClient& instance_client, + privacy_sandbox::server_common::log::PSLogContext& log_context, std::unique_ptr sleep_for = std::make_unique(), int32_t update_interval_millis = 1000) : ClusterMappingsManager(std::move(environment), num_shards, - instance_client), + instance_client, log_context), asg_regex_{std::regex(absl::StrCat("kv-server-", environment_, R"(-(\d+)-instance-asg)"))} {} @@ -46,8 +47,8 @@ class AwsClusterMappingsManager : public ClusterMappingsManager { describe_instance_group_input); }, "DescribeInstanceGroupInstances", - LogStatusSafeMetricsFn()); - + LogStatusSafeMetricsFn(), + GetLogContext()); return GroupInstancesToClusterMappings(instance_group_instances); } @@ -81,8 +82,8 @@ class AwsClusterMappingsManager : public ClusterMappingsManager { [&instance_client, &instance_ids] { return instance_client.DescribeInstances(instance_ids); }, - "DescribeInstances", - LogStatusSafeMetricsFn()); + "DescribeInstances", LogStatusSafeMetricsFn(), 
+ GetLogContext()); absl::flat_hash_map mapping; for (const auto& instance : instances_detailed_info) { @@ -123,9 +124,10 @@ class AwsClusterMappingsManager : public ClusterMappingsManager { std::unique_ptr ClusterMappingsManager::Create( std::string environment, int32_t num_shards, - InstanceClient& instance_client, ParameterFetcher& parameter_fetcher) { - return std::make_unique(environment, num_shards, - instance_client); + InstanceClient& instance_client, ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + return std::make_unique( + environment, num_shards, instance_client, log_context); } } // namespace kv_server diff --git a/components/sharding/cluster_mappings_manager_gcp.cc b/components/sharding/cluster_mappings_manager_gcp.cc index 67da5f33..4042cb83 100644 --- a/components/sharding/cluster_mappings_manager_gcp.cc +++ b/components/sharding/cluster_mappings_manager_gcp.cc @@ -26,11 +26,12 @@ constexpr std::string_view kInitializedTag = "initialized"; class GcpClusterMappingsManager : public ClusterMappingsManager { public: - GcpClusterMappingsManager(std::string environment, int32_t num_shards, - InstanceClient& instance_client, - std::string project_id) + GcpClusterMappingsManager( + std::string environment, int32_t num_shards, + InstanceClient& instance_client, std::string project_id, + privacy_sandbox::server_common::log::PSLogContext& log_context) : ClusterMappingsManager(std::move(environment), num_shards, - instance_client), + instance_client, log_context), project_id_{project_id} {} std::vector> GetClusterMappings() override { @@ -42,7 +43,8 @@ class GcpClusterMappingsManager : public ClusterMappingsManager { describe_instance_group_input); }, "DescribeInstanceGroupInstances", - LogStatusSafeMetricsFn()); + LogStatusSafeMetricsFn(), + GetLogContext()); return GroupInstancesToClusterMappings(instance_group_instances); } @@ -92,11 +94,13 @@ class GcpClusterMappingsManager : public 
ClusterMappingsManager { std::unique_ptr ClusterMappingsManager::Create( std::string environment, int32_t num_shards, - InstanceClient& instance_client, ParameterFetcher& parameter_fetcher) { + InstanceClient& instance_client, ParameterFetcher& parameter_fetcher, + privacy_sandbox::server_common::log::PSLogContext& log_context) { std::string project_id = parameter_fetcher.GetParameter(kProjectIdParameterName); return std::make_unique( - environment, num_shards, instance_client, std::move(project_id)); + environment, num_shards, instance_client, std::move(project_id), + log_context); } } // namespace kv_server diff --git a/components/sharding/shard_manager.cc b/components/sharding/shard_manager.cc index 3a7e567f..783a977c 100644 --- a/components/sharding/shard_manager.cc +++ b/components/sharding/shard_manager.cc @@ -50,10 +50,12 @@ class ShardManagerImpl : public ShardManager { int32_t num_shards, std::function(const std::string& ip)> client_factory, - std::unique_ptr random_generator) + std::unique_ptr random_generator, + privacy_sandbox::server_common::log::PSLogContext& log_context) : num_shards_{num_shards}, client_factory_{client_factory}, - random_generator_{std::move(random_generator)} {} + random_generator_{std::move(random_generator)}, + log_context_(log_context) {} // taking in a set to exclude duplicates. // set doesn't have an O(1) lookup --> converting to vector. 
@@ -109,6 +111,7 @@ class ShardManagerImpl : public ShardManager { std::function(const std::string& ip)> client_factory_; std::unique_ptr random_generator_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; absl::Status ValidateMapping( @@ -140,7 +143,8 @@ absl::StatusOr> ShardManager::Create( int32_t num_shards, privacy_sandbox::server_common::KeyFetcherManagerInterface& key_fetcher_manager, - const std::vector>& cluster_mappings) { + const std::vector>& cluster_mappings, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto validationStatus = ValidateMapping(num_shards, cluster_mappings); if (!validationStatus.ok()) { return validationStatus; @@ -150,7 +154,7 @@ absl::StatusOr> ShardManager::Create( [&key_fetcher_manager](const std::string& ip) { return RemoteLookupClient::Create(ip, key_fetcher_manager); }, - std::make_unique()); + std::make_unique(), log_context); shard_manager->InsertBatch(std::move(cluster_mappings)); return shard_manager; } @@ -160,13 +164,15 @@ absl::StatusOr> ShardManager::Create( const std::vector>& cluster_mappings, std::unique_ptr random_generator, std::function(const std::string& ip)> - client_factory) { + client_factory, + privacy_sandbox::server_common::log::PSLogContext& log_context) { auto validationStatus = ValidateMapping(num_shards, cluster_mappings); if (!validationStatus.ok()) { return validationStatus; } auto shard_manager = std::make_unique( - cluster_mappings.size(), client_factory, std::move(random_generator)); + cluster_mappings.size(), client_factory, std::move(random_generator), + log_context); shard_manager->InsertBatch(std::move(cluster_mappings)); return shard_manager; } diff --git a/components/sharding/shard_manager.h b/components/sharding/shard_manager.h index ea4ff8ca..b1017148 100644 --- a/components/sharding/shard_manager.h +++ b/components/sharding/shard_manager.h @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_set.h" #include 
"components/internal_server/remote_lookup_client.h" +#include "src/logger/request_context_logger.h" namespace kv_server { // This class is useful for testing ShardManager @@ -54,13 +55,19 @@ class ShardManager { int32_t num_shards, privacy_sandbox::server_common::KeyFetcherManagerInterface& key_fetcher_manager, - const std::vector>& cluster_mappings); + const std::vector>& cluster_mappings, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); static absl::StatusOr> Create( int32_t num_shards, const std::vector>& cluster_mappings, std::unique_ptr random_generator, std::function(const std::string& ip)> - client_factory); + client_factory, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)); }; } // namespace kv_server #endif // COMPONENTS_SHARDING_SHARD_MANAGER_H_ diff --git a/components/telemetry/BUILD.bazel b/components/telemetry/BUILD.bazel index 67289032..77120400 100644 --- a/components/telemetry/BUILD.bazel +++ b/components/telemetry/BUILD.bazel @@ -55,6 +55,7 @@ cc_library( "@google_privacysandbox_servers_common//src/metric:context_map", "@google_privacysandbox_servers_common//src/util:duration", "@google_privacysandbox_servers_common//src/util:read_system", + "@io_opentelemetry_cpp//sdk/src/metrics", ], ) diff --git a/components/telemetry/init_local_otlp.cc b/components/telemetry/init_local_otlp.cc index d0754012..bfd0ab03 100644 --- a/components/telemetry/init_local_otlp.cc +++ b/components/telemetry/init_local_otlp.cc @@ -20,8 +20,8 @@ // To use Jaeger first run a local instance of the collector // https://www.jaegertracing.io/docs/1.42/getting-started/ // Then build run server with flags for local and otlp. 
Ex: -// `bazel run //components/data_server/server:server --//:instance=local -// --//:platform=aws +// `bazel run //components/data_server/server:server --config=local_instance +// --config=aws_platform // --@google_privacysandbox_servers_common//src/telemetry:local_otel_export=otlp // -- // --environment="test"` diff --git a/components/telemetry/server_definition.h b/components/telemetry/server_definition.h index 339dd851..a8e39782 100644 --- a/components/telemetry/server_definition.h +++ b/components/telemetry/server_definition.h @@ -23,6 +23,7 @@ #include "absl/time/time.h" #include "components/telemetry/error_code.h" +#include "opentelemetry/metrics/provider.h" #include "src/core/common/uuid/uuid.h" #include "src/metric/context_map.h" #include "src/util/duration.h" @@ -40,21 +41,24 @@ constexpr std::string_view kInternalLookupServiceName = "InternalLookupServer"; // metric monitoring set up. // TODO(b/307362951): Tune the upper bound and lower bound for different // unsafe metrics. 
-constexpr int kCounterDPLowerBound = 1; -constexpr int kCounterDPUpperBound = 10; +inline constexpr int kCounterDPLowerBound = 1; +inline constexpr int kCounterDPUpperBound = 10; -constexpr int kErrorCounterDPLowerBound = 0; -constexpr int kErrorCounterDPUpperBound = 1; -constexpr int kErrorMaxPartitionsContributed = 1; -constexpr double kErrorMinNoiseToOutput = 0.99; +inline constexpr int kErrorCounterDPLowerBound = 0; +inline constexpr int kErrorCounterDPUpperBound = 1; +inline constexpr int kErrorMaxPartitionsContributed = 1; +inline constexpr double kErrorMinNoiseToOutput = 0.99; -constexpr int kMicroSecondsLowerBound = 1; -constexpr int kMicroSecondsUpperBound = 2'000'000'000; +inline constexpr int kMicroSecondsLowerBound = 1; +// 2 seconds +inline constexpr int kMicroSecondsUpperBound = 2'000'000'000; inline constexpr double kLatencyInMicroSecondsBoundaries[] = { - 160, 220, 280, 320, 640, 1'200, 2'500, - 5'000, 10'000, 20'000, 40'000, 80'000, 160'000, 320'000, - 640'000, 1'000'000, 1'300'000, 2'600'000, 5'000'000, 10'000'000'000}; + 160, 220, 280, 320, 640, + 1'200, 2'500, 5'000, 10'000, 20'000, + 40'000, 80'000, 160'000, 320'000, 640'000, + 1'000'000, 1'300'000, 2'600'000, 5'000'000, 10'000'000'000, +}; // String literals for absl status partition, the string list and literals match // those implemented in the absl::StatusCodeToString method @@ -90,7 +94,9 @@ inline constexpr std::string_view kCacheAccessEvents[] = { kKeyValueSetCacheMiss}; inline constexpr privacy_sandbox::server_common::metrics::PrivacyBudget - privacy_total_budget{/*epsilon*/ 5}; + privacy_total_budget = { + .epsilon = 5, +}; // Metric definitions for request level metrics that are privacy impacting // and should be logged unsafe with DP(differential privacy) noises. 
@@ -141,6 +147,15 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kShardedLookupRunSetQueryIntLatencyInMicros( + "ShardedLookupRunSetQueryIntLatencyInMicros", + "Latency in executing RunQuery in the sharded lookup", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -215,6 +230,15 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, kMicroSecondsLowerBound); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kGetUInt32ValueSetLatencyInMicros( + "GetUInt32ValueSetLatencyInMicros", + "Latency in executing GetUInt32ValueSet in cache", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kImpacting, privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> @@ -314,6 +338,21 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "notification messages", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + 
privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kReceivedLowLatencyNotificationsBytes( + "ReceivedLowLatencyNotificationsBytes", + "Size of messages received through pub/sub", + privacy_sandbox::server_common::metrics::kSizeHistogram); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> + kReceivedLowLatencyNotificationsCount( + "kReceivedLowLatencyNotificationsCount", + "Count of messages received through pub/sub"); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -350,6 +389,13 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in seeking input streambuf seekoff", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kBlobStorageReadBytes( + "BlobStorageReadBytes", "Size of data read from blob storage in bytes", + privacy_sandbox::server_common::metrics::kSizeHistogram); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> @@ -418,6 +464,13 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in key value set update", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + 
kUpdateUInt32ValueSetLatency("UpdateUInt32ValueSetLatency", + "Latency in uint32 key value set update", + kLatencyInMicroSecondsBoundaries); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -431,6 +484,13 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in deleting values in set", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kDeleteUInt32ValueSetLatency("DeleteUInt32ValueSetLatency", + "Latency in deleting values in an uint32 set", + kLatencyInMicroSecondsBoundaries); + inline constexpr privacy_sandbox::server_common::metrics::Definition< double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kHistogram> @@ -453,6 +513,14 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "Latency in cleaning up key value set map", kLatencyInMicroSecondsBoundaries); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kCleanUpUInt32SetMapLatency( + "CleanUpUInt32SetMapMapLatency", + "Latency in cleaning up key value uint32 set map", + kLatencyInMicroSecondsBoundaries); + inline constexpr privacy_sandbox::server_common::metrics::Definition< int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> @@ -460,7 +528,68 @@ inline constexpr privacy_sandbox::server_common::metrics::Definition< "SecureLookupRequestCount", "Number of secure lookup requests received 
from remote server"); -// KV server metrics list contains contains non request related safe metrics +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kTotalV2LatencyWithoutCustomCode( + "TotalV2LatencyWithoutCustomCode", + "Latency for running V2 request without UDF execution time", + kLatencyInMicroSecondsBoundaries, kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + double, privacy_sandbox::server_common::metrics::Privacy::kImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kUDFExecutionLatencyInMicros("UDFExecutionLatencyInMicros", + "UDF execution time", + kLatencyInMicroSecondsBoundaries, + kMicroSecondsUpperBound, + kMicroSecondsLowerBound); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> + kTotalRequestCountV1("request.v1.count", + "Total number of V1 requests received by the server"); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> + kRequestFailedCountByStatusV1( + "request.v1.failed_count_by_status", + "Total number of V1 requests that resulted in failure partitioned by " + "Error Code", + "error_code", kAbslStatusStrings); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kServerTotalTimeMsV1( + "request.v1.duration_ms", + "Total time taken by the server to execute the request", + 
privacy_sandbox::server_common::metrics::kTimeHistogram); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kResponseByteV1("response.v1.size_bytes", "V1 response size in bytes", + privacy_sandbox::server_common::metrics::kSizeHistogram); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kRequestByteV1("request.v1.size_bytes", "V1 request size in bytes", + privacy_sandbox::server_common::metrics::kSizeHistogram); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kGetValuesAdapterLatency( + "GetValuesAdapterLatencyMs", + "GetValues adapter latency in milliseconds", + privacy_sandbox::server_common::metrics::kTimeHistogram); + +// KV server metrics list contains non request related safe metrics // and request metrics collected before stage of internal lookups inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* kKVServerMetricList[] = { @@ -469,21 +598,27 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kShardedLookupGetKeyValuesLatencyInMicros, &kShardedLookupGetKeyValueSetLatencyInMicros, &kShardedLookupRunQueryLatencyInMicros, + &kShardedLookupRunSetQueryIntLatencyInMicros, &kRemoteLookupGetValuesLatencyInMicros, + &kTotalV2LatencyWithoutCustomCode, &kUDFExecutionLatencyInMicros, // Safe metrics &kKVServerError, &privacy_sandbox::server_common::metrics::kTotalRequestCount, &privacy_sandbox::server_common::metrics::kServerTotalTimeMs, &privacy_sandbox::server_common::metrics::kRequestByte, 
&privacy_sandbox::server_common::metrics::kResponseByte, - &kRequestFailedCountByStatus, &kGetParameterStatus, - &kCompleteLifecycleStatus, &kCreateDataOrchestratorStatus, - &kStartDataOrchestratorStatus, &kLoadNewFilesStatus, - &kGetShardManagerStatus, &kDescribeInstanceGroupInstancesStatus, - &kDescribeInstancesStatus, + &kTotalRequestCountV1, &kServerTotalTimeMsV1, &kRequestByteV1, + &kResponseByteV1, &kRequestFailedCountByStatusV1, + &kRequestFailedCountByStatus, &kGetValuesAdapterLatency, + &kGetParameterStatus, &kCompleteLifecycleStatus, + &kCreateDataOrchestratorStatus, &kStartDataOrchestratorStatus, + &kLoadNewFilesStatus, &kGetShardManagerStatus, + &kDescribeInstanceGroupInstancesStatus, &kDescribeInstancesStatus, &kReceivedLowLatencyNotificationsE2ECloudProvided, &kReceivedLowLatencyNotificationsE2E, &kReceivedLowLatencyNotifications, - &kAwsSqsReceiveMessageLatency, &kSeekingInputStreambufSeekoffLatency, + &kReceivedLowLatencyNotificationsBytes, + &kReceivedLowLatencyNotificationsCount, &kAwsSqsReceiveMessageLatency, + &kSeekingInputStreambufSeekoffLatency, &kSeekingInputStreambufSizeLatency, &kSeekingInputStreambufUnderflowLatency, &kTotalRowsDroppedInDataLoading, &kTotalRowsUpdatedInDataLoading, @@ -491,9 +626,12 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kConcurrentStreamRecordReaderReadShardRecordsLatency, &kConcurrentStreamRecordReaderReadStreamRecordsLatency, &kConcurrentStreamRecordReaderReadByteRangeLatency, - &kUpdateKeyValueLatency, &kUpdateKeyValueSetLatency, &kDeleteKeyLatency, - &kDeleteValuesInSetLatency, &kRemoveDeletedKeyLatency, - &kCleanUpKeyValueMapLatency, &kCleanUpKeyValueSetMapLatency}; + &kUpdateKeyValueLatency, &kUpdateKeyValueSetLatency, + &kUpdateUInt32ValueSetLatency, &kDeleteKeyLatency, + &kDeleteValuesInSetLatency, &kDeleteUInt32ValueSetLatency, + &kRemoveDeletedKeyLatency, &kCleanUpKeyValueMapLatency, + &kCleanUpKeyValueSetMapLatency, &kCleanUpUInt32SetMapLatency, + 
&kBlobStorageReadBytes}; // Internal lookup service metrics list contains metrics collected in the // internal lookup server. This separation from KV metrics list allows all @@ -510,7 +648,8 @@ inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* &kInternalGetKeyValuesLatencyInMicros, &kInternalGetKeyValueSetLatencyInMicros, &kInternalSecureLookupLatencyInMicros, &kGetValuePairsLatencyInMicros, - &kGetKeyValueSetLatencyInMicros, &kCacheAccessEventCount}; + &kGetKeyValueSetLatencyInMicros, &kGetUInt32ValueSetLatencyInMicros, + &kCacheAccessEventCount}; inline constexpr absl::Span< const privacy_sandbox::server_common::metrics::DefinitionName* const> @@ -566,6 +705,16 @@ inline void LogIfError(const absl::Status& s, << message << ": " << s; } +template +inline void LogIfError(const absl::StatusOr& s, std::string_view message, + privacy_sandbox::server_common::SourceLocation location + PS_LOC_CURRENT_DEFAULT_ARG) { + if (s.ok()) return; + ABSL_LOG_EVERY_N_SEC(WARNING, 60) + .AtLocation(location.file_name(), location.line()) + << message << ": " << s.status(); +} + template inline absl::AnyInvocable LogStatusSafeMetricsFn() { @@ -623,12 +772,12 @@ inline void LogServerErrorMetric(std::string_view error_code) { {{std::string(error_code), 1}})); } -// Logs common safe request metrics +// Logs common safe request metrics for V2 request path template inline void LogRequestCommonSafeMetrics( const RequestT* request, const ResponseT* response, const grpc::Status& grpc_request_status, - const absl::Time& request_received_time) { + const privacy_sandbox::server_common::Stopwatch& stopwatch) { LogIfError( KVServerContextMap() ->SafeMetric() @@ -642,18 +791,18 @@ inline void LogRequestCommonSafeMetrics( .LogUpDownCounter( {{absl::StatusCodeToString(request_status.code()), 1}})); } - LogIfError(KVServerContextMap() - ->SafeMetric() - .template LogHistogram< - privacy_sandbox::server_common::metrics::kRequestByte>( - (int)request->ByteSizeLong())); - 
LogIfError(KVServerContextMap() - ->SafeMetric() - .template LogHistogram< - privacy_sandbox::server_common::metrics::kResponseByte>( - (int)response->ByteSizeLong())); + LogIfError( + KVServerContextMap() + ->SafeMetric() + .LogHistogram( + static_cast(request->ByteSizeLong()))); + LogIfError( + KVServerContextMap() + ->SafeMetric() + .LogHistogram( + static_cast(response->ByteSizeLong()))); int duration_ms = - (absl::Now() - request_received_time) / absl::Milliseconds(1); + static_cast(absl::ToInt64Milliseconds(stopwatch.GetElapsedTime())); LogIfError( KVServerContextMap() ->SafeMetric() @@ -662,55 +811,33 @@ inline void LogRequestCommonSafeMetrics( duration_ms)); } -// ScopeMetricsContext provides metrics context ties to the request and -// should have the same lifetime of the request. -// The purpose of this class is to avoid explicit creating and deleting metrics -// context from context map. The metrics context associated with the request -// will be destroyed after ScopeMetricsContext goes out of scope. -class ScopeMetricsContext { - public: - explicit ScopeMetricsContext( - std::string request_id = google::scp::core::common::ToString( - google::scp::core::common::Uuid::GenerateUuid())) - : request_id_(std::move(request_id)) { - // Create a metrics context in the context map and - // associated it with request id - KVServerContextMap()->Get(&request_id_); - CHECK_OK([this]() { - // Remove the metrics context for request_id to transfer the ownership - // of metrics context to the ScopeMetricsContext. This is to ensure that - // metrics context has the same lifetime with RequestContext and be - // destroyed when ScopeMetricsContext goes out of scope. 
- PS_ASSIGN_OR_RETURN(udf_request_metrics_context_, - KVServerContextMap()->Remove(&request_id_)); - return absl::OkStatus(); - }()) << "Udf request metrics context is not initialized"; - InternalLookupServerContextMap()->Get(&request_id_); - CHECK_OK([this]() { - // Remove the metrics context for request_id to transfer the ownership - // of metrics context to the ScopeMetricsContext. This is to ensure that - // metrics context has the same lifetime with RequestContext and be - // destroyed when ScopeMetricsContext goes out of scope. - PS_ASSIGN_OR_RETURN( - internal_lookup_metrics_context_, - InternalLookupServerContextMap()->Remove(&request_id_)); - return absl::OkStatus(); - }()) << "Internal lookup metrics context is not initialized"; - } - UdfRequestMetricsContext& GetUdfRequestMetricsContext() const { - return *udf_request_metrics_context_; - } - InternalLookupMetricsContext& GetInternalLookupMetricsContext() const { - return *internal_lookup_metrics_context_; +// Logs safe V1 request metrics +template +inline void LogV1RequestCommonSafeMetrics( + const RequestT* request, const ResponseT* response, + const grpc::Status& grpc_request_status, + const privacy_sandbox::server_common::Stopwatch& stopwatch) { + LogIfError( + KVServerContextMap()->SafeMetric().LogUpDownCounter( + 1)); + if (auto request_status = + privacy_sandbox::server_common::ToAbslStatus(grpc_request_status); + !request_status.ok()) { + LogIfError(KVServerContextMap() + ->SafeMetric() + .LogUpDownCounter( + {{absl::StatusCodeToString(request_status.code()), 1}})); } - - private: - const std::string request_id_; - // Metrics context has the same lifetime of server request context - std::unique_ptr udf_request_metrics_context_; - std::unique_ptr - internal_lookup_metrics_context_; -}; + LogIfError(KVServerContextMap()->SafeMetric().LogHistogram( + static_cast(request->ByteSizeLong()))); + LogIfError(KVServerContextMap()->SafeMetric().LogHistogram( + static_cast(response->ByteSizeLong()))); + int 
duration_ms = + static_cast(absl::ToInt64Milliseconds(stopwatch.GetElapsedTime())); + LogIfError( + KVServerContextMap()->SafeMetric().LogHistogram( + duration_ms)); +} // Measures the latency of a block of code. The latency is recorded in // microseconds as histogram metrics when the object of this class goes diff --git a/components/tools/BUILD.bazel b/components/tools/BUILD.bazel index 91a7483f..a645d8f3 100644 --- a/components/tools/BUILD.bazel +++ b/components/tools/BUILD.bazel @@ -66,6 +66,7 @@ cc_binary( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/data_loading:data_orchestrator", + "//components/tools/util:configure_telemetry_tools", "//components/udf:noop_udf_client", "//components/util:platform_initializer", "//public:base_types_cc_proto", @@ -88,6 +89,7 @@ cc_binary( ], deps = [ "//components/data/blob_storage:blob_storage_change_notifier", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", @@ -114,6 +116,7 @@ cc_binary( deps = [ ":blob_storage_commands", "//components/data/blob_storage:blob_storage_client", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", @@ -130,6 +133,7 @@ cc_binary( "//components/data/blob_storage:blob_storage_client", "//components/data/blob_storage:delta_file_notifier", "//components/data/common:thread_manager", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", @@ -144,6 +148,7 @@ cc_binary( ], deps = [ "//components/data/realtime:delta_file_record_change_notifier", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "//public/data_loading:data_loading_fbs", 
"//public/data_loading:filename_utils", @@ -231,6 +236,7 @@ cc_library( ], }) + [ "//components/data/common:message_service", + "//components/tools/util:configure_telemetry_tools", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/log", @@ -271,6 +277,7 @@ cc_binary( deps = [ ":publisher_service", "//components/data/realtime:realtime_notifier", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "//public/data_loading:data_loading_fbs", "//public/data_loading:filename_utils", diff --git a/components/tools/benchmarks/BUILD.bazel b/components/tools/benchmarks/BUILD.bazel index ccc611e4..8f1e5705 100644 --- a/components/tools/benchmarks/BUILD.bazel +++ b/components/tools/benchmarks/BUILD.bazel @@ -29,6 +29,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/logger:request_context_impl", ], ) @@ -54,6 +55,7 @@ cc_binary( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/cache:noop_key_value_cache", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "//public/data_loading:data_loading_fbs", "//public/data_loading:records_utils", @@ -80,6 +82,7 @@ cc_binary( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/cache:noop_key_value_cache", + "//components/tools/util:configure_telemetry_tools", "//components/util:request_context", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/flags:flag", @@ -93,3 +96,31 @@ cc_binary( "@com_google_benchmark//:benchmark", ], ) + +cc_binary( + name = "query_evaluation_benchmark", + srcs = ["query_evaluation_benchmark.cc"], + malloc = "@com_google_tcmalloc//tcmalloc", + deps = [ + ":benchmark_util", + 
"//components/data_server/cache:get_key_value_set_result_impl", + "//components/data_server/cache:key_value_cache", + "//components/query:ast", + "//components/query:driver", + "//components/query:scanner", + "//components/query:sets", + "//components/tools/util:configure_telemetry_tools", + "//components/util:request_context", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:flags", + "@com_google_absl//absl/log:initialize", + "@com_google_absl//absl/strings", + "@com_google_benchmark//:benchmark", + "@roaring_bitmap//:c_roaring", + ], +) diff --git a/components/tools/benchmarks/benchmark_util.h b/components/tools/benchmarks/benchmark_util.h index 268a0a8d..2b003506 100644 --- a/components/tools/benchmarks/benchmark_util.h +++ b/components/tools/benchmarks/benchmark_util.h @@ -24,6 +24,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "src/logger/request_context_impl.h" namespace kv_server::benchmark { @@ -45,6 +46,12 @@ class AsyncTask { std::thread runner_thread_; }; +class BenchmarkLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + public: + BenchmarkLogContext() = default; +}; + // Generates a random string with `char_count` characters. 
std::string GenerateRandomString(const int64_t char_count); diff --git a/components/tools/benchmarks/cache_benchmark.cc b/components/tools/benchmarks/cache_benchmark.cc index 3aa3133f..c983e506 100644 --- a/components/tools/benchmarks/cache_benchmark.cc +++ b/components/tools/benchmarks/cache_benchmark.cc @@ -35,6 +35,7 @@ #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/cache/noop_key_value_cache.h" #include "components/tools/benchmarks/benchmark_util.h" +#include "components/tools/util/configure_telemetry_tools.h" ABSL_FLAG(std::vector, record_size, std::vector({"1"}), @@ -158,21 +159,23 @@ struct BenchmarkArgs { void BM_GetKeyValuePairs(::benchmark::State& state, BenchmarkArgs args) { uint seed = args.concurrent_tasks; std::vector writer_tasks; + benchmark::BenchmarkLogContext log_context; if (state.thread_index() == 0 && args.concurrent_tasks > 0) { auto num_writers = args.concurrent_tasks; writer_tasks.reserve(num_writers); while (num_writers-- > 0) { - writer_tasks.emplace_back( - [args, &seed, value = GenerateRandomString(args.record_size)]() { - auto key = std::to_string(rand_r(&seed) % args.query_size); - args.cache->UpdateKeyValue(key, value, ++GetLogicalTimestamp()); - }); + writer_tasks.emplace_back([args, &seed, + value = GenerateRandomString(args.record_size), + &log_context]() { + auto key = std::to_string(rand_r(&seed) % args.query_size); + args.cache->UpdateKeyValue(log_context, key, value, + ++GetLogicalTimestamp()); + }); } } auto keys = GetKeys(args.query_size); auto keys_view = ToContainerView>(keys); - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; for (auto _ : state) { ::benchmark::DoNotOptimize( args.cache->GetKeyValuePairs(request_context, keys_view)); @@ -184,24 +187,25 @@ void BM_GetKeyValuePairs(::benchmark::State& state, BenchmarkArgs args) { void BM_GetKeyValueSet(::benchmark::State& state, 
BenchmarkArgs args) { uint seed = args.concurrent_tasks; std::vector writer_tasks; + benchmark::BenchmarkLogContext log_context; if (state.thread_index() == 0 && args.concurrent_tasks > 0) { auto num_writers = args.concurrent_tasks; writer_tasks.reserve(num_writers); while (num_writers-- > 0) { writer_tasks.emplace_back([args, &seed, set_query = GetSetQuery(args.set_query_size, - args.record_size)]() { + args.record_size), + &log_context]() { auto key = std::to_string(rand_r(&seed) % args.query_size); auto view = ToContainerView>(set_query); - args.cache->UpdateKeyValueSet(key, absl::MakeSpan(view), + args.cache->UpdateKeyValueSet(log_context, key, absl::MakeSpan(view), ++GetLogicalTimestamp()); }); } } auto keys = GetKeys(args.query_size); auto keys_view = ToContainerView>(keys); - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; for (auto _ : state) { ::benchmark::DoNotOptimize( args.cache->GetKeyValueSet(request_context, keys_view)); @@ -213,8 +217,8 @@ void BM_GetKeyValueSet(::benchmark::State& state, BenchmarkArgs args) { void BM_UpdateKeyValue(::benchmark::State& state, BenchmarkArgs args) { uint seed = args.concurrent_tasks; std::vector reader_tasks; - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; + benchmark::BenchmarkLogContext log_context; if (state.thread_index() == 0 && args.concurrent_tasks) { auto num_readers = args.concurrent_tasks; reader_tasks.reserve(num_readers); @@ -229,7 +233,8 @@ void BM_UpdateKeyValue(::benchmark::State& state, BenchmarkArgs args) { auto value = GenerateRandomString(args.record_size); for (auto _ : state) { auto key = std::to_string(rand_r(&seed) % args.keyspace_size); - args.cache->UpdateKeyValue(key, value, ++GetLogicalTimestamp()); + args.cache->UpdateKeyValue(log_context, key, value, + ++GetLogicalTimestamp()); } 
state.counters[std::string(kWritesPerSec)] = ::benchmark::Counter(state.iterations(), ::benchmark::Counter::kIsRate); @@ -238,8 +243,8 @@ void BM_UpdateKeyValue(::benchmark::State& state, BenchmarkArgs args) { void BM_UpdateKeyValueSet(::benchmark::State& state, BenchmarkArgs args) { uint seed = args.concurrent_tasks; std::vector reader_tasks; - auto scope_metrics_context = std::make_unique(); - RequestContext request_context(*scope_metrics_context); + RequestContext request_context; + benchmark::BenchmarkLogContext log_context; if (state.thread_index() == 0 && args.concurrent_tasks) { auto num_readers = args.concurrent_tasks; reader_tasks.reserve(num_readers); @@ -254,7 +259,7 @@ void BM_UpdateKeyValueSet(::benchmark::State& state, BenchmarkArgs args) { auto set_view = ToContainerView>(set_value); for (auto _ : state) { auto key = std::to_string(rand_r(&seed) % args.keyspace_size); - args.cache->UpdateKeyValueSet(key, absl::MakeSpan(set_view), + args.cache->UpdateKeyValueSet(log_context, key, absl::MakeSpan(set_view), ++GetLogicalTimestamp()); } state.counters[std::string(kWritesPerSec)] = @@ -367,14 +372,14 @@ void RegisterWriteBenchmarks() { // // bazel run -c opt \ // //components/tools/benchmarks:cache_benchmark \ -// --//:instance=local \ -// --//:platform=local -- \ +// --config=local_instance \ +// --config=local_platform -- \ // --benchmark_counters_tabular=true --stderrthreshold=0 int main(int argc, char** argv) { absl::InitializeLog(); ::benchmark::Initialize(&argc, argv); absl::ParseCommandLine(argc, argv); - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); ::kv_server::RegisterReadBenchmarks(); ::kv_server::RegisterWriteBenchmarks(); ::benchmark::RunSpecifiedBenchmarks(); diff --git a/components/tools/benchmarks/data_loading_benchmark.cc b/components/tools/benchmarks/data_loading_benchmark.cc index cf95b0ea..9c680b17 100644 --- a/components/tools/benchmarks/data_loading_benchmark.cc +++ 
b/components/tools/benchmarks/data_loading_benchmark.cc @@ -35,6 +35,7 @@ #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/cache/noop_key_value_cache.h" #include "components/tools/benchmarks/benchmark_util.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/readers/riegeli_stream_io.h" @@ -177,17 +178,19 @@ void RegisterBenchmarks() { } } -absl::Status ApplyUpdateMutation(const KeyValueMutationRecord& record, - Cache& cache) { +absl::Status ApplyUpdateMutation( + kv_server::benchmark::BenchmarkLogContext& log_context, + const KeyValueMutationRecord& record, Cache& cache) { if (record.value_type() == Value::StringValue) { - cache.UpdateKeyValue(record.key()->string_view(), + cache.UpdateKeyValue(log_context, record.key()->string_view(), GetRecordValue(record), record.logical_commit_time()); return absl::OkStatus(); } if (record.value_type() == Value::StringSet) { auto values = GetRecordValue>(record); - cache.UpdateKeyValueSet(record.key()->string_view(), absl::MakeSpan(values), + cache.UpdateKeyValueSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), record.logical_commit_time()); return absl::OkStatus(); } @@ -196,15 +199,18 @@ absl::Status ApplyUpdateMutation(const KeyValueMutationRecord& record, " has unsupported value type: ", record.value_type())); } -absl::Status ApplyDeleteMutation(const KeyValueMutationRecord& record, - Cache& cache) { +absl::Status ApplyDeleteMutation( + kv_server::benchmark::BenchmarkLogContext& log_context, + const KeyValueMutationRecord& record, Cache& cache) { if (record.value_type() == Value::StringValue) { - cache.DeleteKey(record.key()->string_view(), record.logical_commit_time()); + cache.DeleteKey(log_context, record.key()->string_view(), + record.logical_commit_time()); return absl::OkStatus(); } if (record.value_type() == 
Value::StringSet) { auto values = GetRecordValue>(record); - cache.DeleteValuesInSet(record.key()->string_view(), absl::MakeSpan(values), + cache.DeleteValuesInSet(log_context, record.key()->string_view(), + absl::MakeSpan(values), record.logical_commit_time()); return absl::OkStatus(); } @@ -234,41 +240,46 @@ void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args) { }); auto stream_size = GetBlobSize(*blob_client, GetBlobLocation()); std::atomic num_records_read{0}; + kv_server::benchmark::BenchmarkLogContext log_context; for (auto _ : state) { state.PauseTiming(); auto cache = args.create_cache_fn(); state.ResumeTiming(); - auto status = record_reader.ReadStreamRecords([&num_records_read, - cache = cache.get()]( - std::string_view raw) { - num_records_read++; - return DeserializeDataRecord(raw, [cache](const DataRecord& data_record) { - if (data_record.record_type() == Record::KeyValueMutationRecord) { - const auto* record = data_record.record_as_KeyValueMutationRecord(); - switch (record->mutation_type()) { - case KeyValueMutationType::Update: { - if (auto status = ApplyUpdateMutation(*record, *cache); - status.ok()) { - return status; + auto status = record_reader.ReadStreamRecords( + [&num_records_read, &log_context, + cache = cache.get()](std::string_view raw) { + num_records_read++; + return DeserializeDataRecord(raw, [cache, &log_context]( + const DataRecord& data_record) { + if (data_record.record_type() == Record::KeyValueMutationRecord) { + const auto* record = + data_record.record_as_KeyValueMutationRecord(); + switch (record->mutation_type()) { + case KeyValueMutationType::Update: { + if (auto status = + ApplyUpdateMutation(log_context, *record, *cache); + status.ok()) { + return status; + } + break; + } + case KeyValueMutationType::Delete: { + if (auto status = + ApplyDeleteMutation(log_context, *record, *cache); + status.ok()) { + return status; + } + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Invalid mutation 
type: ", + kv_server::EnumNameKeyValueMutationType( + record->mutation_type()))); } - break; } - case KeyValueMutationType::Delete: { - if (auto status = ApplyDeleteMutation(*record, *cache); - status.ok()) { - return status; - } - } - default: - return absl::InvalidArgumentError( - absl::StrCat("Invalid mutation type: ", - kv_server::EnumNameKeyValueMutationType( - record->mutation_type()))); - } - } - return absl::OkStatus(); - }); - }); + return absl::OkStatus(); + }); + }); benchmark::DoNotOptimize(status); } state.SetItemsProcessed(num_records_read); @@ -280,7 +291,7 @@ void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args) { // // bazel run \ // components/tools/benchmarks:data_loading_benchmark \ -// --//:instance=local --//:platform=local -- \ +// --config=local_instance --config=local_platform -- \ // --benchmark_time_unit=ms \ // --benchmark_counters_tabular=true \ // --data_directory=/tmp/data \ @@ -304,7 +315,7 @@ int main(int argc, char** argv) { LOG(ERROR) << "Flag '--filename' must be not empty."; return -1; } - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); std::unique_ptr blob_storage_client_factory = BlobStorageClientFactory::Create(); std::unique_ptr blob_client = diff --git a/components/tools/benchmarks/query_evaluation_benchmark.cc b/components/tools/benchmarks/query_evaluation_benchmark.cc new file mode 100644 index 00000000..352ee144 --- /dev/null +++ b/components/tools/benchmarks/query_evaluation_benchmark.cc @@ -0,0 +1,210 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/log/initialize.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "components/data_server/cache/cache.h" +#include "components/data_server/cache/key_value_cache.h" +#include "components/query/ast.h" +#include "components/query/driver.h" +#include "components/query/scanner.h" +#include "components/query/sets.h" +#include "components/tools/benchmarks/benchmark_util.h" +#include "components/tools/util/configure_telemetry_tools.h" + +#include "roaring.hh" + +ABSL_FLAG(int64_t, set_size, 1000, "Number of elements in a set."); +ABSL_FLAG(std::string, query, "(A - B) | (C & D)", "Query to evaluate"); +ABSL_FLAG(uint32_t, range_min, 0, "Minimum element in a set"); +ABSL_FLAG(uint32_t, range_max, 65536, "Maximum element in a set"); +ABSL_FLAG(std::vector, set_names, + std::vector({"A", "B", "C", "D"}), + "Set names used in benchmarking query."); + +namespace kv_server { +namespace { + +using RoaringBitSet = roaring::Roaring; +using StringSet = absl::flat_hash_set; + +std::unique_ptr STRING_SET_RESULT = nullptr; +std::unique_ptr UINT32_SET_RESULT = nullptr; + +template +ValueT Lookup(std::string_view); + +template <> +StringSet Lookup(std::string_view key) { + return STRING_SET_RESULT->GetValueSet(key); +} + +template <> +RoaringBitSet Lookup(std::string_view key) { + return UINT32_SET_RESULT->GetUInt32ValueSet(key)->GetValuesBitSet(); +} + +Driver* GetDriver() { + 
static auto* const driver = std::make_unique().release(); + return driver; +} + +Cache* GetKeyValueCache() { + static auto* const cache = KeyValueCache::Create().release(); + return cache; +} + +void SetUpKeyValueCache(int64_t set_size, uint32_t range_min, + uint32_t range_max, + const std::vector& set_names) { + kv_server::benchmark::BenchmarkLogContext log_context; + std::srand(absl::GetCurrentTimeNanos()); + for (const auto& set_name : set_names) { + auto nums = std::vector(); + nums.reserve(set_size); + for (int i = 0; i < set_size; i++) { + nums.push_back(range_min + (std::rand() % (range_max - range_min))); + } + GetKeyValueCache()->UpdateKeyValueSet(log_context, set_name, + absl::MakeSpan(nums), 1); + + auto strings = std::vector(); + std::transform(nums.begin(), nums.end(), std::back_inserter(strings), + [](uint32_t elem) { return absl::StrCat(elem); }); + auto string_views = std::vector(); + std::transform(strings.begin(), strings.end(), + std::back_inserter(string_views), + [](std::string_view elem) { return elem; }); + GetKeyValueCache()->UpdateKeyValueSet(log_context, set_name, + absl::MakeSpan(string_views), 1); + } +} + +template +void BM_SetUnion(::benchmark::State& state) { + for (auto _ : state) { + auto left = Lookup("A"); + auto right = Lookup("B"); + auto result = Union(std::move(left), std::move(right)); + ::benchmark::DoNotOptimize(result); + } + state.counters["Ops/s"] = + ::benchmark::Counter(state.iterations(), ::benchmark::Counter::kIsRate); +} + +template +void BM_SetDifference(::benchmark::State& state) { + for (auto _ : state) { + auto left = Lookup("A"); + auto right = Lookup("B"); + auto result = Difference(std::move(left), std::move(right)); + ::benchmark::DoNotOptimize(result); + } + state.counters["Ops/s"] = + ::benchmark::Counter(state.iterations(), ::benchmark::Counter::kIsRate); +} + +template +void BM_SetIntersection(::benchmark::State& state) { + for (auto _ : state) { + auto left = Lookup("A"); + auto right = Lookup("B"); + 
auto result = Intersection(std::move(left), std::move(right)); + ::benchmark::DoNotOptimize(result); + } + state.counters["Ops/s"] = + ::benchmark::Counter(state.iterations(), ::benchmark::Counter::kIsRate); +} + +template +void BM_AstTreeEvaluation(::benchmark::State& state) { + const auto* ast_tree = GetDriver()->GetRootNode(); + for (auto _ : state) { + auto result = Eval(*ast_tree, Lookup); + ::benchmark::DoNotOptimize(result); + } + state.counters["QueryEvals/s"] = + ::benchmark::Counter(state.iterations(), ::benchmark::Counter::kIsRate); +} + +} // namespace +} // namespace kv_server + +BENCHMARK(kv_server::BM_SetUnion); +BENCHMARK(kv_server::BM_SetUnion); +BENCHMARK(kv_server::BM_SetDifference); +BENCHMARK(kv_server::BM_SetDifference); +BENCHMARK(kv_server::BM_SetIntersection); +BENCHMARK(kv_server::BM_SetIntersection); +BENCHMARK(kv_server::BM_AstTreeEvaluation); +BENCHMARK(kv_server::BM_AstTreeEvaluation); + +using kv_server::ConfigureTelemetryForTools; +using kv_server::GetKeyValueCache; +using kv_server::RequestContext; +using kv_server::RoaringBitSet; +using kv_server::SetUpKeyValueCache; +using kv_server::StringSet; + +// Sample run: +// +// bazel run -c opt //components/tools/benchmarks:query_evaluation_benchmark \ +// -- --benchmark_counters_tabular=true \ +// --benchmark_time_unit=us \ +// --benchmark_filter="*" \ +// --range_min=1000000 --range_max=2000000 \ +// --set_size=10000 \ +// --query="A & B - C | D" \ +// --set_names="A,B,C,D" +int main(int argc, char** argv) { + // Initialize the environment and flags + absl::InitializeLog(); + ::benchmark::Initialize(&argc, argv); + absl::ParseCommandLine(argc, argv); + ConfigureTelemetryForTools(); + auto range_min = absl::GetFlag(FLAGS_range_min); + auto range_max = absl::GetFlag(FLAGS_range_max); + if (range_max <= range_min) { + ABSL_LOG(ERROR) << "range_max: " << range_max + << " must be greater than range_min: " << range_min; + return -1; + } + // Set up the cache and the ast tree. 
+ auto set_names = absl::GetFlag(FLAGS_set_names); + SetUpKeyValueCache(absl::GetFlag(FLAGS_set_size), range_min, range_max, + set_names); + RequestContext request_context; + kv_server::STRING_SET_RESULT = GetKeyValueCache()->GetKeyValueSet( + request_context, absl::flat_hash_set(set_names.begin(), + set_names.end())); + kv_server::UINT32_SET_RESULT = GetKeyValueCache()->GetUInt32ValueSet( + request_context, absl::flat_hash_set(set_names.begin(), + set_names.end())); + std::istringstream stream(absl::GetFlag(FLAGS_query)); + kv_server::Scanner scanner(stream); + kv_server::Parser parser(*kv_server::GetDriver(), scanner); + parser(); + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/components/tools/blob_storage_change_watcher_aws.cc b/components/tools/blob_storage_change_watcher_aws.cc index 48477cce..21c96175 100644 --- a/components/tools/blob_storage_change_watcher_aws.cc +++ b/components/tools/blob_storage_change_watcher_aws.cc @@ -19,6 +19,7 @@ #include "absl/flags/usage.h" #include "components/data/blob_storage/blob_storage_change_notifier.h" #include "components/telemetry/server_definition.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "src/telemetry/telemetry_provider.h" @@ -37,12 +38,7 @@ int main(int argc, char** argv) { return -1; } // Initialize no-op telemetry - privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; - config_proto.set_mode( - privacy_sandbox::server_common::telemetry::TelemetryConfig::PROD); - kv_server::KVServerContextMap( - privacy_sandbox::server_common::telemetry::BuildDependentConfig( - config_proto)); + kv_server::ConfigureTelemetryForTools(); auto message_service_status = kv_server::MessageService::Create( kv_server::AwsNotifierMetadata{"BlobNotifier_", sns_arn}); diff --git a/components/tools/blob_storage_util.cc b/components/tools/blob_storage_util.cc index 4ee70581..354b7f90 100644 --- 
a/components/tools/blob_storage_util.cc +++ b/components/tools/blob_storage_util.cc @@ -20,6 +20,7 @@ #include "absl/strings/match.h" #include "components/data/blob_storage/blob_storage_client.h" #include "components/tools/blob_storage_commands.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "src/telemetry/telemetry_provider.h" @@ -90,7 +91,7 @@ bool CpObjects(std::string bucket, std::string source, std::string dest) { int main(int argc, char** argv) { kv_server::PlatformInitializer initializer; - + kv_server::ConfigureTelemetryForTools(); absl::SetProgramUsageMessage("[cat|cp|ls|rm]"); std::vector commands = absl::ParseCommandLine(argc, argv); if (commands.size() < 2) { diff --git a/components/tools/data_loading_analyzer.cc b/components/tools/data_loading_analyzer.cc index 53ff75a0..f3b95f7d 100644 --- a/components/tools/data_loading_analyzer.cc +++ b/components/tools/data_loading_analyzer.cc @@ -26,6 +26,7 @@ #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/data_loading/data_orchestrator.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/udf/noop_udf_client.h" #include "components/util/platform_initializer.h" #include "public/base_types.pb.h" @@ -43,6 +44,12 @@ ABSL_FLAG(std::string, bucket, "performance-test-data-bucket", namespace kv_server { namespace { +class DataLoadingAnalyzerLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + public: + DataLoadingAnalyzerLogContext() = default; +}; + class NoopBlobStorageChangeNotifier : public BlobStorageChangeNotifier { public: absl::StatusOr> GetNotifications( @@ -139,8 +146,9 @@ std::vector OperationsFromFlag() { } absl::Status InitOnce(Operation operation) { + DataLoadingAnalyzerLogContext log_context; std::unique_ptr noop_udf_client = NewNoopUdfClient(); - InitMetricsContextMap(); + 
ConfigureTelemetryForTools(); std::unique_ptr cache = KeyValueCache::Create(); std::unique_ptr blob_storage_client_factory = @@ -184,6 +192,7 @@ absl::Status InitOnce(Operation operation) { .realtime_thread_pool_manager = realtime_thread_pool_manager, .udf_client = *noop_udf_client, .key_sharder = KeySharder(ShardingFunction{/*seed=*/""}), + .log_context = log_context, }); absl::Time end_time = absl::Now(); LOG(INFO) << "Init used " << (end_time - start_time); diff --git a/components/tools/delta_file_record_change_watcher.cc b/components/tools/delta_file_record_change_watcher.cc index de01a605..5487384e 100644 --- a/components/tools/delta_file_record_change_watcher.cc +++ b/components/tools/delta_file_record_change_watcher.cc @@ -20,6 +20,7 @@ #include "absl/strings/str_join.h" #include "components/data/realtime/delta_file_record_change_notifier.h" #include "components/telemetry/server_definition.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "public/constants.h" #include "public/data_loading/data_loading_generated.h" @@ -104,7 +105,7 @@ int main(int argc, char** argv) { return -1; } - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); auto status_or_notifier = kv_server::ChangeNotifier::Create(kv_server::AwsNotifierMetadata{ .queue_prefix = "QueueNotifier_", diff --git a/components/tools/delta_file_watcher_aws.cc b/components/tools/delta_file_watcher_aws.cc index 9a951a00..fefb9c0d 100644 --- a/components/tools/delta_file_watcher_aws.cc +++ b/components/tools/delta_file_watcher_aws.cc @@ -20,6 +20,7 @@ #include "components/data/blob_storage/blob_storage_client.h" #include "components/data/blob_storage/delta_file_notifier.h" #include "components/data/common/thread_manager.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "src/telemetry/telemetry_provider.h" @@ -46,7 +47,7 @@ int main(int 
argc, char** argv) { std::cerr << "Must specify sns_arn" << std::endl; return -1; } - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); std::unique_ptr blob_storage_client_factory = BlobStorageClientFactory::Create(); std::unique_ptr client = diff --git a/components/tools/query_dot.cc b/components/tools/query_dot.cc index a965a54f..c2e5a404 100644 --- a/components/tools/query_dot.cc +++ b/components/tools/query_dot.cc @@ -14,8 +14,8 @@ #include "components/tools/query_dot.h" -#include #include +#include #include #include "absl/strings/str_join.h" @@ -36,6 +36,12 @@ class ASTNameVisitor : public ASTStringVisitor { class ASTDotGraphLabelVisitor : public ASTStringVisitor { public: + explicit ASTDotGraphLabelVisitor( + absl::AnyInvocable< + absl::flat_hash_set(std::string_view key) const> + lookup_fn) + : lookup_fn_(std::move(lookup_fn)) {} + virtual std::string Visit(const UnionNode& node) { return name_visitor_.Visit(node); } @@ -49,11 +55,17 @@ class ASTDotGraphLabelVisitor : public ASTStringVisitor { } virtual std::string Visit(const ValueNode& node) { - return absl::StrCat(ToString(node.Keys()), "->", ToString(Eval(node))); + return absl::StrCat( + ToString(node.Keys()), "->", + ToString(Eval>( + node, [this](std::string_view key) { return lookup_fn_(key); }))); } private: ASTNameVisitor name_visitor_; + absl::AnyInvocable(std::string_view key) + const> + lookup_fn_; }; std::string DotNodeName(const Node& node, uint32_t namecnt) { @@ -61,8 +73,11 @@ std::string DotNodeName(const Node& node, uint32_t namecnt) { return absl::StrCat(node.Accept(name_visitor), namecnt); } -std::string ToDotGraphBody(const Node& node, uint32_t* namecnt) { - ASTDotGraphLabelVisitor label_visitor; +std::string ToDotGraphBody( + const Node& node, uint32_t* namecnt, + std::function(std::string_view)> + lookup_fn) { + ASTDotGraphLabelVisitor label_visitor(lookup_fn); const std::string label = node.Accept(label_visitor); const std::string node_name = 
DotNodeName(node, *namecnt); std::string dot_str = absl::StrCat(node_name, " [label=\"", label, "\"]\n"); @@ -71,25 +86,29 @@ std::string ToDotGraphBody(const Node& node, uint32_t* namecnt) { const std::string arrow = absl::StrCat(node_name, " -- ", DotNodeName(*node.Left(), *namecnt)); absl::StrAppend(&dot_str, arrow, "\n", - ToDotGraphBody(*node.Left(), namecnt)); + ToDotGraphBody(*node.Left(), namecnt, lookup_fn)); } if (node.Right() != nullptr) { *namecnt = *namecnt + 1; const std::string arrow = absl::StrCat(node_name, " -- ", DotNodeName(*node.Right(), *namecnt)); absl::StrAppend(&dot_str, arrow, "\n", - ToDotGraphBody(*node.Right(), namecnt)); + ToDotGraphBody(*node.Right(), namecnt, lookup_fn)); } return dot_str; } } // namespace -void QueryDotWriter::WriteAst(std::string_view query, const Node& node) { +void QueryDotWriter::WriteAst( + std::string_view query, const Node& node, + std::function(std::string_view)> + lookup_fn) { uint32_t namecnt = 0; const std::string title = absl::StrCat("labelloc=\"t\"\nlabel=\"AST for Query: ", query, "\"\n"); - file_ << absl::StrCat("graph {\n", title, ToDotGraphBody(node, &namecnt), + file_ << absl::StrCat("graph {\n", title, + ToDotGraphBody(node, &namecnt, std::move(lookup_fn)), "\n}\n"); } diff --git a/components/tools/query_dot.h b/components/tools/query_dot.h index 095dc0d0..9894d631 100644 --- a/components/tools/query_dot.h +++ b/components/tools/query_dot.h @@ -32,7 +32,10 @@ class QueryDotWriter { explicit QueryDotWriter(std::string_view path) : file_(path.data()) {} ~QueryDotWriter() { file_.close(); } // Outputs the dot representation of the AST node to the output path. 
- void WriteAst(const std::string_view query, const Node& node); + void WriteAst( + const std::string_view query, const Node& node, + std::function(std::string_view key)> + lookup_fn); void Flush(); private: diff --git a/components/tools/query_toy.cc b/components/tools/query_toy.cc index eab51355..76d66b04 100644 --- a/components/tools/query_toy.cc +++ b/components/tools/query_toy.cc @@ -29,8 +29,6 @@ #include "absl/container/flat_hash_set.h" #include "absl/flags/flag.h" #include "absl/flags/parse.h" -#include "absl/flags/usage.h" -#include "absl/strings/str_join.h" #include "components/query/driver.h" #include "components/query/scanner.h" #include "components/tools/query_dot.h" @@ -47,6 +45,8 @@ ABSL_FLAG( "Output is written to the provided, which can then be visualized. See " "https://graphviz.org/ for details."); +// TODO: Add a flag to allow using uint32_t sets. + absl::flat_hash_map> kDb = { {"A", {"a", "b", "c"}}, {"B", {"b", "c", "d"}}, @@ -86,27 +86,28 @@ absl::flat_hash_set ToView( return result; } +absl::flat_hash_set Lookup(std::string_view key) { + const auto& it = kDb.find(key); + if (it != kDb.end()) { + return ToView(it->second); + } + return kEmptySet; +} + absl::StatusOr> Parse( kv_server::Driver& driver, std::string query) { std::istringstream stream(query); kv_server::Scanner scanner(stream); kv_server::Parser parse(driver, scanner); int parse_result = parse(); - auto result = driver.GetResult(); + auto result = + driver.EvaluateQuery>(Lookup); if (parse_result && result.ok()) { std::cerr << "Unexpected failed parse result with an OK query result."; } return result; } -absl::flat_hash_set Lookup(std::string_view key) { - const auto& it = kDb.find(key); - if (it != kDb.end()) { - return ToView(it->second); - } - return kEmptySet; -} - void ProcessQuery(kv_server::Driver& driver, std::string query) { const auto result = Parse(driver, query); if (!result.ok()) { @@ -125,7 +126,7 @@ void PromptForQuery( std::getline(std::cin, query); 
ProcessQuery(driver, query); if (dot_writer && driver.GetRootNode()) { - dot_writer->WriteAst(query, *driver.GetRootNode()); + dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); dot_writer->Flush(); } } @@ -138,7 +139,7 @@ void SignalHandler(int signal) { int main(int argc, char* argv[]) { absl::ParseCommandLine(argc, argv); - kv_server::Driver driver(Lookup); + kv_server::Driver driver; const std::string query = absl::GetFlag(FLAGS_query); const std::optional dot_path = absl::GetFlag(FLAGS_dot_path); std::optional dot_writer = @@ -148,7 +149,7 @@ int main(int argc, char* argv[]) { if (!query.empty()) { ProcessQuery(driver, query); if (dot_writer && driver.GetRootNode()) { - dot_writer->WriteAst(query, *driver.GetRootNode()); + dot_writer->WriteAst(query, *driver.GetRootNode(), Lookup); } return 0; } diff --git a/components/tools/realtime_notifier.cc b/components/tools/realtime_notifier.cc index 65667625..c1fb2800 100644 --- a/components/tools/realtime_notifier.cc +++ b/components/tools/realtime_notifier.cc @@ -24,6 +24,7 @@ #include "absl/strings/str_join.h" #include "components/data/common/msg_svc.h" #include "components/tools/publisher_service.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/filename_utils.h" @@ -86,7 +87,7 @@ void Print(std::string string_decoded) { absl::Status Run() { PlatformInitializer initializer; NotifierMetadata metadata; - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); // TODO(b/299623229): Remove CLOUD_PLATFORM_LOCAL macro and extract to // publisher_service. 
#if defined(CLOUD_PLATFORM_LOCAL) diff --git a/components/tools/realtime_updates_publisher.cc b/components/tools/realtime_updates_publisher.cc index 65854d17..5629a0fb 100644 --- a/components/tools/realtime_updates_publisher.cc +++ b/components/tools/realtime_updates_publisher.cc @@ -26,6 +26,7 @@ #include "components/data/common/msg_svc.h" #include "components/tools/concurrent_publishing_engine.h" #include "components/tools/publisher_service.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/util/platform_initializer.h" ABSL_FLAG(std::string, deltas_folder_path, "", @@ -57,6 +58,7 @@ void PopulateQueue(const std::string& deltas_folder_path) { absl::Status Run() { PlatformInitializer initializer; + kv_server::ConfigureTelemetryForTools(); auto maybe_notifier_metadata = PublisherService::GetNotifierMetadata(); if (!maybe_notifier_metadata.ok()) { return maybe_notifier_metadata.status(); diff --git a/components/tools/sharding_correctness_validator/BUILD.bazel b/components/tools/sharding_correctness_validator/BUILD.bazel index 160abb3d..643d25c7 100644 --- a/components/tools/sharding_correctness_validator/BUILD.bazel +++ b/components/tools/sharding_correctness_validator/BUILD.bazel @@ -17,9 +17,21 @@ load("@rules_cc//cc:defs.bzl", "cc_binary") cc_binary( name = "validator", srcs = ["validator.cc"], + copts = select({ + "//:aws_platform": ["-DCLOUD_PLATFORM_AWS=1"], + "//:gcp_platform": ["-DCLOUD_PLATFORM_GCP=1"], + "//conditions:default": [], + }), deps = [ + "//components/cloud_config:parameter_client", + "//components/data_server/request_handler:ohttp_client_encryptor", + "//components/data_server/server:key_fetcher_factory", + "//components/data_server/server:parameter_fetcher", + "//components/tools/util:configure_telemetry_tools", + "//components/util:platform_initializer", "//public/applications/pa:response_utils", "//public/query/cpp:grpc_client", + "@com_github_google_quiche//quiche:binary_http_unstable_api", 
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/log", @@ -27,5 +39,6 @@ cc_binary( "@com_google_absl//absl/log:initialize", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", + "@google_privacysandbox_servers_common//src/communication:encoding_utils", ], ) diff --git a/components/tools/sharding_correctness_validator/validator.cc b/components/tools/sharding_correctness_validator/validator.cc index 5299bf79..a33af827 100644 --- a/components/tools/sharding_correctness_validator/validator.cc +++ b/components/tools/sharding_correctness_validator/validator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include "absl/flags/flag.h" @@ -22,8 +23,19 @@ #include "absl/random/random.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "components/cloud_config/parameter_client.h" +#include "components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/data_server/server/key_fetcher_factory.h" +#include "components/data_server/server/parameter_fetcher.h" +#include "components/tools/util/configure_telemetry_tools.h" +#include "components/util/platform_initializer.h" #include "public/applications/pa/response_utils.h" #include "public/query/cpp/grpc_client.h" +#include "public/query/v2/get_values_v2.grpc.pb.h" +#include "quiche/binary_http/binary_http_message.h" +#include "src/communication/encoding_utils.h" + +ABSL_DECLARE_FLAG(std::string, gcp_project_id); ABSL_FLAG(std::string, kv_endpoint, ":50051", "KV grpc endpoint"); ABSL_FLAG(int, inclusive_upper_bound, 999999999, "Inclusive upper bound"); @@ -33,17 +45,151 @@ ABSL_FLAG(int, value_size, 10000, "Specify the size of value for the key"); ABSL_FLAG(int, batch_size, 10, "Batch size"); ABSL_FLAG(std::string, key_prefix, "foo", "Key prefix"); ABSL_FLAG(bool, use_tls, false, "Whether to use TLS for grpc calls."); 
+ABSL_FLAG(std::string, environment, "NOT_SPECIFIED", "Environment name."); +ABSL_FLAG(bool, use_coordinator, false, + "Whether to use coordinator for query encryption."); namespace kv_server { namespace { +inline constexpr std::string_view kPublicKeyEndpointParameterSuffix = + "public-key-endpoint"; +inline constexpr std::string_view kUseRealCoordinatorsParameterSuffix = + "use-real-coordinators"; +inline constexpr std::string_view kContentTypeHeader = "content-type"; +inline constexpr std::string_view kContentEncodingProtoHeaderValue = + "application/protobuf"; + absl::BitGen bitgen; int total_failures = 0; int total_mismatches = 0; +privacy_sandbox::server_common::CloudPlatform GetCloudPlatform() { +#if defined(CLOUD_PLATFORM_AWS) + return privacy_sandbox::server_common::CloudPlatform::kAws; +#elif defined(CLOUD_PLATFORM_GCP) + return privacy_sandbox::server_common::CloudPlatform::kGcp; +#endif + return privacy_sandbox::server_common::CloudPlatform::kLocal; +} + int64_t Get(int64_t upper_bound) { return absl::Uniform(bitgen, 0, upper_bound); } +absl::StatusOr +GetPublicKey(std::unique_ptr& parameter_fetcher) { + if (!parameter_fetcher->GetBoolParameter( + kUseRealCoordinatorsParameterSuffix)) { + // The key_fetcher_manager would just return hard coded public key without + // involving private key fetching + auto factory = kv_server::KeyFetcherFactory::Create(); + auto key_fetcher_manager = + factory->CreateKeyFetcherManager(*parameter_fetcher); + auto maybe_public_key = + key_fetcher_manager->GetPublicKey(GetCloudPlatform()); + if (!maybe_public_key.ok()) { + const std::string error = + absl::StrCat("Could not get public key to use for HPKE encryption:", + maybe_public_key.status().message()); + LOG(ERROR) << error; + return absl::InternalError(error); + } + return maybe_public_key.value(); + } + + auto publicKeyEndpointParameter = + parameter_fetcher->GetParameter(kPublicKeyEndpointParameterSuffix); + LOG(INFO) << "Retrieved public_key_endpoint parameter: " + 
<< publicKeyEndpointParameter; + std::vector endpoints = {publicKeyEndpointParameter}; + auto public_key_fetcher = + privacy_sandbox::server_common::PublicKeyFetcherFactory::Create( + {{GetCloudPlatform(), endpoints}}); + if (public_key_fetcher) { + absl::Status public_key_refresh_status = public_key_fetcher->Refresh(); + if (!public_key_refresh_status.ok()) { + const std::string error = absl::StrCat( + "Public key refresh failed: ", public_key_refresh_status.message()); + LOG(ERROR) << error; + return absl::InternalError(error); + } + } + return public_key_fetcher->GetKey(GetCloudPlatform()); +} + +absl::StatusOr GetValuesWithCoordinators( + const v2::GetValuesRequest& proto_req, + std::unique_ptr& stub, + std::unique_ptr& + public_key) { + std::string serialized_req; + if (!proto_req.SerializeToString(&serialized_req)) { + return absl::Status(absl::StatusCode::kUnknown, + absl::StrCat("Protobuf SerializeToString failed!")); + } + quiche::BinaryHttpRequest req_bhttp_layer({}); + req_bhttp_layer.AddHeaderField({ + .name = std::string(kContentTypeHeader), + .value = std::string(kContentEncodingProtoHeaderValue), + }); + req_bhttp_layer.set_body(serialized_req); + auto maybe_serialized_bhttp = req_bhttp_layer.Serialize(); + if (!maybe_serialized_bhttp.ok()) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat(maybe_serialized_bhttp.status().message())); + } + + if (!public_key) { + const std::string error = "public_key==nullptr, cannot proceed."; + LOG(ERROR) << error; + return absl::InternalError(error); + } + OhttpClientEncryptor encryptor(*public_key); + + auto encrypted_serialized_request_maybe = + encryptor.EncryptRequest(*maybe_serialized_bhttp); + if (!encrypted_serialized_request_maybe.ok()) { + return encrypted_serialized_request_maybe.status(); + } + v2::ObliviousGetValuesRequest ohttp_req; + ohttp_req.mutable_raw_body()->set_data(*encrypted_serialized_request_maybe); + google::api::HttpBody ohttp_res; + grpc::ClientContext context; + 
grpc::Status status = + stub->ObliviousGetValues(&context, ohttp_req, &ohttp_res); + if (!status.ok()) { + LOG(ERROR) << status.error_code() << ": " << status.error_message(); + return absl::Status((absl::StatusCode)status.error_code(), + status.error_message()); + } + auto decrypted_ohttp_response_maybe = + encryptor.DecryptResponse(std::move(ohttp_res.data())); + if (!decrypted_ohttp_response_maybe.ok()) { + LOG(ERROR) << "ohttp response decryption failed!"; + return decrypted_ohttp_response_maybe.status(); + } + auto deframed_req = privacy_sandbox::server_common::DecodeRequestPayload( + *decrypted_ohttp_response_maybe); + if (!deframed_req.ok()) { + LOG(ERROR) << "unpadding response failed!"; + return deframed_req.status(); + } + const absl::StatusOr maybe_res_bhttp_layer = + quiche::BinaryHttpResponse::Create(deframed_req->compressed_data); + if (!maybe_res_bhttp_layer.ok()) { + LOG(ERROR) << "Failed to create bhttp resonse layer!"; + return maybe_res_bhttp_layer.status(); + } + v2::GetValuesResponse get_value_response; + if (!get_value_response.ParseFromString( + std::string(maybe_res_bhttp_layer->body()))) { + return absl::Status(absl::StatusCode::kUnknown, + absl::StrCat("Protobuf ParseFromString failed!")); + } + return get_value_response; +} + v2::GetValuesRequest GetRequest(const std::vector& input_values) { v2::GetValuesRequest req; v2::RequestPartition* partition = req.add_partitions(); @@ -117,7 +263,9 @@ void ValidateResponse(absl::StatusOr maybe_response, } } -void Validate() { +void Validate( + std::unique_ptr& + public_key) { const std::string kv_endpoint = absl::GetFlag(FLAGS_kv_endpoint); const int inclusive_upper_bound = absl::GetFlag(FLAGS_inclusive_upper_bound); const int batch_size = absl::GetFlag(FLAGS_batch_size); @@ -142,7 +290,13 @@ void Validate() { random_index *= batch_size; std::vector keys = GetKeys(random_index, batch_size); auto req = GetRequest(keys); - ValidateResponse(client.GetValues(req), keys); + absl::StatusOr 
get_value_response; + if (absl::GetFlag(FLAGS_use_coordinator)) { + get_value_response = GetValuesWithCoordinators(req, stub, public_key); + } else { + get_value_response = client.GetValues(req); + } + ValidateResponse(get_value_response, keys); requests_made_this_second++; // rate limit to N files per second if (requests_made_this_second % qps == 0) { @@ -179,7 +333,41 @@ void Validate() { int main(int argc, char** argv) { const std::vector commands = absl::ParseCommandLine(argc, argv); absl::InitializeLog(); - kv_server::Validate(); + kv_server::ConfigureTelemetryForTools(); + + // ptrs for validation with coordinators + std::unique_ptr platform_initializer; + std::unique_ptr parameter_client; + std::unique_ptr parameter_fetcher; + std::unique_ptr + public_key; + + if (absl::GetFlag(FLAGS_use_coordinator)) { + // Initializes GCP platform and its parameter client. + platform_initializer = std::make_unique(); + parameter_client = kv_server::ParameterClient::Create(); + + // Gets environment name + std::string environment = absl::GetFlag(FLAGS_environment); + if (environment == "NOT_SPECIFIED") { + LOG(ERROR) << "Flag environment is required to get key fetch parameters"; + return 1; + } + + // Create parameter fetcher and key fetcher manager + auto parameter_fetcher = std::make_unique( + environment, *parameter_client); + auto maybe_public_key = kv_server::GetPublicKey(parameter_fetcher); + if (!maybe_public_key.ok()) { + LOG(ERROR) << "GetPublicKey failed with error: " + << maybe_public_key.status().message(); + return 1; + } + public_key = + std::make_unique( + maybe_public_key.value()); + } + kv_server::Validate(public_key); if (kv_server::total_failures > 0 || kv_server::total_mismatches > 0) { LOG(ERROR) << "Validation failed with total_failures: " @@ -187,6 +375,6 @@ int main(int argc, char** argv) { << ", total_mismatches: " << kv_server::total_mismatches; return 1; } - + LOG(INFO) << "Query Validation succeed!"; return 0; } diff --git 
a/components/tools/util/BUILD.bazel b/components/tools/util/BUILD.bazel new file mode 100644 index 00000000..0470f400 --- /dev/null +++ b/components/tools/util/BUILD.bazel @@ -0,0 +1,17 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(default_visibility = [ + "//components/tools:__subpackages__", + "//production/packaging/tools:__subpackages__", + "//tools:__subpackages__", +]) + +cc_library( + name = "configure_telemetry_tools", + hdrs = ["configure_telemetry_tools.h"], + deps = [ + "//components/telemetry:server_definition", + "@google_privacysandbox_servers_common//src/logger:request_context_impl", + "@io_opentelemetry_cpp//exporters/ostream:ostream_log_record_exporter", + ], +) diff --git a/components/tools/util/configure_telemetry_tools.h b/components/tools/util/configure_telemetry_tools.h new file mode 100644 index 00000000..3eca5a4e --- /dev/null +++ b/components/tools/util/configure_telemetry_tools.h @@ -0,0 +1,49 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_TOOLS_UTIL_CONFIGURE_TELEMETRY_TOOLS_H_ +#define COMPONENTS_TOOLS_UTIL_CONFIGURE_TELEMETRY_TOOLS_H_ +#include + +#include "components/telemetry/server_definition.h" +#include "opentelemetry/exporters/ostream/log_record_exporter.h" +#include "opentelemetry/sdk/logs/logger_provider_factory.h" +#include "opentelemetry/sdk/logs/simple_log_record_processor_factory.h" +#include "src/logger/request_context_impl.h" + +namespace kv_server { + +// Configure telemetry and logger for tools +inline void ConfigureTelemetryForTools() { + // Init noop telemetry for metrics + InitMetricsContextMap(); + // Init logger to write to console + static opentelemetry::logs::LoggerProvider* logger_provider = + opentelemetry::sdk::logs::LoggerProviderFactory::Create( + opentelemetry::sdk::logs::SimpleLogRecordProcessorFactory::Create( + std::make_unique< + opentelemetry::exporter::logs::OStreamLogRecordExporter>( + std::cerr))) + .release(); + privacy_sandbox::server_common::log::logger_private = + logger_provider->GetLogger("default").get(); + // Turn on all logging + privacy_sandbox::server_common::log::AlwaysLogOtel(true); +} + +} // namespace kv_server + +#endif // COMPONENTS_TOOLS_UTIL_CONFIGURE_TELEMETRY_TOOLS_H_ diff --git a/components/udf/BUILD.bazel b/components/udf/BUILD.bazel index 5f7e7525..21fdd0d9 100644 --- a/components/udf/BUILD.bazel +++ b/components/udf/BUILD.bazel @@ -39,6 +39,7 @@ cc_library( ], deps = [ ":code_config", + "//components/errors:error_tag", "//components/errors:retry", "//components/udf/hooks:get_values_hook", "//components/udf/hooks:run_query_hook", @@ -48,8 +49,10 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", "@google_privacysandbox_servers_common//src/roma/interface", "@google_privacysandbox_servers_common//src/roma/roma_service", + 
"@google_privacysandbox_servers_common//src/util:duration", ], ) @@ -107,6 +110,7 @@ cc_test( "//components/udf/hooks:run_query_hook", "//public/query/v2:get_values_v2_cc_proto", "//public/test_util:proto_matcher", + "//public/test_util:request_example", "//public/udf:constants", "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", @@ -114,6 +118,7 @@ cc_test( "@com_google_googletest//:gtest_main", "@google_privacysandbox_servers_common//src/roma/interface", "@google_privacysandbox_servers_common//src/roma/roma_service", + "@io_opentelemetry_cpp//exporters/ostream:ostream_log_record_exporter", ], ) diff --git a/components/udf/hooks/BUILD.bazel b/components/udf/hooks/BUILD.bazel index c3a37094..fa917de1 100644 --- a/components/udf/hooks/BUILD.bazel +++ b/components/udf/hooks/BUILD.bazel @@ -45,9 +45,6 @@ cc_library( cc_library( name = "run_query_hook", - srcs = [ - "run_query_hook.cc", - ], hdrs = [ "run_query_hook.h", ], @@ -59,6 +56,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@google_privacysandbox_servers_common//src/roma/interface", "@google_privacysandbox_servers_common//src/roma/interface:function_binding_io_cc_proto", "@nlohmann_json//:lib", diff --git a/components/udf/hooks/get_values_hook.cc b/components/udf/hooks/get_values_hook.cc index e35fa682..b40c05af 100644 --- a/components/udf/hooks/get_values_hook.cc +++ b/components/udf/hooks/get_values_hook.cc @@ -65,7 +65,8 @@ void SetStatusAsBytes(absl::StatusCode code, std::string_view message, } void SetOutputAsBytes(const InternalLookupResponse& response, - FunctionBindingIoProto& io) { + FunctionBindingIoProto& io, + const RequestContext& request_context) { BinaryGetValuesResponse binary_response; for (auto&& [k, v] : response.kv_pairs()) { Value value; @@ -91,14 +92,18 @@ void SetStatusAsString(absl::StatusCode code, std::string_view message, } void 
SetOutputAsString(const InternalLookupResponse& response, - FunctionBindingIoProto& io) { - VLOG(9) << "Processing internal lookup response"; + FunctionBindingIoProto& io, + const RequestContext& request_context) { + PS_VLOG(9, request_context.GetPSLogContext()) + << "Processing internal lookup response"; std::string kv_pairs_json; if (const auto json_status = MessageToJsonString(response, &kv_pairs_json); !json_status.ok()) { SetStatusAsString(json_status.code(), json_status.message(), io); - LOG(ERROR) << "MessageToJsonString failed with " << json_status; - VLOG(1) << "getValues result: " << io.DebugString(); + PS_LOG(ERROR, request_context.GetPSLogContext()) + << "MessageToJsonString failed with " << json_status; + PS_VLOG(1, request_context.GetPSLogContext()) + << "getValues result: " << io.DebugString(); return; } @@ -108,7 +113,8 @@ void SetOutputAsString(const InternalLookupResponse& response, if (kv_pairs_json_object.is_discarded()) { SetStatusAsString(absl::StatusCode::kInvalidArgument, "Error while parsing JSON string.", io); - LOG(ERROR) << "json parse failed for " << kv_pairs_json; + PS_LOG(ERROR, request_context.GetPSLogContext()) + << "json parse failed for " << kv_pairs_json; return; } @@ -128,21 +134,30 @@ class GetValuesHookImpl : public GetValuesHook { } } - void operator()(FunctionBindingPayload& payload) { - VLOG(9) << "Called getValues hook"; + void operator()( + FunctionBindingPayload>& payload) { + std::shared_ptr request_context = payload.metadata.lock(); + if (request_context == nullptr) { + PS_VLOG(1) << "Request context is not available, the request might " + "have been marked as complete"; + return; + } + PS_VLOG(9, request_context->GetPSLogContext()) << "Called getValues hook"; if (lookup_ == nullptr) { SetStatus(absl::StatusCode::kInternal, "getValues has not been initialized yet", payload.io_proto); - LOG(ERROR) + PS_LOG(ERROR, request_context->GetPSLogContext()) << "getValues hook is not initialized properly: lookup is nullptr"; 
return; } - VLOG(9) << "getValues request: " << payload.io_proto.DebugString(); + PS_VLOG(9, request_context->GetPSLogContext()) + << "getValues request: " << payload.io_proto.DebugString(); if (!payload.io_proto.has_input_list_of_string()) { SetStatus(absl::StatusCode::kInvalidArgument, "getValues input must be list of strings", payload.io_proto); - VLOG(1) << "getValues result: " << payload.io_proto.DebugString(); + PS_VLOG(1, request_context->GetPSLogContext()) + << "getValues result: " << payload.io_proto.DebugString(); return; } @@ -151,18 +166,21 @@ class GetValuesHookImpl : public GetValuesHook { keys.insert(key); } - VLOG(9) << "Calling internal lookup client"; + PS_VLOG(9, request_context->GetPSLogContext()) + << "Calling internal lookup client"; absl::StatusOr response_or_status = - lookup_->GetKeyValues(payload.metadata, keys); + lookup_->GetKeyValues(*request_context, keys); if (!response_or_status.ok()) { SetStatus(response_or_status.status().code(), response_or_status.status().message(), payload.io_proto); - VLOG(1) << "getValues result: " << payload.io_proto.DebugString(); + PS_VLOG(1, request_context->GetPSLogContext()) + << "getValues result: " << payload.io_proto.DebugString(); return; } - SetOutput(response_or_status.value(), payload.io_proto); - VLOG(9) << "getValues result: " << payload.io_proto.DebugString(); + SetOutput(response_or_status.value(), payload.io_proto, *request_context); + PS_VLOG(9, request_context->GetPSLogContext()) + << "getValues result: " << payload.io_proto.DebugString(); } private: @@ -176,11 +194,12 @@ class GetValuesHookImpl : public GetValuesHook { } void SetOutput(const InternalLookupResponse& response, - FunctionBindingIoProto& io) { + FunctionBindingIoProto& io, + const RequestContext& request_context) { if (output_type_ == OutputType::kString) { - SetOutputAsString(response, io); + SetOutputAsString(response, io, request_context); } else { - SetOutputAsBytes(response, io); + SetOutputAsBytes(response, io, 
request_context); } } diff --git a/components/udf/hooks/get_values_hook.h b/components/udf/hooks/get_values_hook.h index 24d576ac..733756c0 100644 --- a/components/udf/hooks/get_values_hook.h +++ b/components/udf/hooks/get_values_hook.h @@ -46,7 +46,8 @@ class GetValuesHook { // This is registered with v8 and is exposed to the UDF. Internally, it calls // the internal lookup client. virtual void operator()( - google::scp::roma::FunctionBindingPayload& payload) = 0; + google::scp::roma::FunctionBindingPayload>& + payload) = 0; static std::unique_ptr Create(OutputType output_type); }; diff --git a/components/udf/hooks/get_values_hook_test.cc b/components/udf/hooks/get_values_hook_test.cc index 3d6f2b03..0edf28a3 100644 --- a/components/udf/hooks/get_values_hook_test.cc +++ b/components/udf/hooks/get_values_hook_test.cc @@ -40,7 +40,14 @@ using testing::Return; class GetValuesHookTest : public ::testing::Test { protected: - void SetUp() override { InitMetricsContextMap(); } + GetValuesHookTest() { + InitMetricsContextMap(); + request_context_ = std::make_shared(); + } + std::shared_ptr GetRequestContext() { + return request_context_; + } + std::shared_ptr request_context_; }; TEST_F(GetValuesHookTest, StringOutput_SuccessfullyProcessesValue) { @@ -65,9 +72,8 @@ TEST_F(GetValuesHookTest, StringOutput_SuccessfullyProcessesValue) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kString); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*get_values_hook)(payload); nlohmann::json result_json = @@ -104,9 +110,8 @@ TEST_F(GetValuesHookTest, StringOutput_SuccessfullyProcessesResultsWithStatus) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kString); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - 
FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*get_values_hook)(payload); nlohmann::json expected = @@ -126,9 +131,8 @@ TEST_F(GetValuesHookTest, StringOutput_LookupReturnsError) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kString); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*get_values_hook)(payload); nlohmann::json expected = R"({"code":2,"message":"Some error"})"_json; @@ -144,9 +148,8 @@ TEST_F(GetValuesHookTest, StringOutput_InputIsNotListOfStrings) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kString); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*get_values_hook)(payload); nlohmann::json expected = @@ -177,9 +180,8 @@ TEST_F(GetValuesHookTest, BinaryOutput_SuccessfullyProcessesValue) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kBinary); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*get_values_hook)(payload); EXPECT_TRUE(io.has_output_bytes()); @@ -214,9 +216,8 @@ TEST_F(GetValuesHookTest, BinaryOutput_LookupReturnsError) { auto get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kBinary); get_values_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - FunctionBindingPayload payload{ - io, RequestContext(metrics_context)}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; 
(*get_values_hook)(payload); EXPECT_TRUE(io.has_output_bytes()); diff --git a/components/udf/hooks/logging_hook.h b/components/udf/hooks/logging_hook.h index 3e541600..020e557b 100644 --- a/components/udf/hooks/logging_hook.h +++ b/components/udf/hooks/logging_hook.h @@ -17,6 +17,7 @@ #ifndef COMPONENTS_UDF_LOGGING_HOOK_H_ #define COMPONENTS_UDF_LOGGING_HOOK_H_ +#include #include #include @@ -27,9 +28,17 @@ namespace kv_server { // Logging function to register with Roma. inline void LoggingFunction(absl::LogSeverity severity, - const RequestContext& context, + const std::weak_ptr& context, std::string_view msg) { - LOG(LEVEL(severity)) << msg; + std::shared_ptr request_context = context.lock(); + if (request_context == nullptr) { + PS_VLOG(1) << "Request context is not available, the request might " + "have been marked as complete"; + return; + } + PS_VLOG(9, request_context->GetPSLogContext()) << "Called logging hook"; + privacy_sandbox::server_common::log::LogWithPSLog( + severity, request_context->GetPSLogContext(), msg); } } // namespace kv_server diff --git a/components/udf/hooks/run_query_hook.cc b/components/udf/hooks/run_query_hook.cc deleted file mode 100644 index c2aa113f..00000000 --- a/components/udf/hooks/run_query_hook.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "components/udf/hooks/run_query_hook.h" - -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/log/log.h" -#include "absl/status/statusor.h" -#include "components/internal_server/lookup.h" -#include "nlohmann/json.hpp" - -namespace kv_server { -namespace { - -using google::scp::roma::FunctionBindingPayload; - -class RunQueryHookImpl : public RunQueryHook { - public: - void FinishInit(std::unique_ptr lookup) { - if (lookup_ == nullptr) { - lookup_ = std::move(lookup); - } - } - - void operator()(FunctionBindingPayload& payload) { - if (lookup_ == nullptr) { - nlohmann::json status; - status["code"] = absl::StatusCode::kInternal; - status["message"] = "runQuery has not been initialized yet"; - payload.io_proto.mutable_output_list_of_string()->add_data(status.dump()); - LOG(ERROR) - << "runQuery hook is not initialized properly: lookup is nullptr"; - return; - } - - VLOG(9) << "runQuery request: " << payload.io_proto.DebugString(); - if (!payload.io_proto.has_input_string()) { - nlohmann::json status; - status["code"] = absl::StatusCode::kInvalidArgument; - status["message"] = "runQuery input must be a string"; - payload.io_proto.mutable_output_list_of_string()->add_data(status.dump()); - VLOG(1) << "runQuery result: " << payload.io_proto.DebugString(); - return; - } - - VLOG(9) << "Calling internal run query client"; - absl::StatusOr response_or_status = - lookup_->RunQuery(payload.metadata, payload.io_proto.input_string()); - - if (!response_or_status.ok()) { - LOG(ERROR) << "Internal run query returned error: " - << response_or_status.status(); - payload.io_proto.mutable_output_list_of_string()->mutable_data(); - VLOG(1) << "runQuery result: " << payload.io_proto.DebugString(); - return; - } - - VLOG(9) << "Processing internal run query response"; - *payload.io_proto.mutable_output_list_of_string()->mutable_data() = - *std::move(response_or_status.value().mutable_elements()); - VLOG(9) << 
"runQuery result: " << payload.io_proto.DebugString(); - } - - private: - // `lookup_` is initialized separately, since its dependencies create threads. - // Lazy load is used to ensure that it only happens after Roma forks. - std::unique_ptr lookup_; -}; -} // namespace - -std::unique_ptr RunQueryHook::Create() { - return std::make_unique(); -} - -} // namespace kv_server diff --git a/components/udf/hooks/run_query_hook.h b/components/udf/hooks/run_query_hook.h index 35560c3e..db260fe8 100644 --- a/components/udf/hooks/run_query_hook.h +++ b/components/udf/hooks/run_query_hook.h @@ -19,35 +19,149 @@ #include #include -#include -#include +#include -#include "absl/functional/any_invocable.h" +#include "absl/strings/str_cat.h" #include "components/internal_server/lookup.h" #include "components/util/request_context.h" +#include "nlohmann/json.hpp" #include "src/roma/config/function_binding_object_v2.h" namespace kv_server { // Functor that acts as a wrapper for the internal query client call. -class RunQueryHook { +template +class RunSetQueryHook { public: - virtual ~RunQueryHook() = default; - // We need to split the hook init, since lookup depends on the cache. // However, UdfClient init requires the hook and it also forks. // The cache is only initialized after UdfClient init, so the hook // init can only be completed after UdfClient and cache init. - virtual void FinishInit(std::unique_ptr lookup) = 0; - + void FinishInit(std::unique_ptr lookup); // This is registered with v8 and is exposed to the UDF. Internally, it calls // the internal query client. 
- virtual void operator()( - google::scp::roma::FunctionBindingPayload& payload) = 0; + void operator()( + google::scp::roma::FunctionBindingPayload>& + payload); + static std::unique_ptr> Create(); + + private: + constexpr std::string_view HookName(); + void ReportErrorStatus( + absl::StatusCode error_code, std::string_view error_message, + google::scp::roma::FunctionBindingPayload>& + payload); - static std::unique_ptr Create(); + // `lookup_` is initialized separately, since its dependencies create + // threads. Lazy load is used to ensure that it only happens after Roma + // forks. + std::unique_ptr lookup_; }; +template +constexpr std::string_view RunSetQueryHook::HookName() { + if constexpr (std::is_same_v) { + return "runQuery"; + } + if constexpr (std::is_same_v) { + return "runSetQueryInt"; + } +} + +template +void RunSetQueryHook::ReportErrorStatus( + absl::StatusCode error_code, std::string_view error_message, + google::scp::roma::FunctionBindingPayload>& + payload) { + nlohmann::json status; + status["code"] = error_code; + status["message"] = error_message; + payload.io_proto.mutable_output_list_of_string()->add_data(status.dump()); +} + +template +void RunSetQueryHook::FinishInit(std::unique_ptr lookup) { + if (lookup_ == nullptr) { + lookup_ = std::move(lookup); + } +} + +template +void RunSetQueryHook::operator()( + google::scp::roma::FunctionBindingPayload>& + payload) { + std::shared_ptr request_context = payload.metadata.lock(); + if (request_context == nullptr) { + PS_VLOG(1) << "Request context is not available, the request might " + "have been marked as complete"; + return; + } + if (lookup_ == nullptr) { + ReportErrorStatus(absl::StatusCode::kInternal, + absl::StrCat(HookName(), " has not been initialized yet"), + payload); + PS_LOG(ERROR, request_context->GetPSLogContext()) << absl::StrCat( + HookName(), " hook is not initialized properly: lookup is nullptr"); + return; + } + PS_VLOG(9, request_context->GetPSLogContext()) + << HookName() << 
" request: " << payload.io_proto.DebugString(); + if (!payload.io_proto.has_input_string()) { + ReportErrorStatus(absl::StatusCode::kInvalidArgument, + absl::StrCat(HookName(), " input must be a string"), + payload); + PS_VLOG(1, request_context->GetPSLogContext()) + << HookName() << " result: " << payload.io_proto.DebugString(); + return; + } + PS_VLOG(9, request_context->GetPSLogContext()) + << "Calling internal " << HookName() << " client"; + absl::StatusOr response_or_status; + if constexpr (std::is_same_v) { + response_or_status = + lookup_->RunQuery(*request_context, payload.io_proto.input_string()); + } + if constexpr (std::is_same_v) { + response_or_status = lookup_->RunSetQueryInt( + *request_context, payload.io_proto.input_string()); + } + if (!response_or_status.ok()) { + PS_LOG(ERROR, request_context->GetPSLogContext()) + << "Internal " << HookName() + << " returned error: " << response_or_status.status(); + const auto& status = response_or_status.status(); + ReportErrorStatus( + status.code(), + absl::StrCat(HookName(), " failed with error: ", status.message()), + payload); + PS_VLOG(1, request_context->GetPSLogContext()) + << HookName() << " result: " << payload.io_proto.DebugString(); + return; + } + PS_VLOG(9, request_context->GetPSLogContext()) + << "Processing internal " << HookName() << " response"; + if constexpr (std::is_same_v) { + *payload.io_proto.mutable_output_list_of_string()->mutable_data() = + std::move(*response_or_status.value().mutable_elements()); + } + if constexpr (std::is_same_v) { + const auto& elements = response_or_status->elements(); + payload.io_proto.set_output_bytes(elements.data(), + elements.size() * sizeof(uint32_t)); + } + PS_VLOG(9, request_context->GetPSLogContext()) + << HookName() << " result: " << payload.io_proto.DebugString(); +} + +template +std::unique_ptr> +RunSetQueryHook::Create() { + return std::make_unique>(); +} + +using RunSetQueryStringHook = RunSetQueryHook; +using RunSetQueryIntHook = 
RunSetQueryHook; + } // namespace kv_server #endif // COMPONENTS_UDF_RUN_QUERY_HOOK_H_ diff --git a/components/udf/hooks/run_query_hook_test.cc b/components/udf/hooks/run_query_hook_test.cc index 985513de..749d5edd 100644 --- a/components/udf/hooks/run_query_hook_test.cc +++ b/components/udf/hooks/run_query_hook_test.cc @@ -15,16 +15,14 @@ #include "components/udf/hooks/run_query_hook.h" #include -#include #include -#include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "components/internal_server/mocks.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" +#include "public/test_util/proto_matcher.h" namespace kv_server { namespace { @@ -38,7 +36,17 @@ using testing::UnorderedElementsAreArray; class RunQueryHookTest : public ::testing::Test { protected: - void SetUp() override { InitMetricsContextMap(); } + RunQueryHookTest() { + InitMetricsContextMap(); + request_context_ = std::make_unique(); + request_context_->UpdateLogContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); + } + std::shared_ptr GetRequestContext() { + return request_context_; + } + std::shared_ptr request_context_; }; TEST_F(RunQueryHookTest, SuccessfullyProcessesValue) { @@ -52,16 +60,38 @@ TEST_F(RunQueryHookTest, SuccessfullyProcessesValue) { FunctionBindingIoProto io; TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); - auto run_query_hook = RunQueryHook::Create(); + auto run_query_hook = RunSetQueryStringHook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - RequestContext request_context(metrics_context); - FunctionBindingPayload payload{io, request_context}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*run_query_hook)(payload); EXPECT_THAT(io.output_list_of_string().data(), UnorderedElementsAreArray({"a", "b"})); } +TEST_F(RunQueryHookTest, 
VerifyProcessingIntSetsSuccessfully) { + InternalRunSetQueryIntResponse run_query_response; + TextFormat::ParseFromString(R"pb(elements: 1000 elements: 1001)pb", + &run_query_response); + auto mock_lookup = std::make_unique(); + EXPECT_CALL(*mock_lookup, RunSetQueryInt(_, "Q")) + .WillOnce(Return(run_query_response)); + FunctionBindingIoProto io; + TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); + auto run_query_hook = RunSetQueryIntHook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + FunctionBindingPayload> payload{ + io, GetRequestContext()}; + (*run_query_hook)(payload); + ASSERT_TRUE(io.has_output_bytes()); + InternalRunSetQueryIntResponse actual_response; + actual_response.mutable_elements()->Resize( + io.output_bytes().size() / sizeof(uint32_t), 0); + std::memcpy(actual_response.mutable_elements()->mutable_data(), + io.output_bytes().data(), io.output_bytes().size()); + EXPECT_THAT(actual_response, EqualsProto(run_query_response)); +} + TEST_F(RunQueryHookTest, RunQueryClientReturnsError) { std::string query = "Q"; auto mock_lookup = std::make_unique(); @@ -70,13 +100,32 @@ TEST_F(RunQueryHookTest, RunQueryClientReturnsError) { FunctionBindingIoProto io; TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); - auto run_query_hook = RunQueryHook::Create(); + auto run_query_hook = RunSetQueryStringHook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + FunctionBindingPayload> payload{ + io, GetRequestContext()}; + (*run_query_hook)(payload); + EXPECT_THAT( + io.output_list_of_string().data(), + UnorderedElementsAreArray( + {R"({"code":2,"message":"runQuery failed with error: Some error"})"})); +} + +TEST_F(RunQueryHookTest, RunSetQueryIntClientReturnsError) { + auto mock_lookup = std::make_unique(); + EXPECT_CALL(*mock_lookup, RunSetQueryInt(_, "Q")) + .WillOnce(Return(absl::UnknownError("Some error"))); + FunctionBindingIoProto io; + TextFormat::ParseFromString(R"pb(input_string: "Q")pb", &io); + auto 
run_query_hook = RunSetQueryIntHook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - RequestContext request_context(metrics_context); - FunctionBindingPayload payload{io, request_context}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*run_query_hook)(payload); - EXPECT_TRUE(io.output_list_of_string().data().empty()); + EXPECT_THAT( + io.output_list_of_string().data(), + UnorderedElementsAreArray( + {R"({"code":2,"message":"runSetQueryInt failed with error: Some error"})"})); } TEST_F(RunQueryHookTest, InputIsNotString) { @@ -85,13 +134,11 @@ TEST_F(RunQueryHookTest, InputIsNotString) { FunctionBindingIoProto io; TextFormat::ParseFromString(R"pb(input_list_of_string { data: "key1" })pb", &io); - auto run_query_hook = RunQueryHook::Create(); + auto run_query_hook = RunSetQueryStringHook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); - ScopeMetricsContext metrics_context; - RequestContext request_context(metrics_context); - FunctionBindingPayload payload{io, request_context}; + FunctionBindingPayload> payload{ + io, GetRequestContext()}; (*run_query_hook)(payload); - EXPECT_THAT( io.output_list_of_string().data(), UnorderedElementsAreArray( diff --git a/components/udf/mocks.h b/components/udf/mocks.h index 8e8bbbc5..52a5d916 100644 --- a/components/udf/mocks.h +++ b/components/udf/mocks.h @@ -17,6 +17,7 @@ #ifndef COMPONENTS_UDF_MOCKS_H_ #define COMPONENTS_UDF_MOCKS_H_ +#include #include #include #include @@ -32,14 +33,21 @@ namespace kv_server { class MockUdfClient : public UdfClient { public: MOCK_METHOD((absl::StatusOr), ExecuteCode, - (RequestContext, std::vector), (const, override)); + (const RequestContextFactory&, std::vector, + ExecutionMetadata& execution_metadata), + (const, override)); MOCK_METHOD((absl::StatusOr), ExecuteCode, - (RequestContext, UDFExecutionMetadata&&, - const google::protobuf::RepeatedPtrField&), + (const RequestContextFactory&, 
UDFExecutionMetadata&&, + const google::protobuf::RepeatedPtrField&, + ExecutionMetadata& execution_metadata), (const, override)); MOCK_METHOD((absl::Status), Stop, (), (override)); - MOCK_METHOD((absl::Status), SetCodeObject, (CodeConfig), (override)); - MOCK_METHOD((absl::Status), SetWasmCodeObject, (CodeConfig), (override)); + MOCK_METHOD((absl::Status), SetCodeObject, + (CodeConfig, privacy_sandbox::server_common::log::PSLogContext&), + (override)); + MOCK_METHOD((absl::Status), SetWasmCodeObject, + (CodeConfig, privacy_sandbox::server_common::log::PSLogContext&), + (override)); }; } // namespace kv_server diff --git a/components/udf/noop_udf_client.cc b/components/udf/noop_udf_client.cc index e257a7e7..af45a4f1 100644 --- a/components/udf/noop_udf_client.cc +++ b/components/udf/noop_udf_client.cc @@ -31,23 +31,31 @@ namespace kv_server { namespace { class NoopUdfClientImpl : public UdfClient { public: - absl::StatusOr ExecuteCode(RequestContext request_context, - std::vector keys) const { + absl::StatusOr ExecuteCode( + const RequestContextFactory& request_context_factory, + std::vector keys, + ExecutionMetadata& execution_metadata) const { return ""; } absl::StatusOr ExecuteCode( - RequestContext request_context, UDFExecutionMetadata&&, - const google::protobuf::RepeatedPtrField& arguments) const { + const RequestContextFactory& request_context_factory, + UDFExecutionMetadata&&, + const google::protobuf::RepeatedPtrField& arguments, + ExecutionMetadata& execution_metadata) const { return ""; } absl::Status Stop() { return absl::OkStatus(); } - absl::Status SetCodeObject(CodeConfig code_config) { + absl::Status SetCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return absl::OkStatus(); } - absl::Status SetWasmCodeObject(CodeConfig code_config) { + absl::Status SetWasmCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context) { return absl::OkStatus(); } }; 
diff --git a/components/udf/udf_client.cc b/components/udf/udf_client.cc index 8569c7f5..9ed410e9 100644 --- a/components/udf/udf_client.cc +++ b/components/udf/udf_client.cc @@ -27,14 +27,22 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "absl/time/time.h" +#include "components/errors/error_tag.h" #include "google/protobuf/util/json_util.h" #include "src/roma/config/config.h" #include "src/roma/interface/roma.h" #include "src/roma/roma_service/roma_service.h" +#include "src/util/duration.h" namespace kv_server { namespace { + +enum class ErrorTag : int { + kCodeUpdateTimeoutError = 1, + kUdfExecutionTimeoutError = 2 +}; + using google::protobuf::json::MessageToJsonString; using google::scp::roma::CodeObject; using google::scp::roma::Config; @@ -43,8 +51,6 @@ using google::scp::roma::kTimeoutDurationTag; using google::scp::roma::ResponseObject; using google::scp::roma::sandbox::roma_service::RomaService; -constexpr absl::Duration kCodeUpdateTimeout = absl::Seconds(1); - // Roma IDs and version numbers are required for execution. // We do not currently make use of IDs or the code version number, set them to // constants. @@ -54,17 +60,22 @@ constexpr int kUdfInterfaceVersion = 1; class UdfClientImpl : public UdfClient { public: - explicit UdfClientImpl( - Config&& config = Config(), - absl::Duration udf_timeout = absl::Seconds(5), int udf_min_log_level = 0) + explicit UdfClientImpl(Config>&& config = + Config>(), + absl::Duration udf_timeout = absl::Seconds(5), + absl::Duration udf_update_timeout = absl::Seconds(30), + int udf_min_log_level = 0) : udf_timeout_(udf_timeout), + udf_update_timeout_(udf_update_timeout), roma_service_(std::move(config)), udf_min_log_level_(udf_min_log_level) {} // Converts the arguments into plain JSON strings to pass to Roma. 
absl::StatusOr ExecuteCode( - RequestContext request_context, UDFExecutionMetadata&& execution_metadata, - const google::protobuf::RepeatedPtrField& arguments) const { + const RequestContextFactory& request_context_factory, + UDFExecutionMetadata&& execution_metadata, + const google::protobuf::RepeatedPtrField& arguments, + ExecutionMetadata& metadata) const { execution_metadata.set_udf_interface_version(kUdfInterfaceVersion); std::vector string_args; string_args.reserve(arguments.size() + 1); @@ -91,22 +102,26 @@ class UdfClientImpl : public UdfClient { } string_args.push_back(json_arg); } - return ExecuteCode(std::move(request_context), std::move(string_args)); + return ExecuteCode(request_context_factory, std::move(string_args), + metadata); } absl::StatusOr ExecuteCode( - RequestContext request_context, std::vector input) const { + const RequestContextFactory& request_context_factory, + std::vector input, ExecutionMetadata& metadata) const { std::shared_ptr response_status = std::make_shared(); std::shared_ptr result = std::make_shared(); std::shared_ptr notification = std::make_shared(); auto invocation_request = - BuildInvocationRequest(std::move(request_context), std::move(input)); - VLOG(9) << "Executing UDF with input arg(s): " - << absl::StrJoin(invocation_request.input, ","); + BuildInvocationRequest(request_context_factory, std::move(input)); + PS_VLOG(9, request_context_factory.Get().GetPSLogContext()) + << "Executing UDF with input arg(s): " + << absl::StrJoin(invocation_request.input, ","); + privacy_sandbox::server_common::Stopwatch stopwatch; const auto status = roma_service_.Execute( - std::make_unique>( + std::make_unique>>( std::move(invocation_request)), [notification, response_status, result](absl::StatusOr response) { @@ -118,37 +133,54 @@ class UdfClientImpl : public UdfClient { notification->Notify(); }); if (!status.ok()) { - LOG(ERROR) << "Error sending UDF for execution: " << status; + PS_LOG(ERROR, 
request_context_factory.Get().GetPSLogContext()) + << "Error sending UDF for execution: " << status; return status; } notification->WaitForNotificationWithTimeout(udf_timeout_); if (!notification->HasBeenNotified()) { - return absl::InternalError("Timed out waiting for UDF result."); + return StatusWithErrorTag( + absl::Status(absl::StatusCode::kDeadlineExceeded, + "Timed out waiting for UDF execution result."), + __FILE__, ErrorTag::kUdfExecutionTimeoutError); } if (!response_status->ok()) { - LOG(ERROR) << "Error executing UDF: " << *response_status; + PS_LOG(ERROR, request_context_factory.Get().GetPSLogContext()) + << "Error executing UDF: " << *response_status; return *response_status; } + // TODO(b/338813575): waiting on the K&B team. Once that's + // implemented we should just use that number. + const auto udf_execution_time = stopwatch.GetElapsedTime(); + metadata.custom_code_total_execution_time_micros = + absl::ToInt64Milliseconds(udf_execution_time); + LogIfError(request_context_factory.Get() + .GetUdfRequestMetricsContext() + .LogHistogram( + absl::ToDoubleMicroseconds(udf_execution_time))); return *result; } absl::Status Init() { return roma_service_.Init(); } absl::Status Stop() { return roma_service_.Stop(); } - absl::Status SetCodeObject(CodeConfig code_config) { + absl::Status SetCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context) { // Only update code if logical commit time is larger. if (logical_commit_time_ >= code_config.logical_commit_time) { - VLOG(1) << "Not updating code object. logical_commit_time " - << code_config.logical_commit_time - << " too small, should be greater than " << logical_commit_time_; + PS_VLOG(1, log_context) + << "Not updating code object. 
logical_commit_time " + << code_config.logical_commit_time + << " too small, should be greater than " << logical_commit_time_; return absl::OkStatus(); } std::shared_ptr response_status = std::make_shared(); std::shared_ptr notification = std::make_shared(); - VLOG(9) << "Setting UDF: " << code_config.js; + PS_VLOG(9, log_context) << "Setting UDF: " << code_config.js; CodeObject code_object = BuildCodeObject(std::move(code_config.js), std::move(code_config.wasm), code_config.version); @@ -161,28 +193,37 @@ class UdfClientImpl : public UdfClient { notification->Notify(); }); if (!load_status.ok()) { - LOG(ERROR) << "Error setting UDF Code object: " << load_status; + PS_LOG(ERROR, log_context) + << "Error setting UDF Code object: " << load_status; return load_status; } - notification->WaitForNotificationWithTimeout(kCodeUpdateTimeout); + notification->WaitForNotificationWithTimeout(udf_update_timeout_); if (!notification->HasBeenNotified()) { - return absl::InternalError("Timed out setting UDF code object."); + return StatusWithErrorTag( + absl::Status(absl::StatusCode::kDeadlineExceeded, + "Timed out setting UDF code object."), + __FILE__, ErrorTag::kCodeUpdateTimeoutError); } if (!response_status->ok()) { - LOG(ERROR) << "Error compiling UDF code object. " << *response_status; + PS_LOG(ERROR, log_context) + << "Error compiling UDF code object. 
" << *response_status; return *response_status; } handler_name_ = std::move(code_config.udf_handler_name); logical_commit_time_ = code_config.logical_commit_time; version_ = code_config.version; - VLOG(5) << "Successfully set UDF code object with handler_name " - << handler_name_; + PS_VLOG(5, log_context) + << "Successfully set UDF code object with handler_name " + << handler_name_; return absl::OkStatus(); } - absl::Status SetWasmCodeObject(CodeConfig code_config) { - const auto code_object_status = SetCodeObject(std::move(code_config)); + absl::Status SetWasmCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context) { + const auto code_object_status = + SetCodeObject(std::move(code_config), log_context); if (!code_object_status.ok()) { return code_object_status; } @@ -190,15 +231,16 @@ class UdfClientImpl : public UdfClient { } private: - InvocationStrRequest BuildInvocationRequest( - RequestContext request_context, std::vector input) const { + InvocationStrRequest> BuildInvocationRequest( + const RequestContextFactory& request_context_factory, + std::vector input) const { return {.id = kInvocationRequestId, .version_string = absl::StrCat("v", version_), .handler_name = handler_name_, .tags = {{std::string(kTimeoutDurationTag), FormatDuration(udf_timeout_)}}, .input = std::move(input), - .metadata = std::move(request_context), + .metadata = request_context_factory.GetWeakCopy(), .min_log_level = absl::LogSeverity(udf_min_log_level_)}; } @@ -214,6 +256,7 @@ class UdfClientImpl : public UdfClient { int64_t logical_commit_time_ = -1; int64_t version_ = 1; const absl::Duration udf_timeout_; + const absl::Duration udf_update_timeout_; int udf_min_log_level_; // Per b/299667930, RomaService has been extended to support metadata storage // as a side effect of RomaService::Execute(), making it no longer const. 
@@ -222,16 +265,16 @@ class UdfClientImpl : public UdfClient { // concerns about mutable or go/totw/174, RomaService is thread-safe, so // losing the thread-safety of usage within a const function is a lesser // concern. - mutable RomaService roma_service_; + mutable RomaService> roma_service_; }; } // namespace absl::StatusOr> UdfClient::Create( - Config&& config, absl::Duration udf_timeout, - int udf_min_log_level) { + Config>&& config, absl::Duration udf_timeout, + absl::Duration udf_update_timeout, int udf_min_log_level) { auto udf_client = std::make_unique( - std::move(config), udf_timeout, udf_min_log_level); + std::move(config), udf_timeout, udf_update_timeout, udf_min_log_level); const auto init_status = udf_client->Init(); if (!init_status.ok()) { return init_status; diff --git a/components/udf/udf_client.h b/components/udf/udf_client.h index abb9e8bd..cb4bbe85 100644 --- a/components/udf/udf_client.h +++ b/components/udf/udf_client.h @@ -28,11 +28,16 @@ #include "components/util/request_context.h" #include "google/protobuf/message.h" #include "public/api_schema.pb.h" +#include "src/logger/request_context_logger.h" #include "src/roma/config/config.h" #include "src/roma/interface/roma.h" namespace kv_server { +struct ExecutionMetadata { + std::optional custom_code_total_execution_time_micros; +}; + // Client to execute UDF class UdfClient { public: @@ -43,28 +48,41 @@ class UdfClient { // UDF signature. ABSL_DEPRECATED("Use ExecuteCode(metadata, arguments) instead") virtual absl::StatusOr ExecuteCode( - RequestContext request_context, std::vector keys) const = 0; + const RequestContextFactory& request_context_factory, + std::vector keys, + ExecutionMetadata& execution_metadata) const = 0; // Executes the UDF. Code object must be set before making // this call. 
virtual absl::StatusOr ExecuteCode( - RequestContext request_context, UDFExecutionMetadata&& execution_metadata, - const google::protobuf::RepeatedPtrField& arguments) - const = 0; + const RequestContextFactory& request_context_factory, + UDFExecutionMetadata&& execution_metadata, + const google::protobuf::RepeatedPtrField& arguments, + ExecutionMetadata& metadata) const = 0; virtual absl::Status Stop() = 0; // Sets the code object that will be used for UDF execution - virtual absl::Status SetCodeObject(CodeConfig code_config) = 0; + virtual absl::Status SetCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) = 0; // Sets the WASM code object that will be used for UDF execution - virtual absl::Status SetWasmCodeObject(CodeConfig code_config) = 0; + virtual absl::Status SetWasmCodeObject( + CodeConfig code_config, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) = 0; // Creates a UDF executor. This calls Roma::Init, which forks. 
static absl::StatusOr> Create( - google::scp::roma::Config&& config = - google::scp::roma::Config(), - absl::Duration udf_timeout = absl::Seconds(5), int udf_min_log_level = 0); + google::scp::roma::Config>&& config = + google::scp::roma::Config>(), + absl::Duration udf_timeout = absl::Seconds(5), + absl::Duration udf_update_timeout = absl::Seconds(30), + int udf_min_log_level = 0); }; } // namespace kv_server diff --git a/components/udf/udf_client_test.cc b/components/udf/udf_client_test.cc index 755d5a00..10cc5092 100644 --- a/components/udf/udf_client_test.cc +++ b/components/udf/udf_client_test.cc @@ -31,7 +31,12 @@ #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" +#include "opentelemetry/exporters/ostream/log_record_exporter.h" +#include "opentelemetry/sdk/logs/logger_provider_factory.h" +#include "opentelemetry/sdk/logs/simple_log_record_processor_factory.h" +#include "opentelemetry/sdk/resource/resource.h" #include "public/query/v2/get_values_v2.pb.h" +#include "public/test_util/request_example.h" #include "public/udf/constants.h" #include "src/roma/config/config.h" #include "src/roma/interface/roma.h" @@ -41,20 +46,30 @@ using google::scp::roma::Config; using google::scp::roma::FunctionBindingObjectV2; using google::scp::roma::FunctionBindingPayload; using testing::_; +using testing::ContainsRegex; +using testing::HasSubstr; using testing::Return; namespace kv_server { namespace { - absl::StatusOr> CreateUdfClient() { - Config config; + Config> config; config.number_of_workers = 1; return UdfClient::Create(std::move(config)); } class UdfClientTest : public ::testing::Test { protected: - void SetUp() override { InitMetricsContextMap(); } + UdfClientTest() { + privacy_sandbox::server_common::log::ServerToken( + kExampleConsentedDebugToken); + InitMetricsContextMap(); + request_context_factory_ = std::make_unique( + privacy_sandbox::server_common::LogContext(), + 
privacy_sandbox::server_common::ConsentedDebugConfiguration()); + } + std::unique_ptr request_context_factory_; + ExecutionMetadata execution_metadata_; }; TEST_F(UdfClientTest, UdfClient_Create_Success) { @@ -75,9 +90,8 @@ TEST_F(UdfClientTest, JsCallSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr result = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello world!")"); @@ -85,6 +99,26 @@ TEST_F(UdfClientTest, JsCallSucceeds) { EXPECT_TRUE(stop.ok()); } +TEST_F(UdfClientTest, JsExceptionReturnsStatus) { + auto udf_client = CreateUdfClient(); + EXPECT_TRUE(udf_client.ok()); + + absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ + .js = "function hello() { throw new Error('Oh no!'); }", + .udf_handler_name = "hello", + .logical_commit_time = 1, + .version = 1, + }); + EXPECT_TRUE(code_obj_status.ok()); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), HasSubstr("Oh no!")); + + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} + TEST_F(UdfClientTest, RepeatedJsCallsSucceed) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); @@ -96,14 +130,13 @@ TEST_F(UdfClientTest, RepeatedJsCallsSucceed) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr result1 = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result1 = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result1.ok()); EXPECT_EQ(*result1, R"("Hello world!")"); - absl::StatusOr result2 = - 
udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result2 = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result2.ok()); EXPECT_EQ(*result2, R"("Hello world!")"); @@ -122,9 +155,8 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"("ECHO")"}); + *request_context_factory_, {R"("ECHO")"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello world! \"ECHO\"")"); @@ -151,9 +183,8 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds_SimpleUDFArg_string) { arg.mutable_data()->set_string_value("ECHO"); return arg; }()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {}, args); + *request_context_factory_, {}, args, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello world! \"ECHO\"")"); @@ -181,9 +212,8 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds_SimpleUDFArg_string_tagged) { arg.mutable_data()->set_string_value("ECHO"); return arg; }()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {}, args); + *request_context_factory_, {}, args, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello world! 
{\"tags\":[\"tag1\"],\"data\":\"ECHO\"}")"); @@ -213,9 +243,8 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds_SimpleUDFArg_string_tagged_list) { list_value->add_values()->set_string_value("key2"); return arg; }()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {}, args); + *request_context_factory_, {}, args, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ( *result, @@ -245,9 +274,8 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds_SimpleUDFArg_struct) { .set_string_value("value"); return arg; }()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {}, args); + *request_context_factory_, {}, args, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello world! {\"key\":\"value\"}")"); @@ -255,18 +283,19 @@ TEST_F(UdfClientTest, JsEchoCallSucceeds_SimpleUDFArg_struct) { EXPECT_TRUE(stop.ok()); } -static void udfCbEcho(FunctionBindingPayload& payload) { +static void udfCbEcho( + FunctionBindingPayload>& payload) { payload.io_proto.set_output_string("Echo: " + payload.io_proto.input_string()); } TEST_F(UdfClientTest, JsEchoHookCallSucceeds) { - auto function_object = - std::make_unique>(); + auto function_object = std::make_unique< + FunctionBindingObjectV2>>(); function_object->function_name = "echo"; function_object->function = udfCbEcho; - Config config; + Config> config; config.number_of_workers = 1; config.RegisterFunctionBinding(std::move(function_object)); absl::StatusOr> udf_client = @@ -280,9 +309,8 @@ TEST_F(UdfClientTest, JsEchoHookCallSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"("I'm a key")"}); + *request_context_factory_, {R"("I'm a key")"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Hello 
world! Echo: I'm a key")"); @@ -329,9 +357,8 @@ TEST_F(UdfClientTest, JsStringInWithGetValuesHookSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"("key1")"}); + *request_context_factory_, {R"("key1")"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Key: key1, Value: value1")"); @@ -380,9 +407,8 @@ TEST_F(UdfClientTest, JsJSONObjectInWithGetValuesHookSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"({"keys":["key1"]})"}); + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("Key: key1, Value: value1")"); @@ -397,11 +423,11 @@ TEST_F(UdfClientTest, JsJSONObjectInWithRunQueryHookSucceeds) { TextFormat::ParseFromString(R"pb(elements: "a")pb", &response); ON_CALL(*mock_lookup, RunQuery(_, _)).WillByDefault(Return(response)); - auto run_query_hook = RunQueryHook::Create(); + auto run_query_hook = RunSetQueryStringHook::Create(); run_query_hook->FinishInit(std::move(mock_lookup)); UdfConfigBuilder config_builder; absl::StatusOr> udf_client = UdfClient::Create( - std::move(config_builder.RegisterRunQueryHook(*run_query_hook) + std::move(config_builder.RegisterRunSetQueryStringHook(*run_query_hook) .SetNumberOfWorkers(1) .Config())); EXPECT_TRUE(udf_client.ok()); @@ -419,9 +445,8 @@ TEST_F(UdfClientTest, JsJSONObjectInWithRunQueryHookSucceeds) { .version = 1, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"({"keys":["key1"]})"}); + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"(["a"])"); 
@@ -429,7 +454,55 @@ TEST_F(UdfClientTest, JsJSONObjectInWithRunQueryHookSucceeds) { EXPECT_TRUE(stop.ok()); } -TEST_F(UdfClientTest, JsCallsLoggingFunctionSucceeds) { +TEST_F(UdfClientTest, VerifyJsRunSetQueryIntHookSucceeds) { + auto mock_lookup = std::make_unique(); + InternalRunSetQueryIntResponse response; + TextFormat::ParseFromString(R"pb(elements: 1000 elements: 1001)pb", + &response); + ON_CALL(*mock_lookup, RunSetQueryInt(_, _)).WillByDefault(Return(response)); + auto run_query_hook = RunSetQueryIntHook::Create(); + run_query_hook->FinishInit(std::move(mock_lookup)); + UdfConfigBuilder config_builder; + absl::StatusOr> udf_client = UdfClient::Create( + std::move(config_builder.RegisterRunSetQueryIntHook(*run_query_hook) + .SetNumberOfWorkers(1) + .Config())); + EXPECT_TRUE(udf_client.ok()); + absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ + .js = R"( + function hello(input) { + let keys = input.keys; + let bytes = runSetQueryInt(keys[0]); + if (bytes instanceof Uint8Array) { + return Array.from(new Uint32Array(bytes.buffer)); + } + return "runSetQueryInt failed."; + } + )", + .udf_handler_name = "hello", + .logical_commit_time = 1, + .version = 1, + }); + EXPECT_TRUE(code_obj_status.ok()); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); + ASSERT_TRUE(result.ok()) << result.status(); + EXPECT_EQ(*result, R"([1000,1001])"); + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} + +TEST_F(UdfClientTest, JsCallsLoggingFunctionLogForConsentedRequests) { + std::stringstream log_ss; + auto* logger_provider = + opentelemetry::sdk::logs::LoggerProviderFactory::Create( + opentelemetry::sdk::logs::SimpleLogRecordProcessorFactory::Create( + std::make_unique< + opentelemetry::exporter::logs::OStreamLogRecordExporter>( + log_ss))) + .release(); + privacy_sandbox::server_common::log::logger_private = + 
logger_provider->GetLogger("test").get(); UdfConfigBuilder config_builder; absl::StatusOr> udf_client = UdfClient::Create(std::move(config_builder.RegisterLoggingFunction() @@ -450,23 +523,72 @@ TEST_F(UdfClientTest, JsCallsLoggingFunctionSucceeds) { .logical_commit_time = 1, .version = 1, }); + privacy_sandbox::server_common::ConsentedDebugConfiguration + consented_debug_configuration; + consented_debug_configuration.set_is_consented(true); + consented_debug_configuration.set_token(kExampleConsentedDebugToken); + privacy_sandbox::server_common::LogContext log_context; + request_context_factory_->UpdateLogContext(log_context, + consented_debug_configuration); EXPECT_TRUE(code_obj_status.ok()); - - absl::ScopedMockLog log; - EXPECT_CALL(log, Log(absl::LogSeverity::kError, testing::_, "Error message")); - EXPECT_CALL(log, - Log(absl::LogSeverity::kWarning, testing::_, "Warning message")); - EXPECT_CALL(log, Log(absl::LogSeverity::kInfo, testing::_, "Info message")); - log.StartCapturingLogs(); - - ScopeMetricsContext metrics_context; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), {R"({"keys":["key1"]})"}); + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("")"); + auto output_log = log_ss.str(); + EXPECT_THAT(output_log, ContainsRegex("Error message")); + EXPECT_THAT(output_log, ContainsRegex("Warning message")); + EXPECT_THAT(output_log, ContainsRegex("Info message")); - log.StopCapturingLogs(); + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} +TEST_F(UdfClientTest, JsCallsLoggingFunctionNoLogForNonConsentedRequests) { + std::stringstream log_ss; + auto* logger_provider = + opentelemetry::sdk::logs::LoggerProviderFactory::Create( + opentelemetry::sdk::logs::SimpleLogRecordProcessorFactory::Create( + std::make_unique< + opentelemetry::exporter::logs::OStreamLogRecordExporter>( + log_ss))) + .release(); + 
privacy_sandbox::server_common::log::logger_private = + logger_provider->GetLogger("test").get(); + UdfConfigBuilder config_builder; + absl::StatusOr> udf_client = + UdfClient::Create(std::move(config_builder.RegisterLoggingFunction() + .SetNumberOfWorkers(1) + .Config())); + EXPECT_TRUE(udf_client.ok()); + + absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{ + .js = R"( + function hello(input) { + const a = console.error("Error message"); + const b = console.warn("Warning message"); + const c = console.log("Info message"); + return ""; + } + )", + .udf_handler_name = "hello", + .logical_commit_time = 1, + .version = 1, + }); + privacy_sandbox::server_common::ConsentedDebugConfiguration + consented_debug_configuration; + consented_debug_configuration.set_is_consented(false); + consented_debug_configuration.set_token("mismatch_token"); + privacy_sandbox::server_common::LogContext log_context; + request_context_factory_->UpdateLogContext(log_context, + consented_debug_configuration); + EXPECT_TRUE(code_obj_status.ok()); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {R"({"keys":["key1"]})"}, execution_metadata_); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, R"("")"); + auto output_log = log_ss.str(); + EXPECT_TRUE(output_log.empty()); absl::Status stop = udf_client.value()->Stop(); EXPECT_TRUE(stop.ok()); } @@ -490,9 +612,8 @@ TEST_F(UdfClientTest, UpdatesCodeObjectTwice) { .version = 2, }); EXPECT_TRUE(status.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr result = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("2")"); @@ -519,9 +640,8 @@ TEST_F(UdfClientTest, IgnoresCodeObjectWithSameCommitTime) { .version = 1, }); EXPECT_TRUE(status.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr 
result = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("1")"); @@ -548,9 +668,8 @@ TEST_F(UdfClientTest, IgnoresCodeObjectWithSmallerCommitTime) { .version = 1, }); EXPECT_TRUE(status.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr result = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("1")"); @@ -561,9 +680,8 @@ TEST_F(UdfClientTest, IgnoresCodeObjectWithSmallerCommitTime) { TEST_F(UdfClientTest, CodeObjectNotSetError) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); - ScopeMetricsContext metrics_context; - absl::StatusOr result = - udf_client.value()->ExecuteCode(RequestContext(metrics_context), {}); + absl::StatusOr result = udf_client.value()->ExecuteCode( + *request_context_factory_, {}, execution_metadata_); EXPECT_FALSE(result.ok()); EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); @@ -597,16 +715,18 @@ TEST_F(UdfClientTest, MetadataPassedSuccesfully) { "true"); UDFExecutionMetadata udf_metadata; *udf_metadata.mutable_request_metadata() = *req.mutable_metadata(); - ScopeMetricsContext metrics_context; google::protobuf::RepeatedPtrField args; absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), std::move(udf_metadata), args); + *request_context_factory_, std::move(udf_metadata), args, + execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("true")"); UDFExecutionMetadata udf_metadata_non_pas; - result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), std::move(udf_metadata_non_pas), args); + + result = udf_client.value()->ExecuteCode(*request_context_factory_, 
+ std::move(udf_metadata_non_pas), + args, execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"("false")"); absl::Status stop = udf_client.value()->Stop(); @@ -639,7 +759,6 @@ TEST_F(UdfClientTest, DefaultUdfPASucceeds) { .version = kDefaultVersion, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; UDFExecutionMetadata udf_metadata; google::protobuf::RepeatedPtrField args; args.Add([] { @@ -656,7 +775,8 @@ TEST_F(UdfClientTest, DefaultUdfPASucceeds) { return arg; }()); absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), std::move(udf_metadata), args); + *request_context_factory_, std::move(udf_metadata), args, + execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ( *result, @@ -685,7 +805,6 @@ TEST_F(UdfClientTest, DefaultUdfPasKeyLookupFails) { .version = kDefaultVersion, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; v2::GetValuesRequest req; (*(req.mutable_metadata()->mutable_fields()))["is_pas"].set_string_value( "true"); @@ -706,7 +825,8 @@ TEST_F(UdfClientTest, DefaultUdfPasKeyLookupFails) { return arg; }()); absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), std::move(udf_metadata), args); + *request_context_factory_, std::move(udf_metadata), args, + execution_metadata_); EXPECT_FALSE(result.ok()); absl::Status stop = udf_client.value()->Stop(); EXPECT_TRUE(stop.ok()); @@ -737,7 +857,6 @@ TEST_F(UdfClientTest, DefaultUdfPasSucceeds) { .version = kDefaultVersion, }); EXPECT_TRUE(code_obj_status.ok()); - ScopeMetricsContext metrics_context; v2::GetValuesRequest req; (*(req.mutable_metadata()->mutable_fields()))["is_pas"].set_string_value( "true"); @@ -758,7 +877,8 @@ TEST_F(UdfClientTest, DefaultUdfPasSucceeds) { return arg; }()); absl::StatusOr result = udf_client.value()->ExecuteCode( - RequestContext(metrics_context), std::move(udf_metadata), args); + *request_context_factory_, 
std::move(udf_metadata), args, + execution_metadata_); EXPECT_TRUE(result.ok()); EXPECT_EQ(*result, R"({"key1":{"value":"value1"}})"); absl::Status stop = udf_client.value()->Stop(); diff --git a/components/udf/udf_config_builder.cc b/components/udf/udf_config_builder.cc index e6bc3fbe..6ba2e3c8 100644 --- a/components/udf/udf_config_builder.cc +++ b/components/udf/udf_config_builder.cc @@ -38,15 +38,17 @@ using google::scp::roma::FunctionBindingPayload; constexpr char kStringGetValuesHookJsName[] = "getValues"; constexpr char kBinaryGetValuesHookJsName[] = "getValuesBinary"; constexpr char kRunQueryHookJsName[] = "runQuery"; +constexpr char kRunSetQueryIntHookJsName[] = "runSetQueryInt"; -std::unique_ptr> +std::unique_ptr>> GetValuesFunctionObject(GetValuesHook& get_values_hook, std::string handler_name) { - auto get_values_function_object = - std::make_unique>(); + auto get_values_function_object = std::make_unique< + FunctionBindingObjectV2>>(); get_values_function_object->function_name = std::move(handler_name); get_values_function_object->function = - [&get_values_hook](FunctionBindingPayload& in) { + [&get_values_hook]( + FunctionBindingPayload>& in) { get_values_hook(in); }; return get_values_function_object; @@ -68,19 +70,34 @@ UdfConfigBuilder& UdfConfigBuilder::RegisterBinaryGetValuesHook( return *this; } -UdfConfigBuilder& UdfConfigBuilder::RegisterRunQueryHook( - RunQueryHook& run_query_hook) { - auto run_query_function_object = - std::make_unique>(); +UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryStringHook( + RunSetQueryStringHook& run_query_hook) { + auto run_query_function_object = std::make_unique< + FunctionBindingObjectV2>>(); run_query_function_object->function_name = kRunQueryHookJsName; run_query_function_object->function = - [&run_query_hook](FunctionBindingPayload& in) { + [&run_query_hook]( + FunctionBindingPayload>& in) { run_query_hook(in); }; config_.RegisterFunctionBinding(std::move(run_query_function_object)); return *this; } 
+UdfConfigBuilder& UdfConfigBuilder::RegisterRunSetQueryIntHook( + RunSetQueryIntHook& run_set_query_int_hook) { + auto run_query_function_object = std::make_unique< + FunctionBindingObjectV2>>(); + run_query_function_object->function_name = kRunSetQueryIntHookJsName; + run_query_function_object->function = + [&run_set_query_int_hook]( + FunctionBindingPayload>& in) { + run_set_query_int_hook(in); + }; + config_.RegisterFunctionBinding(std::move(run_query_function_object)); + return *this; +} + UdfConfigBuilder& UdfConfigBuilder::RegisterLoggingFunction() { config_.SetLoggingFunction(LoggingFunction); return *this; @@ -92,7 +109,8 @@ UdfConfigBuilder& UdfConfigBuilder::SetNumberOfWorkers( return *this; } -google::scp::roma::Config& UdfConfigBuilder::Config() { +google::scp::roma::Config>& +UdfConfigBuilder::Config() { return config_; } diff --git a/components/udf/udf_config_builder.h b/components/udf/udf_config_builder.h index 9ad033fb..e6bcd965 100644 --- a/components/udf/udf_config_builder.h +++ b/components/udf/udf_config_builder.h @@ -27,15 +27,19 @@ class UdfConfigBuilder { UdfConfigBuilder& RegisterBinaryGetValuesHook(GetValuesHook& get_values_hook); - UdfConfigBuilder& RegisterRunQueryHook(RunQueryHook& run_query_hook); + UdfConfigBuilder& RegisterRunSetQueryStringHook( + RunSetQueryStringHook& run_query_hook); + + UdfConfigBuilder& RegisterRunSetQueryIntHook( + RunSetQueryIntHook& run_set_query_int_hook); UdfConfigBuilder& RegisterLoggingFunction(); UdfConfigBuilder& SetNumberOfWorkers(int number_of_workers); - google::scp::roma::Config& Config(); + google::scp::roma::Config>& Config(); private: - google::scp::roma::Config config_; + google::scp::roma::Config> config_; }; } // namespace kv_server diff --git a/components/util/BUILD.bazel b/components/util/BUILD.bazel index a017a0dd..a1423ce6 100644 --- a/components/util/BUILD.bazel +++ b/components/util/BUILD.bazel @@ -74,6 +74,10 @@ local_defines = select({ ":local_otel_ostream": 
["OTEL_EXPORT=\\\"ostream\\\""], ":local_otel_otlp": ["OTEL_EXPORT=\\\"otlp\\\""], "//conditions:default": ["OTEL_EXPORT=\\\"unknown\\\""], +}) + select({ + "//:nonprod_mode": ["MODE=\\\"non_prod\\\""], + "//:prod_mode": ["MODE=\\\"prod\\\""], + "//conditions:default": ["MODE=\\\"unknown\\\""], }) genrule( @@ -126,6 +130,9 @@ cc_library( copts = select({ "//:gcp_platform": ["-DCLOUD_PLATFORM_GCP=1"], "//conditions:default": [], + }) + select({ + "//:local_instance": ["-DINSTANCE_LOCAL=1"], + "//conditions:default": [], }), visibility = [ "//components:__subpackages__", @@ -136,6 +143,7 @@ cc_library( "//:aws_platform": [ "//components/errors:aws_error_util", "@aws_sdk_cpp//:core", + "@google_privacysandbox_servers_common//src/public/core/interface:execution_result", "@google_privacysandbox_servers_common//src/public/cpio/interface:cpio", ], "//:gcp_platform": [ @@ -209,5 +217,16 @@ cc_library( ], deps = [ "//components/telemetry:server_definition", + "@google_privacysandbox_servers_common//src/logger:request_context_impl", + ], +) + +cc_library( + name = "safe_path_log_context", + hdrs = [ + "safe_path_log_context.h", + ], + deps = [ + "@google_privacysandbox_servers_common//src/logger:request_context_impl", ], ) diff --git a/components/util/build_flags.cc b/components/util/build_flags.cc index 2b35c370..46bad137 100644 --- a/components/util/build_flags.cc +++ b/components/util/build_flags.cc @@ -17,6 +17,7 @@ namespace kv_server { const std::string_view kVersionBuildFlavor = - "instance:" INSTANCE " platform:" PLATFORM " otel_export:" OTEL_EXPORT; + "instance:" INSTANCE " platform:" PLATFORM " otel_export:" OTEL_EXPORT + " mode:" MODE; } // namespace kv_server diff --git a/components/util/platform_initializer_aws.cc b/components/util/platform_initializer_aws.cc index 1f2fc025..babe87be 100644 --- a/components/util/platform_initializer_aws.cc +++ b/components/util/platform_initializer_aws.cc @@ -12,12 +12,17 @@ // See the License for the specific language governing 
permissions and // limitations under the License. +#include + +#include "absl/flags/flag.h" #include "absl/log/log.h" #include "aws/core/Aws.h" #include "components/util/platform_initializer.h" +#include "src/public/core/interface/execution_result.h" #include "src/public/cpio/interface/cpio.h" namespace kv_server { +using google::scp::core::errors::GetErrorMessage; using google::scp::cpio::Cpio; using google::scp::cpio::CpioOptions; using google::scp::cpio::LogOption; @@ -36,16 +41,28 @@ PlatformInitializer::PlatformInitializer() { cpio_options_.log_option = LogOption::kConsoleLog; cpio_options_.cloud_init_option = google::scp::cpio::CloudInitOption::kNoInitInCpio; - auto result = Cpio::InitCpio(cpio_options_); - if (!result.Successful()) { - LOG(ERROR) << "Failed to initialize CPIO." << std::endl; + +// TODO(b/338206801): Remove this aws region logic for aws local instance once +// it's fixed on the CPIO side. +#if defined(INSTANCE_LOCAL) + if (std::string aws_region = std::getenv("AWS_DEFAULT_REGION"); + aws_region.empty()) { + LOG(WARNING) << "Failed to get environment variable 'AWS_DEFAULT_REGION' " + "for PlatformInitializer."; + } else { + cpio_options_.region = aws_region; + } +#endif + if (auto error = Cpio::InitCpio(cpio_options_); !error.Successful()) { + LOG(ERROR) << "Failed to initialize CPIO: " + << GetErrorMessage(error.status_code) << std::endl; } } PlatformInitializer::~PlatformInitializer() { - auto result = Cpio::ShutdownCpio(cpio_options_); - if (!result.Successful()) { - LOG(ERROR) << "Failed to shutdown CPIO." 
<< std::endl; + if (auto error = Cpio::ShutdownCpio(cpio_options_); !error.Successful()) { + LOG(ERROR) << "Failed to shutdown CPIO: " + << GetErrorMessage(error.status_code) << std::endl; } Aws::ShutdownAPI(options_); } diff --git a/components/util/platform_initializer_gcp.cc b/components/util/platform_initializer_gcp.cc index e15be624..b695735d 100644 --- a/components/util/platform_initializer_gcp.cc +++ b/components/util/platform_initializer_gcp.cc @@ -43,17 +43,17 @@ google::scp::cpio::CpioOptions cpio_options_; PlatformInitializer::PlatformInitializer() { cpio_options_.log_option = LogOption::kConsoleLog; cpio_options_.project_id = absl::GetFlag(FLAGS_gcp_project_id); - auto execution_result = Cpio::InitCpio(cpio_options_); - CHECK(execution_result.Successful()) - << "Failed to initialize CPIO: " - << GetErrorMessage(execution_result.status_code); + if (auto error = Cpio::InitCpio(cpio_options_); !error.Successful()) { + LOG(ERROR) << "Failed to initialize CPIO: " + << GetErrorMessage(error.status_code) << std::endl; + } } PlatformInitializer::~PlatformInitializer() { - auto execution_result = Cpio::ShutdownCpio(cpio_options_); - if (!execution_result.Successful()) { + if (auto error = Cpio::ShutdownCpio(cpio_options_); !error.Successful()) { LOG(ERROR) << "Failed to shutdown CPIO: " - << GetErrorMessage(execution_result.status_code); + << GetErrorMessage(error.status_code) << std::endl; } } + } // namespace kv_server diff --git a/components/util/request_context.cc b/components/util/request_context.cc index aec75902..6531b4fc 100644 --- a/components/util/request_context.cc +++ b/components/util/request_context.cc @@ -19,6 +19,11 @@ #include "components/telemetry/server_definition.h" namespace kv_server { +namespace { +constexpr char kGenerationId[] = "generationId"; +constexpr char kAdtechDebugId[] = "adtechDebugId"; +constexpr char kDefaultConsentedGenerationId[] = "consented"; +} // namespace UdfRequestMetricsContext& 
RequestContext::GetUdfRequestMetricsContext() const { return udf_request_metrics_context_; @@ -27,5 +32,87 @@ InternalLookupMetricsContext& RequestContext::GetInternalLookupMetricsContext() const { return internal_lookup_metrics_context_; } +RequestLogContext& RequestContext::GetRequestLogContext() const { + return *request_log_context_; +} +void RequestContext::UpdateLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config) { + request_log_context_ = + std::make_unique(log_context, consented_debug_config); + if (request_log_context_->GetRequestLoggingContext().is_consented()) { + const std::string generation_id = + request_log_context_->GetLogContext().generation_id().empty() + ? kDefaultConsentedGenerationId + : request_log_context_->GetLogContext().generation_id(); + udf_request_metrics_context_.SetConsented(generation_id); + internal_lookup_metrics_context_.SetConsented(generation_id); + } +} +RequestContext::RequestContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, + std::string request_id) + : request_id_(request_id), + udf_request_metrics_context_(KVServerContextMap()->Get(&request_id_)), + internal_lookup_metrics_context_( + InternalLookupServerContextMap()->Get(&request_id_)) { + request_log_context_ = + std::make_unique(log_context, consented_debug_config); +} + +privacy_sandbox::server_common::log::ContextImpl<>& +RequestContext::GetPSLogContext() const { + return request_log_context_->GetRequestLoggingContext(); +} + +RequestLogContext::RequestLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config) + : log_context_(log_context), + consented_debug_config_(consented_debug_config), + 
request_logging_context_(GetContextMap(log_context), + consented_debug_config) {} + +privacy_sandbox::server_common::log::ContextImpl<>& +RequestLogContext::GetRequestLoggingContext() { + return request_logging_context_; +} +const privacy_sandbox::server_common::LogContext& +RequestLogContext::GetLogContext() const { + return log_context_; +} +const privacy_sandbox::server_common::ConsentedDebugConfiguration& +RequestLogContext::GetConsentedDebugConfiguration() const { + return consented_debug_config_; +} +absl::btree_map RequestLogContext::GetContextMap( + const privacy_sandbox::server_common::LogContext& log_context) { + return {{kGenerationId, log_context.generation_id()}, + {kAdtechDebugId, log_context.adtech_debug_id()}}; +} +RequestContextFactory::RequestContextFactory( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config) { + request_context_ = + std::make_shared(log_context, consented_debug_config); +} +std::weak_ptr RequestContextFactory::GetWeakCopy() const { + return request_context_; +} +const RequestContext& RequestContextFactory::Get() const { + return *request_context_; +} + +void RequestContextFactory::UpdateLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config) { + request_context_->UpdateLogContext(log_context, consented_debug_config); +} } // namespace kv_server diff --git a/components/util/request_context.h b/components/util/request_context.h index b58036c2..07fc8c80 100644 --- a/components/util/request_context.h +++ b/components/util/request_context.h @@ -22,29 +22,123 @@ #include #include "components/telemetry/server_definition.h" +#include "src/logger/request_context_impl.h" namespace kv_server { -// RequestContext holds the reference of udf request metrics context and -// internal lookup request context that ties to a 
single -// request. The request_id can be either passed from upper stream or assigned -// from uuid generated when RequestContext is constructed. +// RequestLogContext holds value of LogContext and ConsentedDebugConfiguration +// passed from the upstream application. +class RequestLogContext { + public: + explicit RequestLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config); + + privacy_sandbox::server_common::log::ContextImpl<>& + GetRequestLoggingContext(); + + const privacy_sandbox::server_common::LogContext& GetLogContext() const; + + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + GetConsentedDebugConfiguration() const; + + private: + // Parses the LogContext to btree map which is used to construct request + // logging context and used as labels for logging messages + absl::btree_map GetContextMap( + const privacy_sandbox::server_common::LogContext& log_context); + const privacy_sandbox::server_common::LogContext log_context_; + const privacy_sandbox::server_common::ConsentedDebugConfiguration + consented_debug_config_; + privacy_sandbox::server_common::log::ContextImpl<> request_logging_context_; +}; + +// RequestContext holds the reference of udf request metrics context, +// internal lookup request context, and request log context +// that ties to a single request. The request_id can be either passed +// from upper stream or assigned from uuid generated when +// RequestContext is constructed. 
class RequestContext { public: - explicit RequestContext(const ScopeMetricsContext& metrics_context) - : udf_request_metrics_context_( - metrics_context.GetUdfRequestMetricsContext()), - internal_lookup_metrics_context_( - metrics_context.GetInternalLookupMetricsContext()) {} + explicit RequestContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, + std::string request_id = google::scp::core::common::ToString( + google::scp::core::common::Uuid::GenerateUuid())); + RequestContext() + : RequestContext( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()) {} + // Updates request log context with the new log context and consented debug + // configuration. This function is typically called after RequestContext is + // created and the consented debugging information is available after request + // is decrypted. + void UpdateLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config); UdfRequestMetricsContext& GetUdfRequestMetricsContext() const; InternalLookupMetricsContext& GetInternalLookupMetricsContext() const; + RequestLogContext& GetRequestLogContext() const; + privacy_sandbox::server_common::log::ContextImpl<>& GetPSLogContext() const; - ~RequestContext() = default; + ~RequestContext() { + // Remove the metrics context for request_id, This is to ensure that + // metrics context has the same lifetime with RequestContext and be + // destroyed when RequestContext goes out of scope. 
+ LogIfError(KVServerContextMap()->Remove(&request_id_), + "When removing Udf request metrics context"); + LogIfError(InternalLookupServerContextMap()->Remove(&request_id_), + "When removing internal lookup metrics context"); + } private: + const std::string request_id_; UdfRequestMetricsContext& udf_request_metrics_context_; InternalLookupMetricsContext& internal_lookup_metrics_context_; + std::unique_ptr request_log_context_; +}; + +// Class that facilitates the passing around of request context to +// public interfaces while hiding the implementation details like wrapping +// the request context with a smart pointer +class RequestContextFactory { + public: + RequestContextFactory() + : RequestContextFactory( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()) {} + explicit RequestContextFactory( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config); + // Returns a weak pointer of the RequestContext. This function should only be + // used to pass a weakly shared ownership of the RequestContext, it should not + // be used to get access to the RequestContext. + std::weak_ptr GetWeakCopy() const; + // Provide access to RequestContext via const reference + const RequestContext& Get() const; + // Updates request log context with the new log context and consented debug + // configuration. This function is typically called after RequestContext is + // created and the consented debugging information is available after request + // is decrypted. 
+ void UpdateLogContext( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config); + // Not movable and copyable to prevent making unnecessary + // copies of underlying shared_ptr of request context, and moving of + // shared ownership of request context + RequestContextFactory(RequestContextFactory&& other) = delete; + RequestContextFactory& operator=(RequestContextFactory&& other) = delete; + RequestContextFactory(const RequestContextFactory&) = delete; + RequestContextFactory& operator=(const RequestContextFactory&) = delete; + + private: + std::shared_ptr request_context_; }; } // namespace kv_server diff --git a/components/util/safe_path_log_context.h b/components/util/safe_path_log_context.h new file mode 100644 index 00000000..52a7f811 --- /dev/null +++ b/components/util/safe_path_log_context.h @@ -0,0 +1,36 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_UTIL_SAFE_PATH_LOG_CONTEXT_H_ +#define COMPONENTS_UTIL_SAFE_PATH_LOG_CONTEXT_H_ + +#include "src/logger/request_context_impl.h" + +namespace kv_server { + +// Token that allows otel logging for safe code execution path +class KVServerSafeLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + ~KVServerSafeLogContext() override = default; + + private: + KVServerSafeLogContext() = default; + friend class Server; +}; + +} // namespace kv_server + +#endif // COMPONENTS_UTIL_SAFE_PATH_LOG_CONTEXT_H_ diff --git a/docs/AWS_Terraform_vars.md b/docs/AWS_Terraform_vars.md index cdb80321..7892310c 100644 --- a/docs/AWS_Terraform_vars.md +++ b/docs/AWS_Terraform_vars.md @@ -28,6 +28,12 @@ If you want to import an existing public certificate into ACM, follow these steps to [import the certificate](https://docs.aws.amazon.com/acm/latest/userguide/import-certificate.html). +- **consented_debug_token** + + Consented debug token to enable the otel collection of consented logs. Empty token means no-op + and no logs will be collected for consented requests. The token in the request's consented debug + configuration needs to match this debug token to make the server treat the request as consented. + - **data_loading_blob_prefix_allowlist** A comma separated list of prefixes (i.e., directories) where data is loaded from. @@ -42,6 +48,16 @@ the number of concurrent threads used to read and load a single delta or snapshot file from blob storage. +- **enable_consented_log** + + Enable the logging of consented requests. If it is set to true, the consented debug token + parameter value must not be an empty string. + +- **enable_external_traffic** + + Whether to serve external traffic. If disabled, only internal traffic under existing VPC will be + served. + - **enclave_cpu_count** Set how many CPUs the server will use. 
@@ -63,6 +79,19 @@ strings like `staging` and `prod` can be used to represent the environment that the Key/Value server will run in. +- **existing_vpc_environment** + + Environment of the existing VPC. Ignored if use_existing_vpc is false. + +- **existing_vpc_operator** + + Operator of the existing VPC. Ignored if use_existing_vpc is false. + +- **healthcheck_grace_period_sec** + + Amount of time to wait for service inside enclave to start up before starting health checks, in + seconds. + - **healthcheck_healthy_threshold** Consecutive health check successes required to be considered healthy @@ -71,6 +100,10 @@ Amount of time between health check intervals in seconds. +- **healthcheck_timeout_sec** + + Amount of time to wait for a health check response in seconds. + - **healthcheck_unhealthy_threshold** Consecutive health check failures required to be considered unhealthy. @@ -233,6 +266,15 @@ Total number of workers for UDF execution +- **udf_update_timeout_millis** + + UDF update timeout in milliseconds. Default is 30000. + +- **use_existing_vpc** + + Whether to use existing VPC. If true, only internal traffic via mesh will be served; variables + existing_vpc_operator and existing_vpc_environment will be required. + - **use_external_metrics_collector_endpoint** Whether to use external metrics collector endpoint. For AWS it is false because KV instance diff --git a/docs/GCP_Terraform_vars.md b/docs/GCP_Terraform_vars.md index 3ad47098..c98ac831 100644 --- a/docs/GCP_Terraform_vars.md +++ b/docs/GCP_Terraform_vars.md @@ -28,6 +28,12 @@ The grpc port that receives traffic destined for the OpenTelemetry collector +- **consented_debug_token** + + Consented debug token to enable the otel collection of consented logs. Empty token means no-op + and no logs will be collected for consented requests. The token in the request's consented debug + configuration needs to match this debug token to make the server treat the request as consented.
+ - **cpu_utilization_percent** CPU utilization percentage across an instance group required for autoscaler to add instances. @@ -44,6 +50,11 @@ Number of parallel threads for reading and loading data files. +- **enable_consented_log** + + Enable the logging of consented requests. If it is set to true, the consented debug token + parameter value must not be an empty string. + - **enable_external_traffic** Whether to serve external traffic. If disabled, only internal traffic via service mesh will be @@ -215,6 +226,10 @@ Number of workers for UDF execution. +- **udf_update_timeout_millis** + + UDF update timeout in milliseconds. Default is 30000. + - **use_confidential_space_debug_image** If true, use the Confidential space debug image. Else use the prod image, which does not allow diff --git a/docs/assets/aws_nonsharded_init.png b/docs/assets/aws_nonsharded_init.png deleted file mode 100644 index bf178389..00000000 Binary files a/docs/assets/aws_nonsharded_init.png and /dev/null differ diff --git a/docs/assets/aws_nonsharded_init.svg b/docs/assets/aws_nonsharded_init.svg new file mode 100644 index 00000000..fe1b2858 --- /dev/null +++ b/docs/assets/aws_nonsharded_init.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0KV server AWS Non sharded initializationLoad balancerLoad balancerAuto scaling groupAuto scaling groupEC2EC2KV ServerKV ServerLifecycle heartbeatLifecycle heartbeatGRPC serverGRPC serverStartNew EC2 addedEC2 is in Pending state because of the launch hook(Periodic) Are you InService?(Periodic) Are you InService?StartStartPeriodic heartbeatKeeps EC2 in Pending stateStartHealth checks availableLoad data Can take a few minutesPeriodic heartbeatDoneDoneTransition to InService statePeriodic healthchecksStart serving traffic from EC2Periodic healthchecks diff --git a/docs/assets/aws_sharded_init.png b/docs/assets/aws_sharded_init.png deleted file mode 100644 index 9b1c1dfb..00000000 Binary files a/docs/assets/aws_sharded_init.png and /dev/null differ diff --git 
a/docs/assets/aws_sharded_init.svg b/docs/assets/aws_sharded_init.svg new file mode 100644 index 00000000..700f0639 --- /dev/null +++ b/docs/assets/aws_sharded_init.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0KV server AWS Sharded initializationLoad balancerLoad balancerAutoscaling groupAutoscaling groupEC2EC2KV ServerKV ServerLifecycle heartbeatLifecycle heartbeatGRPC serverGRPC serverAutoscaling groups ManagerAutoscaling groups ManagerStartNew EC2 addedEC2 is in Pending state because of the launch hook(Periodic) Are you InService?(Periodic) Are you InService?StartStartPeriodic heartbeatKeeps EC2 in Pending stateStartHealth checks availableLoad data Can take a few minutesPeriodic heartbeatDoneDoneTransition to InService statePeriodic healthchecksStart serving traffic from EC2Periodic healthchecksServed requets return Unitialized(periodic) Get Shards MappingWaiting on at least 1 replica marked as InService for each shard.Proper Shards MappingServe requests diff --git a/docs/assets/gcp_nonsharded_init.png b/docs/assets/gcp_nonsharded_init.png deleted file mode 100644 index 5e62b893..00000000 Binary files a/docs/assets/gcp_nonsharded_init.png and /dev/null differ diff --git a/docs/assets/gcp_nonsharded_init.svg b/docs/assets/gcp_nonsharded_init.svg new file mode 100644 index 00000000..c9b433e4 --- /dev/null +++ b/docs/assets/gcp_nonsharded_init.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0KV server Gcp Non Sharded initializationMeshMeshInstance GroupInstance GroupVMVMKV ServerKV ServerGRPC serverGRPC serverStartNew VM addedVM is in Running stateStartStartHealth checks available. 
IG health check - Serving LB health check - Not ServingPeriodic healthchecksServingVM is healthy I won't kill it.Periodic healthchecksNot servingLoad data Can take a few minutesDoneSet LB health check as ServingPeriodic healthchecksServingStart serving traffic from VM diff --git a/docs/assets/gcp_sharded_init.png b/docs/assets/gcp_sharded_init.png deleted file mode 100644 index 2553552f..00000000 Binary files a/docs/assets/gcp_sharded_init.png and /dev/null differ diff --git a/docs/assets/gcp_sharded_init.svg b/docs/assets/gcp_sharded_init.svg new file mode 100644 index 00000000..56584f3a --- /dev/null +++ b/docs/assets/gcp_sharded_init.svg @@ -0,0 +1 @@ +Created with Raphaël 2.2.0KV server Gcp Sharded initializationMeshMeshInstance GroupInstance GroupVMVMKV ServerKV ServerGRPC serverGRPC serverInstance Groups ManagerInstance Groups ManagerStartNew VM addedVM is in Running stateStartStartHealth checks available. IG health check - Serving LB health check - Not ServingPeriodic healthchecksServingVM is healthy I won't kill it.Periodic healthchecksNot servingLoad data Can take a few minutesTag as Initialized(periodic) Get Running VMs tagged as InitializedWaiting for other VMs Need at least 1 initialized replica per shardReturn VM list that satisfies the above requirementDoneSet LB health check as ServingPeriodic healthchecksServingStart serving traffic from VM diff --git a/docs/assets/query_visualization.png b/docs/assets/query_visualization.png new file mode 100644 index 00000000..093ab975 Binary files /dev/null and b/docs/assets/query_visualization.png differ diff --git a/docs/build_flavor.md b/docs/build_flavor.md new file mode 100644 index 00000000..1812069b --- /dev/null +++ b/docs/build_flavor.md @@ -0,0 +1,24 @@ +# Build flavors + +## prod_mode vs nonprod_mode + +KV server now has two build flavors: prod_mode for running server in TEE against production traffic; +nonprod_mode for running server in local machine and TEE for testing and debugging purposes against 
+non-production traffic. The default build mode is prod_mode. To build nonprod_mode, the flag +"--config=nonprod_mode" needs to be provided to the bazel build command. + +### prod_mode: + +- Can attest with production key management system thus can decrypt production traffic +- Console logs will be disabled. Logs in safe code path(non-request related) such as data loading + and logs for consented requests will be exported by open-telemetry logger and published to AWS + Cloudwatch or GCP Logs Explorer. +- Metrics will be published as unnoised for consented requests and noised for non-consented + requests. + +### nonprod_mode + +- Cannot attest with production key management system thus cannot decrypt production traffic +- Console logs will be enabled, and additionally all logs will be published to AWS Cloudwatch or + GCP Logs Explorer if telemetry is configured on from server parameter. +- All metrics will be published as unnoised. diff --git a/docs/data_loading/data_loading_capabilities.md b/docs/data_loading/data_loading_capabilities.md index e0d9450d..bcbd37f6 100644 --- a/docs/data_loading/data_loading_capabilities.md +++ b/docs/data_loading/data_loading_capabilities.md @@ -20,10 +20,10 @@ parameters: The data loading benchmark tool can be used to search for optimal [tuning parameters](#tuning-parameters) that are best suited to specific hardware, memory and network specs. 
To build the benchmarking tool for AWS use the following command (note the -`--//:platform=aws` build flag): +`--config=aws_platform` build flag): ```sh -builders/tools/bazel-debian run //production/packaging/tools:copy_to_dist --//:instance=local --//:platform=aws +builders/tools/bazel-debian run //production/packaging/tools:copy_to_dist --config=local_instance --config=aws_platform ``` After building, load the tool into docker as follows: diff --git a/docs/data_loading/loading_data.md b/docs/data_loading/loading_data.md index deebff18..867cfee6 100644 --- a/docs/data_loading/loading_data.md +++ b/docs/data_loading/loading_data.md @@ -45,7 +45,7 @@ Confirm that the sample data file `DELTA_\d{16}` has been generated. The data CLI is located under: `//tools/data_cli`. First build the cli using the following command: ```sh --$ builders/tools/bazel-debian run //production/packaging/tools:copy_to_dist --//:instance=local --//:platform=local +-$ builders/tools/bazel-debian run //production/packaging/tools:copy_to_dist --config=local_instance --config=local_platform ``` After building, the cli will be packaged into a docker image tar file under diff --git a/docs/deployment/deploying_locally.md b/docs/deployment/deploying_locally.md index 486090e9..d0e1ce03 100644 --- a/docs/deployment/deploying_locally.md +++ b/docs/deployment/deploying_locally.md @@ -43,8 +43,9 @@ From the Key/Value server repo folder, execute the following command: ```sh ./builders/tools/bazel-debian build //components/data_server/server:server \ - --//:platform=local \ - --//:instance=local + --config=local_instance \ + --config=local_platform \ + --config=nonprod_mode ``` ## Generate UDF delta file @@ -75,12 +76,39 @@ their contents on startup and continue to watch them while it is running. 
```sh ./bazel-bin/components/data_server/server/server \ --delta_directory=/tmp/deltas \ - --realtime_directory=/tmp/realtime --v=4 --stderrthreshold=0 + --realtime_directory=/tmp/realtime \ + --logging_verbosity_level=4 --stderrthreshold=0 ``` The server will start up and begin listening for new delta and realtime files in the directories provided. +# Build and run Key/Value server locally in docker + +You can also build and run KV server locally in docker. + +## Build the docker image + +```sh +builders/tools/bazel-debian run //production/packaging/local/data_server:copy_to_dist \ + --config=local_instance --config=local_platform --config=nonprod_mode +``` + +## Load the image + +```sh +docker load -i dist/server_docker_image.tar +``` + +## Run the server in docker + +```sh +docker run -it --network=host --entrypoint=/server --init --rm \ +--volume=/tmp/deltas:/tmp/deltas --volume=/tmp/realtime:/tmp/realtime \ +--security-opt=seccomp=unconfined bazel/production/packaging/local/data_server:server_docker_image \ +--port 50051 -stderrthreshold=0 -delta_directory=/tmp/deltas -realtime_directory=/tmp/realtime +``` + # Common operations ## Query the server diff --git a/docs/deployment/deploying_on_aws.md b/docs/deployment/deploying_on_aws.md index 83d0e38b..f7f4818a 100644 --- a/docs/deployment/deploying_on_aws.md +++ b/docs/deployment/deploying_on_aws.md @@ -10,7 +10,7 @@ To learn more about FLEDGE and the Key/Value server, take a look at the followin - [FLEDGE Key/Value server explainer](https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md) - [FLEDGE Key/Value server trust model](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_service_trust_model.md) -- [FLEDGE explainer](https://developer.chrome.com/en/docs/privacy-sandbox/fledge/) +- [FLEDGE explainer](https://developer.chrome.com/en/docs/privacy-sandbox/protected-audience/) - [FLEDGE API developer guide](https://developer.chrome.com/blog/fledge-api/) > The instructions
written in this document are for running a test Key/Value server that does @@ -90,19 +90,27 @@ run into any Docker access errors, follow the instructions for ## Get the source code from GitHub The code for the FLEDGE Key/Value server is released on -[GitHub](https://github.com/privacysandbox/fledge-key-value-service). +[GitHub](https://github.com/privacysandbox/protected-auction-key-value-service). The main branch is under active development. For a more stable experience, please use the -[latest release branch](https://github.com/privacysandbox/fledge-key-value-service/releases). +[latest release branch](https://github.com/privacysandbox/protected-auction-key-value-service/releases). ## Build the Amazon Machine Image (AMI) From the Key/Value server repo folder, execute the following command: +prod_mode (default mode) + ```sh production/packaging/aws/build_and_test --with-ami us-east-1 --with-ami us-west-1 ``` +nonprod_mode + +```sh +production/packaging/aws/build_and_test --with-ami us-east-1 --with-ami us-west-1 --mode nonprod +``` + The script will build the Enclave Image File (EIF), store it in an AMI, and upload the AMI. If the build is successful, you will see an output similar to: @@ -388,12 +396,16 @@ You should see an output similar to the following: ## Read the server log -Most recent server logs can be read by executing the following command: +Most recent server (`nonprod_mode`) console logs can be read by executing the following command: ```sh ENCLAVE_ID=$(nitro-cli describe-enclaves | jq -r ".[0].EnclaveID"); [ "$ENCLAVE_ID" != "null" ] && nitro-cli console --enclave-id ${ENCLAVE_ID} ``` +If `enable_otel_logger` parameter is set to true, KV server also exports server logs to Cloudwatch +via otel collector, located at Cloudwatch log group `kv-server-log-group` More details about logging +in `prod mode` and `nonprod mode` in ![developing the server](/docs/developing_the_server.md). 
+ ## Start the server If you have shutdown your server for any reason, you can start the Key/Value server by executing the diff --git a/docs/deployment/deploying_on_gcp.md b/docs/deployment/deploying_on_gcp.md index e258f70c..f3eff5c0 100644 --- a/docs/deployment/deploying_on_gcp.md +++ b/docs/deployment/deploying_on_gcp.md @@ -105,10 +105,18 @@ The main branch is under active development. For a more stable experience, pleas From the Key/Value server repo folder, execute the following command: +prod_mode(default mode) + ```sh ./production/packaging/gcp/build_and_test ``` +nonprod_mode + +```sh +./production/packaging/gcp/build_and_test --mode nonprod +``` + This script will build a Docker image and store it locally under `/dist`. Before uploading the image to your docker image repo, please specify an image tag for the docker @@ -239,7 +247,13 @@ listed to the right. Instances associated with your Kv-server have the name star In the instance details page, under `Logs`, you can access server logs in both `Logging` and `Serial port (console)`. The `Logging` option is more powerful with better filtering and query -support. +support on `Logs Explorer`. + +The console log is located under resource type `VM Instance`. When server is running in prod mode, +the console log will not be available. However, if parameter `enable_otel_logger` is set to true, KV +server will export selective server logs to `Logs Explorer` under resource type `Generic Task`. More +details about logging in `prod mode` and `nonprod mode` in +![developing the server](/docs/developing_the_server.md). 
![how to access GCP instance logs](../assets/gcp_instance_logs.png) diff --git a/docs/developing_the_server.md b/docs/developing_the_server.md index ad636737..ac8ad0b1 100644 --- a/docs/developing_the_server.md +++ b/docs/developing_the_server.md @@ -3,6 +3,27 @@ # FLEDGE K/V Server developer guide +## Build flavors: prod_mode vs nonprod_mode + +KV server now has two build flavors: more details in this [build flavor doc](/docs/build_flavor.md). + +## Otel logging for consented requests + +To turn on otel logging for consented requests in the server prod mode, the following parameters +(terraform variables) should be set: + +`enable_consented_log`: true\ +`consented_debug_token`: non-empty string + +To make a request consented, the ConsentedDebugConfiguration proto in the +[V2 request API](/public/query/v2/get_values_v2.proto) needs to be set to: +`"consented_debug_config": {"is_consented": true, "token": <consented_debug_token>}` + +Example consented V2 requests in json can be found [here](/public/test_util/request_example.h). + +More background information about consented debugging can be found in +[Debugging Protected Audience API Services](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/debugging_protected_audience_api_services.md) + ## Develop and run the server for AWS platform in your local machine The data server provides the read API for the KV service. @@ -64,7 +85,7 @@ The data server provides the read API for the KV service. 1. Build the server artifacts and copy them into the `dist/debian/` directory. ```sh - builders/tools/bazel-debian run //production/packaging/aws/data_server:copy_to_dist --config local_instance --//:platform=aws + builders/tools/bazel-debian run //production/packaging/aws/data_server:copy_to_dist --config=local_instance --config=aws_platform --config=nonprod_mode ``` 1.
Load the image into docker @@ -107,7 +128,7 @@ docker run -it --rm --network host bazel/testing/run_local:envoy_image For example: ```sh -builders/tools/bazel-debian run //components/data_server/server:server --config local_instance --//:platform=aws -- --environment="dev" +builders/tools/bazel-debian run //components/data_server/server:server --config=local_instance --config=aws_platform --config=nonprod_mode -- --environment="dev" ``` We are currently developing this server for local testing and for use on AWS Nitro instances @@ -150,7 +171,7 @@ as parameters, GCS data bucket) are still required and please follow From the kv-server repo folder, execute the following command ```sh -builders/tools/bazel-debian run //production/packaging/gcp/data_server:copy_to_dist --config=local_instance --config=gcp_platform +builders/tools/bazel-debian run //production/packaging/gcp/data_server:copy_to_dist --config=local_instance --config=gcp_platform --config=nonprod_mode ``` #### Load the image into docker diff --git a/docs/playbook/index.md b/docs/playbook/index.md index 69deee93..53dd577c 100644 --- a/docs/playbook/index.md +++ b/docs/playbook/index.md @@ -1,8 +1,109 @@ -# FLEDGE Key/Value Server Alert Playbooks +# Key Value Server Playbook -For each of the available alerts that this server can cause, there is a playbook entry that has -background and what actions to take. +This article lists important metrics ad techs should monitor and set up alerts for. -## Alerts +There are other metrics that you should probably monitor and have alerts for, based on your use case +and service level indicator guarantees (SLIs). Further guidance will be available when Key Value +Server (KV server) moves to general availability (GA). 
-- [Template](template.md) +## Important metrics + +For each metric, review our recommendations to get set up, verify an issue, and troubleshoot common +errors: + +- [Uptime](./server_is_unhealthy.md) +- [Read latency](./read_latency_too_high.md) +- [Total error rate](./total_error_rate_too_high.md) + +## Debugging + +[Debugging Protected Audience API Services](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/debugging_protected_audience_api_services.md) +provides an overview of ways to debug a key-value server. Important points are made about trade-offs +between privacy and ease of debugging. + +Note that KV Server does not currently support all the features mentioned in this document, but +these features are on our roadmap and will be available when we move to GA. + +## Monitoring + +[Monitoring Protected Audience API Services](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md) +has high-level suggestions for Protected Audience monitoring. All metrics available to KV server are +defined in [server_definition.h](../../components/telemetry/server_definition.h). + +[Privacy safe telemetry](https://github.com/privacysandbox/protected-auction-services-docs/blob/main/monitoring_protected_audience_api_services.md#privacy-safe-telemetry) +outlines metrics noising. + +### Logs + +Our default terraform parameters set the variable `enable_otel_logger` to true. Our default set up +starts up the oTel collector. + +#### AWS logs + +[AWS](../../production/terraform/aws/environments/kv_server_variables.tf#L271) + +Find logs in Cloud Watch with the group name and log stream specified under +[awscloudwatchlogs](../../production/packaging/aws/otel_collector/otel_collector_config.yaml#L45) + +#### GCP logs + +[GCP](../../production/terraform/gcp/environments/kv_server_variables.tf#L290) + +You can find your logs in `Logs Explorer`.
+ +### Dashboard + +[AWS dashboard](../../production/terraform/aws/services/dashboard/main.tf) + +[GCP dashboard](../../production/terraform/gcp/services/dashboards/main.tf) + +## Typical problems + +### How do I know if my system is healthy? + +The two most important things are [Uptime](./server_is_unhealthy.md) and +[Read latency](./read_latency_too_high.md). + +You should also monitor logs to make sure you don't see any errors in there. + +Beyond that, you can verify that you can [load data](../data_loading/loading_data.md) and then +[read](../testing_the_query_protocol.md) it. + +### How do I know if my read latency is meeting my SLI? + +Please see [Read latency](read_latency_too_high.md). + +### My realtime messages don't go through. How to fix it? + +Please review this [Realtime updates](../data_loading/loading_data.md#realtime-updates) section. + +This document provides sample code to send updates to different cloud providers, from the console as +well as from code. + +Try sending an update using the steps described in the document, and see if it goes through. + +[AWS Realtime update capabilities](../data_loading/realtime_updates_capabilities.md) +explains how to further tune and test real-time updates. + +For real-time updates, pay attention to the +[kReceivedLowLatencyNotificationsE2E](../../components/telemetry/server_definition.h#L311) metric. + +## Escalation + +### Cloud specific + +#### AWS + +[Support](https://aws.amazon.com/contact-us/) + +#### GCP + +[Support](https://cloud.google.com/support?hl=en) + +### Privacy Sandbox + +#### KV server + +Our oncall is responsible for monitoring the issues created for this repo. + +Alternatively, you can email us at diff --git a/docs/playbook/read_latency_too_high.md b/docs/playbook/read_latency_too_high.md new file mode 100644 index 00000000..44b1cebc --- /dev/null +++ b/docs/playbook/read_latency_too_high.md @@ -0,0 +1,80 @@ +# ReadLatencyTooHigh + +## Overview + +Read requests are taking too long.
+ +Most likely the volume of requests is too high and the autoscaling group was not able to add more +capacity. + +Alternatively, the UDF logic is taking too long and its logic should be improved. + +## Recommended alert level + +p50 for 5 mins is over 500 ms, but this is highly dependent on your use case. + +Note that this metric is noised for privacy purposes, and as such might actually show a slightly +different number from the actual measurement. + +## Alert Severity + +High. This affects the server's core flow of serving read requests in time. + +## Verification + +Check out the +[metric](https://github.com/privacysandbox/data-plane-shared-libraries/blob/5753af60e8cfae76ef2bb35c4cc105d0ac24481d/src/metric/definition.h#L300), +and see if it's outside of the expected range. + +[AWS Request Latency](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L110) +[GCP Request Latency](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L148) + +## Troubleshooting and solution + +### Not enough capacity + +Check out the CPU metrics for your machines. They are available in the dashboard linked +[here](index.md). + +Consider updating the max capacity limits for your auto scaling groups, and how quickly those kick +in + +[AWS](../../production/terraform/aws/services/autoscaling/main.tf#L70) + +[GCP](../../production/terraform/gcp/services/autoscaling/main.tf#L144) + +### Too many read requests + +Check out the +[metric](https://github.com/privacysandbox/data-plane-shared-libraries/blob/5753af60e8cfae76ef2bb35c4cc105d0ac24481d/src/metric/definition.h#L293) +for total requests count, and see if it's outside of the expected range. 
+ +[AWS Request Count](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L44) + +[GCP Request Count](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L25) + +Note that this metric is noised for privacy purposes, and as such might actually show a slightly +different number from the actual measurement. + +Consider escalating to the upstream component, if you believe you're getting more requests than you +reasonably should. + +### Investigate the UDF + +If the number of requests is reasonable and so are CPU numbers, it probably means that the UDF you +have might not be optimally written. + +Consider improving the logic, or updating the alert threshold above accordingly. + +### Upgrade hardware + +You can pick more performant hardware. + +## Escalation + +If you believe that the number of requests you're getting is too high, consider escalating to +upstream components. + +## Related Links + +[UDF explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_service_user_defined_functions.md#keyvalue-service-user-defined-functions-udfs) diff --git a/docs/playbook/server_is_unhealthy.md b/docs/playbook/server_is_unhealthy.md new file mode 100644 index 00000000..50393925 --- /dev/null +++ b/docs/playbook/server_is_unhealthy.md @@ -0,0 +1,127 @@ +# ServerIsUnhealthy + +## Overview + +This alert means that all kv servers (ec2/vm instances) are down and cannot serve traffic. + +If you see this alert, there is a big chance you'll see most, if not all, other alerts firing too. + +## Recommended alert level + +Fire an alert if over 90% of responses did not return OK over 5 mins. Probe interval = 300 ms. + +## Alert Severity + +Critical. + +The service is down, and is fully unavailable to serve read requests.
+ +This condition directly affects the uptime SLO. + +## Verification + +_Http_ + +```sh +curl $YOURSERVERURL/healthcheck +``` + +should return + +```json +{ + "status": "SERVING" +} +``` + +_GRPC_ + +Run from the repo root, since you need access to the \*.pb file + +```sh +grpcurl --protoset dist/query_api_descriptor_set.pb $YOURSERVERURL:8443 grpc.health.v1.Health/Check +``` + +should return the same response above, as the http call. + +Additionally, any read requests will fail. You can try + +```sh +curl $YOURSERVERURL/v1/getvalues?keys=hi +``` + +A healthy response is + +```json +{ + "keys": { + "hi": { + "value": "Hello, world! If you are seeing this, it means you can query me successfully" + } + }, + "renderUrls": {}, + "adComponentRenderUrls": {}, + "kvInternal": {} +} +``` + +## Troubleshooting & Solution + +The fact that there are no healthy machines in kv's autoscaling group(s) means that the autoscaling +group manager tried to rotate in new machines, but failed. + +You can check out the autoscaling events to see when that started happening. + +For AWS: pay particular attention to "TerminateInstances" events that you can query on CloudTrail. + +## Solution + +If this alert is firing, it means that something went wrong in a big way. + +You should check the metrics dashboard and logs, that are linked for your cloud [here](index.md). + +Metrics noising and other privacy enhancing observability frameworks should not interfere with +troubleshooting too much for this alert. + +### Out of memory + +In the dashboard check the memory consumption and see if it comes close to the threshold value +(total available memory), and then the machine disappears. + +The solution here is to remove the _excess_ data loaded through the standard and fast path. And then +add sharding, if necessary. + +### Out of cpu + +Similarly, check the cpu and how close it is to 100%.
+ +The solution here is to add more machines to the autoscaling group or change your hardware to be +more powerful. Additionally, it might be necessary to speed up how quickly the autoscaling group is +performing machine rotation. + +You could turn off read traffic for your server, as part of debugging. You can check if the spike in +traffic looks like a DDOS attack, and act accordingly. + +Lastly, you need to analyze if you can optimize CPU consuming tasks, e.g. your UDFs. + +### Out of disk + +Similarly, check if you're out of disk. If you are -- you can bump up the amount of disk you're +using by updating the hardware, and also how much you allocate to the enclave. You should analyze +which part of your disk usage is growing, e.g. maybe your logs are stored on the disk and are bound to +run out of space. In that case you need to figure out a proper log rotation strategy. + +### An implementation bug + +It might be that some incorrect implementation hit an edge case. It might be helpful to turn up log +verbosity and analyze the last few entries before the machine crashes. + +A common technique to address this bug is to revert to a previous more stable build. + +If you believe that this is a KV server issue, you should escalate using the info from +[here](index.md) + +## Related Links + +[Server initialization](../server_initialization.md) -- provides extra details on the server +initialization lifecycle and how it affects the health check.
diff --git a/docs/playbook/template.md b/docs/playbook/template.md index 596a4e34..94297920 100644 --- a/docs/playbook/template.md +++ b/docs/playbook/template.md @@ -5,7 +5,6 @@ _Address the following:_ - _What does this alert mean?_ -- _Does this default to paging, emailing, filing a ticket or something else?_ - _What factors contributed to the alert?_ - _What parts of the service are affected?_ - _What other alerts accompany this alert?_ @@ -13,8 +12,8 @@ _Address the following:_ ## Alert Severity -_Indicate the reason for the severity (email or paging) of the alert and the impact of the alerted -condition on the site._ +_Indicate the reason for the severity of the alert and the impact of the alerted condition on the +site._ _For example, is the serice as a whole down? Is it running in a degraded manner? How long do you likely have until the service goes out of SLA?_ @@ -48,12 +47,6 @@ _List and describe possible solutions for addressing this alert. Address the fol _List and describe paths of escalation. Identify whom to notify (person or team) and when. If there is no need to escalate, indicate that._ -_TODO(b/266432861): This should link to the page in this directory about generic cloud-provider -escalation paths once that is added._ - -_TODO(b/266432861): This should link to the page in this directory about generic -server-author/support escalation paths once that is added._ - ## Related Links _Provide links to relevant related alerts, procedures, and component documentation._ diff --git a/docs/playbook/total_error_rate_too_high.md b/docs/playbook/total_error_rate_too_high.md new file mode 100644 index 00000000..612ac75a --- /dev/null +++ b/docs/playbook/total_error_rate_too_high.md @@ -0,0 +1,52 @@ +# TotalErrorRateTooHigh + +## Overview + +The number of total errors has crossed an alert threshold. + +## Recommended alert level + +Number of errors in the past 5 minutes is over 100. 
+

Note that this metric is noised for privacy purposes, and as such might actually show a slightly
different number from the actual measurement.

## Alert Severity

High. While it does not necessarily mean that the server is down or that any critical flows are
affected, there is a high chance that they are.

## Verification

Check out the [metric](../../components/telemetry/server_definition.h#100),
[metric](../../components/telemetry/server_definition.h#154),
[metric](../../components/telemetry/server_definition.h#301) and see if it's outside of the expected
range.

Dashboards:

AWS:

[Request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L176),
[Internal Request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L198),
[Server Non-request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/aws/services/dashboard/main.tf#L220)

GCP:

[Request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L259),
[Internal Request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L301),
[Server Non-request Errors](https://github.com/privacysandbox/protected-auction-key-value-service/blob/552934a1e1e8d1a8beed4474408127104cdf3207/production/terraform/gcp/services/dashboards/main.tf#L343)

## Troubleshooting and solution

Figure out what specific error is firing. You can group the error by dimensions, and most likely
only one of those is firing. 
Having that knowledge, check the [logs](index.md).

Can it be that this is a transient error due to some temporary outage? Can this be caused by a recent
regression in the codebase? Beyond that it's hard to give a specific recommendation other than
analyzing the code and trying to understand what might have changed in the recent usage patterns.

## Escalation

You might want to [escalate](index.md) to your cloud provider or to privacy sandbox based on the
nature of the error. diff --git a/docs/profiling_the_server.md b/docs/profiling_the_server.md index 797f349e..72f11e90 100644 --- a/docs/profiling_the_server.md +++ b/docs/profiling_the_server.md @@ -20,7 +20,7 @@ following command from workspace root: --dynamic_mode=off -c opt --copt=-gmlt \ --copt=-fno-omit-frame-pointer \ production/packaging/local/data_server:server_profiling_docker_image.tar \ - --//:instance=local --//:platform=local + --config=local_instance --config=local_platform --config=nonprod_mode ``` The `--dynamic_mode=off -c opt --copt=-gmlt` flags are needed to generate a server binary that is diff --git a/docs/protected_app_signals/ad_retrieval_overview.md b/docs/protected_app_signals/ad_retrieval_overview.md index ddfef42f..2dba3935 100644 --- a/docs/protected_app_signals/ad_retrieval_overview.md +++ b/docs/protected_app_signals/ad_retrieval_overview.md @@ -1,3 +1,7 @@ +!! This workflow works only for
+[Protected App Signals (PAS)](https://developers.google.com/privacy-sandbox/relevance/protected-audience/android/protected-app-signals).
+It does not support Protected Audience (PA), since the UDF and Server API are different.!! 
+ ## Background This document provides a detailed overview of the Ad Retrieval server, which is a server-side diff --git a/getting_started/examples/sample_word2vec/README.md b/docs/protected_app_signals/advanced_onboarding_dev_guide.md similarity index 66% rename from getting_started/examples/sample_word2vec/README.md rename to docs/protected_app_signals/advanced_onboarding_dev_guide.md index 3b1c56df..5cd50448 100644 --- a/getting_started/examples/sample_word2vec/README.md +++ b/docs/protected_app_signals/advanced_onboarding_dev_guide.md @@ -1,13 +1,25 @@ -# The word2vec sample +# Protected App Signals, Advanced Ad Retrieval developer guide + +This is an advanced part of the +[Ad Retreival dev guide](/docs/protected_app_signals/onboarding_dev_guide.md) + +It illustrates the points made in the +[use case overview](/docs/protected_app_signals/ad_retrieval_overview.md#use-case-overview). +![alt_text](../assets/ad_retrieval_filter_funnel.png 'image_tooltip') This sample demonstrates how key->set(set queries) and key->value(value lookups) data can be loaded into a server and used together. In this case the key->set data is categorized groups of words. The -key->value data is a word to embedding mapping. An embedding is a vector of numbers. vectors for a -set of words. +key->value data is a word to embedding mapping. An embedding is a vector of numbers. The sample will demonstrate how you can query for a set of words, and sort them based on scoring criteria defined by word similarities, using embeddings. +## Word2vec + +We are using [word2vec](https://en.wikipedia.org/wiki/Word2vec) technique here. It is an NLP +technique for obtaining vector representations of words. These vectors capture information about the +meaning of the word based on the surrounding words. + ## Generating DELTA files There are 2 categories of DELTA files we build, data and udf. 
@@ -23,8 +35,8 @@ BUILD rules take care of generating the csv and piping them to the `data_cli` fo commands will build DELTA files for both embeddings and category DELTA files. ```sh -builders/tools/bazel-debian build tools/udf/sample_word2vec:generate_categories_delta -builders/tools/bazel-debian build tools/udf/sample_word2vec:generate_embeddings_delta +builders/tools/bazel-debian build //docs/protected_app_signals/examples/advanced:generate_categories_delta +builders/tools/bazel-debian build //docs/protected_app_signals/examples/advanced:generate_embeddings_delta ``` generated csv data for categories looks like (key="catalyst"): @@ -46,7 +58,7 @@ In this example you can see that the embedding is stored as a JSON string. Build the udf: ```sh -builders/tools/bazel-debian build tools/udf/sample_word2vec:udf_delta +builders/tools/bazel-debian build //docs/protected_app_signals/examples/advanced:udf_delta ``` At this point there are 3 DELTA files: @@ -61,7 +73,7 @@ Set up the data: ```sh mkdir /tmp/deltas -cp $(builders/tools/bazel-debian aquery 'tools/udf/sample_word2vec:udf_delta' | +cp $(builders/tools/bazel-debian aquery '//docs/protected_app_signals/examples/advanced:udf_delta' | sed -n 's/Outputs: \[\(.*\)\]/\1/p' | xargs dirname)/DELTA* /tmp/deltas ``` @@ -69,7 +81,7 @@ cp $(builders/tools/bazel-debian aquery 'tools/udf/sample_word2vec:udf_delta' | Build the local server: ```sh -./builders/tools/bazel-debian build //components/data_server/server:server --//:platform=local --//:instance=local +./builders/tools/bazel-debian build //components/data_server/server:server --config=local_instance --config=local_platform --config=nonprod_mode ``` Run the local server: @@ -92,6 +104,6 @@ The UDF returns the top 5 results and their scores. 
```sh grpc_cli call localhost:50051 kv_server.v2.KeyValueService/GetValuesHttp \ - "raw_body: {data: $(tr -d '\n' < tools/udf/sample_word2vec/body.txt)}" \ + "raw_body: {data: $(tr -d '\n' < docs/protected_app_signals/examples/advanced/body.txt)}" \ --channel_creds_type=insecure ``` diff --git a/getting_started/examples/sample_word2vec/BUILD.bazel b/docs/protected_app_signals/examples/advanced/BUILD.bazel similarity index 100% rename from getting_started/examples/sample_word2vec/BUILD.bazel rename to docs/protected_app_signals/examples/advanced/BUILD.bazel diff --git a/docs/protected_app_signals/examples/advanced/body.txt b/docs/protected_app_signals/examples/advanced/body.txt new file mode 100644 index 00000000..179338df --- /dev/null +++ b/docs/protected_app_signals/examples/advanced/body.txt @@ -0,0 +1,12 @@ +'{ + "partitions":[ + { + "id":0, + "arguments":[ + { + "data": "{ \\"metadataKeys\\":[ \\"animals\\", \\"garden\\", \\"dinner\\", \\"hotels\\" ], \\"signal\\" : \\"food\\" }" + } + ] + } + ] +}' diff --git a/getting_started/examples/sample_word2vec/data_generator.py b/docs/protected_app_signals/examples/advanced/data_generator.py similarity index 100% rename from getting_started/examples/sample_word2vec/data_generator.py rename to docs/protected_app_signals/examples/advanced/data_generator.py diff --git a/getting_started/examples/sample_word2vec/udf.js b/docs/protected_app_signals/examples/advanced/udf.js similarity index 85% rename from getting_started/examples/sample_word2vec/udf.js rename to docs/protected_app_signals/examples/advanced/udf.js index 2e197b2d..0e25d30c 100644 --- a/getting_started/examples/sample_word2vec/udf.js +++ b/docs/protected_app_signals/examples/advanced/udf.js @@ -79,21 +79,22 @@ function associateCosineSimilarity(wordEmbeddings, embedding) { } /** - * Computes the set union of all `metadata` keys and scores their similarity agains the `signal` word. 
- * The words and scores of the top 10 most similar words are returned, in order of similarity. + * Computes the set union of all `metadata` keys and scores their similarity against the `signal` word. + * The words and scores of the top 5 most similar words are returned, in order of similarity. * * @param metadataKeys Keys into the set data which do a UNION of all entries. * @param signal Orders unioned data by similarity to signal word. - * @returns A sorted list of top 10 words and their scores. + * @returns A sorted list of top 5 words and their scores. */ -function HandleRequest(executionMetadata, metadataKeys, signal) { +function HandleRequest(requestMetadata, protectedSignals, deviceMetadata, contextualSignals, contextualAdIds) { + const parsedProtectedSignals = JSON.parse(protectedSignals); results = []; - if (metadataKeys.length) { + if (parsedProtectedSignals.metadataKeys.length) { // Union all of the sets of the given metadata category - results = runQuery(metadataKeys.join('|')); + results = runQuery(parsedProtectedSignals.metadataKeys.join('|')); } wordSimilarity = {}; - embedding = getWordEmbedding(signal); + embedding = getWordEmbedding(parsedProtectedSignals.signal); if (embedding != null) { wordSimilarity = associateCosineSimilarity(associateEmbeddings(results), embedding); } diff --git a/docs/protected_app_signals/onboarding_dev_guide.md b/docs/protected_app_signals/onboarding_dev_guide.md index a0b3fe2b..bf060c8a 100644 --- a/docs/protected_app_signals/onboarding_dev_guide.md +++ b/docs/protected_app_signals/onboarding_dev_guide.md @@ -71,9 +71,9 @@ More [details](../data_loading/loading_data.md#upload-data-files-to-gcp) Note that this is a simplistic example created for an illustrative purpose. The retrieval case can get more complicated. -See this [example](/getting_started/examples/sample_word2vec/). 
The sample demonstrates how you can -query for a set of words, taking advantage of the native set query support, and sort them based on -scoring criteria defined by word similarities, using embeddings. +[Advanced_onboarding_dev_guide](/docs/protected_app_signals/advanced_onboarding_dev_guide.md) +demonstrates how you can query for a set of words, taking advantage of the native set query support, +and sort them based on scoring criteria defined by word similarities, using embeddings. ### UDF diff --git a/docs/server_initialization.md b/docs/server_initialization.md index 4255cdaa..fd03de91 100644 --- a/docs/server_initialization.md +++ b/docs/server_initialization.md @@ -6,18 +6,18 @@ Sequence diagrams below outline initilization steps for AWS/GCP Sharded/Non shar ### Non sharded -![GCP Non sharded initialization](assets/gcp_sharded_init.png) +![GCP Non sharded initialization](assets/gcp_nonsharded_init.svg) ### Sharded -![GCP Sharded initialization](assets/gcp_sharded_init.png) +![GCP Sharded initialization](assets/gcp_sharded_init.svg) ## AWS ### Non sharded -![AWS Non sharded initialization](assets/aws_nonsharded_init.png) +![AWS Non sharded initialization](assets/aws_nonsharded_init.svg) ### Sharded -![AWS Sharded initialization](assets/aws_sharded_init.png) +![AWS Sharded initialization](assets/aws_sharded_init.svg) diff --git a/docs/working_with_set_queries.md b/docs/working_with_set_queries.md new file mode 100644 index 00000000..efcff5af --- /dev/null +++ b/docs/working_with_set_queries.md @@ -0,0 +1,126 @@ +# Set query language

The K/V server supports a simple set query language that can be invoked using two UDF read APIs (1)
`runQuery("query")` and (2) `runSetQueryInt("query")`. The query language implements three set
operations `union` denoted as `|`, `difference` denoted as `-` and `intersection` denoted as `&` and
queries use sets of strings and numbers as input. 
+

As an example, suppose that we have indexed two sets of 32 bit integer ad ids targeting `games` and
`news` to the K/V server memory store. We can find the sets of ad ids targeting both `games` and
`news` by intersecting `games` and `news` sets using the following query in a UDF:
`games_and_news_ads = runSetQueryInt("games & news")`.

# Constructing queries

## Query operators

The set query language [grammar](/components/query/parser.yy) supports three binary operators:

- `union` or `|` operator - e.g., given `A = [1, 2, 3]` and `B = [3, 4, 5]`, then
  `A | B = [1, 2, 3, 4, 5]`.
- `difference` or `-` operator - e.g., given `A = [1, 2, 3]` and `B = [3, 4, 5]`, then
  `A - B = [1, 2]`
- `intersection` or `&` operator - e.g., given `A = [1, 2, 3]` and `B = [3, 4, 5]`, then
  `A & B = [ 3 ]`

## Query operands

Queries support two types of sets as operands:

- set of numbers - only 32 bit unsigned numbers are supported. The data loading type is
  `UInt32Set` in [data_loading.fbs](/public/data_loading/data_loading.fbs).
- set of strings - arbitrary length UTF-8 or 7-bit ASCII strings are supported. The data loading
  type is `StringSet` in [data_loading.fbs](/public/data_loading/data_loading.fbs).

## Query syntax

Queries can be constructed using long form operator names (`union`, `difference`, `intersection`) or
short form operator symbols (`|`, `-`, `&`). By default, queries are evaluated left to right and one
can use parentheses to override default precedence, e.g., (1) `A & B - C` is semantically different
from (2) `A & (B - C)` because in (1), `A` is intersected with `B` and then `C` is subtracted from
the result whereas in (2) `A` is intersected with the result of finding the difference between `B`
and `C`.

Note that valid queries always have an operator between two operands. 
For example, given three sets
`A`, `B` and `C`, valid queries include `A | B & C`, `A - B - C`, `B & C - A`, but `A | B | C &` is
invalid because of the last `&`.

## Visualizing ASTs for queries

We provide a tool [query_toy.cc](/components/tools/query_toy.cc) that can be used to visualize AST
trees for queries. For example, run the following commands to evaluate and visualize the AST for
`A & C | (B - D)`:

```bash
./builders/tools/bazel-debian build -c opt //components/tools:query_toy &&
bazel-bin/components/tools/query_toy --query="A & C | (B - D)" --dot_path="$PWD/query.dot" &&
dot -Tpng query.dot > query.png
```

This produces the following image: ![query visualization image](assets/query_visualization.png)

# Running queries in UDFs

The K/V server supports two APIs for running queries inside UDFs.

- `runQuery("query")`
    - Takes a valid `query` as input and returns a set of strings.
    - See [run_query_udf.js](/tools/udf/sample_udf/run_query_udf.js) for a JavaScript example.
- `runSetQueryInt("query")`
    - Takes a valid `query` as input and returns a byte array of serialized 32 bit unsigned
      integers.
    - See [run_set_query_int_udf.js](/tools/udf/sample_udf/run_set_query_int_udf.js) for a
      JavaScript example.

## Which API should I use, `runQuery` vs. `runSetQueryInt`?

`runSetQueryInt` implements a much more performant version of query evaluation based on bitmaps. So
if your sets can be represented as 32 bit unsigned integer sets and are relatively dense compared to
the range of numbers, then use `runSetQueryInt`. We also provide a benchmarking tool
[query_evaluation_benchmark](/components/tools/benchmarks/query_evaluation_benchmark.cc) that can be
used to determine whether using integer sets vs. string sets would be much more performant for a
given scenario. 
For example: + +```bash +./builders/tools/bazel-debian run -c opt //components/tools/benchmarks:query_evaluation_benchmark -- \ + --benchmark_counters_tabular=true \ + --benchmark_time_unit=us \ + --set_size=500000 \ + --range_min=1000000 \ + --range_max=2000000 \ + --set_names="A,B,C,D" \ + --query="(A - B) | (C & D)" +``` + +benchamrks evaluating set operations and the query `(A - B) | (C & D)` using sets with 500,000 +elements randomly selected from the range `[1,000,000 - 2,000,000]`. On my machine, the benchmark +produces the following output which shows superior performance for integer sets (Note that output +for integer sets is denoted using `.*kv_server::RoaringBitSet.*`): + +```bash +Run on (128 X 2450 MHz CPU s) +CPU Caches: + L1 Data 32 KiB (x64) + L1 Instruction 32 KiB (x64) + L2 Unified 512 KiB (x64) + L3 Unified 32768 KiB (x8) +Load Average: 4.73, 5.96, 6.31 +--------------------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations Ops/s +--------------------------------------------------------------------------------------------------------------- +kv_server::BM_SetUnion 15.5 us 15.5 us 45888 64.4784k/s +kv_server::BM_SetUnion 60104 us 60104 us 10 16.6378/s +kv_server::BM_SetDifference 15.3 us 15.3 us 43768 65.4571k/s +kv_server::BM_SetDifference 21194 us 21192 us 32 47.1883/s +kv_server::BM_SetIntersection 16.6 us 16.6 us 43286 60.1952k/s +kv_server::BM_SetIntersection 20601 us 20599 us 32 48.5452/s +----------------------------------------------------------------------------------------------------------------- +Benchmark Time CPU Iterations QueryEvals/s +----------------------------------------------------------------------------------------------------------------- +kv_server::BM_AstTreeEvaluation 43.4 us 43.4 us 15977 23.0672k/s +kv_server::BM_AstTreeEvaluation 60834 us 60826 us 10 16.4404/s +``` + +# Loading sets into the K/V server + +See [data loading 
guide](data_loading/loading_data.md) on how to load sets into the K/V server. diff --git a/getting_started/examples/sample_word2vec/body.txt b/getting_started/examples/sample_word2vec/body.txt deleted file mode 100644 index 809c4311..00000000 --- a/getting_started/examples/sample_word2vec/body.txt +++ /dev/null @@ -1,23 +0,0 @@ -'{ - "metadata":{ - "hostname":"example.com" - }, - "partitions":[ - { - "id":0, - "arguments":[ - { - "data":[ - "animals", - "garden", - "dinner", - "hotels" - ] - }, - { - "data": "food" - } - ] - } - ] -}' diff --git a/getting_started/onboarding.md b/getting_started/onboarding.md new file mode 100644 index 00000000..1915d5e2 --- /dev/null +++ b/getting_started/onboarding.md @@ -0,0 +1,206 @@ +# KV server onboarding and self-serve guide + +This document provides guidance to adtechs to onboard to Key Value server. + +You may refer to KV services timeline and roadmap [here][1]. You may refer to the high level +architecture and design [here][2]. You may also refer to [privacy considerations][8], [security +goals][9] and the [trust model][3]. + +If you have any questions, please submit your feedback by filing a github issue for this repo. +Alternatively, you can email us at + +Following are the steps that adtechs need to follow to onboard, integrate, deploy and run KV +services in non-production and production environments. + +## Step 1: Enroll with Privacy Sandbox + +Refer to the [guide][13] to enroll with Privacy Sandbox. This is also a prerequisite for [enrolling +with Coordinators][14]. + +## Step 2: Cloud set up + +Adtechs need to choose one of the currently [supported cloud platforms][27] to run services. + +Refer to the corresponding cloud support explainer for details: + +- [AWS support][4] + - Adtechs must set up an [AWS account][34], create IAM users and security credentials. +- [GCP support][5] + - Create a GCP Project and generate a cloud service account. Refer to the [GCP project + setup][6] section for more details. 
+

## Step 3: Enroll with Coordinators

Adtechs must enroll with [Coordinators][37] for the specific [cloud platform][27] where they plan to
run KV services in production environments ([Beta testing][7] and [GA][10]).

The Coordinators run [key management systems (key services)][43] that provision keys to KV services
running in a [trusted execution environment (TEE)][44] after service attestation. Integration of KV
server workloads with Coordinators would enable TEE server attestation and allow fetching live
encryption / decryption keys from [public or private key service endpoints][43] in KV services.

To integrate with Coordinators, adtechs will have to follow the steps below:

- Set up a cloud account on a preferred [cloud platform][27] that is supported.
- Provide the Coordinators with specific information related to their cloud account.
- Coordinators will provide URL endpoints of key services and other information that must be
  incorporated in KV server configurations.

Adtechs would have to enroll with Coordinators running key management systems that provision keys to
KV services after server attestation.

Adtechs should only enroll with the Coordinators for the specific cloud platform where they plan to
run KV services.

### `use_real_coordinators`

KV supports [cryptographic protection][45] with hardcoded public-private key pairs, while disabling
TEE server attestation. During initial phases of onboarding, this would allow adtechs to test KV
server workloads even before integration with Coordinators.

- Set `use_real_coordinators` flag to `false` ([AWS][12], [GCP][11]).
- _Note: `use_real_coordinators` should be set to `true` in production_.

### Amazon Web Services (AWS)

An adtech should provide their AWS Account Id. The Coordinators will create IAM roles. They would
attach that AWS Account Id to the IAM roles and include in an allowlist. 
Then the Coordinators would
let adtechs know about the IAM roles and that should be included in the KV server Terraform configs
that fetch cryptographic keys from key management systems.

Following config parameters in KV server configs would include the IAM roles information provided by
the Coordinators.

PRIMARY_COORDINATOR_ACCOUNT_IDENTITY SECONDARY_COORDINATOR_ACCOUNT_IDENTITY

### Google Cloud Platform (GCP)

An adtech should provide [**IAM service account email**][107] to both the Coordinators.

The Coordinators would create IAM roles. After adtechs provide their service account email, the
Coordinators would attach that information to the IAM roles and include in an allowlist. Then the
Coordinators would let adtechs know about the IAM roles and that should be included in the KV
server Terraform configs that fetch cryptographic keys from key management systems.

Enroll with Coordinators via the Protected Auction [intake form][51]. You will need to provide the
service account or AWS account ID created in the previous step.

_Note:_

- _Adtechs can only run images attested by the key management systems_ _(Coordinators) in
  production._
- _Without successfully completing the Coordinator enrollment process, adtechs_ _will not be able
  to run attestable services in TEE and therefore will not be_ _able to process production data
  using KV services._
- _Key management systems are not in the critical path of KV services. The_ _cryptographic keys
  are fetched from the key services in the non critical_ _path at service startup and periodically
  every few hours, and cached in-memory_ _server side. Refer here for more information._

## Step 4: Build, deploy services

Follow the steps only for your preferred cloud platform to build and deploy KV services.

### KV code repository

KV server code and configurations are open sourced in this repo.

Starting with KV 0.17, release hashes are allowlisted with coordinators on GCP and AWS. 
+ +### Prerequisites + +The following prerequisites are required before building and packaging KV services. + +- AWS: Refer to the prerequisite steps [here][15]. +- GCP: Refer to the prerequisite steps [here][16]. + +### Build service images + +To run KV services locally, refer [here][17]. + +- AWS: Refer to the detailed instructions [here][18]. +- GCP: Refer to the detailed instructions [here][19]. + +### Deploy services + +_Note: The configurations set the default parameter values. The values that are_ _not set must be +filled in by adtechs before deployment to the cloud._ + +[AWS][28] + +[GCP][29] + +## Step 5: Pick your usecase + +KV supports the following usecases: + +- [Protected Audience (PA)][20] +- [Protected App Signals (PAS)][21] + +## Step 6: Enable debugging and monitor services + +Refer to the [debugging explainer][80] to understand how user consented debugging can be used in +production. + +Refer to the [monitoring explainer][81] to understand how cloud monitoring is integrated and service +[metrics][82] that are exported for monitoring by adtechs. + +## Github issues + +Adtechs can file issues and feature requests on this Github project. 
+ +[1]: + https://github.com/privacysandbox/protected-auction-key-value-service/tree/main?tab=readme-ov-file#timeline-and-roadmap +[2]: /docs/APIs.md +[3]: + https://github.com/privacysandbox/protected-auction-services-docs/blob/main/key_value_service_trust_model.md +[4]: /docs/deployment/deploying_on_aws.md +[5]: /docs/deployment/deploying_on_gcp.md +[6]: + https://github.com/privacysandbox/protected-auction-key-value-service/blob/main/docs/deployment/deploying_on_gcp.md#set-up-your-gcp-project +[7]: + https://github.com/privacysandbox/protected-auction-key-value-service/tree/main?tab=readme-ov-file#june-2024-beta-release +[8]: + https://github.com/privacysandbox/fledge-docs/blob/main/trusted_services_overview.md#privacy-considerations +[9]: + https://github.com/privacysandbox/fledge-docs/blob/main/trusted_services_overview.md#security-goals +[10]: + https://github.com/privacysandbox/protected-auction-key-value-service/tree/main?tab=readme-ov-file#h2-2024-android-pa-ga-pas-ga +[11]: + https://github.com/privacysandbox/protected-auction-key-value-service/blob/9a60180f9d6f52a4ca805e5463ecc9e5e80e88f9/production/terraform/gcp/environments/kv_server_variables.tf#L193 +[12]: + https://github.com/privacysandbox/protected-auction-key-value-service/blob/9a60180f9d6f52a4ca805e5463ecc9e5e80e88f9/production/terraform/aws/environments/kv_server_variables.tf#L216 +[13]: https://developers.google.com/privacy-sandbox/relevance/enrollment +[14]: #step-3-enroll-with-coordinators +[15]: /docs/deployment/deploying_on_aws.md#set-up-your-aws-account +[16]: /docs/deployment/deploying_on_gcp.md#set-up-your-gcp-project +[17]: /docs/deployment/deploying_locally.md +[18]: /docs/deployment/deploying_on_aws.md#build-the-keyvalue-server-artifacts +[19]: /docs/deployment/deploying_on_gcp.md#build-the-keyvalue-server-artifacts +[20]: /docs/protected_audience/integrating_with_fledge.md +[21]: /docs/protected_app_signals/ad_retrieval_overview.md +[27]: + 
https://github.com/privacysandbox/fledge-docs/blob/main/bidding_auction_services_api.md#supported-public-cloud-platforms +[28]: + https://github.com/privacysandbox/protected-auction-key-value-service/blob/release-0.16/docs/deployment/deploying_on_aws.md +[29]: + https://github.com/privacysandbox/protected-auction-key-value-service/blob/release-0.16/docs/deployment/deploying_on_gcp.md#deployment +[34]: + https://docs.aws.amazon.com/signin/latest/userguide/introduction-to-iam-user-sign-in-tutorial.html +[37]: + https://github.com/privacysandbox/fledge-docs/blob/main/trusted_services_overview.md#deployment-by-coordinators +[43]: + https://github.com/privacysandbox/fledge-docs/blob/main/trusted_services_overview.md#key-management-systems +[44]: + https://github.com/privacysandbox/fledge-docs/blob/main/trusted_services_overview.md#trusted-execution-environment +[45]: + https://github.com/privacysandbox/fledge-docs/blob/main/bidding_auction_services_api.md#client--server-and-server--server-communication +[51]: + https://docs.google.com/forms/d/e/1FAIpQLSduotEEI9h_Y8uEvSGdFoL-SqHAD--NVNaX1X1UTBeCeEM-Og/viewform +[80]: + https://github.com/privacysandbox/fledge-docs/blob/main/debugging_protected_audience_api_services.md +[81]: + https://github.com/privacysandbox/fledge-docs/blob/main/monitoring_protected_audience_api_services.md +[82]: + https://github.com/privacysandbox/fledge-docs/blob/main/monitoring_protected_audience_api_services.md#proposed-metrics +[107]: + https://github.com/privacysandbox/bidding-auction-servers/blob/main/production/deploy/aws/terraform/environment/demo/README.md#using-the-demo-configuration diff --git a/getting_started/quick_start.md b/getting_started/quick_start.md index befbbafc..1347b16d 100644 --- a/getting_started/quick_start.md +++ b/getting_started/quick_start.md @@ -33,17 +33,18 @@ From the Key Value server repo folder, execute the following command: ```sh ./builders/tools/bazel-debian build //components/data_server/server:server \ - 
--//:platform=local \ - --//:instance=local + --config=local_instance \ + --config=local_platform \ + --config=nonprod_mode ``` This will take a while for the first time. Subsequent builds can reuse cached progress. This command starts a build environment docker container and performs build from within. -- The `--//:instance=local` means the server itself runs as a local binary instead of running on a - specific cloud. -- The `--//:platform=local` means the server will integrate with local version of auxiliary +- The `--config=local_instance` means the server itself runs as a local binary instead of running + on a specific cloud. +- The `--config=local_platform` means the server will integrate with local version of auxiliary systems such as blob storage, parameter, etc. Other possible values are cloud-specific, in which case the server will use the corresponding cloud APIs to interact. diff --git a/production/packaging/aws/build_and_test b/production/packaging/aws/build_and_test index 6be00c9b..9c99fcd1 100755 --- a/production/packaging/aws/build_and_test +++ b/production/packaging/aws/build_and_test @@ -209,7 +209,13 @@ if [[ -n ${AMI_REGIONS[0]} ]]; then regions="$(arr_to_string_list AMI_REGIONS)" builder::cbuild_al2 " set -o errexit -packer build -var=regions='${regions}' -var=commit_version=$(git rev-parse HEAD) -var=distribution_dir=dist/aws -var=workspace=/src/workspace production/packaging/aws/data_server/ami/image.pkr.hcl +packer build \ + -var=regions='${regions}' \ + -var=commit_version=$(git rev-parse HEAD) \ + -var=build_mode=${MODE} \ + -var=distribution_dir=dist/aws \ + -var=workspace=/src/workspace \ + production/packaging/aws/data_server/ami/image.pkr.hcl " fi diff --git a/production/packaging/aws/data_server/BUILD.bazel b/production/packaging/aws/data_server/BUILD.bazel index b60c3563..be177338 100644 --- a/production/packaging/aws/data_server/BUILD.bazel +++ b/production/packaging/aws/data_server/BUILD.bazel @@ -34,24 +34,6 @@ pkg_files( prefix = 
"/", ) -pkg_files( - name = "kmstool_enclave_executables", - srcs = [ - "@google_privacysandbox_servers_common//src/cpio/client_providers/kms_client_provider/aws:kms_cli", - ], - attributes = pkg_attributes(mode = "0555"), - prefix = "/cpio/bin", -) - -pkg_files( - name = "kmstool_enclave_libs", - srcs = [ - "@google_privacysandbox_servers_common//src/cpio/client_providers/kms_client_provider/aws:libnsm_so", - ], - attributes = pkg_attributes(mode = "0444"), - prefix = "/cpio/lib", -) - # Create a symlink between where kmstool_enclave_cli expects shell to be # (/bin/sh) and where it actually is on our image (/busybox/sh). pkg_mklink( @@ -61,8 +43,6 @@ pkg_mklink( ) server_binaries = [ - ":kmstool_enclave_executables", - ":kmstool_enclave_libs", ":server_executables", ":busybox_sh_symlink", ] @@ -77,17 +57,6 @@ pkg_tar( srcs = server_binaries, ) -# Ensure libnsm ends up in the load path. -pkg_tar( - name = "libnsm-tar", - srcs = [ - "@google_privacysandbox_servers_common//src/cpio/client_providers/kms_client_provider/aws:libnsm_so", - ], - mode = "0444", - package_dir = "/cpio/lib", - visibility = ["//visibility:public"], -) - # This image target is meant for testing running the server in an enclave using. # # See project README.md on how to run the image. 
@@ -106,13 +75,13 @@ oci_image( "--port=50051", # These affect PCR0, so changing these would result in the loss of ability to communicate with # the downstream components - "--public_key_endpoint=https://publickeyservice.pa.aws.privacysandboxservices.com/.well-known/protected-auction/v1/public-keys", + "--public_key_endpoint=https://publickeyservice.pa-3.aws.privacysandboxservices.com/v1alpha/publicKeys", "--stderrthreshold=0", ], tars = [ "@google_privacysandbox_servers_common//src/aws/proxy:libnsm_and_proxify_tar", + "@google_privacysandbox_servers_common//src/cpio/client_providers/kms_client_provider/aws:kms_binaries", "//production/packaging/aws/resolv:resolv_config_layer", - ":libnsm-tar", ":server_binaries_tar", ], ) diff --git a/production/packaging/aws/data_server/ami/envoy_networking.sh b/production/packaging/aws/data_server/ami/envoy_networking.sh new file mode 100644 index 00000000..26e48a4f --- /dev/null +++ b/production/packaging/aws/data_server/ami/envoy_networking.sh @@ -0,0 +1,249 @@ +#!/bin/bash -e +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configure the iptables for envoy for the appmesh. +APPMESH_IGNORE_UID="1337" +APPMESH_APP_PORTS="50051" +APPMESH_EGRESS_IGNORED_PORTS="443" + +APPMESH_ENVOY_EGRESS_PORT="15001" +APPMESH_ENVOY_INGRESS_PORT="15000" +APPMESH_EGRESS_IGNORED_IP="169.254.169.254,169.254.170.2" + +# Enable IPv6. 
+[ -z "$APPMESH_ENABLE_IPV6" ] && APPMESH_ENABLE_IPV6="0" + +# Egress traffic from the processes owned by the following UID/GID will be ignored. +if [ -z "$APPMESH_IGNORE_UID" ] && [ -z "$APPMESH_IGNORE_GID" ]; then + echo "Variables APPMESH_IGNORE_UID and/or APPMESH_IGNORE_GID must be set." + echo "Envoy must run under those IDs to be able to properly route its egress traffic." + exit 1 +fi + +# Port numbers Application and Envoy are listening on. +if [ -z "$APPMESH_ENVOY_EGRESS_PORT" ]; then + echo "APPMESH_ENVOY_EGRESS_PORT must be defined to forward traffic from the application to the proxy." + exit 1 +fi + +# If an app port was specified, then we also need to enforce the proxy's ingress port so we know where to forward traffic. +if [ -n "$APPMESH_APP_PORTS" ] && [ -z "$APPMESH_ENVOY_INGRESS_PORT" ]; then + echo "APPMESH_ENVOY_INGRESS_PORT must be defined to forward traffic from the APPMESH_APP_PORTS to the proxy." + exit 1 +fi + +# Comma separated list of ports for which egress traffic will be ignored, we always refuse to route SSH traffic. 
+if [ -z "$APPMESH_EGRESS_IGNORED_PORTS" ]; then + APPMESH_EGRESS_IGNORED_PORTS="22" +else + APPMESH_EGRESS_IGNORED_PORTS="$APPMESH_EGRESS_IGNORED_PORTS,22" +fi + +# +# End of configurable options +# + +function initialize() { + echo "=== Initializing ===" + if [ -n "$APPMESH_APP_PORTS" ]; then + iptables -t nat -N APPMESH_INGRESS + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + ip6tables -t nat -N APPMESH_INGRESS + fi + fi + iptables -t nat -N APPMESH_EGRESS + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + ip6tables -t nat -N APPMESH_EGRESS + fi +} + +function enable_egress_routing() { + # Stuff to ignore + [ -n "$APPMESH_IGNORE_UID" ] && \ + iptables -t nat -A APPMESH_EGRESS \ + -m owner --uid-owner $APPMESH_IGNORE_UID \ + -j RETURN + + [ -n "$APPMESH_IGNORE_GID" ] && \ + iptables -t nat -A APPMESH_EGRESS \ + -m owner --gid-owner "$APPMESH_IGNORE_GID" \ + -j RETURN + + [ -n "$APPMESH_EGRESS_IGNORED_PORTS" ] && \ + for IGNORED_PORT in $(echo "$APPMESH_EGRESS_IGNORED_PORTS" | tr "," "\n"); do + iptables -t nat -A APPMESH_EGRESS \ + -p tcp \ + -m multiport --dports "$IGNORED_PORT" \ + -j RETURN + done + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + # Stuff to ignore ipv6 + [ -n "$APPMESH_IGNORE_UID" ] && \ + ip6tables -t nat -A APPMESH_EGRESS \ + -m owner --uid-owner $APPMESH_IGNORE_UID \ + -j RETURN + + [ -n "$APPMESH_IGNORE_GID" ] && \ + ip6tables -t nat -A APPMESH_EGRESS \ + -m owner --gid-owner "$APPMESH_IGNORE_GID" \ + -j RETURN + + [ -n "$APPMESH_EGRESS_IGNORED_PORTS" ] && \ + for IGNORED_PORT in $(echo "$APPMESH_EGRESS_IGNORED_PORTS" | tr "," "\n"); do + ip6tables -t nat -A APPMESH_EGRESS \ + -p tcp \ + -m multiport --dports "$IGNORED_PORT" \ + -j RETURN + done + fi + + # The list can contain both IPv4 and IPv6 addresses. We will loop over this list + # to add every IPv4 address into `iptables` and every IPv6 address into `ip6tables`. 
+ [ -n "$APPMESH_EGRESS_IGNORED_IP" ] && \ + for IP_ADDR in $(echo "$APPMESH_EGRESS_IGNORED_IP" | tr "," "\n"); do + if [[ $IP_ADDR =~ .*:.* ]] + then + [ "$APPMESH_ENABLE_IPV6" == "1" ] && \ + ip6tables -t nat -A APPMESH_EGRESS \ + -p tcp \ + -d "$IP_ADDR" \ + -j RETURN + else + iptables -t nat -A APPMESH_EGRESS \ + -p tcp \ + -d "$IP_ADDR" \ + -j RETURN + fi + done + + # Redirect egress traffic destined to application port to envoy. + sudo iptables -t nat -A APPMESH_EGRESS -p tcp --dport 50051 -j REDIRECT \ + --to $APPMESH_ENVOY_EGRESS_PORT + + # Apply APPMESH_EGRESS chain to non local traffic + iptables -t nat -A OUTPUT \ + -p tcp \ + -m addrtype ! --dst-type LOCAL \ + -j APPMESH_EGRESS + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + # Redirect everything that is not ignored ipv6 + ip6tables -t nat -A APPMESH_EGRESS \ + -p tcp \ + -j REDIRECT --to $APPMESH_ENVOY_EGRESS_PORT + # Apply APPMESH_EGRESS chain to non local traffic ipv6 + ip6tables -t nat -A OUTPUT \ + -p tcp \ + -m addrtype ! --dst-type LOCAL \ + -j APPMESH_EGRESS + fi + +} + +function enable_ingress_redirect_routing() { + # Route everything arriving at the application port to Envoy + iptables -t nat -A APPMESH_INGRESS \ + -p tcp \ + -m multiport --dports "$APPMESH_APP_PORTS" \ + -j REDIRECT --to-port "$APPMESH_ENVOY_INGRESS_PORT" + + # Apply AppMesh ingress chain to everything non-local + iptables -t nat -A PREROUTING \ + -p tcp \ + -m addrtype ! --src-type LOCAL \ + -j APPMESH_INGRESS + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + # Route everything arriving at the application port to Envoy ipv6 + ip6tables -t nat -A APPMESH_INGRESS \ + -p tcp \ + -m multiport --dports "$APPMESH_APP_PORTS" \ + -j REDIRECT --to-port "$APPMESH_ENVOY_INGRESS_PORT" + + # Apply AppMesh ingress chain to everything non-local ipv6 + ip6tables -t nat -A PREROUTING \ + -p tcp \ + -m addrtype ! 
--src-type LOCAL \ + -j APPMESH_INGRESS + fi +} + +function enable_routing() { + echo "=== Enabling routing ===" + enable_egress_routing + if [ -n "$APPMESH_APP_PORTS" ]; then + echo "=== Enabling ingress routing ===" + enable_ingress_redirect_routing + fi +} + +function disable_routing() { + echo "=== Disabling routing ===" + iptables -t nat -F APPMESH_INGRESS + iptables -t nat -F APPMESH_EGRESS + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + ip6tables -t nat -F APPMESH_INGRESS + ip6tables -t nat -F APPMESH_EGRESS + fi +} + +function dump_status() { + echo "=== iptables FORWARD table ===" + iptables -L -v -n + echo "=== iptables NAT table ===" + iptables -t nat -L -v -n + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + echo "=== ip6tables FORWARD table ===" + ip6tables -L -v -n + echo "=== ip6tables NAT table ===" + ip6tables -t nat -L -v -n + fi +} + +function clean_up() { + disable_routing + ruleNum=$(iptables -L PREROUTING -t nat --line-numbers | grep APPMESH_INGRESS | cut -d " " -f 1) + iptables -t nat -D PREROUTING "$ruleNum" + + ruleNum=$(iptables -L OUTPUT -t nat --line-numbers | grep APPMESH_EGRESS | cut -d " " -f 1) + iptables -t nat -D OUTPUT "$ruleNum" + + iptables -t nat -X APPMESH_INGRESS + iptables -t nat -X APPMESH_EGRESS + + if [ "$APPMESH_ENABLE_IPV6" == "1" ]; then + ruleNum=$(ip6tables -L PREROUTING -t nat --line-numbers | grep APPMESH_INGRESS | cut -d " " -f 1) + ip6tables -t nat -D PREROUTING "$ruleNum" + + ruleNum=$(ip6tables -L OUTPUT -t nat --line-numbers | grep APPMESH_EGRESS | cut -d " " -f 1) + ip6tables -t nat -D OUTPUT "$ruleNum" + + ip6tables -t nat -X APPMESH_INGRESS + ip6tables -t nat -X APPMESH_EGRESS + fi +} + +function print_config() { + echo "=== Input configuration ===" + env | grep APPMESH_ || true +} + +print_config + +initialize +enable_routing diff --git a/production/packaging/aws/data_server/ami/hc.bash b/production/packaging/aws/data_server/ami/hc.bash new file mode 100755 index 00000000..1af0b259 --- /dev/null +++ 
b/production/packaging/aws/data_server/ami/hc.bash @@ -0,0 +1,209 @@ +#!/usr/bin/env bash +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage: +# bash ./hc.bash -p /usr/local/google/home/akundla/bidding-auction-server/production/packaging/aws/common/ami -n health.proto -a localhost:50051 -i 10 -t 5 -h 2 -u 10 -e i-08756c3a64a78711a -g 2 -r us-west-1 -s srv-ajxdkksp7d5wpmou + +# https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-retries.html +export AWS_RETRY_MODE="standard" +export AWS_MAX_ATTEMPTS=4 + +GRPC_STATUS_SERVING="SERVING" + +# Input parameter validation values. +healthcheck_min_frequency_sec=1 +healthcheck_max_frequency_sec=90 +min_checks_threshold=2 +max_checks_threshold=10 + +while getopts "p:n:a:i:t:h:u:e:g:r:s:" flag; do + case $flag in + p) + path_to_folder_of_healthcheck_proto_file=$OPTARG + ;; + n) + healthcheck_proto_file_name=$OPTARG + ;; + a) + address_and_port_of_service=$OPTARG + ;; + i) + if [[ $OPTARG -ge $healthcheck_min_frequency_sec ]] && [[ $OPTARG -le $healthcheck_max_frequency_sec ]] + then + interval_between_hc_sec=$OPTARG + else + echo "Invalid value for flag -i (interval_between_hc_sec): value must be between ${healthcheck_min_frequency_sec} and ${healthcheck_max_frequency_sec}, inclusive." 
+ exit 1 + fi + ;; + t) + if [[ $OPTARG -ge $healthcheck_min_frequency_sec ]] && [[ $OPTARG -le $healthcheck_max_frequency_sec ]] + then + hc_timeout_sec=$OPTARG + else + echo "Invalid value for flag -t (hc_timeout_sec): value must be between ${healthcheck_min_frequency_sec} and ${healthcheck_max_frequency_sec}, inclusive." + exit 1 + fi + ;; + h) + if [[ $OPTARG -ge $min_checks_threshold ]] && [[ $OPTARG -le $max_checks_threshold ]] + then + healthy_threshold=$OPTARG + else + echo "Invalid value for flag -h (healthy_threshold): value must be between ${min_checks_threshold} and ${max_checks_threshold}, inclusive." + exit 1 + fi + ;; + u) + if [[ $OPTARG -ge $min_checks_threshold ]] && [[ $OPTARG -le $max_checks_threshold ]] + then + unhealthy_threshold=$OPTARG + else + echo "Invalid value for flag -u (unhealthy_threshold): value must be between ${min_checks_threshold} and ${max_checks_threshold}, inclusive." + exit 1 + fi + ;; + e) + instance_id=$OPTARG + ;; + g) + startup_grace_period=$OPTARG + ;; + r) + region=$OPTARG + ;; + s) + cloud_map_service_id=$OPTARG + ;; + \?) + # Handle invalid options + echo "Invalid option detected" + exit 1 + ;; + esac +done + +input_param_names=("path_to_folder_of_healthcheck_proto_file" "healthcheck_proto_file_name" "address_and_port_of_service" "interval_between_hc_sec" "hc_timeout_sec" "healthy_threshold" "unhealthy_threshold" "instance_id" "startup_grace_period" "region" "cloud_map_service_id") + +input_param_values=("${path_to_folder_of_healthcheck_proto_file}" "${healthcheck_proto_file_name}" "${address_and_port_of_service}" "${interval_between_hc_sec}" "${hc_timeout_sec}" "${healthy_threshold}" "${unhealthy_threshold}" "${instance_id}" "${startup_grace_period}" "${region}" "${cloud_map_service_id}") + +any_params_missing=false + +for index in "${!input_param_names[@]}"; +do + if [[ -z "${input_param_values[$index]}" ]] + then + echo "${input_param_names[$index]} is missing!" 
+ any_params_missing=true + fi +done + +if [[ ${any_params_missing} =~ "true" ]] +then + echo "All params required, exiting." + exit 1 +fi + +# Holds the last n statuses, where n = max(healthy_threshold, unhealthy_threshold) +healthcheck_status_queue=() +hc_stat_queue_max_len=$(( healthy_threshold > unhealthy_threshold ? healthy_threshold : unhealthy_threshold )) + +last_set_status_healthy=true + +echo "Custom health checking script initialized, waiting for grace period before beginning health checks." + +sleep "${startup_grace_period}" + +echo "Custom health checking script beginning healthchecks now." + +while true +do + current_hc_response=$(grpcurl --plaintext -connect-timeout="${hc_timeout_sec}" -max-time="${hc_timeout_sec}" -import-path="${path_to_folder_of_healthcheck_proto_file}" -proto="${healthcheck_proto_file_name}" "${address_and_port_of_service}" grpc.health.v1.Health/Check) + + if [[ $current_hc_response == *$GRPC_STATUS_SERVING* ]] + then + healthcheck_status_queue+=(true) + # if the server has been set to UNHEALTHY in the cloud map, it can be set to HEALTHY again. But if it has been condemned in the ASG, it is shutting down; even if the server recovers it will still shut down. + if [[ "${last_set_status_healthy}" =~ true ]] + then + aws servicediscovery update-instance-custom-health-status --service-id "$cloud_map_service_id" --instance-id "$instance_id" --region "$region" --status HEALTHY + fi + else + healthcheck_status_queue+=(false) + echo "Server is not serving; fails custom health check running on machine!" + echo "Last status between the quotes if present: '${current_hc_response}'" + aws servicediscovery update-instance-custom-health-status --service-id "$cloud_map_service_id" --instance-id "$instance_id" --region "$region" --status UNHEALTHY + echo "Set to UNHEALTHY in cloud map" + fi + + # Keep queue at max length and no larger. 
+ if [[ "${#healthcheck_status_queue[@]}" -gt hc_stat_queue_max_len ]] + then + # Remove first element + healthcheck_status_queue=("${healthcheck_status_queue[@]:1}") + fi + + current_hc_stat_queue_len=${#healthcheck_status_queue[@]} + + if [[ ${healthcheck_status_queue[-1]} =~ "true" ]] && [[ current_hc_stat_queue_len -ge $healthy_threshold ]] + then + start_i="$(( current_hc_stat_queue_len - healthy_threshold ))" + end_i="$(( current_hc_stat_queue_len - 1 ))" + can_set_server_as_healthy=true + for i in $(seq $start_i $end_i) + do + if [[ ${healthcheck_status_queue[$i]} =~ "false" ]] + then + can_set_server_as_healthy=false + fi + done + if [[ ${can_set_server_as_healthy} =~ "true" ]] && [[ "${last_set_status_healthy}" =~ "false" ]] + then + aws autoscaling set-instance-health --instance-id "$instance_id" --region "$region" --health-status Healthy + last_set_status_healthy=true + echo "Just made ASG HC Call to set server to healthy" + fi + fi + + if [[ ${healthcheck_status_queue[-1]} =~ "false" ]] && [[ current_hc_stat_queue_len -ge $unhealthy_threshold ]] + then + echo "Can check if server is un-healthy." + start_i="$(( current_hc_stat_queue_len - unhealthy_threshold ))" + end_i="$(( current_hc_stat_queue_len - 1 ))" + can_set_server_as_unhealthy=true + for i in $(seq $start_i $end_i) + do + if [[ ${healthcheck_status_queue[$i]} =~ "true" ]] + then + can_set_server_as_unhealthy=false + fi + echo "healthcheck_status_queue[$i]: ${healthcheck_status_queue[$i]}" + done + echo "Can set server as unhealthy: ${can_set_server_as_unhealthy}" + if [[ ${can_set_server_as_unhealthy} =~ "true" ]] && [[ "${last_set_status_healthy}" =~ "true" ]] + then + # We're about to flag this instance to be killed by ASG, so it must be de-registered from the cloud map - and we need to de-register it NOW, before the machine is shut down. 
+ aws servicediscovery deregister-instance --instance-id "$instance_id" --service-id "$cloud_map_service_id" --region "$region" + echo "Just made Cloud Map Call to de-register instance" + # Now set the instance to be grought down and replaced by ASG. + aws autoscaling set-instance-health --instance-id "$instance_id" --region "$region" --health-status Unhealthy + last_set_status_healthy=false + echo "Just made ASG HC Call to set server to Unhealthy" + fi + fi + + # Sleep no matter what. + sleep "${interval_between_hc_sec}s" +done diff --git a/production/packaging/aws/data_server/ami/image.pkr.hcl b/production/packaging/aws/data_server/ami/image.pkr.hcl index fe0f4b41..d22b995c 100644 --- a/production/packaging/aws/data_server/ami/image.pkr.hcl +++ b/production/packaging/aws/data_server/ami/image.pkr.hcl @@ -26,6 +26,10 @@ variable "commit_version" { type = string } +variable "build_mode" { + type = string +} + # Directory path where the built artifacts appear variable "distribution_dir" { type = string @@ -76,6 +80,7 @@ source "amazon-ebs" "dataserver" { } tags = { commit_version = var.commit_version + build_mode = var.build_mode } ssh_username = "ec2-user" } @@ -116,6 +121,18 @@ build { source = join("/", [var.distribution_dir, "otel_collector_config.yaml"]) destination = "/home/ec2-user/otel_collector_config.yaml" } + provisioner "file" { + source = join("/", [var.workspace, "production/packaging/aws/data_server/ami/envoy_networking.sh"]) + destination = "/home/ec2-user/envoy_networking.sh" + } + provisioner "file" { + source = join("/", [var.workspace, "production/packaging/aws/data_server/ami/hc.bash"]) + destination = "/home/ec2-user/hc.bash" + } + provisioner "file" { + source = join("/", [var.workspace, "components/health_check/health.proto"]) + destination = "/home/ec2-user/health.proto" + } provisioner "shell" { script = join("/", [var.workspace, "production/packaging/aws/data_server/ami/setup.sh"]) } diff --git 
a/production/packaging/aws/data_server/ami/setup.sh b/production/packaging/aws/data_server/ami/setup.sh index 0c2a5636..74d6d994 100644 --- a/production/packaging/aws/data_server/ami/setup.sh +++ b/production/packaging/aws/data_server/ami/setup.sh @@ -28,6 +28,12 @@ sudo cp /home/ec2-user/server_enclave_image.eif /opt/privacysandbox/server_encla OTEL_COL_CONF=/opt/aws/aws-otel-collector/etc/otel_collector_config.yaml sudo mkdir -p "$(dirname "${OTEL_COL_CONF}")" sudo cp /home/ec2-user/otel_collector_config.yaml "${OTEL_COL_CONF}" +sudo cp /home/ec2-user/envoy_networking.sh /opt/privacysandbox/envoy_networking.sh +sudo cp /home/ec2-user/hc.bash /opt/privacysandbox/hc.bash +sudo cp /home/ec2-user/health.proto /opt/privacysandbox/health.proto +sudo chmod 555 /opt/privacysandbox/envoy_networking.sh +sudo chmod 555 /opt/privacysandbox/hc.bash +sudo chmod 555 /opt/privacysandbox/health.proto # Install necessary dependencies sudo yum update -y @@ -43,3 +49,8 @@ sudo docker pull envoyproxy/envoy-distroless:v1.24.1 sudo mkdir /etc/envoy sudo chown ec2-user:ec2-user /etc/envoy + +# Install grpcurl +cd /tmp +wget -q https://github.com/fullstorydev/grpcurl/releases/download/v1.9.1/grpcurl_1.9.1_linux_amd64.rpm +sudo rpm -i grpcurl_1.9.1_linux_amd64.rpm diff --git a/production/packaging/aws/data_server/test/structure.yaml b/production/packaging/aws/data_server/test/structure.yaml index 3ac2d784..40ef88f1 100644 --- a/production/packaging/aws/data_server/test/structure.yaml +++ b/production/packaging/aws/data_server/test/structure.yaml @@ -17,21 +17,6 @@ schemaVersion: 2.0.0 fileExistenceTests: - - name: config - path: /etc/envoy/envoy.yaml - shouldExist: true - permissions: "-r--r--r--" - - - name: proto descriptor - path: /etc/envoy/config/query_api_descriptor_set.pb - shouldExist: true - permissions: "-r--r--r--" - - - name: envoy - path: /usr/local/bin/envoy - shouldExist: true - isExecutableBy: any - - name: init_server_basic path: /init_server_basic shouldExist: true diff 
--git a/production/packaging/build_and_test_all_in_docker b/production/packaging/build_and_test_all_in_docker index 4cca3b85..efcb1be6 100755 --- a/production/packaging/build_and_test_all_in_docker +++ b/production/packaging/build_and_test_all_in_docker @@ -154,7 +154,7 @@ bazel ${BAZEL_STARTUP_ARGS} info ${BAZEL_EXTRA_ARGS} bazel-testlogs 2>/dev/null bazel ${BAZEL_STARTUP_ARGS} build ${BAZEL_EXTRA_ARGS} //components/... //public/... //tools/... if [[ ${RUN_TESTS} -ne 0 ]]; then printf 'Tests enabled. Running tests...' - bazel ${BAZEL_STARTUP_ARGS} test ${BAZEL_EXTRA_ARGS} --build_tests_only --test_size_filters=small //... + bazel ${BAZEL_STARTUP_ARGS} test ${BAZEL_EXTRA_ARGS} --build_tests_only //... fi if [[ ${PLATFORM} == gcp ]]; then bazel ${BAZEL_STARTUP_ARGS} run ${BAZEL_EXTRA_ARGS} //production/packaging/gcp/data_server:copy_to_dist diff --git a/production/packaging/gcp/data_server/BUILD.bazel b/production/packaging/gcp/data_server/BUILD.bazel index 6e9e90ff..eb077cc8 100644 --- a/production/packaging/gcp/data_server/BUILD.bazel +++ b/production/packaging/gcp/data_server/BUILD.bazel @@ -136,7 +136,7 @@ container_image( env = { "GRPC_DNS_RESOLVER": "native", }, - labels = {"tee.launch_policy.log_redirect": "always"}, + labels = {"tee.launch_policy.log_redirect": "debugonly"}, layers = [ ":server_binary_layer", ":envoy_distroless_layer", diff --git a/production/packaging/gcp/data_server/envoy/envoy.yaml b/production/packaging/gcp/data_server/envoy/envoy.yaml index 11dba3ac..c1ea1b38 100644 --- a/production/packaging/gcp/data_server/envoy/envoy.yaml +++ b/production/packaging/gcp/data_server/envoy/envoy.yaml @@ -62,8 +62,6 @@ static_resources: services: - kv_server.v1.KeyValueService - kv_server.v2.KeyValueService - ignored_query_parameters: - - "interestGroupNames" print_options: add_whitespace: true always_print_primitive_fields: true diff --git a/production/packaging/tools/request_simulation/metrics_dashboard/cloudwatch.json 
b/production/packaging/tools/request_simulation/metrics_dashboard/cloudwatch.json index a10c5e5c..78563a82 100644 --- a/production/packaging/tools/request_simulation/metrics_dashboard/cloudwatch.json +++ b/production/packaging/tools/request_simulation/metrics_dashboard/cloudwatch.json @@ -7,18 +7,19 @@ "x": 9, "type": "metric", "properties": { - "view": "timeSeries", - "stacked": false, - "region": "us-east-1", "metrics": [ [ { - "expression": "SEARCH('{Request-simulation,OTelLib,event,host.arch,service.name,service.version,telemetry.sdk.language,telemetry.sdk.name,telemetry.sdk.version, testing.server} EstimatedQPS', 'Average', 60)", + "expression": "SEARCH('service.name=\"request-simulation\" MetricName=\"EstimatedQPS\"', 'Maximum', 60)", "id": "e1", - "period": 60 + "label": "$${PROP('Dim.testing.server')}", + "region": "us-east-1" } ] ], + "view": "timeSeries", + "stacked": false, + "region": "us-east-1", "title": "Estimated QPS", "period": 60, "yAxis": { @@ -26,7 +27,8 @@ "showUnits": false, "min": 0 } - } + }, + "stat": "Average" } }, { @@ -36,31 +38,22 @@ "x": 0, "type": "metric", "properties": { + "sparkline": false, "metrics": [ [ { - "expression": "SELECT SUM(EventStatus) FROM SCHEMA(\"Request-simulation\", OTelLib,event,\"host.arch\",\"service.name\",\"service.version\",status,\"telemetry.sdk.language\",\"telemetry.sdk.name\",\"telemetry.sdk.version\", \"testing.server\") WHERE event = 'ServerResponseStatus' GROUP BY status, \"testing.server\"", - "region": "us-east-1", - "period": 604800, - "stat": "Sum" + "expression": "SEARCH('service.name=\"request-simulation\" MetricName=\"ServerResponseStatus\"', 'Average', 300)", + "id": "e1", + "label": "$${PROP('Dim.testing.server')} $${PROP('Dim.status')}", + "period": 300 } ] ], "view": "singleValue", "stacked": false, "region": "us-east-1", - "title": "Server response status", - "yAxis": { - "left": { - "showUnits": false, - "min": 0 - } - }, - "period": 604800, - "setPeriodToTimeRange": true, - "sparkline": 
false, - "trend": false, - "stat": "Sum" + "stat": "Average", + "period": 300 } }, { @@ -73,23 +66,10 @@ "metrics": [ [ { - "expression": "SEARCH('{Request-simulation,OTelLib,event,host.arch,service.name,service.version,telemetry.sdk.language,telemetry.sdk.name,telemetry.sdk.version, testing.server} P50GrpcLatency', 'Average', 60)", + "expression": "SEARCH('service.name=\"request-simulation\" MetricName=\"P50GrpcLatency\" OR \"P90GrpcLatency\" OR \"P99GrpcLatency\"', 'Maximum', 60)", "id": "e1", - "period": 60 - } - ], - [ - { - "expression": "SEARCH('{Request-simulation,OTelLib,event,host.arch,service.name,service.version,telemetry.sdk.language,telemetry.sdk.name,telemetry.sdk.version, testing.server} P90GrpcLatency', 'Average', 60)", - "id": "e2", - "period": 60 - } - ], - [ - { - "expression": "SEARCH('{Request-simulation,OTelLib,event,host.arch,service.name,service.version,telemetry.sdk.language,telemetry.sdk.name,telemetry.sdk.version, testing.server} P99GrpcLatency', 'Average', 60)", - "id": "e3", - "period": 60 + "label": "$${PROP('Dim.testing.server')}", + "region": "us-east-1" } ] ], @@ -98,7 +78,7 @@ "region": "us-east-1", "stat": "Average", "period": 60, - "title": "Server response latency(microseconds)", + "title": "Server response latency ms", "yAxis": { "left": { "showUnits": true, @@ -120,9 +100,10 @@ "metrics": [ [ { - "expression": "SEARCH('{Request-simulation,OTelLib,event,host.arch,service.name,service.version,telemetry.sdk.language,telemetry.sdk.name,telemetry.sdk.version, testing.server} RequestsSent', 'Average', 60)", + "expression": "SEARCH('service.name=\"request-simulation\" MetricName=\"RequestsSent\"', 'Maximum', 60)", "id": "e1", - "period": 60 + "label": "$${PROP('Dim.testing.server')}", + "region": "us-east-1" } ] ], diff --git a/production/terraform/aws/environments/demo/us-east-1.tfvars.json b/production/terraform/aws/environments/demo/us-east-1.tfvars.json index aa0cb895..2f13cb21 100644 --- 
a/production/terraform/aws/environments/demo/us-east-1.tfvars.json +++ b/production/terraform/aws/environments/demo/us-east-1.tfvars.json @@ -5,15 +5,22 @@ "autoscaling_min_size": 4, "backup_poll_frequency_secs": 300, "certificate_arn": "cert-arn", + "consented_debug_token": "", "data_loading_blob_prefix_allowlist": ",", "data_loading_file_format": "riegeli", "data_loading_num_threads": 16, + "enable_consented_log": false, + "enable_external_traffic": true, "enclave_cpu_count": 2, "enclave_enable_debug_mode": true, "enclave_memory_mib": 3072, "environment": "demo", + "existing_vpc_environment": "", + "existing_vpc_operator": "", + "healthcheck_grace_period_sec": 60, "healthcheck_healthy_threshold": 3, "healthcheck_interval_sec": 30, + "healthcheck_timeout_sec": 5, "healthcheck_unhealthy_threshold": 3, "http_api_paths": ["/v1/*", "/v2/*", "/healthcheck"], "instance_ami_id": "ami-0000000", @@ -48,6 +55,8 @@ "telemetry_config": "mode: PROD", "udf_min_log_level": 0, "udf_num_workers": 2, + "udf_update_timeout_millis": 30000, + "use_existing_vpc": false, "use_external_metrics_collector_endpoint": false, "use_real_coordinators": false, "vpc_cidr_block": "10.0.0.0/16" diff --git a/production/terraform/aws/environments/kv_server.tf b/production/terraform/aws/environments/kv_server.tf index 2cc19747..c5089963 100644 --- a/production/terraform/aws/environments/kv_server.tf +++ b/production/terraform/aws/environments/kv_server.tf @@ -22,10 +22,14 @@ module "kv_server" { region = var.region # Variables related to network, dns and certs configuration. 
- vpc_cidr_block = var.vpc_cidr_block - root_domain = var.root_domain - root_domain_zone_id = var.root_domain_zone_id - certificate_arn = var.certificate_arn + vpc_cidr_block = var.vpc_cidr_block + root_domain = var.root_domain + root_domain_zone_id = var.root_domain_zone_id + certificate_arn = var.certificate_arn + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment + enable_external_traffic = var.enable_external_traffic # Variables related to EC2 instances. instance_type = var.instance_type @@ -58,6 +62,8 @@ module "kv_server" { healthcheck_healthy_threshold = var.healthcheck_healthy_threshold healthcheck_interval_sec = var.healthcheck_interval_sec healthcheck_unhealthy_threshold = var.healthcheck_unhealthy_threshold + healthcheck_timeout_sec = var.healthcheck_timeout_sec + healthcheck_grace_period_sec = var.healthcheck_grace_period_sec # Variables related to SSH ssh_source_cidr_blocks = var.ssh_source_cidr_blocks @@ -86,9 +92,10 @@ module "kv_server" { sharding_key_regex = var.sharding_key_regex # Variables related to UDF execution. 
- udf_num_workers = var.udf_num_workers - udf_timeout_millis = var.udf_timeout_millis - udf_min_log_level = var.udf_min_log_level + udf_num_workers = var.udf_num_workers + udf_timeout_millis = var.udf_timeout_millis + udf_update_timeout_millis = var.udf_update_timeout_millis + udf_min_log_level = var.udf_min_log_level # Variables related to coordinators use_real_coordinators = var.use_real_coordinators @@ -103,6 +110,8 @@ module "kv_server" { # Variables related to logging logging_verbosity_level = var.logging_verbosity_level enable_otel_logger = var.enable_otel_logger + consented_debug_token = var.consented_debug_token + enable_consented_log = var.enable_consented_log } output "kv_server_url" { diff --git a/production/terraform/aws/environments/kv_server_variables.tf b/production/terraform/aws/environments/kv_server_variables.tf index 48a68fc9..f857e0ac 100644 --- a/production/terraform/aws/environments/kv_server_variables.tf +++ b/production/terraform/aws/environments/kv_server_variables.tf @@ -49,7 +49,7 @@ variable "autoscaling_min_size" { } variable "certificate_arn" { - description = "ARN for a certificate to be attached to the ALB listener." + description = "ARN for a certificate to be attached to the ALB listener. Ingored if enable_external_traffic is false." type = string } @@ -126,6 +126,16 @@ variable "healthcheck_unhealthy_threshold" { type = number } +variable "healthcheck_timeout_sec" { + description = "Amount of time to wait for a health check response in seconds." + type = number +} + +variable "healthcheck_grace_period_sec" { + description = "Amount of time to wait for service inside enclave to start up before starting health checks, in seconds." + type = number +} + variable "ssh_source_cidr_blocks" { description = "Source ips allowed to send ssh traffic to the ssh instance." 
type = set(string) @@ -273,6 +283,12 @@ variable "udf_timeout_millis" { type = number } +variable "udf_update_timeout_millis" { + description = "UDF update timeout in milliseconds. Default is 30000." + default = 30000 + type = number +} + variable "udf_min_log_level" { description = "Minimum log level for UDFs. Info = 0, Warn = 1, Error = 2. The UDF will only attempt to log for min_log_level and above. Default is 0(info)." default = 0 @@ -315,3 +331,33 @@ variable "public_key_endpoint" { description = "Public key endpoint. Can only be overriden in non-prod mode." type = string } + +variable "consented_debug_token" { + description = "Consented debug token to enable the otel collection of consented logs. Empty token means no-op and no logs will be collected for consented requests. The token in the request's consented debug configuration needs to match this debug token to make the server treat the request as consented." + type = string +} + +variable "enable_consented_log" { + description = "Enable the logging of consented requests. If it is set to true, the consented debug token parameter value must not be an empty string." + type = bool +} + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be requried." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ingored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ingored if use_existing_vpc is false." + type = string +} + +variable "enable_external_traffic" { + description = "Whether to serve external traffic. If disabled, only internal traffic under existing VPC will be served." 
+ type = bool +} diff --git a/production/terraform/aws/modules/kv_server/main.tf b/production/terraform/aws/modules/kv_server/main.tf index 63043ecb..90989a36 100644 --- a/production/terraform/aws/modules/kv_server/main.tf +++ b/production/terraform/aws/modules/kv_server/main.tf @@ -19,15 +19,21 @@ locals { } module "iam_roles" { - source = "../../services/iam_roles" - environment = var.environment - service = local.service + source = "../../services/iam_roles" + environment = var.environment + service = local.service + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment } module "iam_groups" { - source = "../../services/iam_groups" - environment = var.environment - service = local.service + source = "../../services/iam_groups" + environment = var.environment + service = local.service + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment } module "data_storage" { @@ -51,17 +57,23 @@ module "sqs_cleanup" { } module "networking" { - source = "../../services/networking" - service = local.service - environment = var.environment - vpc_cidr_block = var.vpc_cidr_block + source = "../../services/networking" + service = local.service + environment = var.environment + vpc_cidr_block = var.vpc_cidr_block + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment } module "security_groups" { - source = "../../services/security_groups" - environment = var.environment - service = local.service - vpc_id = module.networking.vpc_id + source = "../../services/security_groups" + environment = var.environment + service = local.service + vpc_id = module.networking.vpc_id + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = 
var.existing_vpc_environment } module "backend_services" { @@ -76,6 +88,9 @@ module "backend_services" { server_instance_role_arn = module.iam_roles.instance_role_arn ssh_instance_role_arn = module.iam_roles.ssh_instance_role_arn prometheus_service_region = var.prometheus_service_region + use_existing_vpc = var.use_existing_vpc + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment } module "telemetry" { @@ -86,7 +101,25 @@ module "telemetry" { prometheus_service_region = var.prometheus_service_region } +module "mesh_service" { + count = var.use_existing_vpc ? 1 : 0 + source = "../../services/mesh_service" + environment = var.environment + service = local.service + root_domain = var.root_domain + service_port = 50051 + server_instance_role_name = module.iam_roles.instance_role_name + root_domain_zone_id = var.root_domain_zone_id + existing_vpc_operator = var.existing_vpc_operator + existing_vpc_environment = var.existing_vpc_environment + healthcheck_interval_sec = var.healthcheck_interval_sec + healthcheck_timeout_sec = var.healthcheck_timeout_sec + healthcheck_healthy_threshold = var.healthcheck_healthy_threshold + healthcheck_unhealthy_threshold = var.healthcheck_unhealthy_threshold +} + module "load_balancing" { + count = var.enable_external_traffic ? 
1 : 0 source = "../../services/load_balancing" environment = var.environment service = local.service @@ -104,31 +137,39 @@ module "load_balancing" { } module "autoscaling" { - count = var.num_shards - source = "../../services/autoscaling" - environment = var.environment - region = var.region - service = local.service - autoscaling_subnet_ids = module.networking.private_subnet_ids - instance_ami_id = var.instance_ami_id - instance_security_group_id = module.security_groups.instance_security_group_id - instance_type = var.instance_type - target_group_arns = module.load_balancing.target_group_arns - autoscaling_desired_capacity = var.autoscaling_desired_capacity - autoscaling_max_size = var.autoscaling_max_size - autoscaling_min_size = var.autoscaling_min_size - wait_for_capacity_timeout = var.autoscaling_wait_for_capacity_timeout - instance_profile_arn = module.iam_roles.instance_profile_arn - enclave_cpu_count = var.enclave_cpu_count - enclave_memory_mib = var.enclave_memory_mib - enclave_enable_debug_mode = var.enclave_enable_debug_mode - server_port = var.server_port - launch_hook_name = module.parameter.launch_hook_parameter_value - depends_on = [module.iam_role_policies.instance_role_policy_attachment] - prometheus_service_region = var.prometheus_service_region - prometheus_workspace_id = var.prometheus_workspace_id != "" ? var.prometheus_workspace_id : module.telemetry.prometheus_workspace_id - shard_num = count.index - run_server_outside_tee = var.run_server_outside_tee + count = var.num_shards + source = "../../services/autoscaling" + environment = var.environment + region = var.region + service = local.service + autoscaling_subnet_ids = module.networking.private_subnet_ids + instance_ami_id = var.instance_ami_id + instance_security_group_id = module.security_groups.instance_security_group_id + instance_type = var.instance_type + target_group_arns = var.enable_external_traffic ? 
module.load_balancing[0].target_group_arns : [] + autoscaling_desired_capacity = var.autoscaling_desired_capacity + autoscaling_max_size = var.autoscaling_max_size + autoscaling_min_size = var.autoscaling_min_size + wait_for_capacity_timeout = var.autoscaling_wait_for_capacity_timeout + instance_profile_arn = module.iam_roles.instance_profile_arn + enclave_cpu_count = var.enclave_cpu_count + enclave_memory_mib = var.enclave_memory_mib + enclave_enable_debug_mode = var.enclave_enable_debug_mode + server_port = var.server_port + launch_hook_name = module.parameter.launch_hook_parameter_value + depends_on = [module.iam_role_policies.instance_role_policy_attachment] + prometheus_service_region = var.prometheus_service_region + prometheus_workspace_id = var.prometheus_workspace_id != "" ? var.prometheus_workspace_id : module.telemetry.prometheus_workspace_id + shard_num = count.index + run_server_outside_tee = var.run_server_outside_tee + cloud_map_service_id = var.use_existing_vpc ? module.mesh_service[0].cloud_map_service_id : "" + app_mesh_name = var.use_existing_vpc ? module.mesh_service[0].app_mesh_name : "" + virtual_node_name = var.use_existing_vpc ? 
module.mesh_service[0].virtual_node_name : "" + healthcheck_interval_sec = var.healthcheck_interval_sec + healthcheck_timeout_sec = var.healthcheck_timeout_sec + healthcheck_healthy_threshold = var.healthcheck_healthy_threshold + healthcheck_unhealthy_threshold = var.healthcheck_unhealthy_threshold + healthcheck_grace_period_sec = var.healthcheck_grace_period_sec } module "ssh" { @@ -160,6 +201,7 @@ module "parameter" { num_shards_parameter_value = var.num_shards udf_num_workers_parameter_value = var.udf_num_workers udf_timeout_millis_parameter_value = var.udf_timeout_millis + udf_update_timeout_millis_parameter_value = var.udf_update_timeout_millis udf_min_log_level_parameter_value = var.udf_min_log_level route_v1_requests_to_v2_parameter_value = var.route_v1_requests_to_v2 add_missing_keys_v1_parameter_value = var.add_missing_keys_v1 @@ -171,6 +213,8 @@ module "parameter" { primary_coordinator_region_parameter_value = var.primary_coordinator_region secondary_coordinator_region_parameter_value = var.secondary_coordinator_region public_key_endpoint_parameter_value = var.public_key_endpoint + consented_debug_token_parameter_value = var.consented_debug_token + enable_consented_log_parameter_value = var.enable_consented_log data_loading_file_format_parameter_value = var.data_loading_file_format @@ -194,6 +238,7 @@ module "security_group_rules" { vpce_security_group_id = module.security_groups.vpc_endpoint_security_group_id gateway_endpoints_prefix_list_ids = module.backend_services.gateway_endpoints_prefix_list_ids ssh_source_cidr_blocks = var.ssh_source_cidr_blocks + use_existing_vpc = var.use_existing_vpc } module "iam_role_policies" { @@ -229,8 +274,10 @@ module "iam_role_policies" { module.parameter.use_real_coordinators_parameter_arn, module.parameter.use_sharding_key_regex_parameter_arn, module.parameter.udf_timeout_millis_parameter_arn, + module.parameter.udf_update_timeout_millis_parameter_arn, module.parameter.udf_min_log_level_parameter_arn, 
module.parameter.enable_otel_logger_parameter_arn, + module.parameter.enable_consented_log_parameter_arn, module.parameter.data_loading_blob_prefix_allowlist_parameter_arn] coordinator_parameter_arns = ( var.use_real_coordinators ? [ @@ -253,6 +300,9 @@ module "iam_role_policies" { module.parameter.sharding_key_regex_parameter_arn ] : [] ) + consented_debug_token_arns = (var.enable_consented_log ? [ + module.parameter.consented_debug_token_parameter_arn] : [] + ) } module "iam_group_policies" { diff --git a/production/terraform/aws/modules/kv_server/outputs.tf b/production/terraform/aws/modules/kv_server/outputs.tf index 7518bba7..b6a671b4 100644 --- a/production/terraform/aws/modules/kv_server/outputs.tf +++ b/production/terraform/aws/modules/kv_server/outputs.tf @@ -15,5 +15,5 @@ */ output "kv_server_url" { - value = module.load_balancing.kv_server_url + value = "${var.enable_external_traffic ? "External url: ${module.load_balancing[0].kv_server_url}\n" : ""}${var.use_existing_vpc ? "Mesh virtual service name: ${module.mesh_service[0].virtual_service_name}" : ""}" } diff --git a/production/terraform/aws/modules/kv_server/variables.tf b/production/terraform/aws/modules/kv_server/variables.tf index 39364af9..21345044 100644 --- a/production/terraform/aws/modules/kv_server/variables.tf +++ b/production/terraform/aws/modules/kv_server/variables.tf @@ -42,7 +42,7 @@ variable "server_port" { } variable "certificate_arn" { - description = "ARN for an ACM managed certificate." + description = "ARN for an ACM managed certificate. Ingored if enable_external_traffic is false." type = string } @@ -127,6 +127,16 @@ variable "healthcheck_unhealthy_threshold" { type = number } +variable "healthcheck_timeout_sec" { + description = "Amount of time to wait for a health check response in seconds." + type = number +} + +variable "healthcheck_grace_period_sec" { + description = "Amount of time to wait for service inside enclave to start up before starting health checks, in seconds." 
+ type = number +} + variable "ssh_source_cidr_blocks" { description = "Source ips allowed to send ssh traffic to the ssh instance." type = set(string) @@ -270,6 +280,12 @@ variable "udf_timeout_millis" { type = number } +variable "udf_update_timeout_millis" { + description = "UDF update timeout in milliseconds. Default is 30000." + default = 30000 + type = number +} + variable "udf_min_log_level" { description = "Minimum log level for UDFs. Info = 0, Warn = 1, Error = 2. The UDF will only attempt to log for min_log_level and above. Default is 0(info)." type = number @@ -310,3 +326,33 @@ variable "public_key_endpoint" { description = "Public key endpoint. Can only be overriden in non-prod mode." type = string } + +variable "consented_debug_token" { + description = "Consented debug token to enable the otel collection of consented logs. Empty token means no-op and no logs will be collected for consented requests. The token in the request's consented debug configuration needs to match this debug token to make the server treat the request as consented." + type = string +} + +variable "enable_consented_log" { + description = "Enable the logging of consented requests. If it is set to true, the consented debug token parameter value must not be an empty string." + type = bool +} + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be requried." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ingored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ingored if use_existing_vpc is false." + type = string +} + +variable "enable_external_traffic" { + description = "Whether to serve external traffic. If disabled, only internal traffic under existing VPC will be served. 
" + type = bool +} diff --git a/production/terraform/aws/services/app_mesh/main.tf b/production/terraform/aws/services/app_mesh/main.tf new file mode 100644 index 00000000..22d29995 --- /dev/null +++ b/production/terraform/aws/services/app_mesh/main.tf @@ -0,0 +1,24 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +data "aws_appmesh_mesh" "existing_app_mesh" { + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-app-mesh" +} + +data "aws_service_discovery_dns_namespace" "existing_cloud_map_private_dns_namespace" { + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-cloud-map-private-dns-namespace" + type = "DNS_PRIVATE" +} diff --git a/production/terraform/aws/services/app_mesh/output.tf b/production/terraform/aws/services/app_mesh/output.tf new file mode 100644 index 00000000..d406d162 --- /dev/null +++ b/production/terraform/aws/services/app_mesh/output.tf @@ -0,0 +1,36 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +output "app_mesh_name" { + description = "The name of the app mesh." + value = data.aws_appmesh_mesh.existing_app_mesh.name +} + +output "app_mesh_id" { + description = "The ID of the app mesh." + value = data.aws_appmesh_mesh.existing_app_mesh.id +} + +output "cloud_map_private_dns_namespace_id" { + description = "ID of the cloud map namespace" + value = data.aws_service_discovery_dns_namespace.existing_cloud_map_private_dns_namespace.id +} + +output "cloud_map_private_dns_namespace_name" { + description = "Name of the cloud map namespace" + value = data.aws_service_discovery_dns_namespace.existing_cloud_map_private_dns_namespace.name +} diff --git a/production/terraform/aws/services/app_mesh/variables.tf b/production/terraform/aws/services/app_mesh/variables.tf new file mode 100644 index 00000000..553ef1be --- /dev/null +++ b/production/terraform/aws/services/app_mesh/variables.tf @@ -0,0 +1,25 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Required if use_existing_vpc is true." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Required if use_existing_vpc is true." 
+ type = string +} diff --git a/production/terraform/aws/services/autoscaling/instance_init_script.tftpl b/production/terraform/aws/services/autoscaling/instance_init_script.tftpl index 18dcbc0d..8a50b7da 100644 --- a/production/terraform/aws/services/autoscaling/instance_init_script.tftpl +++ b/production/terraform/aws/services/autoscaling/instance_init_script.tftpl @@ -34,6 +34,55 @@ sed -i -e 's/$REGION/'${prometheus_service_region}'/g' -e 's/$WORKSPACE_ID/'${pr # Start the otel collector sudo /opt/aws/aws-otel-collector/bin/aws-otel-collector-ctl -c /opt/aws/aws-otel-collector/etc/otel_collector_config.yaml -a start +if [[ -n "${app_mesh_name}" && -n "${virtual_node_name}" ]]; then +# Authenticate with the Envoy Amazon ECR repository in the Region that you want +# your Docker client to pull the image from. +aws ecr get-login-password \ + --region ${region} \ +| docker login \ + --username AWS \ + --password-stdin 840364872350.dkr.ecr.${region}.amazonaws.com + +# Start the App Mesh Envoy container. +sudo docker run --detach --env APPMESH_RESOURCE_ARN=mesh/${app_mesh_name}/virtualNode/${virtual_node_name} \ +-v /tmp:/tmp \ +-u 1337 --network host public.ecr.aws/appmesh/aws-appmesh-envoy:v1.29.4.0-prod +fi + +if [[ -n "${cloud_map_service_id}" && -n "${region}" ]]; then +# Grab the metadata needed for registering instance. +TOKEN=`curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600"` \ +&& IP_ADDRESS=`curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/local-ipv4` + +TOKEN=`curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600"` \ +&& INSTANCE_ID=`curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-id` + +# Actually register the present EC2 with the cloud map. 
+register_instance_out=$(aws servicediscovery register-instance \ + --service-id ${cloud_map_service_id} \ + --instance-id $INSTANCE_ID \ + --attributes "AWS_INSTANCE_IPV4="$IP_ADDRESS \ + --region ${region} 2>&1) +while [[ "$?" -gt 0 ]] && [[ "$register_instance_out" =~ "not authorized to perform" ]]; do + echo "Registering service instance failed ... This can be transient and thus trying again in 2 seconds" + echo "Observed failure: $register_instance_out" + sleep 2 + register_instance_out=$(aws servicediscovery register-instance \ + --service-id ${cloud_map_service_id} \ + --instance-id $INSTANCE_ID \ + --attributes "AWS_INSTANCE_IPV4="$IP_ADDRESS \ + --region ${region} 2>&1) +done +fi + +if [[ -n "${app_mesh_name}" && -n "${virtual_node_name}" && -n "${cloud_map_service_id}" && -n "${region}" ]]; then +echo "Will wait for service mesh envoy proxy to come up" +while [ "$(curl localhost:9901/ready)" != "LIVE" ] ; do + echo "Service mesh envoy proxy is not ready.. will check again in 1 second" + sleep 1 +done +fi + if [ ${run_server_outside_tee} = false ]; then # Make sure nitro enclave allocator service is stopped @@ -61,11 +110,31 @@ then --cpu-count ${enclave_cpu_count} \ --memory ${enclave_memory_mib} \ --eif-path /opt/privacysandbox/server_enclave_image.eif \ - --enclave-cid 16 ${enclave_enable_debug_mode} + --enclave-cid 16 ${enclave_enable_debug_mode} & else # Load the docker image docker load -i /home/ec2-user/server_docker_image.tar # Run the docker image docker run --detach --rm --network host --security-opt=seccomp=unconfined \ - --entrypoint=/init_server_basic bazel/production/packaging/aws/data_server:server_docker_image + --entrypoint=/init_server_basic bazel/production/packaging/aws/data_server:server_docker_image & +fi + +SECONDS_TRIED=0 + +echo "Will wait for ${healthcheck_grace_period_sec} seconds for the service to come up" +while ! grpcurl --plaintext localhost:50051 list; do + echo "Service/Vsock proxy is not reachable.. 
will retry in 1 second" + ((SECONDS_TRIED++)) + if (( SECONDS_TRIED > ${healthcheck_grace_period_sec} )) + then + echo "Timing out: tried for ${healthcheck_grace_period_sec} seconds and the service and its vsock proxy are still not reachable." + break + fi + sleep 1 +done + +if [[ -n "${app_mesh_name}" && -n "${virtual_node_name}" && -n "${cloud_map_service_id}" && -n "${region}" ]]; then + bash /opt/privacysandbox/hc.bash -p /opt/privacysandbox -n health.proto -a localhost:50051 -i ${healthcheck_interval_sec} -t ${healthcheck_timeout_sec} -h ${healthcheck_healthy_threshold} -u ${healthcheck_unhealthy_threshold} -e $INSTANCE_ID -g 0 -r ${region} -s ${cloud_map_service_id} & + echo "Setting up iptables to route traffic via service mesh / envoy" + sudo bash -x /opt/privacysandbox/envoy_networking.sh fi diff --git a/production/terraform/aws/services/autoscaling/main.tf b/production/terraform/aws/services/autoscaling/main.tf index f6aa04f6..4965059b 100644 --- a/production/terraform/aws/services/autoscaling/main.tf +++ b/production/terraform/aws/services/autoscaling/main.tf @@ -36,14 +36,23 @@ resource "aws_launch_template" "instance_launch_template" { user_data = base64encode(templatefile( "${path.module}/instance_init_script.tftpl", { - enclave_memory_mib = var.enclave_memory_mib, - enclave_cpu_count = var.enclave_cpu_count, - enclave_enable_debug_mode = "${var.enclave_enable_debug_mode ? "--debug-mode" : " "}" - server_port = var.server_port, - region = var.region, - prometheus_service_region = var.prometheus_service_region - prometheus_workspace_id = var.prometheus_workspace_id - run_server_outside_tee = var.run_server_outside_tee + enclave_memory_mib = var.enclave_memory_mib, + enclave_cpu_count = var.enclave_cpu_count, + enclave_enable_debug_mode = "${var.enclave_enable_debug_mode ? 
"--debug-mode" : " "}" + server_port = var.server_port, + region = var.region, + prometheus_service_region = var.prometheus_service_region + prometheus_workspace_id = var.prometheus_workspace_id + run_server_outside_tee = var.run_server_outside_tee + cloud_map_service_id = var.cloud_map_service_id + app_mesh_name = var.app_mesh_name + virtual_node_name = var.virtual_node_name + healthcheck_interval_sec = var.healthcheck_interval_sec + healthcheck_timeout_sec = var.healthcheck_timeout_sec + healthcheck_healthy_threshold = var.healthcheck_healthy_threshold + healthcheck_unhealthy_threshold = var.healthcheck_unhealthy_threshold + healthcheck_grace_period_sec = var.healthcheck_grace_period_sec + })) # Enforce IMDSv2. diff --git a/production/terraform/aws/services/autoscaling/variables.tf b/production/terraform/aws/services/autoscaling/variables.tf index 428a8f87..d5cf5b49 100644 --- a/production/terraform/aws/services/autoscaling/variables.tf +++ b/production/terraform/aws/services/autoscaling/variables.tf @@ -125,3 +125,43 @@ locals { validate_prometheus_workspace_id = ( var.prometheus_service_region == var.region || var.prometheus_workspace_id != null) ? true : tobool("If Prometheus service runs in a different region, please create the workspace first and specify the workspace id in var file.") } + +variable "cloud_map_service_id" { + description = "The ID of the service discovery service" + type = string +} + +variable "app_mesh_name" { + description = "Name of the AWS App Mesh in which this service will communicate." + type = string +} + +variable "virtual_node_name" { + description = "Name of the App Mesh Virtual Node of which instance in this ASG will be a part." + type = string +} + +variable "healthcheck_interval_sec" { + description = "Amount of time between health check intervals in seconds." + type = number +} + +variable "healthcheck_healthy_threshold" { + description = "Consecutive health check successes required to be considered healthy." 
+ type = number +} + +variable "healthcheck_unhealthy_threshold" { + description = "Consecutive health check failures required to be considered unhealthy." + type = number +} + +variable "healthcheck_timeout_sec" { + description = "Amount of time to wait for a health check response in seconds." + type = number +} + +variable "healthcheck_grace_period_sec" { + description = "Amount of time to wait for service inside enclave to start up before starting health checks, in seconds." + type = number +} diff --git a/production/terraform/aws/services/backend_services/main.tf b/production/terraform/aws/services/backend_services/main.tf index e283dcb2..a95aa4ea 100644 --- a/production/terraform/aws/services/backend_services/main.tf +++ b/production/terraform/aws/services/backend_services/main.tf @@ -14,6 +14,20 @@ * limitations under the License. */ +# Existing aws_vpc_endpoint from existing VPC +data "aws_vpc_endpoint" "existing_vpc_gateway_endpoint" { + count = var.use_existing_vpc ? 1 : 0 + service_name = "com.amazonaws.${var.region}.s3" + filter { + name = "tag:operator" + values = [var.existing_vpc_operator] + } + filter { + name = "tag:environment" + values = [var.existing_vpc_environment] + } +} + # Restrict VPC endpoint access only to instances in this environment and service. data "aws_iam_policy_document" "vpce_policy_doc" { statement { @@ -35,31 +49,31 @@ data "aws_iam_policy_document" "vpce_policy_doc" { # Create gateway endpoints for accessing AWS services. resource "aws_vpc_endpoint" "vpc_gateway_endpoint" { - for_each = toset([ - "s3" - ]) - service_name = "com.amazonaws.${var.region}.${each.key}" + count = var.use_existing_vpc ? 
0 : 1 + service_name = "com.amazonaws.${var.region}.s3" vpc_id = var.vpc_id vpc_endpoint_type = "Gateway" route_table_ids = var.vpc_endpoint_route_table_ids policy = data.aws_iam_policy_document.vpce_policy_doc.json tags = { - Name = "${var.service}-${var.environment}-${each.key}-endpoint" + Name = "${var.service}-${var.environment}-s3-endpoint" } } # Create interface endpoints for accessing AWS services. resource "aws_vpc_endpoint" "vpc_interface_endpoint" { for_each = toset(concat([ - "ec2", - "ssm", "sns", "sqs", - "autoscaling", - "xray", "logs", - ], var.prometheus_service_region == var.region ? ["aps-workspaces"] : [])) + ], var.prometheus_service_region == var.region ? ["aps-workspaces"] : [] + , var.use_existing_vpc ? [] : [ + "ec2", + "ssm", + "autoscaling", + "xray", + ])) service_name = "com.amazonaws.${var.region}.${each.key}" vpc_id = var.vpc_id vpc_endpoint_type = "Interface" diff --git a/production/terraform/aws/services/backend_services/outputs.tf b/production/terraform/aws/services/backend_services/outputs.tf index 1ed7bc8a..971f32dd 100644 --- a/production/terraform/aws/services/backend_services/outputs.tf +++ b/production/terraform/aws/services/backend_services/outputs.tf @@ -15,5 +15,5 @@ */ output "gateway_endpoints_prefix_list_ids" { - value = [for ep in aws_vpc_endpoint.vpc_gateway_endpoint : ep.prefix_list_id] + value = var.use_existing_vpc ? 
[data.aws_vpc_endpoint.existing_vpc_gateway_endpoint[0].prefix_list_id] : [aws_vpc_endpoint.vpc_gateway_endpoint[0].prefix_list_id] } diff --git a/production/terraform/aws/services/backend_services/variables.tf b/production/terraform/aws/services/backend_services/variables.tf index 98f099e8..c167af5c 100644 --- a/production/terraform/aws/services/backend_services/variables.tf +++ b/production/terraform/aws/services/backend_services/variables.tf @@ -63,3 +63,18 @@ variable "prometheus_service_region" { description = "Region where prometheus service runs that other services deployed by this file should interact with." type = string } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be requried." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ingored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ingored if use_existing_vpc is false." 
+ type = string +} diff --git a/production/terraform/aws/services/dashboard/main.tf b/production/terraform/aws/services/dashboard/main.tf index 76dcf871..994b2c3b 100644 --- a/production/terraform/aws/services/dashboard/main.tf +++ b/production/terraform/aws/services/dashboard/main.tf @@ -30,7 +30,9 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.count\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.count\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] + ], "region": "${var.region}", "view": "timeSeries", @@ -41,7 +43,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "request.count [MEAN]" + "title": "request.count per second [MEAN]" } }, { @@ -52,7 +54,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"SecureLookupRequestCount\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"SecureLookupRequestCount\" 
Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -63,7 +66,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "Secure lookup request count [MEAN]" + "title": "Secure lookup request count per second [MEAN]" } }, { @@ -74,7 +77,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.failed_count_by_status\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.failed_count_by_status\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -85,7 +89,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "request.failed_count_by_status [MEAN]" + "title": "request.failed_count_by_status per second [MEAN]" } }, { @@ -96,7 +100,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" 
deployment.environment=${var.environment} MetricName=\"request.duration_ms\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.duration_ms\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ -118,7 +122,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.size_bytes\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"request.size_bytes\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ -140,7 +144,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"response.size_bytes\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" 
} ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"response.size_bytes\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ -162,7 +166,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"KVUdfRequestError\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"KVUdfRequestError\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -173,7 +178,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "Request Errors [MEAN]" + "title": "Request Errors Per Second [MEAN]" } }, { @@ -184,7 +189,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"InternalLookupRequestError\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} 
$${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"InternalLookupRequestError\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.error_code')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -195,7 +201,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "Internal Request Errors [MEAN]" + "title": "Internal Request Errors Per Second [MEAN]" } }, { @@ -228,7 +234,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"ShardedLookupGetKeyValuesLatencyInMicros\" OR \"ShardedLookupGetKeyValueSetLatencyInMicros\" OR \"ShardedLookupRunQueryLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"ShardedLookupGetKeyValuesLatencyInMicros\" OR \"ShardedLookupGetKeyValueSetLatencyInMicros\" OR \"ShardedLookupRunQueryLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ 
-250,7 +256,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"ShardedLookupKeyCountByShard\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.key_shard_num')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"ShardedLookupKeyCountByShard\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": "$${PROP('Dim.Noise')} $${PROP('Dim.key_shard_num')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -261,7 +268,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "Sharded Lookup Key Count By Shard [MEAN]" + "title": "Sharded Lookup Key Count By Shard Per Second [MEAN]" } }, { @@ -272,7 +279,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"InternalGetKeyValuesLatencyInMicros\" OR \"InternalGetKeyValueSetLatencyInMicros\" OR \"InternalRunQueryLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"InternalGetKeyValuesLatencyInMicros\" OR \"InternalGetKeyValueSetLatencyInMicros\" OR \"InternalRunQueryLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": 
"e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ -294,7 +301,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"GetValuePairsLatencyInMicros\" OR \"GetKeyValueSetLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=(\"GetValuePairsLatencyInMicros\" OR \"GetKeyValueSetLatencyInMicros\") Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('MetricName')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ] ], "region": "${var.region}", "view": "timeSeries", @@ -316,7 +323,8 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "type": "metric", "properties": { "metrics": [ - [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"CacheAccessEventCount\" Noise=(\"Raw\" OR \"Noised\")', 'Average', 60))", "id": "e1", "label": "$${PROP('Dim.Noise')} $${PROP('Dim.cache_access')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')}" } ] + [ { "expression": "REMOVE_EMPTY(SEARCH('service.name=\"kv-server\" deployment.environment=${var.environment} MetricName=\"CacheAccessEventCount\" Noise=(\"Raw\" OR \"Noised\") generation_id=(\"consented\" OR \"not_consented\")', 'Sum', 60))", "id": "e1", "visible": false, "label": 
"$${PROP('Dim.Noise')} $${PROP('Dim.cache_access')} $${PROP('Dim.service.instance.id')} $${PROP('Dim.shard_number')} $${PROP('Dim.generation_id')}" } ], + [ { "expression": "e1 / 60"} ] ], "region": "${var.region}", "view": "timeSeries", @@ -327,7 +335,7 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { "showUnits": false } }, - "title": "Cache Access Event Count [MEAN]" + "title": "Cache Access Event Count Per Second [MEAN]" } }, { @@ -527,6 +535,29 @@ resource "aws_cloudwatch_dashboard" "environment_dashboard" { }, "title": "system.cpu.total_cores [MEAN]" } + }, + { + "height": 10, + "width": 12, + "y": 110, + "x": 12, + "type": "metric", + "properties": { + "metrics": [ + [ "AWS/Lambda", "Errors", "FunctionName", "kv-server-${var.environment}-sqs-cleanup", { "id": "errors", "stat": "Sum", "color": "#d13212" } ], + [ ".", "Invocations", ".", ".", { "id": "invocations", "stat": "Sum", "visible": false } ], + [ { "expression": "100 - 100 * errors / MAX([errors, invocations])", "label": "Success rate (%)", "id": "availability", "yAxis": "right" } ] + ], + "view": "timeSeries", + "stacked": false, + "region": "${var.region}", + "yAxis": { + "right": { + "max": 100 + } + }, + "title": "Sqs cleanup job error count and success rate (%)" + } } ] } diff --git a/production/terraform/aws/services/iam_groups/main.tf b/production/terraform/aws/services/iam_groups/main.tf index 8b3d78aa..a0845d0d 100644 --- a/production/terraform/aws/services/iam_groups/main.tf +++ b/production/terraform/aws/services/iam_groups/main.tf @@ -14,6 +14,12 @@ * limitations under the License. */ +data "aws_iam_group" "existing_ssh_users_group" { + count = var.use_existing_vpc ? 1 : 0 + group_name = format("%s-%s-ssh-users", var.existing_vpc_operator, var.existing_vpc_environment) +} + resource "aws_iam_group" "ssh_users_group" { - name = format("%s-%s-ssh-users", var.service, var.environment) + count = var.use_existing_vpc ? 
0 : 1 + name = format("%s-%s-ssh-users", var.service, var.environment) } diff --git a/production/terraform/aws/services/iam_groups/outputs.tf b/production/terraform/aws/services/iam_groups/outputs.tf index 9eef87fe..3eca22b9 100644 --- a/production/terraform/aws/services/iam_groups/outputs.tf +++ b/production/terraform/aws/services/iam_groups/outputs.tf @@ -15,5 +15,5 @@ */ output "ssh_users_group_name" { - value = aws_iam_group.ssh_users_group.name + value = var.use_existing_vpc ? data.aws_iam_group.existing_ssh_users_group[0].group_name : aws_iam_group.ssh_users_group[0].name } diff --git a/production/terraform/aws/services/iam_groups/variables.tf b/production/terraform/aws/services/iam_groups/variables.tf index 03b995fc..9b44fad2 100644 --- a/production/terraform/aws/services/iam_groups/variables.tf +++ b/production/terraform/aws/services/iam_groups/variables.tf @@ -23,3 +23,18 @@ variable "environment" { description = "Assigned environment name to group related resources." type = string } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be required." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ignored if use_existing_vpc is false." 
+ type = string +} diff --git a/production/terraform/aws/services/iam_role_policies/main.tf b/production/terraform/aws/services/iam_role_policies/main.tf index 3d1d95e1..ce1bc9be 100644 --- a/production/terraform/aws/services/iam_role_policies/main.tf +++ b/production/terraform/aws/services/iam_role_policies/main.tf @@ -61,7 +61,7 @@ data "aws_iam_policy_document" "instance_policy_doc" { sid = "AllowInstancesToReadParameters" actions = ["ssm:GetParameter"] effect = "Allow" - resources = setunion(var.server_parameter_arns, var.coordinator_parameter_arns, var.metrics_collector_endpoint_arns, var.sharding_key_regex_arns) + resources = setunion(var.server_parameter_arns, var.coordinator_parameter_arns, var.metrics_collector_endpoint_arns, var.sharding_key_regex_arns, var.consented_debug_token_arns) } statement { sid = "AllowInstancesToAssumeRole" @@ -126,6 +126,16 @@ data "aws_iam_policy_document" "instance_policy_doc" { ] resources = ["*"] } + statement { + sid = "AllowInstancesToSetInstanceHealthForASGandCloudMap" + actions = [ + "autoscaling:SetInstanceHealth", + "servicediscovery:UpdateInstanceCustomHealthStatus", + "servicediscovery:DeregisterInstance", + ] + effect = "Allow" + resources = ["*"] + } } resource "aws_iam_policy" "instance_policy" { diff --git a/production/terraform/aws/services/iam_role_policies/variables.tf b/production/terraform/aws/services/iam_role_policies/variables.tf index 0b190ad3..191701b3 100644 --- a/production/terraform/aws/services/iam_role_policies/variables.tf +++ b/production/terraform/aws/services/iam_role_policies/variables.tf @@ -59,6 +59,12 @@ variable "sharding_key_regex_arns" { type = set(string) } +variable "consented_debug_token_arns" { + description = "A set of arns for consented debug token" + type = set(string) +} + + variable "sns_data_updates_topic_arn" { description = "ARN for the sns topic that receives s3 delta file updates." 
type = string diff --git a/production/terraform/aws/services/iam_roles/main.tf b/production/terraform/aws/services/iam_roles/main.tf index 293f4fb6..fea4390f 100644 --- a/production/terraform/aws/services/iam_roles/main.tf +++ b/production/terraform/aws/services/iam_roles/main.tf @@ -14,10 +14,36 @@ * limitations under the License. */ -#################################################### -# Create EC2 instance profile. -#################################################### + +################################################################################ +# If use_existing_vpc is true, we need to use existing iam roles. +################################################################################ + +data "aws_iam_role" "existing_instance_role" { + count = var.use_existing_vpc ? 1 : 0 + name = format("%s-%s-InstanceRole", var.existing_vpc_operator, var.existing_vpc_environment) +} + +data "aws_iam_instance_profile" "existing_instance_profile" { + count = var.use_existing_vpc ? 1 : 0 + name = format("%s-%s-InstanceProfile", var.existing_vpc_operator, var.existing_vpc_environment) +} + +data "aws_iam_role" "existing_ssh_instance_role" { + count = var.use_existing_vpc ? 1 : 0 + name = format("%s-%s-sshInstanceRole", var.existing_vpc_operator, var.existing_vpc_environment) +} + +data "aws_iam_instance_profile" "existing_ssh_instance_profile" { + count = var.use_existing_vpc ? 1 : 0 + name = format("%s-%s-sshInstanceProfile", var.existing_vpc_operator, var.existing_vpc_environment) +} + +################################################################################ +# If use_existing_vpc is false, create EC2 instance profile. +################################################################################ data "aws_iam_policy_document" "ec2_assume_role_policy" { + count = var.use_existing_vpc ? 
0 : 1 statement { actions = [ "sts:AssumeRole" @@ -32,8 +58,9 @@ data "aws_iam_policy_document" "ec2_assume_role_policy" { } resource "aws_iam_role" "instance_role" { + count = var.use_existing_vpc ? 0 : 1 name = format("%s-%s-InstanceRole", var.service, var.environment) - assume_role_policy = data.aws_iam_policy_document.ec2_assume_role_policy.json + assume_role_policy = data.aws_iam_policy_document.ec2_assume_role_policy[0].json tags = { Name = format("%s-%s-InstanceRole", var.service, var.environment) @@ -41,20 +68,22 @@ resource "aws_iam_role" "instance_role" { } resource "aws_iam_instance_profile" "instance_profile" { - name = format("%s-%s-InstanceProfile", var.service, var.environment) - role = aws_iam_role.instance_role.name + count = var.use_existing_vpc ? 0 : 1 + name = format("%s-%s-InstanceProfile", var.service, var.environment) + role = aws_iam_role.instance_role[0].name tags = { Name = format("%s-%s-InstanceProfile", var.service, var.environment) } } -#################################################### -# Create SSH role for using EC2 instance connect. -#################################################### +################################################################################ +# If use_existing_vpc is false, create SSH role for using EC2 instance connect. +################################################################################ resource "aws_iam_role" "ssh_instance_role" { + count = var.use_existing_vpc ? 
0 : 1 name = format("%s-%s-sshInstanceRole", var.service, var.environment) - assume_role_policy = data.aws_iam_policy_document.ec2_assume_role_policy.json + assume_role_policy = data.aws_iam_policy_document.ec2_assume_role_policy[0].json tags = { Name = format("%s-%s-sshInstanceRole", var.service, var.environment) @@ -62,8 +91,9 @@ resource "aws_iam_role" "ssh_instance_role" { } resource "aws_iam_instance_profile" "ssh_instance_profile" { - name = format("%s-%s-sshInstanceProfile", var.service, var.environment) - role = aws_iam_role.ssh_instance_role.name + count = var.use_existing_vpc ? 0 : 1 + name = format("%s-%s-sshInstanceProfile", var.service, var.environment) + role = aws_iam_role.ssh_instance_role[0].name tags = { Name = format("%s-%s-sshInstanceProfile", var.service, var.environment) diff --git a/production/terraform/aws/services/iam_roles/outputs.tf b/production/terraform/aws/services/iam_roles/outputs.tf index 32dcff16..b2974d4d 100644 --- a/production/terraform/aws/services/iam_roles/outputs.tf +++ b/production/terraform/aws/services/iam_roles/outputs.tf @@ -15,15 +15,15 @@ */ output "instance_profile_arn" { - value = aws_iam_instance_profile.instance_profile.arn + value = var.use_existing_vpc ? data.aws_iam_instance_profile.existing_instance_profile[0].arn : aws_iam_instance_profile.instance_profile[0].arn } output "instance_role_name" { - value = aws_iam_role.instance_role.name + value = var.use_existing_vpc ? data.aws_iam_role.existing_instance_role[0].name : aws_iam_role.instance_role[0].name } output "instance_role_arn" { - value = aws_iam_role.instance_role.arn + value = var.use_existing_vpc ? data.aws_iam_role.existing_instance_role[0].arn : aws_iam_role.instance_role[0].arn } output "lambda_role_arn" { @@ -35,13 +35,13 @@ output "lambda_role_name" { } output "ssh_instance_role_arn" { - value = aws_iam_role.ssh_instance_role.arn + value = var.use_existing_vpc ? 
 data.aws_iam_role.existing_ssh_instance_role[0].arn : aws_iam_role.ssh_instance_role[0].arn } output "ssh_instance_role_name" { - value = aws_iam_role.ssh_instance_role.name + value = var.use_existing_vpc ? data.aws_iam_role.existing_ssh_instance_role[0].name : aws_iam_role.ssh_instance_role[0].name } output "ssh_instance_profile_name" { - value = aws_iam_instance_profile.ssh_instance_profile.name + value = var.use_existing_vpc ? data.aws_iam_instance_profile.existing_ssh_instance_profile[0].name : aws_iam_instance_profile.ssh_instance_profile[0].name } diff --git a/production/terraform/aws/services/iam_roles/variables.tf b/production/terraform/aws/services/iam_roles/variables.tf index 03b995fc..9b44fad2 100644 --- a/production/terraform/aws/services/iam_roles/variables.tf +++ b/production/terraform/aws/services/iam_roles/variables.tf @@ -23,3 +23,18 @@ variable "environment" { description = "Assigned environment name to group related resources." type = string } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be required." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} diff --git a/production/terraform/aws/services/load_balancing/variables.tf b/production/terraform/aws/services/load_balancing/variables.tf index 51ef211e..8c73006d 100644 --- a/production/terraform/aws/services/load_balancing/variables.tf +++ b/production/terraform/aws/services/load_balancing/variables.tf @@ -35,7 +35,7 @@ variable "elb_subnet_ids" { } variable "certificate_arn" { - description = "ARN for a certificate to be attached to the NLB listener." 
+ description = "ARN for a certificate to be attached to the NLB listener. Ignored if enable_external_traffic is false." type = string } diff --git a/production/terraform/aws/services/mesh_service/main.tf b/production/terraform/aws/services/mesh_service/main.tf new file mode 100644 index 00000000..890718cc --- /dev/null +++ b/production/terraform/aws/services/mesh_service/main.tf @@ -0,0 +1,117 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +data "aws_appmesh_mesh" "existing_app_mesh" { + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-app-mesh" +} + +data "aws_service_discovery_dns_namespace" "existing_cloud_map_private_dns_namespace" { + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-cloud-map-private-dns-namespace" + type = "DNS_PRIVATE" +} + + +resource "aws_service_discovery_service" "cloud_map_service" { + name = "${var.service}-${var.environment}-cloud-map-service.${var.root_domain}" + + dns_config { + namespace_id = data.aws_service_discovery_dns_namespace.existing_cloud_map_private_dns_namespace.id + + dns_records { + ttl = 10 + type = "A" + } + } + health_check_custom_config { + failure_threshold = 1 + } +} + +resource "aws_appmesh_virtual_node" "appmesh_virtual_node" { + name = "${var.service}-${var.environment}-appmesh-virtual-node" + mesh_name = data.aws_appmesh_mesh.existing_app_mesh.id + spec { + listener { + port_mapping { + port = var.service_port + protocol = "grpc" + } + + health_check { + protocol = "grpc" + healthy_threshold = var.healthcheck_healthy_threshold + unhealthy_threshold = var.healthcheck_unhealthy_threshold + timeout_millis = var.healthcheck_timeout_sec * 1000 + interval_millis = var.healthcheck_interval_sec * 1000 + } + } + + service_discovery { + aws_cloud_map { + service_name = aws_service_discovery_service.cloud_map_service.name + namespace_name = data.aws_service_discovery_dns_namespace.existing_cloud_map_private_dns_namespace.name + } + } + } +} + +resource "aws_appmesh_virtual_service" "appmesh_virtual_service" { + name = "${var.service}-${var.environment}-appmesh-virtual-service.${var.root_domain}" + mesh_name = data.aws_appmesh_mesh.existing_app_mesh.name + spec { + provider { + virtual_node { + virtual_node_name = aws_appmesh_virtual_node.appmesh_virtual_node.name + } + } + } +} + +resource "aws_route53_record" "mesh_node_record" { + name = aws_appmesh_virtual_service.appmesh_virtual_service.name + type = "A" 
+ // In seconds + ttl = 300 + zone_id = var.root_domain_zone_id + // Any non-loopback IP will do, this record just needs to exist, not go anywhere (should be overridden by appmesh). + records = ["10.10.10.10"] +} + +data "aws_iam_policy_document" "virtual_node_policy_document" { + statement { + actions = [ + "appmesh:StreamAggregatedResources" + ] + resources = [ + aws_appmesh_virtual_node.appmesh_virtual_node.arn + ] + } +} + +resource "aws_iam_policy" "app_mesh_node_policy" { + name = format("%s-%s-virtualNodePolicy", var.service, var.environment) + policy = data.aws_iam_policy_document.virtual_node_policy_document.json +} + +resource "aws_iam_role_policy_attachment" "app_mesh_node_policy_to_ec2_attachment" { + role = var.server_instance_role_name + policy_arn = aws_iam_policy.app_mesh_node_policy.arn +} + +resource "aws_iam_role_policy_attachment" "amazon_ec2_container_registry_read_only_to_ec2_attachment" { + role = var.server_instance_role_name + policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly" +} diff --git a/production/terraform/aws/services/mesh_service/output.tf b/production/terraform/aws/services/mesh_service/output.tf new file mode 100644 index 00000000..443e1f54 --- /dev/null +++ b/production/terraform/aws/services/mesh_service/output.tf @@ -0,0 +1,41 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +output "app_mesh_name" { + description = "The name of the app mesh." 
+ value = data.aws_appmesh_mesh.existing_app_mesh.name +} + +output "cloud_map_service_id" { + description = "The ID of the service discovery service" + value = aws_service_discovery_service.cloud_map_service.id +} + +output "virtual_node_name" { + description = "The name of the virtual node." + value = aws_appmesh_virtual_node.appmesh_virtual_node.name +} + +output "virtual_service_name" { + description = "The name of the virtual service." + value = aws_appmesh_virtual_service.appmesh_virtual_service.name +} + +output "cloud_map_service_name" { + description = "The name of the service discovery service" + value = aws_service_discovery_service.cloud_map_service.name +} diff --git a/production/terraform/aws/services/mesh_service/variables.tf b/production/terraform/aws/services/mesh_service/variables.tf new file mode 100644 index 00000000..84845597 --- /dev/null +++ b/production/terraform/aws/services/mesh_service/variables.tf @@ -0,0 +1,75 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +variable "environment" { + description = "Assigned environment name to group related resources." + type = string +} + +variable "service" { + description = "One of: bidding, auction, bfe, sfe" + type = string +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ignored if use_existing_vpc is false." 
+ type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} + +variable "root_domain" { + description = "Root domain for APIs." + type = string +} + +variable "service_port" { + description = "Port on which this service receives outbound traffic" + type = number +} + +variable "root_domain_zone_id" { + description = "Zone id for the root domain." + type = string +} + +variable "server_instance_role_name" { + description = "Role for server EC2 instance profile." + type = string +} + +variable "healthcheck_interval_sec" { + description = "Amount of time between health check intervals in seconds." + type = number +} + +variable "healthcheck_healthy_threshold" { + description = "Consecutive health check successes required to be considered healthy." + type = number +} + +variable "healthcheck_unhealthy_threshold" { + description = "Consecutive health check failures required to be considered unhealthy." + type = number +} + +variable "healthcheck_timeout_sec" { + description = "Amount of time to wait for a health check response in seconds." + type = number +} diff --git a/production/terraform/aws/services/networking/main.tf b/production/terraform/aws/services/networking/main.tf index 5aa05531..0931e0f9 100644 --- a/production/terraform/aws/services/networking/main.tf +++ b/production/terraform/aws/services/networking/main.tf @@ -14,12 +14,77 @@ * limitations under the License. */ +################################################################################ +# If use_existing_vpc is true, we need to use an existing network. +################################################################################ + +data "aws_vpc" "existing_vpc" { + count = var.use_existing_vpc ? 
1 : 0 + filter { + name = "tag:operator" + values = [var.existing_vpc_operator] + } + filter { + name = "tag:environment" + values = [var.existing_vpc_environment] + } +} + +data "aws_subnets" "existing_public_subnet" { + depends_on = [data.aws_vpc.existing_vpc] + filter { + name = "tag:Name" + values = ["*public*"] + } + filter { + name = "tag:operator" + values = [var.existing_vpc_operator] + } + filter { + name = "tag:environment" + values = [var.existing_vpc_environment] + } +} + +data "aws_subnets" "existing_private_subnet" { + depends_on = [data.aws_vpc.existing_vpc] + filter { + name = "tag:Name" + values = ["*private*"] + } + filter { + name = "tag:operator" + values = [var.existing_vpc_operator] + } + filter { + name = "tag:environment" + values = [var.existing_vpc_environment] + } +} + +data "aws_route_tables" "existing_private_rt" { + depends_on = [data.aws_vpc.existing_vpc] + filter { + name = "tag:Name" + values = ["*private-rt*"] + } + filter { + name = "tag:operator" + values = [var.existing_vpc_operator] + } + filter { + name = "tag:environment" + values = [var.existing_vpc_environment] + } +} + ################################################################################ # Setup VPC, and networking for private subnets and public subnets. ################################################################################ # Create the VPC where server instances will be launched. resource "aws_vpc" "vpc" { + count = var.use_existing_vpc ? 0 : 1 cidr_block = var.vpc_cidr_block enable_dns_support = true enable_dns_hostnames = true @@ -36,9 +101,9 @@ data "aws_availability_zones" "azs" { # Create public subnets used to connect to instances in private subnets. resource "aws_subnet" "public_subnet" { - count = length(data.aws_availability_zones.azs.names) - cidr_block = cidrsubnet(aws_vpc.vpc.cidr_block, 4, count.index) - vpc_id = aws_vpc.vpc.id + count = var.use_existing_vpc ? 
0 : length(data.aws_availability_zones.azs.names) + cidr_block = cidrsubnet(aws_vpc.vpc[0].cidr_block, 4, count.index) + vpc_id = aws_vpc.vpc[0].id availability_zone = data.aws_availability_zones.azs.names[count.index] map_public_ip_on_launch = true @@ -49,9 +114,9 @@ resource "aws_subnet" "public_subnet" { # Create private subnets where instances will be launched. resource "aws_subnet" "private_subnet" { - count = length(data.aws_availability_zones.azs.names) - cidr_block = cidrsubnet(aws_vpc.vpc.cidr_block, 4, 15 - count.index) - vpc_id = aws_vpc.vpc.id + count = var.use_existing_vpc ? 0 : length(data.aws_availability_zones.azs.names) + cidr_block = cidrsubnet(aws_vpc.vpc[0].cidr_block, 4, 15 - count.index) + vpc_id = aws_vpc.vpc[0].id availability_zone = data.aws_availability_zones.azs.names[count.index] map_public_ip_on_launch = false @@ -62,7 +127,8 @@ resource "aws_subnet" "private_subnet" { # Create networking components for public subnets. resource "aws_internet_gateway" "igw" { - vpc_id = aws_vpc.vpc.id + count = var.use_existing_vpc ? 0 : 1 + vpc_id = aws_vpc.vpc[0].id tags = { Name = "${var.service}-${var.environment}-igw" @@ -70,7 +136,8 @@ resource "aws_internet_gateway" "igw" { } resource "aws_route_table" "public_rt" { - vpc_id = aws_vpc.vpc.id + count = var.use_existing_vpc ? 0 : 1 + vpc_id = aws_vpc.vpc[0].id tags = { Name = "${var.service}-${var.environment}-public-rt" @@ -78,8 +145,9 @@ resource "aws_route_table" "public_rt" { } resource "aws_route" "public_route" { - route_table_id = aws_route_table.public_rt.id - gateway_id = aws_internet_gateway.igw.id + count = var.use_existing_vpc ? 0 : 1 + route_table_id = aws_route_table.public_rt[0].id + gateway_id = aws_internet_gateway.igw[0].id destination_cidr_block = "0.0.0.0/0" depends_on = [ @@ -88,14 +156,14 @@ resource "aws_route" "public_route" { } resource "aws_route_table_association" "public_rt_assoc" { - count = length(aws_subnet.public_subnet) + count = var.use_existing_vpc ? 
0 : length(aws_subnet.public_subnet) subnet_id = aws_subnet.public_subnet[count.index].id - route_table_id = aws_route_table.public_rt.id + route_table_id = aws_route_table.public_rt[0].id } # Create private route tables required for gateway endpoints. resource "aws_eip" "private_subnet_eip" { - count = length(aws_subnet.private_subnet) + count = var.use_existing_vpc ? 0 : length(aws_subnet.private_subnet) vpc = true depends_on = [ aws_internet_gateway.igw @@ -103,7 +171,7 @@ resource "aws_eip" "private_subnet_eip" { } resource "aws_nat_gateway" "nat_gateway" { - count = length(aws_subnet.private_subnet) + count = var.use_existing_vpc ? 0 : length(aws_subnet.private_subnet) subnet_id = aws_subnet.public_subnet[count.index].id allocation_id = aws_eip.private_subnet_eip[count.index].id @@ -117,8 +185,8 @@ resource "aws_nat_gateway" "nat_gateway" { } resource "aws_route_table" "private_rt" { - count = length(aws_subnet.private_subnet) - vpc_id = aws_vpc.vpc.id + count = var.use_existing_vpc ? 0 : length(aws_subnet.private_subnet) + vpc_id = aws_vpc.vpc[0].id tags = { Name = "${var.service}-${var.environment}-private-rt${count.index}" @@ -126,14 +194,14 @@ resource "aws_route_table" "private_rt" { } resource "aws_route" "private_route" { - count = length(aws_subnet.private_subnet) + count = var.use_existing_vpc ? 0 : length(aws_subnet.private_subnet) route_table_id = aws_route_table.private_rt[count.index].id nat_gateway_id = aws_nat_gateway.nat_gateway[count.index].id destination_cidr_block = "0.0.0.0/0" } resource "aws_route_table_association" "private_rt_assoc" { - count = length(aws_subnet.private_subnet) + count = var.use_existing_vpc ? 
0 : length(aws_subnet.private_subnet) route_table_id = aws_route_table.private_rt[count.index].id subnet_id = aws_subnet.private_subnet[count.index].id } diff --git a/production/terraform/aws/services/networking/outputs.tf b/production/terraform/aws/services/networking/outputs.tf index c12a351e..cc082ddc 100644 --- a/production/terraform/aws/services/networking/outputs.tf +++ b/production/terraform/aws/services/networking/outputs.tf @@ -15,17 +15,17 @@ */ output "vpc_id" { - value = aws_vpc.vpc.id + value = var.use_existing_vpc ? data.aws_vpc.existing_vpc[0].id : aws_vpc.vpc[0].id } output "public_subnet_ids" { - value = [for subnet in aws_subnet.public_subnet : subnet.id] + value = var.use_existing_vpc ? data.aws_subnets.existing_public_subnet.ids : [for subnet in aws_subnet.public_subnet : subnet.id] } output "private_subnet_ids" { - value = [for subnet in aws_subnet.private_subnet : subnet.id] + value = var.use_existing_vpc ? data.aws_subnets.existing_private_subnet.ids : [for subnet in aws_subnet.private_subnet : subnet.id] } output "private_route_table_ids" { - value = [for rt in aws_route_table.private_rt : rt.id] + value = var.use_existing_vpc ? data.aws_route_tables.existing_private_rt.ids : [for rt in aws_route_table.private_rt : rt.id] } diff --git a/production/terraform/aws/services/networking/variables.tf b/production/terraform/aws/services/networking/variables.tf index 452bc7c6..63bb6cc6 100644 --- a/production/terraform/aws/services/networking/variables.tf +++ b/production/terraform/aws/services/networking/variables.tf @@ -27,3 +27,18 @@ variable "vpc_cidr_block" { description = "CIDR range for the VPC where KV server will be deployed." type = string } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be requried." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. 
Ignored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} diff --git a/production/terraform/aws/services/parameter/main.tf b/production/terraform/aws/services/parameter/main.tf index 7d6d6450..eecf0fa8 100644 --- a/production/terraform/aws/services/parameter/main.tf +++ b/production/terraform/aws/services/parameter/main.tf @@ -175,7 +175,7 @@ resource "aws_ssm_parameter" "primary_coordinator_private_key_endpoint_parameter resource "aws_ssm_parameter" "secondary_coordinator_private_key_endpoint_parameter" { count = (var.use_real_coordinators_parameter_value) ? 1 : 0 - name = "${var.service}-${var.environment}-primary-coordinator-region" + name = "${var.service}-${var.environment}-secondary-coordinator-private-key-endpoint" type = "String" value = var.secondary_coordinator_private_key_endpoint_parameter_value overwrite = true } resource "aws_ssm_parameter" "primary_coordinator_region_parameter" { count = (var.use_real_coordinators_parameter_value) ?
1 : 0 - name = "${var.service}-${var.environment}-secondary-coordinator-private-key-endpoint" + name = "${var.service}-${var.environment}-primary-coordinator-region" type = "String" value = var.primary_coordinator_region_parameter_value overwrite = true @@ -241,6 +241,13 @@ resource "aws_ssm_parameter" "udf_timeout_millis_parameter" { overwrite = true } +resource "aws_ssm_parameter" "udf_update_timeout_millis_parameter" { + name = "${var.service}-${var.environment}-udf-update-timeout-millis" + type = "String" + value = var.udf_update_timeout_millis_parameter_value + overwrite = true +} + resource "aws_ssm_parameter" "udf_min_log_level_parameter" { name = "${var.service}-${var.environment}-udf-min-log-level" type = "String" @@ -261,3 +268,18 @@ resource "aws_ssm_parameter" "data_loading_blob_prefix_allowlist" { value = var.data_loading_blob_prefix_allowlist overwrite = true } + +resource "aws_ssm_parameter" "consented_debug_token_parameter" { + count = (var.enable_consented_log_parameter_value) ? 
1 : 0 + name = "${var.service}-${var.environment}-consented-debug-token" + type = "String" + value = var.consented_debug_token_parameter_value + overwrite = true +} + +resource "aws_ssm_parameter" "enable_consented_log_parameter" { + name = "${var.service}-${var.environment}-enable-consented-log" + type = "String" + value = var.enable_consented_log_parameter_value + overwrite = true +} diff --git a/production/terraform/aws/services/parameter/outputs.tf b/production/terraform/aws/services/parameter/outputs.tf index 32e9aacf..9b6fd2ca 100644 --- a/production/terraform/aws/services/parameter/outputs.tf +++ b/production/terraform/aws/services/parameter/outputs.tf @@ -142,6 +142,10 @@ output "udf_timeout_millis_parameter_arn" { value = aws_ssm_parameter.udf_timeout_millis_parameter.arn } +output "udf_update_timeout_millis_parameter_arn" { + value = aws_ssm_parameter.udf_update_timeout_millis_parameter.arn +} + output "udf_min_log_level_parameter_arn" { value = aws_ssm_parameter.udf_min_log_level_parameter.arn } @@ -153,3 +157,11 @@ output "enable_otel_logger_parameter_arn" { output "data_loading_blob_prefix_allowlist_parameter_arn" { value = aws_ssm_parameter.data_loading_blob_prefix_allowlist.arn } + +output "consented_debug_token_parameter_arn" { + value = (var.enable_consented_log_parameter_value) ? aws_ssm_parameter.consented_debug_token_parameter[0].arn : "" +} + +output "enable_consented_log_parameter_arn" { + value = aws_ssm_parameter.enable_consented_log_parameter.arn +} diff --git a/production/terraform/aws/services/parameter/variables.tf b/production/terraform/aws/services/parameter/variables.tf index 32228f3b..3ed6e121 100644 --- a/production/terraform/aws/services/parameter/variables.tf +++ b/production/terraform/aws/services/parameter/variables.tf @@ -149,6 +149,11 @@ variable "udf_timeout_millis_parameter_value" { type = number } +variable "udf_update_timeout_millis_parameter_value" { + description = "UDF update timeout in milliseconds." 
+ type = number +} + variable "udf_min_log_level_parameter_value" { description = "Minimum log level for UDFs. Info = 0, Warn = 1, Error = 2. The UDF will only attempt to log for min_log_level and above. Default is 0(info)." type = number @@ -188,3 +193,13 @@ variable "public_key_endpoint_parameter_value" { description = "Public key endpoint. Can only be overriden in non-prod mode." type = string } + +variable "consented_debug_token_parameter_value" { + description = "Consented debug token to enable the otel collection of consented logs. Empty token means no-op and no logs will be collected for consented requests. The token in the request's consented debug configuration needs to match this debug token to make the server treat the request as consented." + type = string +} + +variable "enable_consented_log_parameter_value" { + description = "Enable the logging of consented requests. If it is set to true, the consented debug token parameter value must not be an empty string." + type = bool +} diff --git a/production/terraform/aws/services/security_group_rules/main.tf b/production/terraform/aws/services/security_group_rules/main.tf index efe01002..6b144a9d 100644 --- a/production/terraform/aws/services/security_group_rules/main.tf +++ b/production/terraform/aws/services/security_group_rules/main.tf @@ -16,6 +16,7 @@ # Ingress and egress rules for the load balancer listener. resource "aws_security_group_rule" "allow_all_elb_ingress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.elb_security_group_id @@ -44,6 +45,7 @@ resource "aws_security_group_rule" "allow_elb_to_ec2_egress" { # Ingress and egress rules for SSH. resource "aws_security_group_rule" "allow_all_ssh_ingress" { + count = var.use_existing_vpc ? 
0 : 1 from_port = 22 protocol = "TCP" security_group_id = var.ssh_security_group_id @@ -53,6 +55,7 @@ resource "aws_security_group_rule" "allow_all_ssh_ingress" { } resource "aws_security_group_rule" "allow_ssh_to_ec2_egress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 22 protocol = "TCP" security_group_id = var.ssh_security_group_id @@ -62,6 +65,7 @@ resource "aws_security_group_rule" "allow_ssh_to_ec2_egress" { } resource "aws_security_group_rule" "allow_ssh_secure_tcp_egress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.ssh_security_group_id @@ -81,6 +85,7 @@ resource "aws_security_group_rule" "allow_elb_to_ec2_ingress" { } resource "aws_security_group_rule" "allow_ssh_to_ec2_ingress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 22 protocol = "TCP" security_group_id = var.instances_security_group_id @@ -90,6 +95,7 @@ resource "aws_security_group_rule" "allow_ssh_to_ec2_ingress" { } resource "aws_security_group_rule" "allow_ec2_to_vpc_endpoint_egress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.instances_security_group_id @@ -99,6 +105,7 @@ resource "aws_security_group_rule" "allow_ec2_to_vpc_endpoint_egress" { } resource "aws_security_group_rule" "allow_ec2_to_vpc_ge_egress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.instances_security_group_id @@ -113,6 +120,7 @@ data "aws_ip_ranges" "ec2_instance_connect_ip_ranges" { } resource "aws_security_group_rule" "allow_ec2_instance_connect_ingress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 22 protocol = "TCP" security_group_id = var.instances_security_group_id @@ -123,6 +131,7 @@ resource "aws_security_group_rule" "allow_ec2_instance_connect_ingress" { # Ingress and egress rules for backend vpc interface endpoints. resource "aws_security_group_rule" "allow_ec2_to_vpce_ingress" { + count = var.use_existing_vpc ? 
0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.vpce_security_group_id @@ -132,6 +141,7 @@ resource "aws_security_group_rule" "allow_ec2_to_vpce_ingress" { } resource "aws_security_group_rule" "allow_ssh_instance_to_vpce_ingress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.vpce_security_group_id @@ -159,6 +169,7 @@ resource "aws_security_group_rule" "allow_ec2_to_ec2_endpoint_ingress" { } resource "aws_security_group_rule" "allow_ec2_secure_tcp_egress" { + count = var.use_existing_vpc ? 0 : 1 from_port = 443 protocol = "TCP" security_group_id = var.instances_security_group_id diff --git a/production/terraform/aws/services/security_group_rules/variables.tf b/production/terraform/aws/services/security_group_rules/variables.tf index 4a03b5a6..60213277 100644 --- a/production/terraform/aws/services/security_group_rules/variables.tf +++ b/production/terraform/aws/services/security_group_rules/variables.tf @@ -66,3 +66,8 @@ variable "ssh_source_cidr_blocks" { description = "Source ips allowed to send ssh traffic to the ssh instance." type = set(string) } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be requried." + type = bool +} diff --git a/production/terraform/aws/services/security_groups/main.tf b/production/terraform/aws/services/security_groups/main.tf index c9c39e74..3381ea47 100644 --- a/production/terraform/aws/services/security_groups/main.tf +++ b/production/terraform/aws/services/security_groups/main.tf @@ -18,8 +18,37 @@ # # NOTE that security group rules are managed in "../security_group_rules" module. +################################################################################ +# If use_existing_vpc is true, we need to use existing security groups. 
+################################################################################ + +data "aws_security_group" "existing_elb_security_group" { + count = var.use_existing_vpc ? 1 : 0 + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-elb-sg" +} + +data "aws_security_group" "existing_ssh_security_group" { + count = var.use_existing_vpc ? 1 : 0 + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-ssh-sg" +} + +data "aws_security_group" "existing_instance_security_group" { + count = var.use_existing_vpc ? 1 : 0 + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-instance-sg" +} + +data "aws_security_group" "existing_vpce_security_group" { + count = var.use_existing_vpc ? 1 : 0 + name = "${var.existing_vpc_operator}-${var.existing_vpc_environment}-vpce-sg" +} + +################################################################################ +# If use_existing_vpc is false, create security groups. +################################################################################ + # Security group to control ingress and egress traffic for the load balancer. resource "aws_security_group" "elb_security_group" { + count = var.use_existing_vpc ? 0 : 1 name = "${var.service}-${var.environment}-elb-sg" vpc_id = var.vpc_id @@ -30,6 +59,7 @@ resource "aws_security_group" "elb_security_group" { # Security group to control ingress and egress traffic for the ssh ec2 instance. resource "aws_security_group" "ssh_security_group" { + count = var.use_existing_vpc ? 0 : 1 name = "${var.service}-${var.environment}-ssh-sg" vpc_id = var.vpc_id @@ -40,6 +70,7 @@ resource "aws_security_group" "ssh_security_group" { # Security group to control ingress and egress traffic for the server ec2 instances. resource "aws_security_group" "instance_security_group" { + count = var.use_existing_vpc ? 
0 : 1 name = "${var.service}-${var.environment}-instance-sg" vpc_id = var.vpc_id @@ -50,6 +81,7 @@ resource "aws_security_group" "instance_security_group" { # Security group to control ingress and egress traffic to backend vpc endpoints. resource "aws_security_group" "vpce_security_group" { + count = var.use_existing_vpc ? 0 : 1 name = "${var.service}-${var.environment}-vpce-sg" vpc_id = var.vpc_id diff --git a/production/terraform/aws/services/security_groups/outputs.tf b/production/terraform/aws/services/security_groups/outputs.tf index 68d82520..6be2e38f 100644 --- a/production/terraform/aws/services/security_groups/outputs.tf +++ b/production/terraform/aws/services/security_groups/outputs.tf @@ -15,17 +15,17 @@ */ output "instance_security_group_id" { - value = aws_security_group.instance_security_group.id + value = var.use_existing_vpc ? data.aws_security_group.existing_instance_security_group[0].id : aws_security_group.instance_security_group[0].id } output "elb_security_group_id" { - value = aws_security_group.elb_security_group.id + value = var.use_existing_vpc ? data.aws_security_group.existing_elb_security_group[0].id : aws_security_group.elb_security_group[0].id } output "ssh_security_group_id" { - value = aws_security_group.ssh_security_group.id + value = var.use_existing_vpc ? data.aws_security_group.existing_ssh_security_group[0].id : aws_security_group.ssh_security_group[0].id } output "vpc_endpoint_security_group_id" { - value = aws_security_group.vpce_security_group.id + value = var.use_existing_vpc ? 
data.aws_security_group.existing_vpce_security_group[0].id : aws_security_group.vpce_security_group[0].id } diff --git a/production/terraform/aws/services/security_groups/variables.tf b/production/terraform/aws/services/security_groups/variables.tf index 36e10631..0c50ab40 100644 --- a/production/terraform/aws/services/security_groups/variables.tf +++ b/production/terraform/aws/services/security_groups/variables.tf @@ -27,3 +27,18 @@ variable "vpc_id" { description = "VPC id where security groups will be created." type = string } + +variable "use_existing_vpc" { + description = "Whether to use existing VPC. If true, only internal traffic via mesh will be served; variable vpc_operator and vpc_environment will be required." + type = bool +} + +variable "existing_vpc_operator" { + description = "Operator of the existing VPC. Ignored if use_existing_vpc is false." + type = string +} + +variable "existing_vpc_environment" { + description = "Environment of the existing VPC. Ignored if use_existing_vpc is false."
+ type = string +} diff --git a/production/terraform/gcp/environments/demo/us-east1.tfvars.json b/production/terraform/gcp/environments/demo/us-east1.tfvars.json index b0cf074d..212f64ff 100644 --- a/production/terraform/gcp/environments/demo/us-east1.tfvars.json +++ b/production/terraform/gcp/environments/demo/us-east1.tfvars.json @@ -6,10 +6,12 @@ "collector_machine_type": "e2-micro", "collector_service_name": "otel-collector", "collector_service_port": 4317, + "consented_debug_token": "EMPTY_STRING", "cpu_utilization_percent": 0.9, "data_bucket_id": "your-delta-file-bucket", "data_loading_blob_prefix_allowlist": ",", "data_loading_num_threads": 16, + "enable_consented_log": false, "enable_external_traffic": true, "environment": "demo", "envoy_port": 51052, @@ -51,6 +53,7 @@ "tee_impersonate_service_accounts": "", "telemetry_config": "mode: EXPERIMENT", "udf_num_workers": 2, + "udf_update_timeout_millis": 30000, "use_confidential_space_debug_image": false, "use_existing_service_mesh": false, "use_existing_vpc": false, diff --git a/production/terraform/gcp/environments/kv_server.tf b/production/terraform/gcp/environments/kv_server.tf index 332507de..4fb509c2 100644 --- a/production/terraform/gcp/environments/kv_server.tf +++ b/production/terraform/gcp/environments/kv_server.tf @@ -56,6 +56,7 @@ module "kv_server" { collector_service_name = var.collector_service_name collector_machine_type = var.collector_machine_type collector_service_port = var.collector_service_port + collector_startup_script_path = var.collector_startup_script_path collector_domain_name = var.collector_domain_name collector_dns_zone = var.collector_dns_zone data_bucket_id = var.data_bucket_id @@ -81,6 +82,7 @@ module "kv_server" { num-shards = var.num_shards udf-num-workers = var.udf_num_workers udf-timeout-millis = var.udf_timeout_millis + udf-update-timeout-millis = var.udf_update_timeout_millis udf-min-log-level = var.udf_min_log_level route-v1-to-v2 = var.route_v1_to_v2 add-missing-keys-v1 
= var.add_missing_keys_v1 @@ -107,5 +109,7 @@ module "kv_server" { enable-external-traffic = var.enable_external_traffic telemetry-config = var.telemetry_config data-loading-blob-prefix-allowlist = var.data_loading_blob_prefix_allowlist + consented-debug-token = var.consented_debug_token + enable-consented-log = var.enable_consented_log } } diff --git a/production/terraform/gcp/environments/kv_server_variables.tf b/production/terraform/gcp/environments/kv_server_variables.tf index 8c0a8ba0..94247d2a 100644 --- a/production/terraform/gcp/environments/kv_server_variables.tf +++ b/production/terraform/gcp/environments/kv_server_variables.tf @@ -174,6 +174,12 @@ variable "udf_timeout_millis" { description = "UDF execution timeout in milliseconds." } +variable "udf_update_timeout_millis" { + type = number + default = 30000 + description = "UDF update timeout in milliseconds." +} + variable "udf_min_log_level" { type = number default = 0 @@ -210,6 +216,12 @@ variable "collector_service_port" { type = number } +variable "collector_startup_script_path" { + description = "Relative path from main.tf to collector service startup script." + type = string + default = "../../services/metrics_collector_autoscaling/collector_startup.sh" +} + variable "collector_domain_name" { description = "Google Cloud domain name for OpenTelemetry collector" type = string @@ -352,3 +364,13 @@ variable "public_key_endpoint" { description = "Public key endpoint. Can only be overriden in non-prod mode." type = string } + +variable "consented_debug_token" { + description = "Consented debug token to enable the otel collection of consented logs. Empty token means no-op and no logs will be collected for consented requests. The token in the request's consented debug configuration needs to match this debug token to make the server treat the request as consented." + type = string +} + +variable "enable_consented_log" { + description = "Enable the logging of consented requests. 
If it is set to true, the consented debug token parameter value must not be an empty string." + type = bool +} diff --git a/production/terraform/gcp/modules/kv_server/main.tf b/production/terraform/gcp/modules/kv_server/main.tf index a43531d9..168032dd 100644 --- a/production/terraform/gcp/modules/kv_server/main.tf +++ b/production/terraform/gcp/modules/kv_server/main.tf @@ -33,6 +33,7 @@ module "security" { environment = var.environment network_id = module.networking.network_id subnets = module.networking.subnets + proxy_subnets = module.networking.proxy_subnets collector_service_port = var.collector_service_port } @@ -72,13 +73,15 @@ module "metrics_collector_autoscaling" { collector_machine_type = var.collector_machine_type collector_service_name = var.collector_service_name collector_service_port = var.collector_service_port + collector_startup_script_path = "${path.module}/${var.collector_startup_script_path}" max_replicas_per_service_region = var.max_replicas_per_service_region } module "metrics_collector" { source = "../../services/metrics_collector" environment = var.environment - collector_ip_address = module.networking.collector_ip_address + subnets = module.networking.subnets + proxy_subnets = module.networking.proxy_subnets collector_instance_groups = module.metrics_collector_autoscaling.collector_instance_groups collector_service_name = var.collector_service_name collector_service_port = var.collector_service_port diff --git a/production/terraform/gcp/modules/kv_server/variables.tf b/production/terraform/gcp/modules/kv_server/variables.tf index bdcd0fc7..3b4b14a9 100644 --- a/production/terraform/gcp/modules/kv_server/variables.tf +++ b/production/terraform/gcp/modules/kv_server/variables.tf @@ -121,6 +121,11 @@ variable "collector_service_port" { type = number } +variable "collector_startup_script_path" { + description = "Relative path from main.tf to collector service startup script." 
+ type = string +} + variable "collector_machine_type" { description = "Machine type for the collector service." type = string diff --git a/production/terraform/gcp/services/autoscaling/main.tf b/production/terraform/gcp/services/autoscaling/main.tf index 3b10e1bc..44308e8d 100644 --- a/production/terraform/gcp/services/autoscaling/main.tf +++ b/production/terraform/gcp/services/autoscaling/main.tf @@ -53,7 +53,7 @@ resource "google_compute_instance_template" "kv_server" { metadata = { tee-image-reference = "${var.gcp_image_repo}:${var.gcp_image_tag}" - tee-container-log-redirect = true, + tee-container-log-redirect = var.use_confidential_space_debug_image ? true : false tee-impersonate-service-accounts = "${var.tee_impersonate_service_accounts}" environment = var.environment } diff --git a/production/terraform/gcp/services/metrics_collector/main.tf b/production/terraform/gcp/services/metrics_collector/main.tf index 3f872e6c..7d2ed8cb 100644 --- a/production/terraform/gcp/services/metrics_collector/main.tf +++ b/production/terraform/gcp/services/metrics_collector/main.tf @@ -14,13 +14,21 @@ * limitations under the License. */ -resource "google_compute_backend_service" "mesh_collector" { + +############################################################### +# +# Collector LB +# +# The internal lb uses HTTP/2 (gRPC) with no TLS. 
+############################################################### + +resource "google_compute_backend_service" "collector" { name = "${var.environment}-${var.collector_service_name}-service" provider = google-beta port_name = "otlp" protocol = "TCP" - load_balancing_scheme = "EXTERNAL" + load_balancing_scheme = "INTERNAL_MANAGED" timeout_sec = 10 health_checks = [google_compute_health_check.collector.id] @@ -36,22 +44,26 @@ resource "google_compute_backend_service" "mesh_collector" { resource "google_compute_target_tcp_proxy" "collector" { name = "${var.environment}-${var.collector_service_name}-lb-proxy" - backend_service = google_compute_backend_service.mesh_collector.id + backend_service = google_compute_backend_service.collector.id } -resource "google_compute_global_forwarding_rule" "collector" { - name = "${var.environment}-${var.collector_service_name}-forwarding-rule" - provider = google-beta +resource "google_compute_global_forwarding_rule" "collectors" { + for_each = var.subnets + + name = "${var.environment}-${var.collector_service_name}-${each.value.region}-ilb-rule" ip_protocol = "TCP" port_range = var.collector_service_port - load_balancing_scheme = "EXTERNAL" + load_balancing_scheme = "INTERNAL_MANAGED" target = google_compute_target_tcp_proxy.collector.id - ip_address = var.collector_ip_address + subnetwork = each.value.id labels = { service = var.collector_service_name + region = each.value.region } + + depends_on = [var.proxy_subnets] } resource "google_dns_record_set" "collector" { @@ -59,17 +71,22 @@ resource "google_dns_record_set" "collector" { managed_zone = var.collector_dns_zone type = "A" ttl = 10 - rrdatas = [ - var.collector_ip_address - ] + routing_policy { + dynamic "geo" { + for_each = google_compute_global_forwarding_rule.collectors + content { + location = geo.value.labels.region + rrdatas = [geo.value.ip_address] + } + } + } } resource "google_compute_health_check" "collector" { name = 
"${var.environment}-${var.collector_service_name}-lb-hc" - tcp_health_check { - port_name = "otlp" - port = var.collector_service_port + grpc_health_check { + port = var.collector_service_port } timeout_sec = 3 diff --git a/production/terraform/gcp/services/metrics_collector/outputs.tf b/production/terraform/gcp/services/metrics_collector/outputs.tf index a12503c3..ca958892 100644 --- a/production/terraform/gcp/services/metrics_collector/outputs.tf +++ b/production/terraform/gcp/services/metrics_collector/outputs.tf @@ -15,7 +15,7 @@ */ output "collector_forwarding_rule" { - value = google_compute_global_forwarding_rule.collector + value = google_compute_global_forwarding_rule.collectors } output "collector_tcp_proxy" { diff --git a/production/terraform/gcp/services/metrics_collector/variables.tf b/production/terraform/gcp/services/metrics_collector/variables.tf index 23ef7178..ef41c9df 100644 --- a/production/terraform/gcp/services/metrics_collector/variables.tf +++ b/production/terraform/gcp/services/metrics_collector/variables.tf @@ -19,9 +19,9 @@ variable "environment" { type = string } -variable "collector_ip_address" { - description = "Collector IP address" - type = string +variable "subnets" { + description = "All service subnets." + type = any } variable "collector_instance_groups" { @@ -47,3 +47,8 @@ variable "collector_dns_zone" { description = "Google Cloud DNS zone name for collector." type = string } + +variable "proxy_subnets" { + description = "A list of all envoy proxy subnets. Used to allow ingress into the collectors." 
+ type = any +} diff --git a/production/terraform/gcp/services/metrics_collector_autoscaling/collector_startup.sh b/production/terraform/gcp/services/metrics_collector_autoscaling/collector_startup.sh index 64647ac1..33c24b4c 100644 --- a/production/terraform/gcp/services/metrics_collector_autoscaling/collector_startup.sh +++ b/production/terraform/gcp/services/metrics_collector_autoscaling/collector_startup.sh @@ -42,7 +42,7 @@ exporters: # https://github.com/open-telemetry/opentelemetry-collector-contrib/blob/main/exporter/googlecloudexporter/README.md - regex: .* log: - default_log_name: kv-server-metrics + default_log_name: kv-server-logs service: pipelines: diff --git a/production/terraform/gcp/services/metrics_collector_autoscaling/main.tf b/production/terraform/gcp/services/metrics_collector_autoscaling/main.tf index f930e588..1041e078 100644 --- a/production/terraform/gcp/services/metrics_collector_autoscaling/main.tf +++ b/production/terraform/gcp/services/metrics_collector_autoscaling/main.tf @@ -57,7 +57,7 @@ resource "google_compute_instance_template" "collector" { scopes = ["https://www.googleapis.com/auth/cloud-platform"] } metadata = { - startup-script = templatefile("${path.module}/collector_startup.sh", { + startup-script = templatefile(var.collector_startup_script_path, { collector_port = var.collector_service_port, }) } diff --git a/production/terraform/gcp/services/metrics_collector_autoscaling/variables.tf b/production/terraform/gcp/services/metrics_collector_autoscaling/variables.tf index 96da849f..4c43748f 100644 --- a/production/terraform/gcp/services/metrics_collector_autoscaling/variables.tf +++ b/production/terraform/gcp/services/metrics_collector_autoscaling/variables.tf @@ -65,6 +65,11 @@ variable "collector_service_port" { type = number } +variable "collector_startup_script_path" { + description = "Path to collector service startup script." 
+ type = string +} + variable "max_replicas_per_service_region" { description = "Maximum amount of replicas per each service region (a single managed instance group)." type = number diff --git a/production/terraform/gcp/services/networking/main.tf b/production/terraform/gcp/services/networking/main.tf index 87c8fb07..754e45cf 100644 --- a/production/terraform/gcp/services/networking/main.tf +++ b/production/terraform/gcp/services/networking/main.tf @@ -30,6 +30,20 @@ resource "google_compute_subnetwork" "kv_server" { ip_cidr_range = tolist(var.regions_cidr_blocks)[each.key] } +resource "google_compute_subnetwork" "proxy_subnets" { + for_each = { for index, region in tolist(var.regions) : index => region } + + ip_cidr_range = "10.${139 + each.key}.0.0/23" + name = "${var.service}-${var.environment}-${each.value}-collector-proxy-subnet" + network = var.use_existing_vpc ? var.existing_vpc_id : google_compute_network.kv_server[0].id + purpose = "GLOBAL_MANAGED_PROXY" + region = each.value + role = "ACTIVE" + lifecycle { + ignore_changes = [ipv6_access_type] + } +} + resource "google_compute_router" "kv_server" { for_each = var.regions @@ -56,11 +70,6 @@ resource "google_compute_router_nat" "kv_server" { } } -resource "google_compute_global_address" "collector" { - name = "${var.collector_service_name}-${var.environment}-lb" - ip_version = "IPV4" -} - resource "google_compute_global_address" "kv_server" { count = var.enable_external_traffic ? 
1 : 0 name = "${var.service}-${var.environment}-xlb-ip" diff --git a/production/terraform/gcp/services/networking/outputs.tf b/production/terraform/gcp/services/networking/outputs.tf index 3bb29175..e8ac484d 100644 --- a/production/terraform/gcp/services/networking/outputs.tf +++ b/production/terraform/gcp/services/networking/outputs.tf @@ -23,8 +23,9 @@ output "subnets" { value = google_compute_subnetwork.kv_server } -output "collector_ip_address" { - value = google_compute_global_address.collector.address +output "proxy_subnets" { + description = "All service proxy subnets." + value = google_compute_subnetwork.proxy_subnets } output "server_ip_address" { diff --git a/production/terraform/gcp/services/security/main.tf b/production/terraform/gcp/services/security/main.tf index a5c3c376..6d6da042 100644 --- a/production/terraform/gcp/services/security/main.tf +++ b/production/terraform/gcp/services/security/main.tf @@ -73,5 +73,5 @@ resource "google_compute_firewall" "fw_allow_otlp" { ports = [var.collector_service_port] } target_tags = ["allow-otlp"] - source_ranges = [for subnet in var.subnets : subnet.ip_cidr_range] + source_ranges = [for subnet in var.proxy_subnets : subnet.ip_cidr_range] } diff --git a/production/terraform/gcp/services/security/variables.tf b/production/terraform/gcp/services/security/variables.tf index 2f171765..a09424fb 100644 --- a/production/terraform/gcp/services/security/variables.tf +++ b/production/terraform/gcp/services/security/variables.tf @@ -38,3 +38,8 @@ variable "collector_service_port" { description = "The grpc port that receives traffic destined for the OpenTelemetry collector." type = number } + +variable "proxy_subnets" { + description = "A list of all envoy proxy subnets. Used to allow ingress into the collectors." 
+ type = any +} diff --git a/public/applications/pas/retrieval_request_builder.cc b/public/applications/pas/retrieval_request_builder.cc index 45f2d420..1ac5b5bf 100644 --- a/public/applications/pas/retrieval_request_builder.cc +++ b/public/applications/pas/retrieval_request_builder.cc @@ -26,6 +26,9 @@ v2::GetValuesRequest GetRequest() { } v2::GetValuesRequest BuildRetrievalRequest( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, std::string protected_signals, absl::flat_hash_map device_metadata, std::string contextual_signals, std::vector optional_ad_ids) { @@ -59,10 +62,16 @@ v2::GetValuesRequest BuildRetrievalRequest( ->set_string_value(std::move(item)); } } + { *req.mutable_consented_debug_config() = consented_debug_config; } + { *req.mutable_log_context() = log_context; } return req; } -v2::GetValuesRequest BuildLookupRequest(std::vector ad_ids) { +v2::GetValuesRequest BuildLookupRequest( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, + std::vector ad_ids) { v2::GetValuesRequest req = GetRequest(); v2::RequestPartition* partition = req.add_partitions(); auto* ad_id_arg = partition->add_arguments(); @@ -72,6 +81,8 @@ v2::GetValuesRequest BuildLookupRequest(std::vector ad_ids) { ->add_values() ->set_string_value(std::move(item)); } + { *req.mutable_consented_debug_config() = consented_debug_config; } + { *req.mutable_log_context() = log_context; } return req; } diff --git a/public/applications/pas/retrieval_request_builder.h b/public/applications/pas/retrieval_request_builder.h index 99e2f7ce..f622cfc8 100644 --- a/public/applications/pas/retrieval_request_builder.h +++ b/public/applications/pas/retrieval_request_builder.h @@ -30,13 +30,20 @@ namespace kv_server::application_pas { // Input strings must be JSON-compliant, i.e., for binary 
strings, they must be // base64 encoded. v2::GetValuesRequest BuildRetrievalRequest( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, std::string protected_signals, absl::flat_hash_map device_metadata, std::string contextual_signals, std::vector optional_ad_ids = {}); // Builds a GetValuesRequest. Stores the input arguments into the request. -v2::GetValuesRequest BuildLookupRequest(std::vector ad_ids); +v2::GetValuesRequest BuildLookupRequest( + const privacy_sandbox::server_common::LogContext& log_context, + const privacy_sandbox::server_common::ConsentedDebugConfiguration& + consented_debug_config, + std::vector ad_ids); } // namespace kv_server::application_pas diff --git a/public/applications/pas/retrieval_request_builder_test.cc b/public/applications/pas/retrieval_request_builder_test.cc index 43c4aaa5..9e831c5b 100644 --- a/public/applications/pas/retrieval_request_builder_test.cc +++ b/public/applications/pas/retrieval_request_builder_test.cc @@ -27,10 +27,24 @@ namespace { using google::protobuf::TextFormat; TEST(RequestBuilder, Build) { + privacy_sandbox::server_common::ConsentedDebugConfiguration + consented_debug_configuration; + consented_debug_configuration.set_is_consented(true); + consented_debug_configuration.set_token("test_token"); + privacy_sandbox::server_common::LogContext log_context; + log_context.set_generation_id("generation_id"); + log_context.set_adtech_debug_id("debug_id"); + v2::GetValuesRequest expected; TextFormat::ParseFromString( R"pb( client_version: "Retrieval.20231018" + metadata { + fields { + key: "is_pas" + value { string_value: "true" } + } + } partitions { id: 0 arguments { data { string_value: "protected signals" } } @@ -62,11 +76,16 @@ TEST(RequestBuilder, Build) { } } } - - })pb", + } + log_context { + generation_id: "generation_id" + adtech_debug_id: "debug_id" + } + consented_debug_config { is_consented: true 
token: "test_token" })pb", &expected); EXPECT_THAT( - BuildRetrievalRequest("protected signals", + BuildRetrievalRequest(log_context, consented_debug_configuration, + "protected signals", {{"m1", "v1"}, {"m2", "v2"}, {"m3", "v3"}}, "contextual signals", {"item1", "item2", "item3"}), EqualsProto(expected)); diff --git a/public/data_loading/BUILD.bazel b/public/data_loading/BUILD.bazel index 3286748e..8504076b 100644 --- a/public/data_loading/BUILD.bazel +++ b/public/data_loading/BUILD.bazel @@ -91,6 +91,7 @@ cc_test( size = "small", srcs = ["records_utils_test.cc"], deps = [ + ":record_utils", ":records_utils", "@com_google_absl//absl/hash:hash_testing", "@com_google_googletest//:gtest_main", diff --git a/public/data_loading/csv/constants.h b/public/data_loading/csv/constants.h index 15dcea9a..49af6ba0 100644 --- a/public/data_loading/csv/constants.h +++ b/public/data_loading/csv/constants.h @@ -33,6 +33,7 @@ inline constexpr std::string_view kValueColumn = "value"; inline constexpr std::string_view kValueTypeColumn = "value_type"; inline constexpr std::string_view kValueTypeString = "string"; inline constexpr std::string_view kValueTypeStringSet = "string_set"; +inline constexpr std::string_view kValueTypeUInt32Set = "uint32_set"; inline constexpr std::string_view kRecordTypeColumn = "record_type"; inline constexpr std::string_view kRecordTypeKVMutation = "key_value_mutation"; diff --git a/public/data_loading/csv/csv_delta_record_stream_reader.cc b/public/data_loading/csv/csv_delta_record_stream_reader.cc index 433c34ff..a7b94f9b 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader.cc @@ -27,6 +27,7 @@ namespace kv_server { namespace { + absl::StatusOr GetInt64Column(const riegeli::CsvRecord& csv_record, std::string_view column_name) { if (int64_t result; absl::SimpleAtoi(csv_record[column_name], &result)) { @@ -51,23 +52,45 @@ absl::StatusOr GetValue(const riegeli::CsvRecord& 
csv_record, return csv_record[kValueColumn]; } -absl::StatusOr> GetSetValue( +template +absl::StatusOr> BuildSetValue( + const std::vector& set) { + std::vector result; + for (auto&& set_value : set) { + if constexpr (std::is_same_v) { + result.push_back(std::string(set_value)); + } + if constexpr (std::is_same_v) { + if (uint32_t number; absl::SimpleAtoi(set_value, &number)) { + result.push_back(number); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot convert: ", set_value, " to a uint32 number.")); + } + } + } + return result; +} + +template +absl::StatusOr> GetSetValue( const riegeli::CsvRecord& csv_record, char value_separator, const CsvEncoding& csv_encoding) { if (csv_encoding == CsvEncoding::kBase64) { - std::vector result; + std::vector decoded_values; for (auto&& set_value : absl::StrSplit(csv_record[kValueColumn], value_separator)) { if (std::string dest; absl::Base64Unescape(set_value, &dest)) { - result.push_back(std::move(dest)); + decoded_values.push_back(std::move(dest)); } else { return absl::InvalidArgumentError(absl::StrCat( "base64 decode failed for value: ", csv_record[kValueColumn])); } } - return result; + return BuildSetValue(decoded_values); } - return absl::StrSplit(csv_record[kValueColumn], value_separator); + return BuildSetValue( + absl::StrSplit(csv_record[kValueColumn], value_separator)); } absl::StatusOr GetDeltaMutationType( @@ -106,7 +129,8 @@ absl::Status SetRecordValue(char value_separator, return absl::OkStatus(); } if (absl::EqualsIgnoreCase(type, kValueTypeStringSet)) { - auto maybe_value = GetSetValue(csv_record, value_separator, csv_encoding); + auto maybe_value = + GetSetValue(csv_record, value_separator, csv_encoding); if (!maybe_value.ok()) { return maybe_value.status(); } @@ -115,6 +139,17 @@ absl::Status SetRecordValue(char value_separator, mutation_record.value.Set(std::move(set_value)); return absl::OkStatus(); } + if (absl::EqualsIgnoreCase(type, kValueTypeUInt32Set)) { + auto maybe_value = + 
GetSetValue(csv_record, value_separator, csv_encoding); + if (!maybe_value.ok()) { + return maybe_value.status(); + } + UInt32SetT set_value; + set_value.value = std::move(*maybe_value); + mutation_record.value.Set(std::move(set_value)); + return absl::OkStatus(); + } return absl::InvalidArgumentError( absl::StrCat("Value type: ", type, " is not supported")); } diff --git a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc index 5ec95cb0..9add4b92 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc @@ -39,6 +39,12 @@ StringSetT GetStringSetValue(const std::vector& values) { return string_set; } +UInt32SetT GetUInt32SetValue(const std::vector& values) { + UInt32SetT uint32_set; + uint32_set.value = {values.begin(), values.end()}; + return uint32_set; +} + template std::pair GetKVMutationRecord(ValueT&& value, @@ -213,6 +219,32 @@ TEST(CsvDeltaRecordStreamReaderTest, EXPECT_TRUE(status.ok()) << status; } +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingAndWriting_KVMutation_UInt32SetValues_Success) { + const std::vector values{ + 1000, + 1001, + 1002, + }; + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer(string_stream); + auto [legacy_mutation, mutation] = + GetKVMutationRecord(GetUInt32SetValue(values), values); + auto input = GetDataRecord(legacy_mutation); + auto expected = GetNativeDataRecord(mutation); + EXPECT_TRUE(record_writer.WriteRecord(input).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + LOG(INFO) << string_stream.str(); + CsvDeltaRecordStreamReader record_reader(string_stream); + auto status = + record_reader.ReadRecords([&expected](const DataRecord& record) { + std::unique_ptr native_type_record(record.UnPack()); + EXPECT_EQ(*native_type_record, expected); + return absl::OkStatus(); + }); + EXPECT_TRUE(status.ok()) << status; +} + 
TEST(CsvDeltaRecordStreamReaderTest, ValidateReadingAndWriting_KVMutation_SetValues_Base64_Success) { const std::vector values{ diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.cc b/public/data_loading/csv/csv_delta_record_stream_writer.cc index 1fc1d970..5f807d07 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.cc +++ b/public/data_loading/csv/csv_delta_record_stream_writer.cc @@ -52,16 +52,19 @@ std::string MaybeEncode(std::string_view value, return std::string(value); } -std::vector MaybeEncode(std::vector values, +template +std::vector MaybeEncode(std::vector values, const CsvEncoding& csv_encoding) { + std::vector result; if (csv_encoding == CsvEncoding::kBase64) { - std::vector result; for (auto&& value : values) { - result.push_back(absl::Base64Escape(value)); + result.push_back(absl::Base64Escape(absl::StrCat(value))); } return result; } - return std::vector(values.begin(), values.end()); + std::transform(values.begin(), values.end(), std::back_inserter(result), + [](const auto elem) { return absl::StrCat(elem); }); + return result; } absl::StatusOr GetRecordValue( @@ -84,6 +87,13 @@ absl::StatusOr GetRecordValue( value_separator), }; } + if constexpr (std::is_same_v>) { + return ValueStruct{ + .value_type = std::string(kValueTypeUInt32Set), + .value = absl::StrJoin(MaybeEncode(arg, csv_encoding), + value_separator), + }; + } return absl::InvalidArgumentError("Value must be set."); }, value); @@ -100,7 +110,8 @@ absl::StatusOr MakeCsvRecordWithKVMutation( const auto record = std::get(data_record.record); - riegeli::CsvRecord csv_record(kKeyValueMutationRecordHeader); + riegeli::CsvHeader header(kKeyValueMutationRecordHeader); + riegeli::CsvRecord csv_record(header); csv_record[kKeyColumn] = record.key; absl::StatusOr value = GetRecordValue( record.value, std::string(1, value_separator), csv_encoding); @@ -141,7 +152,8 @@ absl::StatusOr MakeCsvRecordWithUdfConfig( const auto udf_config = std::get(data_record.record); - 
riegeli::CsvRecord csv_record(kUserDefinedFunctionsConfigHeader); + riegeli::CsvHeader header(kUserDefinedFunctionsConfigHeader); + riegeli::CsvRecord csv_record(header); csv_record[kCodeSnippetColumn] = udf_config.code_snippet; csv_record[kHandlerNameColumn] = udf_config.handler_name; csv_record[kLogicalCommitTimeColumn] = @@ -164,7 +176,8 @@ absl::StatusOr MakeCsvRecordWithShardMapping( } const auto shard_mapping_struct = std::get(data_record.record); - riegeli::CsvRecord csv_record(kShardMappingRecordHeader); + riegeli::CsvHeader header(kShardMappingRecordHeader); + riegeli::CsvRecord csv_record(header); csv_record[kLogicalShardColumn] = absl::StrCat(shard_mapping_struct.logical_shard); csv_record[kPhysicalShardColumn] = diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.h b/public/data_loading/csv/csv_delta_record_stream_writer.h index 1c828286..849e7fc6 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.h +++ b/public/data_loading/csv/csv_delta_record_stream_writer.h @@ -124,7 +124,8 @@ riegeli::CsvWriterBase::Options GetRecordWriterOptions( kShardMappingRecordHeader.end()); break; } - writer_options.set_header(std::move(header)); + riegeli::CsvHeader header_opt(std::move(header)); + writer_options.set_header(std::move(header_opt)); return writer_options; } } // namespace internal @@ -154,7 +155,7 @@ absl::Status CsvDeltaRecordStreamWriter::WriteRecord( template absl::Status CsvDeltaRecordStreamWriter::Flush() { - record_writer_.dest_writer()->Flush(); + record_writer_.dest().Flush(); return record_writer_.status(); } diff --git a/public/data_loading/data_loading.fbs b/public/data_loading/data_loading.fbs index a2519e2a..955fdcb4 100644 --- a/public/data_loading/data_loading.fbs +++ b/public/data_loading/data_loading.fbs @@ -10,7 +10,8 @@ table StringValue { value:string; } // otherwise inserts the elements into the existing set. // (2) `Delete` mutation removes the elements from existing set. 
table StringSet { value:[string]; } -union Value { StringValue, StringSet } +table UInt32Set { value:[uint]; } +union Value { StringValue, StringSet, UInt32Set } table KeyValueMutationRecord { // Required. For updates, the value will overwrite the previous value, if any. diff --git a/public/data_loading/data_loading_generated.h b/public/data_loading/data_loading_generated.h index 85f78f10..6d68243c 100644 --- a/public/data_loading/data_loading_generated.h +++ b/public/data_loading/data_loading_generated.h @@ -30,14 +30,18 @@ static_assert(FLATBUFFERS_VERSION_MAJOR == 2 && namespace kv_server { -struct String; -struct StringBuilder; -struct StringT; +struct StringValue; +struct StringValueBuilder; +struct StringValueT; struct StringSet; struct StringSetBuilder; struct StringSetT; +struct UInt32Set; +struct UInt32SetBuilder; +struct UInt32SetT; + struct KeyValueMutationRecord; struct KeyValueMutationRecordBuilder; struct KeyValueMutationRecordT; @@ -54,6 +58,25 @@ struct DataRecord; struct DataRecordBuilder; struct DataRecordT; +bool operator==(const StringValueT& lhs, const StringValueT& rhs); +bool operator!=(const StringValueT& lhs, const StringValueT& rhs); +bool operator==(const StringSetT& lhs, const StringSetT& rhs); +bool operator!=(const StringSetT& lhs, const StringSetT& rhs); +bool operator==(const UInt32SetT& lhs, const UInt32SetT& rhs); +bool operator!=(const UInt32SetT& lhs, const UInt32SetT& rhs); +bool operator==(const KeyValueMutationRecordT& lhs, + const KeyValueMutationRecordT& rhs); +bool operator!=(const KeyValueMutationRecordT& lhs, + const KeyValueMutationRecordT& rhs); +bool operator==(const UserDefinedFunctionsConfigT& lhs, + const UserDefinedFunctionsConfigT& rhs); +bool operator!=(const UserDefinedFunctionsConfigT& lhs, + const UserDefinedFunctionsConfigT& rhs); +bool operator==(const ShardMappingRecordT& lhs, const ShardMappingRecordT& rhs); +bool operator!=(const ShardMappingRecordT& lhs, const ShardMappingRecordT& rhs); +bool 
operator==(const DataRecordT& lhs, const DataRecordT& rhs); +bool operator!=(const DataRecordT& lhs, const DataRecordT& rhs); + enum class KeyValueMutationType : int8_t { Update = 0, Delete = 1, @@ -82,24 +105,27 @@ inline const char* EnumNameKeyValueMutationType(KeyValueMutationType e) { enum class Value : uint8_t { NONE = 0, - String = 1, + StringValue = 1, StringSet = 2, + UInt32Set = 3, MIN = NONE, - MAX = StringSet + MAX = UInt32Set }; -inline const Value (&EnumValuesValue())[3] { - static const Value values[] = {Value::NONE, Value::String, Value::StringSet}; +inline const Value (&EnumValuesValue())[4] { + static const Value values[] = {Value::NONE, Value::StringValue, + Value::StringSet, Value::UInt32Set}; return values; } inline const char* const* EnumNamesValue() { - static const char* const names[4] = {"NONE", "String", "StringSet", nullptr}; + static const char* const names[5] = {"NONE", "StringValue", "StringSet", + "UInt32Set", nullptr}; return names; } inline const char* EnumNameValue(Value e) { - if (flatbuffers::IsOutRange(e, Value::NONE, Value::StringSet)) return ""; + if (flatbuffers::IsOutRange(e, Value::NONE, Value::UInt32Set)) return ""; const size_t index = static_cast(e); return EnumNamesValue()[index]; } @@ -110,8 +136,8 @@ struct ValueTraits { }; template <> -struct ValueTraits { - static const Value enum_value = Value::String; +struct ValueTraits { + static const Value enum_value = Value::StringValue; }; template <> @@ -119,14 +145,19 @@ struct ValueTraits { static const Value enum_value = Value::StringSet; }; +template <> +struct ValueTraits { + static const Value enum_value = Value::UInt32Set; +}; + template struct ValueUnionTraits { static const Value enum_value = Value::NONE; }; template <> -struct ValueUnionTraits { - static const Value enum_value = Value::String; +struct ValueUnionTraits { + static const Value enum_value = Value::StringValue; }; template <> @@ -134,6 +165,11 @@ struct ValueUnionTraits { static const Value enum_value = 
Value::StringSet; }; +template <> +struct ValueUnionTraits { + static const Value enum_value = Value::UInt32Set; +}; + struct ValueUnion { Value type; void* value; @@ -176,13 +212,14 @@ struct ValueUnion { flatbuffers::FlatBufferBuilder& _fbb, const flatbuffers::rehasher_function_t* _rehasher = nullptr) const; - kv_server::StringT* AsString() { - return type == Value::String ? reinterpret_cast(value) - : nullptr; + kv_server::StringValueT* AsStringValue() { + return type == Value::StringValue + ? reinterpret_cast(value) + : nullptr; } - const kv_server::StringT* AsString() const { - return type == Value::String - ? reinterpret_cast(value) + const kv_server::StringValueT* AsStringValue() const { + return type == Value::StringValue + ? reinterpret_cast(value) : nullptr; } kv_server::StringSetT* AsStringSet() { @@ -195,8 +232,46 @@ struct ValueUnion { ? reinterpret_cast(value) : nullptr; } + kv_server::UInt32SetT* AsUInt32Set() { + return type == Value::UInt32Set + ? reinterpret_cast(value) + : nullptr; + } + const kv_server::UInt32SetT* AsUInt32Set() const { + return type == Value::UInt32Set + ? 
reinterpret_cast(value) + : nullptr; + } }; +inline bool operator==(const ValueUnion& lhs, const ValueUnion& rhs) { + if (lhs.type != rhs.type) return false; + switch (lhs.type) { + case Value::NONE: { + return true; + } + case Value::StringValue: { + return *(reinterpret_cast(lhs.value)) == + *(reinterpret_cast(rhs.value)); + } + case Value::StringSet: { + return *(reinterpret_cast(lhs.value)) == + *(reinterpret_cast(rhs.value)); + } + case Value::UInt32Set: { + return *(reinterpret_cast(lhs.value)) == + *(reinterpret_cast(rhs.value)); + } + default: { + return false; + } + } +} + +inline bool operator!=(const ValueUnion& lhs, const ValueUnion& rhs) { + return !(lhs == rhs); +} + bool VerifyValue(flatbuffers::Verifier& verifier, const void* obj, Value type); bool VerifyValueVector( flatbuffers::Verifier& verifier, @@ -378,6 +453,40 @@ struct RecordUnion { } }; +inline bool operator==(const RecordUnion& lhs, const RecordUnion& rhs) { + if (lhs.type != rhs.type) return false; + switch (lhs.type) { + case Record::NONE: { + return true; + } + case Record::KeyValueMutationRecord: { + return *(reinterpret_cast( + lhs.value)) == + *(reinterpret_cast( + rhs.value)); + } + case Record::UserDefinedFunctionsConfig: { + return *(reinterpret_cast( + lhs.value)) == + *(reinterpret_cast( + rhs.value)); + } + case Record::ShardMappingRecord: { + return *(reinterpret_cast( + lhs.value)) == + *(reinterpret_cast( + rhs.value)); + } + default: { + return false; + } + } +} + +inline bool operator!=(const RecordUnion& lhs, const RecordUnion& rhs) { + return !(lhs == rhs); +} + bool VerifyRecord(flatbuffers::Verifier& verifier, const void* obj, Record type); bool VerifyRecordVector( @@ -385,14 +494,14 @@ bool VerifyRecordVector( const flatbuffers::Vector>* values, const flatbuffers::Vector* types); -struct StringT : public flatbuffers::NativeTable { - typedef String TableType; +struct StringValueT : public flatbuffers::NativeTable { + typedef StringValue TableType; std::string value{}; 
}; -struct String FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef StringT NativeTableType; - typedef StringBuilder Builder; +struct StringValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef StringValueT NativeTableType; + typedef StringValueBuilder Builder; struct Traits; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_VALUE = 4 @@ -404,53 +513,55 @@ struct String FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUE) && verifier.VerifyString(value()) && verifier.EndTable(); } - StringT* UnPack( + StringValueT* UnPack( + const flatbuffers::resolver_function_t* _resolver = nullptr) const; + void UnPackTo( + StringValueT* _o, const flatbuffers::resolver_function_t* _resolver = nullptr) const; - void UnPackTo(StringT* _o, const flatbuffers::resolver_function_t* _resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder& _fbb, const StringT* _o, + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder& _fbb, const StringValueT* _o, const flatbuffers::rehasher_function_t* _rehasher = nullptr); }; -struct StringBuilder { - typedef String Table; +struct StringValueBuilder { + typedef StringValue Table; flatbuffers::FlatBufferBuilder& fbb_; flatbuffers::uoffset_t start_; void add_value(flatbuffers::Offset value) { - fbb_.AddOffset(String::VT_VALUE, value); + fbb_.AddOffset(StringValue::VT_VALUE, value); } - explicit StringBuilder(flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { + explicit StringValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - flatbuffers::Offset Finish() { + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateString( +inline flatbuffers::Offset CreateStringValue( 
flatbuffers::FlatBufferBuilder& _fbb, flatbuffers::Offset value = 0) { - StringBuilder builder_(_fbb); + StringValueBuilder builder_(_fbb); builder_.add_value(value); return builder_.Finish(); } -struct String::Traits { - using type = String; - static auto constexpr Create = CreateString; +struct StringValue::Traits { + using type = StringValue; + static auto constexpr Create = CreateStringValue; }; -inline flatbuffers::Offset CreateStringDirect( +inline flatbuffers::Offset CreateStringValueDirect( flatbuffers::FlatBufferBuilder& _fbb, const char* value = nullptr) { auto value__ = value ? _fbb.CreateString(value) : 0; - return kv_server::CreateString(_fbb, value__); + return kv_server::CreateStringValue(_fbb, value__); } -flatbuffers::Offset CreateString( - flatbuffers::FlatBufferBuilder& _fbb, const StringT* _o, +flatbuffers::Offset CreateStringValue( + flatbuffers::FlatBufferBuilder& _fbb, const StringValueT* _o, const flatbuffers::rehasher_function_t* _rehasher = nullptr); struct StringSetT : public flatbuffers::NativeTable { @@ -535,6 +646,76 @@ flatbuffers::Offset CreateStringSet( flatbuffers::FlatBufferBuilder& _fbb, const StringSetT* _o, const flatbuffers::rehasher_function_t* _rehasher = nullptr); +struct UInt32SetT : public flatbuffers::NativeTable { + typedef UInt32Set TableType; + std::vector value{}; +}; + +struct UInt32Set FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UInt32SetT NativeTableType; + typedef UInt32SetBuilder Builder; + struct Traits; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUE = 4 + }; + const flatbuffers::Vector* value() const { + return GetPointer*>(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyVector(value()) && verifier.EndTable(); + } + UInt32SetT* UnPack( + const flatbuffers::resolver_function_t* _resolver = nullptr) const; + void UnPackTo( + UInt32SetT* _o, + const 
flatbuffers::resolver_function_t* _resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder& _fbb, const UInt32SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher = nullptr); +}; + +struct UInt32SetBuilder { + typedef UInt32Set Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_value(flatbuffers::Offset> value) { + fbb_.AddOffset(UInt32Set::VT_VALUE, value); + } + explicit UInt32SetBuilder(flatbuffers::FlatBufferBuilder& _fbb) : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUInt32Set( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset> value = 0) { + UInt32SetBuilder builder_(_fbb); + builder_.add_value(value); + return builder_.Finish(); +} + +struct UInt32Set::Traits { + using type = UInt32Set; + static auto constexpr Create = CreateUInt32Set; +}; + +inline flatbuffers::Offset CreateUInt32SetDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const std::vector* value = nullptr) { + auto value__ = value ? _fbb.CreateVector(*value) : 0; + return kv_server::CreateUInt32Set(_fbb, value__); +} + +flatbuffers::Offset CreateUInt32Set( + flatbuffers::FlatBufferBuilder& _fbb, const UInt32SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher = nullptr); + struct KeyValueMutationRecordT : public flatbuffers::NativeTable { typedef KeyValueMutationRecord TableType; kv_server::KeyValueMutationType mutation_type = @@ -572,9 +753,9 @@ struct KeyValueMutationRecord FLATBUFFERS_FINAL_CLASS const void* value() const { return GetPointer(VT_VALUE); } template const T* value_as() const; - const kv_server::String* value_as_String() const { - return value_type() == kv_server::Value::String - ? 
static_cast(value()) + const kv_server::StringValue* value_as_StringValue() const { + return value_type() == kv_server::Value::StringValue + ? static_cast(value()) : nullptr; } const kv_server::StringSet* value_as_StringSet() const { @@ -582,6 +763,11 @@ struct KeyValueMutationRecord FLATBUFFERS_FINAL_CLASS ? static_cast(value()) : nullptr; } + const kv_server::UInt32Set* value_as_UInt32Set() const { + return value_type() == kv_server::Value::UInt32Set + ? static_cast(value()) + : nullptr; + } bool Verify(flatbuffers::Verifier& verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_MUTATION_TYPE, 1) && @@ -602,9 +788,9 @@ struct KeyValueMutationRecord FLATBUFFERS_FINAL_CLASS }; template <> -inline const kv_server::String* -KeyValueMutationRecord::value_as() const { - return value_as_String(); +inline const kv_server::StringValue* +KeyValueMutationRecord::value_as() const { + return value_as_StringValue(); } template <> @@ -613,6 +799,12 @@ KeyValueMutationRecord::value_as() const { return value_as_StringSet(); } +template <> +inline const kv_server::UInt32Set* +KeyValueMutationRecord::value_as() const { + return value_as_UInt32Set(); +} + struct KeyValueMutationRecordBuilder { typedef KeyValueMutationRecord Table; flatbuffers::FlatBufferBuilder& fbb_; @@ -1007,15 +1199,23 @@ flatbuffers::Offset CreateDataRecord( flatbuffers::FlatBufferBuilder& _fbb, const DataRecordT* _o, const flatbuffers::rehasher_function_t* _rehasher = nullptr); -inline StringT* String::UnPack( +inline bool operator==(const StringValueT& lhs, const StringValueT& rhs) { + return (lhs.value == rhs.value); +} + +inline bool operator!=(const StringValueT& lhs, const StringValueT& rhs) { + return !(lhs == rhs); +} + +inline StringValueT* StringValue::UnPack( const flatbuffers::resolver_function_t* _resolver) const { - auto _o = std::make_unique(); + auto _o = std::make_unique(); UnPackTo(_o.get(), _resolver); return _o.release(); } -inline void String::UnPackTo( - StringT* 
_o, const flatbuffers::resolver_function_t* _resolver) const { +inline void StringValue::UnPackTo( + StringValueT* _o, const flatbuffers::resolver_function_t* _resolver) const { (void)_o; (void)_resolver; { @@ -1024,25 +1224,33 @@ inline void String::UnPackTo( } } -inline flatbuffers::Offset String::Pack( - flatbuffers::FlatBufferBuilder& _fbb, const StringT* _o, +inline flatbuffers::Offset StringValue::Pack( + flatbuffers::FlatBufferBuilder& _fbb, const StringValueT* _o, const flatbuffers::rehasher_function_t* _rehasher) { - return CreateString(_fbb, _o, _rehasher); + return CreateStringValue(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateString( - flatbuffers::FlatBufferBuilder& _fbb, const StringT* _o, +inline flatbuffers::Offset CreateStringValue( + flatbuffers::FlatBufferBuilder& _fbb, const StringValueT* _o, const flatbuffers::rehasher_function_t* _rehasher) { (void)_rehasher; (void)_o; struct _VectorArgs { flatbuffers::FlatBufferBuilder* __fbb; - const StringT* __o; + const StringValueT* __o; const flatbuffers::rehasher_function_t* __rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; auto _value = _o->value.empty() ? 
0 : _fbb.CreateString(_o->value); - return kv_server::CreateString(_fbb, _value); + return kv_server::CreateStringValue(_fbb, _value); +} + +inline bool operator==(const StringSetT& lhs, const StringSetT& rhs) { + return (lhs.value == rhs.value); +} + +inline bool operator!=(const StringSetT& lhs, const StringSetT& rhs) { + return !(lhs == rhs); } inline StringSetT* StringSet::UnPack( @@ -1088,6 +1296,69 @@ inline flatbuffers::Offset CreateStringSet( return kv_server::CreateStringSet(_fbb, _value); } +inline bool operator==(const UInt32SetT& lhs, const UInt32SetT& rhs) { + return (lhs.value == rhs.value); +} + +inline bool operator!=(const UInt32SetT& lhs, const UInt32SetT& rhs) { + return !(lhs == rhs); +} + +inline UInt32SetT* UInt32Set::UnPack( + const flatbuffers::resolver_function_t* _resolver) const { + auto _o = std::make_unique(); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void UInt32Set::UnPackTo( + UInt32SetT* _o, const flatbuffers::resolver_function_t* _resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = value(); + if (_e) { + _o->value.resize(_e->size()); + for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { + _o->value[_i] = _e->Get(_i); + } + } + } +} + +inline flatbuffers::Offset UInt32Set::Pack( + flatbuffers::FlatBufferBuilder& _fbb, const UInt32SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher) { + return CreateUInt32Set(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateUInt32Set( + flatbuffers::FlatBufferBuilder& _fbb, const UInt32SetT* _o, + const flatbuffers::rehasher_function_t* _rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder* __fbb; + const UInt32SetT* __o; + const flatbuffers::rehasher_function_t* __rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _value = _o->value.size() ? 
_fbb.CreateVector(_o->value) : 0; + return kv_server::CreateUInt32Set(_fbb, _value); +} + +inline bool operator==(const KeyValueMutationRecordT& lhs, + const KeyValueMutationRecordT& rhs) { + return (lhs.mutation_type == rhs.mutation_type) && + (lhs.logical_commit_time == rhs.logical_commit_time) && + (lhs.key == rhs.key) && (lhs.value == rhs.value); +} + +inline bool operator!=(const KeyValueMutationRecordT& lhs, + const KeyValueMutationRecordT& rhs) { + return !(lhs == rhs); +} + inline KeyValueMutationRecordT* KeyValueMutationRecord::UnPack( const flatbuffers::resolver_function_t* _resolver) const { auto _o = std::make_unique(); @@ -1150,6 +1421,20 @@ inline flatbuffers::Offset CreateKeyValueMutationRecord( _fbb, _mutation_type, _logical_commit_time, _key, _value_type, _value); } +inline bool operator==(const UserDefinedFunctionsConfigT& lhs, + const UserDefinedFunctionsConfigT& rhs) { + return (lhs.language == rhs.language) && + (lhs.code_snippet == rhs.code_snippet) && + (lhs.handler_name == rhs.handler_name) && + (lhs.logical_commit_time == rhs.logical_commit_time) && + (lhs.version == rhs.version); +} + +inline bool operator!=(const UserDefinedFunctionsConfigT& lhs, + const UserDefinedFunctionsConfigT& rhs) { + return !(lhs == rhs); +} + inline UserDefinedFunctionsConfigT* UserDefinedFunctionsConfig::UnPack( const flatbuffers::resolver_function_t* _resolver) const { auto _o = std::make_unique(); @@ -1215,6 +1500,17 @@ CreateUserDefinedFunctionsConfig( _version); } +inline bool operator==(const ShardMappingRecordT& lhs, + const ShardMappingRecordT& rhs) { + return (lhs.logical_shard == rhs.logical_shard) && + (lhs.physical_shard == rhs.physical_shard); +} + +inline bool operator!=(const ShardMappingRecordT& lhs, + const ShardMappingRecordT& rhs) { + return !(lhs == rhs); +} + inline ShardMappingRecordT* ShardMappingRecord::UnPack( const flatbuffers::resolver_function_t* _resolver) const { auto _o = std::make_unique(); @@ -1260,6 +1556,14 @@ inline 
flatbuffers::Offset CreateShardMappingRecord( _physical_shard); } +inline bool operator==(const DataRecordT& lhs, const DataRecordT& rhs) { + return (lhs.record == rhs.record); +} + +inline bool operator!=(const DataRecordT& lhs, const DataRecordT& rhs) { + return !(lhs == rhs); +} + inline DataRecordT* DataRecord::UnPack( const flatbuffers::resolver_function_t* _resolver) const { auto _o = std::make_unique(); @@ -1311,14 +1615,18 @@ inline bool VerifyValue(flatbuffers::Verifier& verifier, const void* obj, case Value::NONE: { return true; } - case Value::String: { - auto ptr = reinterpret_cast(obj); + case Value::StringValue: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case Value::StringSet: { auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case Value::UInt32Set: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return true; } @@ -1343,14 +1651,18 @@ inline void* ValueUnion::UnPack( const flatbuffers::resolver_function_t* resolver) { (void)resolver; switch (type) { - case Value::String: { - auto ptr = reinterpret_cast(obj); + case Value::StringValue: { + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case Value::StringSet: { auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case Value::UInt32Set: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -1361,14 +1673,18 @@ inline flatbuffers::Offset ValueUnion::Pack( const flatbuffers::rehasher_function_t* _rehasher) const { (void)_rehasher; switch (type) { - case Value::String: { - auto ptr = reinterpret_cast(value); - return CreateString(_fbb, ptr, _rehasher).Union(); + case Value::StringValue: { + auto ptr = reinterpret_cast(value); + return CreateStringValue(_fbb, ptr, _rehasher).Union(); } case Value::StringSet: { auto ptr = reinterpret_cast(value); return CreateStringSet(_fbb, ptr, _rehasher).Union(); } + case Value::UInt32Set: { + auto 
ptr = reinterpret_cast(value); + return CreateUInt32Set(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -1377,9 +1693,9 @@ inline flatbuffers::Offset ValueUnion::Pack( inline ValueUnion::ValueUnion(const ValueUnion& u) : type(u.type), value(nullptr) { switch (type) { - case Value::String: { - value = new kv_server::StringT( - *reinterpret_cast(u.value)); + case Value::StringValue: { + value = new kv_server::StringValueT( + *reinterpret_cast(u.value)); break; } case Value::StringSet: { @@ -1387,6 +1703,11 @@ inline ValueUnion::ValueUnion(const ValueUnion& u) *reinterpret_cast(u.value)); break; } + case Value::UInt32Set: { + value = new kv_server::UInt32SetT( + *reinterpret_cast(u.value)); + break; + } default: break; } @@ -1394,8 +1715,8 @@ inline ValueUnion::ValueUnion(const ValueUnion& u) inline void ValueUnion::Reset() { switch (type) { - case Value::String: { - auto ptr = reinterpret_cast(value); + case Value::StringValue: { + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -1404,6 +1725,11 @@ inline void ValueUnion::Reset() { delete ptr; break; } + case Value::UInt32Set: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } diff --git a/public/data_loading/readers/BUILD.bazel b/public/data_loading/readers/BUILD.bazel index 6cb118e1..51f671e5 100644 --- a/public/data_loading/readers/BUILD.bazel +++ b/public/data_loading/readers/BUILD.bazel @@ -108,6 +108,7 @@ cc_library( "//public/data_loading:riegeli_metadata_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/logger:request_context_logger", ], ) diff --git a/public/data_loading/readers/avro_stream_io.cc b/public/data_loading/readers/avro_stream_io.cc index cc648df0..64eb9fa8 100644 --- a/public/data_loading/readers/avro_stream_io.cc +++ b/public/data_loading/readers/avro_stream_io.cc @@ -130,9 +130,9 @@ absl::Status AvroConcurrentStreamRecordReader::ReadStreamRecords( } 
total_records_read += curr_byte_range_result->num_records_read; } - VLOG(2) << "Done reading " << total_records_read << " records in " - << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) - << " ms."; + PS_VLOG(2, options_.log_context) + << "Done reading " << total_records_read << " records in " + << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) << " ms."; return absl::OkStatus(); } @@ -153,20 +153,20 @@ AvroConcurrentStreamRecordReader::ReadByteRange( const ByteRange& byte_range, const std::function& record_callback) { - VLOG(2) << "Reading byte_range: " - << "[" << byte_range.begin_offset << "," << byte_range.end_offset - << "]"; + PS_VLOG(2, options_.log_context) + << "Reading byte_range: " << "[" << byte_range.begin_offset << "," + << byte_range.end_offset << "]"; ScopeLatencyMetricsRecorder latency_recorder(KVServerContextMap()->SafeMetric()); auto record_stream = stream_factory_(); - VLOG(9) << "creating input stream"; + PS_VLOG(9, options_.log_context) << "creating input stream"; avro::InputStreamPtr input_stream = avro::istreamInputStream(record_stream->Stream()); - VLOG(9) << "creating reader"; + PS_VLOG(9, options_.log_context) << "creating reader"; auto record_reader = std::make_unique>( std::move(input_stream)); - VLOG(9) << "syncing to block"; + PS_VLOG(9, options_.log_context) << "syncing to block"; if (record_stream->Stream().bad()) { return absl::InternalError("Avro stream is bad"); } @@ -183,14 +183,15 @@ AvroConcurrentStreamRecordReader::ReadByteRange( // TODO: b/269119466 - Figure out how to handle this better. Maybe add // metrics to track callback failures (??). 
if (!overall_status.ok()) { - LOG(ERROR) << "Record callback failed to process some records with: " - << overall_status; + PS_LOG(ERROR, options_.log_context) + << "Record callback failed to process some records with: " + << overall_status; return overall_status; } - VLOG(2) << "Done reading " << num_records_read << " records in byte_range: [" - << byte_range.begin_offset << "," << byte_range.end_offset << "] in " - << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) - << " ms."; + PS_VLOG(2, options_.log_context) + << "Done reading " << num_records_read << " records in byte_range: [" + << byte_range.begin_offset << "," << byte_range.end_offset << "] in " + << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) << " ms."; ByteRangeResult result; result.num_records_read = num_records_read; return result; diff --git a/public/data_loading/readers/avro_stream_io.h b/public/data_loading/readers/avro_stream_io.h index de7b8c35..818c62e4 100644 --- a/public/data_loading/readers/avro_stream_io.h +++ b/public/data_loading/readers/avro_stream_io.h @@ -34,6 +34,7 @@ #include "components/telemetry/server_definition.h" #include "public/data_loading/readers/stream_record_reader.h" #include "public/data_loading/riegeli_metadata.pb.h" +#include "src/logger/request_context_logger.h" #include "src/telemetry/telemetry_provider.h" namespace kv_server { @@ -93,6 +94,9 @@ class AvroConcurrentStreamRecordReader : public StreamRecordReader { struct Options { int64_t num_worker_threads = std::thread::hardware_concurrency(); int64_t min_byte_range_size_bytes = 8 * 1024 * 1024; // 8MB + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext); Options() {} }; AvroConcurrentStreamRecordReader( diff --git a/public/data_loading/readers/delta_record_stream_reader.h b/public/data_loading/readers/delta_record_stream_reader.h index 2e935e78..f4d04424 100644 --- 
a/public/data_loading/readers/delta_record_stream_reader.h +++ b/public/data_loading/readers/delta_record_stream_reader.h @@ -43,12 +43,20 @@ namespace kv_server { template class DeltaRecordStreamReader : public DeltaRecordReader { public: - explicit DeltaRecordStreamReader(SrcStreamT& src_stream) + explicit DeltaRecordStreamReader( + SrcStreamT& src_stream, + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext)) : stream_reader_(RiegeliStreamReader( - src_stream, [](const riegeli::SkippedRegion& region) { - LOG(ERROR) << "Failed to read region: " << region; + src_stream, + [&log_context](const riegeli::SkippedRegion& region, + riegeli::RecordReaderBase& record_reader) { + PS_LOG(ERROR, log_context) << "Failed to read region: " << region; return true; - })) {} + }, + log_context)), + log_context_(log_context) {} DeltaRecordStreamReader(const DeltaRecordStreamReader&) = delete; DeltaRecordStreamReader& operator=(const DeltaRecordStreamReader&) = delete; @@ -64,6 +72,7 @@ class DeltaRecordStreamReader : public DeltaRecordReader { private: RiegeliStreamReader stream_reader_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; template diff --git a/public/data_loading/readers/riegeli_stream_io.h b/public/data_loading/readers/riegeli_stream_io.h index b15ea78f..8359bfc3 100644 --- a/public/data_loading/readers/riegeli_stream_io.h +++ b/public/data_loading/readers/riegeli_stream_io.h @@ -37,6 +37,7 @@ #include "public/data_loading/riegeli_metadata.pb.h" #include "riegeli/bytes/istream_reader.h" #include "riegeli/records/record_reader.h" +#include "src/logger/request_context_logger.h" #include "src/telemetry/telemetry_provider.h" namespace kv_server { @@ -48,11 +49,15 @@ class RiegeliStreamReader : public StreamRecordReader { // `data_input` must be at the file beginning when passed in. 
explicit RiegeliStreamReader( std::istream& data_input, - std::function recover) + std::function + recover, + privacy_sandbox::server_common::log::PSLogContext& log_context) : reader_(riegeli::RecordReader( riegeli::IStreamReader(&data_input), riegeli::RecordReaderBase::Options().set_recovery( - std::move(recover)))) {} + std::move(recover)))), + log_context_(log_context) {} absl::StatusOr GetKVFileMetadata() override { riegeli::RecordsMetadata metadata; @@ -65,7 +70,8 @@ class RiegeliStreamReader : public StreamRecordReader { } auto file_metadata = metadata.GetExtension(kv_file_metadata); - VLOG(2) << "File metadata: " << file_metadata.DebugString(); + PS_VLOG(2, log_context_) + << "File metadata: " << file_metadata.DebugString(); return file_metadata; } @@ -81,7 +87,7 @@ class RiegeliStreamReader : public StreamRecordReader { overall_status.Update(callback_status); } if (!overall_status.ok()) { - LOG(ERROR) << overall_status; + PS_LOG(ERROR, log_context_) << overall_status; } return reader_.status(); } @@ -91,6 +97,7 @@ class RiegeliStreamReader : public StreamRecordReader { private: riegeli::RecordReader> reader_; + privacy_sandbox::server_common::log::PSLogContext& log_context_; }; const int64_t kDefaultNumWorkerThreads = std::thread::hardware_concurrency(); @@ -134,9 +141,16 @@ class ConcurrentStreamRecordReader : public StreamRecordReader { struct Options { int64_t num_worker_threads = kDefaultNumWorkerThreads; int64_t min_shard_size_bytes = kDefaultMinShardSize; - std::function recovery_callback = - [](const riegeli::SkippedRegion& region) { - LOG(WARNING) << "Skipping over corrupted region: " << region; + privacy_sandbox::server_common::log::PSLogContext& log_context = + const_cast( + privacy_sandbox::server_common::log::kNoOpContext); + std::function + recovery_callback = [log_context = &this->log_context]( + const riegeli::SkippedRegion& region, + riegeli::RecordReaderBase& record_reader) { + PS_LOG(WARNING, *log_context) + << "Skipping over corrupted 
region: " << region; return true; }; }; @@ -188,10 +202,15 @@ absl::StatusOr ConcurrentStreamRecordReader::GetKVFileMetadata() { auto record_stream = stream_factory_(); RiegeliStreamReader metadata_reader( - record_stream->Stream(), [](const riegeli::SkippedRegion& region) { - LOG(WARNING) << "Skipping over corrupted region: " << region; + record_stream->Stream(), + [log_context = &options_.log_context]( + const riegeli::SkippedRegion& region, + riegeli::RecordReaderBase& record_reader) { + PS_LOG(WARNING, *log_context) + << "Skipping over corrupted region: " << region; return true; - }); + }, + options_.log_context); return metadata_reader.GetKVFileMetadata(); } template @@ -292,9 +311,9 @@ absl::Status ConcurrentStreamRecordReader::ReadStreamRecords( total_records_read += curr_shard_result->num_records_read; prev_shard_result = curr_shard_result; } - VLOG(2) << "Done reading " << total_records_read << " records in " - << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) - << " ms."; + PS_VLOG(2, options_.log_context) + << "Done reading " << total_records_read << " records in " + << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) << " ms."; return absl::OkStatus(); } @@ -303,8 +322,9 @@ absl::StatusOr::ShardResult> ConcurrentStreamRecordReader::ReadShardRecords( const ShardRange& shard, const std::function& record_callback) { - VLOG(2) << "Reading shard: " - << "[" << shard.start_pos << "," << shard.end_pos << "]"; + PS_VLOG(2, options_.log_context) + << "Reading shard: " << "[" << shard.start_pos << "," << shard.end_pos + << "]"; ScopeLatencyMetricsRecorder< ServerSafeMetricsContext, kConcurrentStreamRecordReaderReadShardRecordsLatency> @@ -331,18 +351,19 @@ ConcurrentStreamRecordReader::ReadShardRecords( // TODO: b/269119466 - Figure out how to handle this better. Maybe add // metrics to track callback failures (??). 
if (!overall_status.ok()) { - LOG(ERROR) << "Record callback failed to process some records with: " - << overall_status; + PS_LOG(ERROR, options_.log_context) + << "Record callback failed to process some records with: " + << overall_status; } if (!record_reader.ok()) { return record_reader.status(); } shard_result.next_shard_first_record_pos = next_record_pos; shard_result.num_records_read = num_records_read; - VLOG(2) << "Done reading " << num_records_read << " records in shard: [" - << shard.start_pos << "," << shard.end_pos << "] in " - << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) - << " ms."; + PS_VLOG(2, options_.log_context) + << "Done reading " << num_records_read << " records in shard: [" + << shard.start_pos << "," << shard.end_pos << "] in " + << absl::ToDoubleMilliseconds(latency_recorder.GetLatency()) << " ms."; return shard_result; } diff --git a/public/data_loading/readers/riegeli_stream_record_reader_factory.cc b/public/data_loading/readers/riegeli_stream_record_reader_factory.cc index d1b909fb..20133058 100644 --- a/public/data_loading/readers/riegeli_stream_record_reader_factory.cc +++ b/public/data_loading/readers/riegeli_stream_record_reader_factory.cc @@ -21,10 +21,15 @@ namespace kv_server { std::unique_ptr RiegeliStreamRecordReaderFactory::CreateReader(std::istream& data_input) const { return std::make_unique>( - data_input, [](const riegeli::SkippedRegion& skipped_region) { - LOG(WARNING) << "Skipping over corrupted region: " << skipped_region; + data_input, + [log_context = &options_.log_context]( + const riegeli::SkippedRegion& skipped_region, + riegeli::RecordReaderBase& record_reader) { + PS_LOG(WARNING, *log_context) + << "Skipping over corrupted region: " << skipped_region; return true; - }); + }, + options_.log_context); } std::unique_ptr diff --git a/public/data_loading/record_utils.cc b/public/data_loading/record_utils.cc index 75d1ee4c..65126a74 100644 --- a/public/data_loading/record_utils.cc +++ 
b/public/data_loading/record_utils.cc @@ -49,6 +49,11 @@ absl::Status ValidateValue(const KeyValueMutationRecord& kv_mutation_record) { kv_mutation_record.value_as_StringSet()->value() == nullptr)) { return absl::InvalidArgumentError("StringSet value not set."); } + if (kv_mutation_record.value_type() == Value::UInt32Set && + (kv_mutation_record.value_as_UInt32Set() == nullptr || + kv_mutation_record.value_as_UInt32Set()->value() == nullptr)) { + return absl::InvalidArgumentError("UInt32Set value not set."); + } return absl::OkStatus(); } @@ -159,4 +164,18 @@ absl::StatusOr> MaybeGetRecordValue( return values; } +template <> +absl::StatusOr> MaybeGetRecordValue( + const KeyValueMutationRecord& record) { + const kv_server::UInt32Set* maybe_value = record.value_as_UInt32Set(); + if (!maybe_value) { + return absl::InvalidArgumentError(absl::StrCat( + "KeyValueMutationRecord does not contain expected value type. " + "Expected: UInt32Set", + ". Actual: ", EnumNameValue(record.value_type()))); + } + return std::vector(maybe_value->value()->begin(), + maybe_value->value()->end()); +} + } // namespace kv_server diff --git a/public/data_loading/record_utils.h b/public/data_loading/record_utils.h index f56fa6e6..30a752a7 100644 --- a/public/data_loading/record_utils.h +++ b/public/data_loading/record_utils.h @@ -37,6 +37,7 @@ inline std::ostream& operator<<(std::ostream& os, os << string_value.value; return os; } + inline std::ostream& operator<<(std::ostream& os, const StringSetT& string_set_value) { for (const auto& string_value : string_set_value.value) { @@ -45,6 +46,13 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +inline std::ostream& operator<<(std::ostream& os, const UInt32SetT& set_value) { + for (const auto& value : set_value.value) { + os << value << ", "; + } + return os; +} + inline std::ostream& operator<<(std::ostream& os, const ValueUnion& value_union) { switch (value_union.type) { @@ -56,6 +64,10 @@ inline std::ostream& 
operator<<(std::ostream& os, os << *(reinterpret_cast(value_union.value)); break; } + case Value::UInt32Set: { + os << *(reinterpret_cast(value_union.value)); + break; + } case Value::NONE: { break; } @@ -180,6 +192,12 @@ template <> absl::StatusOr> MaybeGetRecordValue( const KeyValueMutationRecord& record); +// Returns the vector of uint32_t stored in `record.value`. Returns error if the +// record.value is not a uint32_t set. +template <> +absl::StatusOr> MaybeGetRecordValue( + const KeyValueMutationRecord& record); + } // namespace kv_server #endif // PUBLIC_DATA_LOADING_RECORD_UTILS_H_ diff --git a/public/data_loading/record_utils_test.cc b/public/data_loading/record_utils_test.cc index d25fc614..e8e3064c 100644 --- a/public/data_loading/record_utils_test.cc +++ b/public/data_loading/record_utils_test.cc @@ -183,6 +183,38 @@ TEST(RecordUtilsTest, DataRecordWithKeyValueMutationRecordWithStringSetValue) { EXPECT_TRUE(status.ok()) << status; } +TEST(RecordUtilsTest, DataRecordWithKeyValueMutationRecordWithUInt32SetValue) { + // Serialize + KeyValueMutationRecordT kv_mutation_record_native; + kv_mutation_record_native.key = "key"; + kv_mutation_record_native.logical_commit_time = 5; + UInt32SetT value_native; + value_native.value = {1000, 1001, 1002}; + kv_mutation_record_native.value.Set(std::move(value_native)); + DataRecordT data_record_native; + data_record_native.record.Set(std::move(kv_mutation_record_native)); + auto [fbs_buffer, serialized_string_view] = Serialize(data_record_native); + // Deserialize + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .Times(1) + .WillOnce([](const DataRecord& fbs_record) { + const KeyValueMutationRecord& kv_mutation_record = + *fbs_record.record_as_KeyValueMutationRecord(); + EXPECT_EQ(kv_mutation_record.key()->string_view(), "key"); + EXPECT_EQ(kv_mutation_record.logical_commit_time(), 5); + absl::StatusOr> maybe_record_value = + MaybeGetRecordValue>(kv_mutation_record); + 
EXPECT_TRUE(maybe_record_value.ok()) << maybe_record_value.status(); + EXPECT_THAT(*maybe_record_value, + testing::UnorderedElementsAre(1000, 1001, 1002)); + return absl::OkStatus(); + }); + auto status = DeserializeRecord(serialized_string_view, + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, DeserializeDataRecordEmptyRecordFailure) { DataRecordT data_record_native; auto [fbs_buffer, serialized_string_view] = Serialize(data_record_native); diff --git a/public/data_loading/records_utils.cc b/public/data_loading/records_utils.cc index 73c42b99..a16f293c 100644 --- a/public/data_loading/records_utils.cc +++ b/public/data_loading/records_utils.cc @@ -14,10 +14,11 @@ #include "public/data_loading/records_utils.h" -#include - #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "flatbuffers/flatbuffer_builder.h" +#include "public/data_loading/record_utils.h" namespace kv_server { namespace { @@ -50,6 +51,13 @@ ValueUnion BuildValueUnion(const KeyValueMutationRecordValueT& value, .value = CreateStringSet(builder, values_offset).Union(), }; } + if constexpr (std::is_same_v>) { + auto values_offset = builder.CreateVector(arg); + return ValueUnion{ + .value_type = Value::UInt32Set, + .value = CreateUInt32Set(builder, values_offset).Union(), + }; + } if constexpr (std::is_same_v) { return ValueUnion{ .value_type = Value::NONE, @@ -148,6 +156,11 @@ absl::Status ValidateValue(const KeyValueMutationRecord& kv_mutation_record) { kv_mutation_record.value_as_StringSet()->value() == nullptr)) { return absl::InvalidArgumentError("StringSet value not set."); } + if (kv_mutation_record.value_type() == Value::UInt32Set && + (kv_mutation_record.value_as_UInt32Set() == nullptr || + kv_mutation_record.value_as_UInt32Set()->value() == nullptr)) { + return absl::InvalidArgumentError("UInt32Set value not set."); + } return absl::OkStatus(); } @@ -202,6 +215,9 @@ KeyValueMutationRecordValueT 
GetRecordStructValue( if (fbs_record.value_type() == Value::StringSet) { value = GetRecordValue>(fbs_record); } + if (fbs_record.value_type() == Value::UInt32Set) { + value = GetRecordValue>(fbs_record); + } return value; } @@ -351,6 +367,12 @@ std::vector GetRecordValue( return values; } +template <> +std::vector GetRecordValue(const KeyValueMutationRecord& record) { + return std::vector(record.value_as_UInt32Set()->value()->begin(), + record.value_as_UInt32Set()->value()->end()); +} + template <> KeyValueMutationRecordStruct GetTypedRecordStruct( const DataRecord& data_record) { diff --git a/public/data_loading/records_utils.h b/public/data_loading/records_utils.h index fbe25f48..a1532079 100644 --- a/public/data_loading/records_utils.h +++ b/public/data_loading/records_utils.h @@ -17,14 +17,11 @@ #ifndef PUBLIC_DATA_LOADING_RECORDS_UTILS_H_ #define PUBLIC_DATA_LOADING_RECORDS_UTILS_H_ -#include #include -#include #include #include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/record_utils.h" @@ -38,7 +35,7 @@ enum class DataRecordType : int { using KeyValueMutationRecordValueT = std::variant>; + std::vector, std::vector>; struct KeyValueMutationRecordStruct { KeyValueMutationType mutation_type; @@ -141,6 +138,8 @@ std::string_view GetRecordValue(const KeyValueMutationRecord& record); template <> std::vector GetRecordValue( const KeyValueMutationRecord& record); +template <> +std::vector GetRecordValue(const KeyValueMutationRecord& record); // Utility function to get the union record set on the `data_record`. 
Must // be called after checking the type of the union record using diff --git a/public/data_loading/records_utils_test.cc b/public/data_loading/records_utils_test.cc index 320d184c..03c5f90f 100644 --- a/public/data_loading/records_utils_test.cc +++ b/public/data_loading/records_utils_test.cc @@ -14,9 +14,9 @@ #include "public/data_loading/records_utils.h" -#include "absl/hash/hash_testing.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "public/data_loading/record_utils.h" namespace kv_server { namespace { @@ -42,10 +42,6 @@ UserDefinedFunctionsConfigStruct GetUdfConfigStruct( return udf_config_struct; } -ShardMappingRecordStruct GetShardMappingRecordStruct() { - return ShardMappingRecordStruct{.logical_shard = 0, .physical_shard = 0}; -} - DataRecordStruct GetDataRecord(RecordT record) { DataRecordStruct data_record_struct; data_record_struct.record = record; @@ -129,6 +125,11 @@ void ExpectEqual(const KeyValueMutationRecordStruct& record, testing::ContainerEq( GetRecordValue>(fbs_record))); } + if (fbs_record.value_type() == Value::UInt32Set) { + EXPECT_THAT(std::get>(record.value), + testing::ContainerEq( + GetRecordValue>(fbs_record))); + } } void ExpectEqual(const UserDefinedFunctionsConfigStruct& record, @@ -237,6 +238,22 @@ TEST(DataRecordTest, EXPECT_TRUE(status.ok()) << status; } +TEST(DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_UInt32VectorValue_Success) { + std::vector values({1000, 1001, 1002}); + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecord& data_record_fbs) { + ExpectEqual(data_record_struct, data_record_fbs); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, 
DeserializeDataRecord_ToFbsRecord_KVMutation_KeyNotSet_Failure) { flatbuffers::FlatBufferBuilder builder; @@ -330,6 +347,29 @@ TEST( EXPECT_EQ(status.message(), "StringSet value not set."); } +TEST( + DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_UInt32SetValueNotSet_Failure) { + flatbuffers::FlatBufferBuilder builder; + const auto kv_mutation_fbs = CreateKeyValueMutationRecordDirect( + builder, + /*mutation_type=*/KeyValueMutationType::Update, + /*logical_commit_time=*/0, + /*key=*/"key", + /*value_type=*/Value::UInt32Set, + /*value=*/CreateUInt32Set(builder).Union()); + const auto data_record_fbs = + CreateDataRecord(builder, /*record_type=*/Record::KeyValueMutationRecord, + kv_mutation_fbs.Union()); + builder.Finish(data_record_fbs); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call).Times(0); + auto status = DeserializeDataRecord(ToStringView(builder), + record_callback.AsStdFunction()); + ASSERT_FALSE(status.ok()) << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + TEST(DataRecordTest, DeserializeDataRecord_ToFbsRecord_UdfConfig_Success) { auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); testing::MockFunction record_callback; @@ -421,6 +461,22 @@ TEST(DataRecordTest, EXPECT_TRUE(status.ok()) << status; } +TEST(DataRecordTest, + DeserializeDataRecord_ToStruct_KVMutation_Uint32VectorValue_Success) { + std::vector values({1000, 1001, 1002}); + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecordStruct& actual_record) { + EXPECT_EQ(data_record_struct, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(DataRecordTest, 
DeserializeDataRecord_ToStruct_UdfConfig_Success) { auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); testing::MockFunction record_callback; diff --git a/public/query/get_values.proto b/public/query/get_values.proto index b375ac2e..444de154 100644 --- a/public/query/get_values.proto +++ b/public/query/get_values.proto @@ -36,6 +36,9 @@ service KeyValueService { message GetValuesRequest { // [DSP] List of keys to query values for, under the namespace keys. repeated string keys = 1; + // [DSP] List of keys to query values for, under the namespace interest_group_names. + // Results will be stored in per_interest_group_data field of GetValuesResponse. + repeated string interest_group_names = 7; // [DSP] The browser sets the hostname of the publisher page to be the value. // If no specific value is available in the system for this subkey, @@ -71,6 +74,9 @@ message GetValuesResponse { // Map of key value pairs for namespace keys. map keys = 1; + // Map of key value pairs for namespace interest_group_names. + map per_interest_group_data = 5; + // Map of key value pairs for namespace renderUrls. map render_urls = 2; diff --git a/public/test_util/BUILD.bazel b/public/test_util/BUILD.bazel index f5ee9f3f..9542d9e6 100644 --- a/public/test_util/BUILD.bazel +++ b/public/test_util/BUILD.bazel @@ -36,3 +36,9 @@ cc_library( "@com_google_googletest//:gtest", ], ) + +cc_library( + name = "request_example", + testonly = 1, + hdrs = ["request_example.h"], +) diff --git a/public/test_util/request_example.h b/public/test_util/request_example.h new file mode 100644 index 00000000..b24c9a73 --- /dev/null +++ b/public/test_util/request_example.h @@ -0,0 +1,143 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PUBLIC_TEST_UTIL_REQUEST_EXAMPLE_H_ +#define PUBLIC_TEST_UTIL_REQUEST_EXAMPLE_H_ + +#include <string_view> + +namespace kv_server { + +// Non-consented V2 request example +constexpr std::string_view kExampleV2RequestInJson = R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + }, + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + } + ] + } + )"; +// Consented V2 request example without log context +constexpr std::string_view kExampleConsentedV2RequestInJson = R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + }, + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + } + ], + "consented_debug_config": { + "is_consented": true, + "token": "debug_token" + } + })"; + +// Consented V2 request example with log context +constexpr std::string_view kExampleConsentedV2RequestWithLogContextInJson = R"( + { + "metadata": { + "hostname": "example.com" + }, + "partitions": [ + { + "id": 0, + "compressionGroupId": 0, + "arguments": [ + { + "tags": [ + "structured", + "groupNames" + ], + "data": [ + "hello" + ] + }, + { + "tags": [ + "custom", + "keys" + ], + "data": [ + "key1" + ] + } + ] + } + ], + "consented_debug_config": { + "is_consented": true, + "token": "debug_token" + }, + "log_context": 
{ + "generation_id": "client_UUID", + "adtech_debug_id": "adtech_debug_test" + } + })"; + +// Example consented debug token used in the unit tests +constexpr std::string_view kExampleConsentedDebugToken = "debug_token"; + +} // namespace kv_server + +#endif // PUBLIC_TEST_UTIL_REQUEST_EXAMPLE_H_ diff --git a/testing/functionaltest/README.md b/testing/functionaltest/README.md index 00968081..1eb2c5ba 100644 --- a/testing/functionaltest/README.md +++ b/testing/functionaltest/README.md @@ -15,8 +15,8 @@ Docker Compose version v2.17.2 ```sh builders/tools/bazel-debian run //production/packaging/aws/data_server:copy_to_dist \ - --//:platform=local \ - --//:instance=local + --config=local_instance \ + --config=local_platform unzip -d dist/debian -j -u dist/debian/server_artifacts.zip server/bin/server touch dist/debian/server docker load -i dist/server_docker_image.tar @@ -33,8 +33,8 @@ Generate data files for the functional test suites, using: ```sh builders/tools/bazel-debian run //testing/functionaltest:copy_to_dist \ - --//:platform=local \ - --//:instance=local + --config=local_instance \ + --config=local_platform ``` ### Direct execution mode diff --git a/third_party_deps/cpp_repositories.bzl b/third_party_deps/cpp_repositories.bzl index 9c2b1a07..e42f7f0e 100644 --- a/third_party_deps/cpp_repositories.bzl +++ b/third_party_deps/cpp_repositories.bzl @@ -23,9 +23,9 @@ def cpp_repositories(): repo_mapping = { "@org_brotli": "@brotli", }, - sha256 = "32f303a9b0b6e07101a7a95a4cc364fb4242f0f7431de5da1a2e0ee61f5924c5", - strip_prefix = "riegeli-562f26cbb685aae10b7d32e32fb53d2e42a5d8c2", - url = "https://github.com/google/riegeli/archive/562f26cbb685aae10b7d32e32fb53d2e42a5d8c2.zip", + sha256 = "0aad9af403e5f394cf30330658a361c622a0155499d8726112b8fb1716750cf9", + strip_prefix = "riegeli-0bf809f36ae5be8a5684f63d8238b5440b42bbec", + url = "https://github.com/google/riegeli/archive/0bf809f36ae5be8a5684f63d8238b5440b42bbec.zip", ) #external deps for riegeli @@ -90,3 +90,14 
@@ def cpp_repositories(): "https://github.com/apache/avro/archive/release-1.10.2.tar.gz", ], ) + + ### Roaring Bitmaps + http_archive( + name = "roaring_bitmap", + build_file = "//third_party_deps:roaring.BUILD", + sha256 = "c7b0e36dfeaca0d951b2842a747ddf6fec95355abba5970511bb68d698e10a90", + strip_prefix = "CRoaring-3.0.1", + urls = [ + "https://github.com/RoaringBitmap/CRoaring/archive/refs/tags/v3.0.1.zip", + ], + ) diff --git a/third_party_deps/roaring.BUILD b/third_party_deps/roaring.BUILD new file mode 100644 index 00000000..fa537a63 --- /dev/null +++ b/third_party_deps/roaring.BUILD @@ -0,0 +1,32 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "c_roaring", + srcs = glob(["src/**/*.c"]), + hdrs = glob(["include/**/*.h"] + ["cpp/**/*.hh"]), + defines = ["ROARING_EXCEPTIONS"], + includes = [ + "cpp/.", + "include/.", + ], +) diff --git a/tools/latency_benchmarking/BUILD.bazel b/tools/benchmarking/BUILD.bazel similarity index 100% rename from tools/latency_benchmarking/BUILD.bazel rename to tools/benchmarking/BUILD.bazel diff --git a/tools/latency_benchmarking/README.md b/tools/benchmarking/README.md similarity index 94% rename from tools/latency_benchmarking/README.md rename to tools/benchmarking/README.md index 6372288d..dcad8f6a 100644 --- a/tools/latency_benchmarking/README.md +++ b/tools/benchmarking/README.md @@ -1,18 +1,7 @@ -# Latency Benchmarking Tool +# Request Benchmarking Tool ## Requirements -- Install ghz - -The `run_benchmarks` script uses [ghz](https://ghz.sh/docs/intro). - -Follow instructions to [install ghz](https://ghz.sh/docs/install) and make sure it is installed -correctly: - -```sh -ghz -v -``` - - Follow deployment guides To use the tool you will need to have a complete deployment setup. @@ -67,7 +56,7 @@ minutes to complete the benchmark + 1-2 minutes for pre- and post processing. Usage: ```sh -./tools/latency_benchmarking/run_benchmarks +./tools/benchmarking/run_benchmarks ``` Flags: @@ -129,14 +118,13 @@ Start from the workspace root. SNAPSHOT_DIR=/path/to/snapshot/dir NUMBER_OF_LOOKUP_KEYS_LIST="1 10 100" SERVER_ADDRESS="demo.kv-server.your-domain.example:8443" -./tools/latency_benchmarking/run_benchmarks \ +./tools/benchmarking/run_benchmarks \ --server-address ${SERVER_ADDRESS} \ --snapshot-dir ${SNAPSHOT_DIR} \ --number-of-lookup-keys-list "${NUMBER_OF_LOOKUP_KEYS_LIST}" ``` -The summary can be found in -`dist/tools/latency_benchmarking/output//summary.csv`. 
+The summary can be found in `dist/tools/benchmarking/output//summary.csv`. ### deploy_and_benchmark @@ -166,7 +154,7 @@ iterate through. Usage: ```sh -./tools/latency_benchmarking/deploy_and_benchmark +./tools/benchmarking/deploy_and_benchmark ``` Flags: @@ -205,7 +193,7 @@ Flags: a comma-separated list of terraform variable overrides, e.g. `instance_type=c5.4xlarge,enclave_cpu_count=12` - Example: `/tools/latency_benchmarking/example/aws_tf_overrides.txt` + Example: `/tools/benchmarking/example/aws_tf_overrides.txt` - `--udf-delta-dir` (Optional) @@ -227,7 +215,7 @@ Flags: Path to JSON lines file to iterate through to generate requests. - Example: `/tools/latency_benchmarking/example/request_metadata.jsonl` + Example: `/tools/benchmarking/example/request_metadata.jsonl` - `--number-of-lookup-keys-list` (Optional) @@ -353,7 +341,7 @@ Start from the workspace root. - Each line is considered a set of variables to be overriden in one `terraform apply` command - - For an example, see `/latency_benchmarking/example/aws_tf_overrides.txt`. + - For an example, see `/benchmarking/example/aws_tf_overrides.txt`. ```sh TF_OVERRIDES=/path/to/tf_variable_overrides.txt @@ -371,7 +359,7 @@ Start from the workspace root. 1. 
Run the script and wait for the result ```sh - ./tools/latency_benchmarking/deploy_and_benchmark \ + ./tools/benchmarking/deploy_and_benchmark \ --cloud-provider ${CLOUD_PROVIDER} \ --server-url ${SERVER_URL} \ --snapshot-dir ${SNAPSHOT_DIR} \ diff --git a/tools/latency_benchmarking/create_csv_summary.py b/tools/benchmarking/create_csv_summary.py similarity index 100% rename from tools/latency_benchmarking/create_csv_summary.py rename to tools/benchmarking/create_csv_summary.py diff --git a/tools/latency_benchmarking/deploy_and_benchmark b/tools/benchmarking/deploy_and_benchmark similarity index 81% rename from tools/latency_benchmarking/deploy_and_benchmark rename to tools/benchmarking/deploy_and_benchmark index 730eacf8..edb1466e 100755 --- a/tools/latency_benchmarking/deploy_and_benchmark +++ b/tools/benchmarking/deploy_and_benchmark @@ -43,15 +43,15 @@ declare SERVER_ENDPOINT # CSV I/O declare -a BENCHMARK_CSVS -CSV_OUTPUT="${WORKSPACE}/dist/tools/latency_benchmarking/output/output.csv" -declare -r DOCKER_OUTPUT_CSV="/tmp/latency_benchmarking/output/deploy_and_benchmark/output.csv" -declare -r SNAPSHOT_CSV_DIR="${WORKSPACE}/dist/tools/latency_benchmarking/deploy_and_benchmark/snapshot_csvs/${START}" -declare -r DOCKER_SNAPSHOT_CSV_DIR="/tmp/latency_benchmarking/deploy_and_benchmark/snapshot_csvs/" -declare -r DOCKER_SNAPSHOT_DIR="/tmp/latency_benchmarking/deploy_and_benchmark/snapshots/" +CSV_OUTPUT="${WORKSPACE}/dist/tools/benchmarking/output/output.csv" +declare -r DOCKER_OUTPUT_CSV="/tmp/benchmarking/output/deploy_and_benchmark/output.csv" +declare -r SNAPSHOT_CSV_DIR="${WORKSPACE}/dist/tools/benchmarking/deploy_and_benchmark/snapshot_csvs/${START}" +declare -r DOCKER_SNAPSHOT_CSV_DIR="/tmp/benchmarking/deploy_and_benchmark/snapshot_csvs/" +declare -r DOCKER_SNAPSHOT_DIR="/tmp/benchmarking/deploy_and_benchmark/snapshots/" # run_benchmarks writes output to this directory -CSV_SUMMARY_INPUT_DIR="${WORKSPACE}/dist/tools/latency_benchmarking/output" 
-DOCKER_INPUT_DIR="/tmp/latency_benchmarking/output" +CSV_SUMMARY_INPUT_DIR="${WORKSPACE}/dist/tools/benchmarking/output" +DOCKER_INPUT_DIR="/tmp/benchmarking/summaries/" readonly DOCKER_INPUT_DIR DESTROY_INSTANCES=0 @@ -84,7 +84,10 @@ usage: [--tf-var-file] (Required) Full path to tfvars.json file. [--tf-backend-config] (Required) Full path to tf backend.conf file. [--server-url] (Required) URL of deployed server. - [--snapshot-dir] (Required) Full path to a directory of snapshot files. + [--snapshot-dir] OR (Required) Full path to a directory of snapshot files + [--lookup-keys-file] or a file with lookup keys. + The lookup-keys-file should have one lookup key per line + and ignores the filter-snapshot-by-sets option. [--csv-output] (Optional) Path to output file for summary of benchmarks. [--tf-overrides] (Optional) Path to file with terraform variable overrides. [--udf-delta-dir] (Optional) Full path to directory of udf delta files. @@ -108,12 +111,12 @@ function convert_snapshots_to_csvs() { for SNAPSHOT_FILE in "${SNAPSHOT_DIR}"/*; do SNAPSHOT_FILENAME=$(basename "${SNAPSHOT_FILE}") EXTRA_DOCKER_RUN_ARGS+=" --volume ${SNAPSHOT_CSV_DIR}:${DOCKER_SNAPSHOT_CSV_DIR} --volume ${SNAPSHOT_DIR}:${DOCKER_SNAPSHOT_DIR} " \ - builders/tools/bazel-debian run //tools/data_cli:data_cli format_data \ - -- \ - --input_file "${DOCKER_SNAPSHOT_DIR}/${SNAPSHOT_FILENAME}" \ - --input_format DELTA \ - --output_file "${DOCKER_SNAPSHOT_CSV_DIR}/${SNAPSHOT_FILENAME}.csv" \ - --output_format CSV + builders/tools/bazel-debian run //tools/data_cli:data_cli format_data \ + -- \ + --input_file "${DOCKER_SNAPSHOT_DIR}/${SNAPSHOT_FILENAME}" \ + --input_format DELTA \ + --output_file "${DOCKER_SNAPSHOT_CSV_DIR}/${SNAPSHOT_FILENAME}.csv" \ + --output_format CSV done } @@ -126,9 +129,17 @@ function set_benchmark_args() { --number-of-lookup-keys-list "${NUMBER_OF_LOOKUP_KEYS_LIST[@]}" --server-address "${SERVER_ADDRESS}" --ghz-tags "${RUN_BENCHMARK_GHZ_TAGS}" - --snapshot-csv-dir 
"${SNAPSHOT_CSV_DIR}" --benchmark-duration "${BENCHMARK_DURATION}" ) + if [[ -f "${LOOKUP_KEYS_FILE}" ]]; then + RUN_BENCHMARK_ARGS+=( + --lookup-keys-file "${LOOKUP_KEYS_FILE}" + ) + else + RUN_BENCHMARK_ARGS+=( + --snapshot-csv-dir "${SNAPSHOT_CSV_DIR}" + ) + fi if [[ -v "${REQUEST_METADATA_JSON}" ]]; then RUN_BENCHMARK_ARGS+=( --request-metadata-json "${REQUEST_METADATA_JSON}" @@ -147,7 +158,7 @@ function run_benchmarks() { printf "BENCHMARK ARGS: %s\n" "${RUN_BENCHMARK_ARGS[*]}" local BENCHMARK_OUTPUT - BENCHMARK_OUTPUT=$(./tools/latency_benchmarking/run_benchmarks "${RUN_BENCHMARK_ARGS[@]}") + BENCHMARK_OUTPUT=$(./tools/benchmarking/run_benchmarks "${RUN_BENCHMARK_ARGS[@]}") BENCHMARK_CSVS+=( "$(echo "${BENCHMARK_OUTPUT}" | tail -n 1 2>&1 | tee /dev/tty)" ) @@ -157,11 +168,11 @@ function run_benchmarks() { printf "BENCHMARK ARGS: %s\n" "${RUN_BENCHMARK_ARGS[*]}" local BENCHMARK_OUTPUT - BENCHMARK_OUTPUT=$(./tools/latency_benchmarking/run_benchmarks "${RUN_BENCHMARK_ARGS[@]}") + BENCHMARK_OUTPUT=$(./tools/benchmarking/run_benchmarks "${RUN_BENCHMARK_ARGS[@]}") BENCHMARK_CSVS+=( "$(echo "${BENCHMARK_OUTPUT}" | tail -n 1 2>&1 | tee /dev/tty)" ) - done < "${REQUEST_METADATA_JSON_FILE}" + done <"${REQUEST_METADATA_JSON_FILE}" fi } @@ -171,13 +182,20 @@ function set_server_address() { SERVER_ENDPOINT="${SERVER_URL}/v1/getvalues?keys=hi" # Build gRPC server address from tf output SERVER_HOSTNAME=$([[ "${SERVER_URL}" =~ https://(.*) ]] && echo "${BASH_REMATCH[1]}") - SERVER_ADDRESS="${SERVER_HOSTNAME}:8443" + if [[ "${CLOUD_PROVIDER}" == "aws" ]]; then + SERVER_ADDRESS="${SERVER_HOSTNAME}:8443" + elif [[ "${CLOUD_PROVIDER}" == "gcp" ]]; then + SERVER_ADDRESS="${SERVER_HOSTNAME}:443" + else + echo "Cloud provider not supported" + exit 1 + fi } function upload_file_to_bucket() { if [[ "${CLOUD_PROVIDER}" == "aws" ]]; then EXTRA_DOCKER_RUN_ARGS+=" --volume ${1}:/tmp/deltas/${1}" \ - builders/tools/aws-cli s3 cp "/tmp/deltas/${1}" "${2}" + builders/tools/aws-cli s3 cp 
"/tmp/deltas/${1}" "${2}" elif [[ "${CLOUD_PROVIDER}" == "gcp" ]]; then gcloud storage cp "${1}" "${2}" else @@ -257,7 +275,7 @@ function merge_benchmark_csvs() { touch "${CSV_OUTPUT}" # Run merge_csvs python script EXTRA_DOCKER_RUN_ARGS+=" --volume ${CSV_SUMMARY_INPUT_DIR}:${DOCKER_INPUT_DIR} --volume ${CSV_OUTPUT}:${DOCKER_OUTPUT_CSV} " \ - builders/tools/bazel-debian run //tools/latency_benchmarking:merge_csvs \ + builders/tools/bazel-debian run //tools/benchmarking:merge_csvs \ -- \ --csv-inputs "${DOCKER_BENCHMARK_CSVS[@]}" \ --csv-output "${DOCKER_OUTPUT_CSV}" @@ -287,6 +305,10 @@ while [[ $# -gt 0 ]]; do SNAPSHOT_DIR="$2" shift 2 ;; + --lookup-keys-file) + LOOKUP_KEYS_FILE="$2" + shift 2 + ;; --udf-delta-dir) UDF_DELTA_DIR="$2" shift 2 @@ -350,9 +372,10 @@ if [[ -z "${SERVER_URL}" ]]; then exit 1 fi -# Check for SNAPSHOT_DIR. If not available, exit. -if [[ -z "${SNAPSHOT_DIR}" || ! -d "${SNAPSHOT_DIR}" ]]; then +# Check for SNAPSHOT_DIR or LOOKUP_KEYS_FILE. If not available, exit. +if [[ (-z "${SNAPSHOT_DIR}" || ! -d "${SNAPSHOT_DIR}") && (-z "${LOOKUP_KEYS_FILE}" || ! -f "${LOOKUP_KEYS_FILE}") ]]; then printf "snapshot-dir not found:%s\n" "${SNAPSHOT_DIR}" + printf "lookup-keys-file not found:%s\n" "${LOOKUP_KEYS_FILE}" exit 1 fi @@ -367,7 +390,9 @@ if [[ -v "${REQUEST_METADATA_JSON_FILE}" && ! -r "${REQUEST_METADATA_JSON_FILE}" exit 1 fi -convert_snapshots_to_csvs +if [[ -d "${SNAPSHOT_DIR}" ]]; then + convert_snapshots_to_csvs +fi # No terraform variable overrides, deploy and benchmark without overrides if [[ -z "${TF_OVERRIDES}" ]]; then @@ -377,16 +402,22 @@ else # Each row defines a set of overrides for terraform variables. # Pass the overrides to deploy_and_benchmark function and set # them as tags in the ghz call. 
+ TF_OVERRIDES_LIST=() while IFS=',' read -ra VARS; do + TF_OVERRIDES_LIST+=("${VARS[*]}") + done <"${TF_OVERRIDES}" + + for TF_OVERRIDES in "${TF_OVERRIDES_LIST[@]}"; do + IFS=" " read -ra VAR_LIST <<<"${TF_OVERRIDES}" declare -a VAR_OVERRIDES=() DEPLOYMENT_GHZ_TAGS="{}" - for VAR in "${VARS[@]}"; do + for VAR in "${VAR_LIST[@]}"; do VAR_OVERRIDES+=(-var "${VAR}") OVERRIDE_VAR_GHZ_TAG=$(echo "${VAR}" | jq -s -R 'split("\n") | .[0] | split("=") | {(.[0]): .[1]}') DEPLOYMENT_GHZ_TAGS=$(echo "${DEPLOYMENT_GHZ_TAGS} ${OVERRIDE_VAR_GHZ_TAG}" | jq -s -c 'add') done deploy_and_benchmark - done <"${TF_OVERRIDES}" + done fi # Benchmarks done, merge CSVs diff --git a/tools/latency_benchmarking/example/aws_tf_overrides.txt b/tools/benchmarking/example/aws_tf_overrides.txt similarity index 100% rename from tools/latency_benchmarking/example/aws_tf_overrides.txt rename to tools/benchmarking/example/aws_tf_overrides.txt diff --git a/tools/latency_benchmarking/example/gcp_tf_overrides.txt b/tools/benchmarking/example/gcp_tf_overrides.txt similarity index 100% rename from tools/latency_benchmarking/example/gcp_tf_overrides.txt rename to tools/benchmarking/example/gcp_tf_overrides.txt diff --git a/tools/latency_benchmarking/example/kv_data/BUILD.bazel b/tools/benchmarking/example/kv_data/BUILD.bazel similarity index 100% rename from tools/latency_benchmarking/example/kv_data/BUILD.bazel rename to tools/benchmarking/example/kv_data/BUILD.bazel diff --git a/tools/latency_benchmarking/example/request_metadata.jsonl b/tools/benchmarking/example/request_metadata.jsonl similarity index 100% rename from tools/latency_benchmarking/example/request_metadata.jsonl rename to tools/benchmarking/example/request_metadata.jsonl diff --git a/tools/latency_benchmarking/example/request_metadata_run_query.jsonl b/tools/benchmarking/example/request_metadata_run_query.jsonl similarity index 100% rename from tools/latency_benchmarking/example/request_metadata_run_query.jsonl rename to 
tools/benchmarking/example/request_metadata_run_query.jsonl diff --git a/tools/latency_benchmarking/example/udf_code/BUILD.bazel b/tools/benchmarking/example/udf_code/BUILD.bazel similarity index 90% rename from tools/latency_benchmarking/example/udf_code/BUILD.bazel rename to tools/benchmarking/example/udf_code/BUILD.bazel index 1745c297..27671cfe 100644 --- a/tools/latency_benchmarking/example/udf_code/BUILD.bazel +++ b/tools/benchmarking/example/udf_code/BUILD.bazel @@ -19,7 +19,7 @@ closure_js_library( # Generates a UDF delta file using the given closure_js_lib target # builders/tools/bazel-debian run \ -# //tools/latency_benchmarking/example/udf_code:benchmark_udf_js_delta +# //tools/benchmarking/example/udf_code:benchmark_udf_js_delta closure_to_delta( name = "benchmark_udf_js_delta", closure_js_library_target = ":benchmark_udf_js_lib", @@ -28,7 +28,7 @@ closure_to_delta( ) # builders/tools/bazel-debian run --config=emscripten \ -# //tools/latency_benchmarking/example/udf_code:benchmark_cpp_wasm_udf_delta +# //tools/benchmarking/example/udf_code:benchmark_cpp_wasm_udf_delta cc_inline_wasm_udf_delta( name = "benchmark_cpp_wasm_udf_delta", srcs = ["benchmark_cpp_wasm_udf.cc"], diff --git a/tools/latency_benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc b/tools/benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc similarity index 99% rename from tools/latency_benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc rename to tools/benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc index b2e278bc..a6c3f9f4 100644 --- a/tools/latency_benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc +++ b/tools/benchmarking/example/udf_code/benchmark_cpp_wasm_udf.cc @@ -171,7 +171,7 @@ std::vector MaybeSplitDataByBatchSize( } // I/O processing, similar to -// tools/latency_benchmarking/example/udf_code/benchmark_udf.js +// tools/benchmarking/example/udf_code/benchmark_udf.js emscripten::val GetKeyGroupOutputs(const emscripten::val& get_values_cb, const 
emscripten::val& get_values_binary_cb, const emscripten::val& request_metadata, diff --git a/tools/latency_benchmarking/example/udf_code/benchmark_cpp_wasm_udf.js b/tools/benchmarking/example/udf_code/benchmark_cpp_wasm_udf.js similarity index 100% rename from tools/latency_benchmarking/example/udf_code/benchmark_cpp_wasm_udf.js rename to tools/benchmarking/example/udf_code/benchmark_cpp_wasm_udf.js diff --git a/tools/latency_benchmarking/example/udf_code/benchmark_udf.js b/tools/benchmarking/example/udf_code/benchmark_udf.js similarity index 100% rename from tools/latency_benchmarking/example/udf_code/benchmark_udf.js rename to tools/benchmarking/example/udf_code/benchmark_udf.js diff --git a/tools/latency_benchmarking/example/udf_code/externs.js b/tools/benchmarking/example/udf_code/externs.js similarity index 100% rename from tools/latency_benchmarking/example/udf_code/externs.js rename to tools/benchmarking/example/udf_code/externs.js diff --git a/tools/latency_benchmarking/generate_requests.py b/tools/benchmarking/generate_requests.py similarity index 80% rename from tools/latency_benchmarking/generate_requests.py rename to tools/benchmarking/generate_requests.py index e55b5733..d61715be 100644 --- a/tools/latency_benchmarking/generate_requests.py +++ b/tools/benchmarking/generate_requests.py @@ -16,9 +16,12 @@ import csv import argparse import json + import base64 from pathlib import Path +from itertools import islice + from typing import Any, Iterator """ @@ -82,7 +85,7 @@ def _ReadCsv(snapshot_csv_file: str) -> Iterator[str]: yield row -def ReadKeys( +def ReadKeysFromDelta( snapshot_csv_file: str, max_number_of_keys: int, filter_by_sets: bool ) -> list[str]: """Read keys from CSV file. Only include update and string type mutations. @@ -108,6 +111,27 @@ def ReadKeys( return list(keys) +def ReadKeysFromFile(lookup_keys_file: str, max_number_of_keys: int) -> list[str]: + """Read keys from a file. + + Args: + lookup_keys_file: Path to file with keys. 
+ max_number_of_keys: Maximum number of keys to read. + + Returns: + List of unique set of keys. + """ + keys = [] + with open(lookup_keys_file, "r") as f: + for _ in range(max_number_of_keys): + try: + key = next(f).rstrip("\n") + keys.append(key) + except StopIteration: + break + return keys + + def Main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -121,7 +145,7 @@ def Main(): "--output-dir", dest="output_dir", type=str, - default="/tmp/latency_benchmarking", + default="/tmp/benchmarking", help="Output directory for benchmarks", ) parser.add_argument( @@ -143,13 +167,26 @@ def Main(): action="store_true", help="Whether to only use keys of sets from the input to build the requests", ) + parser.add_argument( + "--lookup-keys-file", + dest="lookup_keys_file", + type=str, + help="Path to file with keys to use in request. If set, snapshot-csv-dir is ignored.", + ) args = parser.parse_args() metadata = json.loads(args.metadata) if not isinstance(metadata, dict): raise ValueError("metadata is not a JSON object") + if args.lookup_keys_file is not None: + keys = ReadKeysFromFile(args.lookup_keys_file, max(args.number_of_keys_list)) + keys_filename = os.path.basename(args.lookup_keys_file) + output_dir = os.path.join(args.output_dir, keys_filename) + WriteRequests(keys, args.number_of_keys_list, output_dir, metadata) + return + for filename in os.listdir(args.snapshot_csv_dir): snapshot_csv_file = os.path.join(args.snapshot_csv_dir, filename) - keys = ReadKeys( + keys = ReadKeysFromDelta( snapshot_csv_file, max(args.number_of_keys_list), args.filter_by_sets ) output_dir_for_snapshot = os.path.join(args.output_dir, filename) diff --git a/tools/latency_benchmarking/merge_csvs.py b/tools/benchmarking/merge_csvs.py similarity index 100% rename from tools/latency_benchmarking/merge_csvs.py rename to tools/benchmarking/merge_csvs.py diff --git a/tools/latency_benchmarking/run_benchmarks b/tools/benchmarking/run_benchmarks similarity index 56% rename from 
tools/latency_benchmarking/run_benchmarks rename to tools/benchmarking/run_benchmarks index c67d11fb..9305d675 100755 --- a/tools/latency_benchmarking/run_benchmarks +++ b/tools/benchmarking/run_benchmarks @@ -21,7 +21,7 @@ readonly START WORKSPACE="$(git rev-parse --show-toplevel)" -BASE_OUTPUT_DIR="${WORKSPACE}/dist/tools/latency_benchmarking/output/${START}" +BASE_OUTPUT_DIR="${WORKSPACE}/dist/tools/benchmarking/output/${START}" NUMBER_OF_LOOKUP_KEYS="1 5 10" REQUEST_METADATA="{}" @@ -29,11 +29,12 @@ FILTER_BY_SETS=0 FILTER_BY_SETS_JSON='{"filter_snapshot_by_sets": "false"}' BENCHMARK_DURATION="5s" -DOCKER_OUTPUT_DIR="/tmp/latency_benchmarking/output" +DOCKER_OUTPUT_DIR="/tmp/benchmarking/output" readonly DOCKER_OUTPUT_DIR -DOCKER_SNAPSHOT_DIR="/tmp/latency_benchmarking/snapshots" -DOCKER_SNAPSHOT_CSV_DIR="/tmp/latency_benchmarking/snapshot_csvs" +DOCKER_SNAPSHOT_DIR="/tmp/benchmarking/snapshots" +DOCKER_SNAPSHOT_CSV_DIR="/tmp/benchmarking/snapshot_csvs" +DOCKER_LOOKUP_KEYS_DIR="/tmp/benchmarking/lookup_keys/" function usage() { local -r -i exitval=${1-1} @@ -41,8 +42,10 @@ function usage() { usage: ${BASH_SOURCE[0]} [--server-address] (Required) gRPC host and port. - [--snapshot-dir] or (Required) Full path to either snapshot-dir or snapshot-csv-dir - [--snapshot-csv-dir] is required. + [--snapshot-dir] OR (Required) Full path to either snapshot-dir, snapshot-csv-dir, + [--snapshot-csv-dir] OR or lookup-keys-file is required. + [--lookup-keys-file] The lookup-keys-file should have one lookup key per line + and ignores the filter-snapshot-by-sets option. [--number-of-lookup-keys-list] (Optional) List of number of keys to include in a request. [--benchmark-duration] (Optional) Duration of each benchmark. Default "5s". [--ghz-tags] (Optional) Tags to include in the ghz run. 
@@ -53,7 +56,36 @@ USAGE exit ${exitval} } -function generate_requests() { +function generate_requests_with_file() { + # Create output dirs before calling bazel-debian + # If the dir is created in the docker container, we lose write permission + LOOKUP_KEYS_FILENAME=$(basename "${LOOKUP_KEYS_FILE}") + LOOKUP_KEYS_FILE_DIR=$(dirname "${LOOKUP_KEYS_FILE}") + for N in "${NUMBER_OF_LOOKUP_KEYS_LIST[@]}"; do + mkdir -p "${BASE_OUTPUT_DIR}/${LOOKUP_KEYS_FILENAME}/n=${N}" + done + + local -a GENERATE_REQUESTS_ARGS=( + --output-dir "${DOCKER_OUTPUT_DIR}" + --number-of-keys-list "${NUMBER_OF_LOOKUP_KEYS_LIST[@]}" + --metadata "${REQUEST_METADATA}" + --lookup-keys-file "${DOCKER_LOOKUP_KEYS_DIR}/${LOOKUP_KEYS_FILENAME}" + ) + if [[ ${FILTER_BY_SETS} -eq 1 ]]; then + GENERATE_REQUESTS_ARGS+=(--filter-by-sets) + FILTER_BY_SETS_JSON='{"filter_snapshot_by_sets": "true"}' + fi + + # Mount the output dir to docker and write requests to output dir for each item in + # `NUMBER_OF_LOOKUP_KEYS_LIST`. + # This will write a json request for each NUMBER_OF_LOOKUP_KEY=N to + # dist/tools/benchmarking/output/${START}/${LOOKUP_KEYS_FILENAME}/n=${N}/request.json + EXTRA_DOCKER_RUN_ARGS+=" --volume ${BASE_OUTPUT_DIR}:${DOCKER_OUTPUT_DIR} --volume ${LOOKUP_KEYS_FILE_DIR}:${DOCKER_LOOKUP_KEYS_DIR} " \ + builders/tools/bazel-debian run //tools/benchmarking:generate_requests \ + -- "${GENERATE_REQUESTS_ARGS[@]}" +} + +function generate_requests_with_snapshots() { # Create output dirs before calling bazel-debian # If the dir is created in the docker container, we lose write permission for SNAPSHOT_CSV in "${SNAPSHOT_CSV_DIR}"/*; do @@ -76,28 +108,36 @@ function generate_requests() { # Mount the output dir to docker and write requests to output dir for each item in # `NUMBER_OF_LOOKUP_KEYS_LIST`. 
# This will write a json request for each NUMBER_OF_LOOKUP_KEY=N to - # dist/tools/latency_benchmarking/output/${START}/${SNAPSHOT_FILENAME}/n=${N}/request.json + # dist/tools/benchmarking/output/${START}/${SNAPSHOT_FILENAME}/n=${N}/request.json EXTRA_DOCKER_RUN_ARGS+=" --volume ${BASE_OUTPUT_DIR}:${DOCKER_OUTPUT_DIR} --volume ${SNAPSHOT_CSV_DIR}:${DOCKER_SNAPSHOT_CSV_DIR} " \ - builders/tools/bazel-debian run //tools/latency_benchmarking:generate_requests \ + builders/tools/bazel-debian run //tools/benchmarking:generate_requests \ -- "${GENERATE_REQUESTS_ARGS[@]}" } -function run_ghz_for_requests() { - # Iterate through the generated request.json files and call `ghz` to benchmark server at ${SERVER_ADDRESS} +function run_ghz_for_requests_from_file() { + FILENAME=$(basename "${1}") + REQUEST_JSON_DOCKER_OUTPUT_DIR="${DOCKER_OUTPUT_DIR}/${FILENAME}" + REQUEST_JSON_BASE_OUTPUT_DIR="${BASE_OUTPUT_DIR}/${FILENAME}" + for N in "${NUMBER_OF_LOOKUP_KEYS_LIST[@]}"; do - DIR="${OUTPUT_DIR}/n=${N}" - REQUEST_JSON="${DIR}"/request.json - if [[ ! -f "${REQUEST_JSON}" ]]; then + DIR="${REQUEST_JSON_DOCKER_OUTPUT_DIR}/n=${N}" + REQUEST_JSON_DOCKER="${DIR}"/request.json + REQUEST_JSON_BASE="${REQUEST_JSON_BASE_OUTPUT_DIR}/n=${N}"/request.json + if [[ ! -f "${REQUEST_JSON_BASE}" ]]; then continue fi printf "Running ghz for number of keys %s\n" "${N}" - BASE_GHZ_TAGS='{"number_of_lookup_keys": "'"${N}"'", "keys_from_file": "'"${SNAPSHOT_CSV_FILENAME}"'"}' + BASE_GHZ_TAGS=$( + jq -n --arg n "${N}" --arg f "${FILENAME}" \ + '{"number_of_lookup_keys": $n, "keys_from_file": $f}' + ) REQUEST_METADATA_TAGS=$(echo "${REQUEST_METADATA}" | jq 'with_entries(.key |= "request_metadata."+.) 
| with_entries( .value |= @json)') TAGS=$(echo "${GHZ_TAGS} ${BASE_GHZ_TAGS} ${REQUEST_METADATA_TAGS} ${FILTER_BY_SETS_JSON}" | jq -s -c 'add') GHZ_OUTPUT_JSON_FILE="${DIR}/ghz_output.json" - ghz --protoset "${WORKSPACE}/dist/query_api_descriptor_set.pb" \ - -D "${REQUEST_JSON}" \ + EXTRA_DOCKER_RUN_ARGS+=" --volume ${BASE_OUTPUT_DIR}:${DOCKER_OUTPUT_DIR} --volume ${WORKSPACE}:/src/workspace" \ + builders/tools/ghz --protoset /src/workspace/dist/query_api_descriptor_set.pb \ + -D "${REQUEST_JSON_DOCKER}" \ --duration="${BENCHMARK_DURATION}" \ --skipFirst=100 \ --concurrency=100 \ @@ -111,6 +151,18 @@ function run_ghz_for_requests() { done } +function run_ghz_for_requests() { + # Iterate through the generated request.json files and + # call `ghz` to benchmark server at ${SERVER_ADDRESS} + if [[ -f "${LOOKUP_KEYS_FILE}" ]]; then + run_ghz_for_requests_from_file "${LOOKUP_KEYS_FILE}" + else + for SNAPSHOT_CSV_FILE in "${SNAPSHOT_CSV_DIR}"/*; do + run_ghz_for_requests_from_file "${SNAPSHOT_CSV_FILE}" + done + fi +} + while [[ $# -gt 0 ]]; do case "$1" in --server-address) @@ -137,6 +189,10 @@ while [[ $# -gt 0 ]]; do SNAPSHOT_CSV_DIR="$2" shift 2 ;; + --lookup-keys-file) + LOOKUP_KEYS_FILE="$2" + shift 2 + ;; --request-metadata-json) REQUEST_METADATA="$2" shift 2 @@ -150,15 +206,20 @@ while [[ $# -gt 0 ]]; do esac done -# Check for SNAPSHOT_DIR or SNAPSHOT_CSV_DIR. If not available, exit. -if [[ (-z "${SNAPSHOT_DIR}" || ! -d "${SNAPSHOT_DIR}") && (-z "${SNAPSHOT_CSV_DIR}" || ! -d "${SNAPSHOT_CSV_DIR}") ]]; then +# Check for SNAPSHOT_DIR or SNAPSHOT_CSV_DIR or LOOKUP_KEYS_FILE. If not available, exit. +if [[ (-z "${SNAPSHOT_DIR}" || ! -d "${SNAPSHOT_DIR}") && (-z "${SNAPSHOT_CSV_DIR}" || ! -d "${SNAPSHOT_CSV_DIR}") && (-z "${LOOKUP_KEYS_FILE}" || ! 
-f "${LOOKUP_KEYS_FILE}") ]]; then printf "snapshot-dir not found:%s\n" "${SNAPSHOT_DIR}" + printf "snapshot-csv-dir not found:%s\n" "${SNAPSHOT_CSV_DIR}" + printf "lookup-keys-file not found:%s\n" "${LOOKUP_KEYS_FILE}" + printf "Exiting...\n" exit 1 fi IFS=' ' read -ra NUMBER_OF_LOOKUP_KEYS_LIST <<<"${NUMBER_OF_LOOKUP_KEYS}" -if [[ -d "${SNAPSHOT_CSV_DIR}" ]]; then - generate_requests +if [[ -f ${LOOKUP_KEYS_FILE} ]]; then + generate_requests_with_file +elif [[ -d ${SNAPSHOT_CSV_DIR} ]]; then + generate_requests_with_snapshots else # No snapshot csv given, iterate through snapshot dir to create csvs SNAPSHOT_CSV_DIR="${BASE_OUTPUT_DIR}/snapshot_csvs/" @@ -174,22 +235,15 @@ else --output_file "${DOCKER_SNAPSHOT_CSV_DIR}/${SNAPSHOT_FILENAME}.csv" \ --output_format CSV done - generate_requests + generate_requests_with_snapshots fi -# Iterate through each snapshot file and -# 1. create requests under dist/tools/latency_benchmarking/output/${START}/${SNAPSHOT_FILENAME} -# 2. for each request, run benchmarks with ghz -for SNAPSHOT_CSV_FILE in "${SNAPSHOT_CSV_DIR}"/*; do - SNAPSHOT_CSV_FILENAME=$(basename "${SNAPSHOT_CSV_FILE}") - OUTPUT_DIR="${BASE_OUTPUT_DIR}/${SNAPSHOT_CSV_FILENAME}" - run_ghz_for_requests -done +run_ghz_for_requests -# Go through all the ghz results in dist/tools/latency_benchmarking/output/${START} +# Go through all the ghz results in dist/tools/benchmarking/output/${START} # and collect the summary in a csv EXTRA_DOCKER_RUN_ARGS+=" --volume ${BASE_OUTPUT_DIR}:${DOCKER_OUTPUT_DIR} " \ - builders/tools/bazel-debian run //tools/latency_benchmarking:create_csv_summary \ + builders/tools/bazel-debian run //tools/benchmarking:create_csv_summary \ -- \ --ghz-result-dir ${DOCKER_OUTPUT_DIR} diff --git a/tools/data_cli/data_cli.cc b/tools/data_cli/data_cli.cc index f7ec1c91..9db3c576 100644 --- a/tools/data_cli/data_cli.cc +++ b/tools/data_cli/data_cli.cc @@ -138,7 +138,7 @@ bool IsSupportedCommand(std::string_view command) { // // bazel run \ // 
//tools/data_cli:data_cli \ -// --//:instance=local --//:platform=local -- \ +// --config=local_instance --config=local_platform -- \ // format_data \ // --input_file=/data/DELTA_1689344645643610 \ // --input_format=DELTA \ diff --git a/tools/request_simulation/BUILD.bazel b/tools/request_simulation/BUILD.bazel index f09a4173..ffd6eb7d 100644 --- a/tools/request_simulation/BUILD.bazel +++ b/tools/request_simulation/BUILD.bazel @@ -82,7 +82,7 @@ cc_library( "//components/util:sleepfor", "@com_github_grpc_grpc//test/core/util:grpc_test_util_base", "@com_google_absl//absl/log", - "@google_privacysandbox_servers_common//src/telemetry:metrics_recorder", + "@google_privacysandbox_servers_common//src/metric:context_map", ], ) @@ -120,7 +120,6 @@ cc_library( "@com_github_google_flatbuffers//:flatbuffers", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", - "@google_privacysandbox_servers_common//src/telemetry:metrics_recorder", "@google_privacysandbox_servers_common//src/telemetry:tracing", ], ) @@ -159,6 +158,7 @@ cc_library( "//components/data/blob_storage:delta_file_notifier", "//components/data/common:change_notifier", "//components/data/common:thread_manager", + "//components/tools/util:configure_telemetry_tools", "//components/util:platform_initializer", "//components/util:version_linkstamp", "//public/data_loading/readers:riegeli_stream_io", diff --git a/tools/request_simulation/client_worker_test.cc b/tools/request_simulation/client_worker_test.cc index a9c45ea1..ee2ecd76 100644 --- a/tools/request_simulation/client_worker_test.cc +++ b/tools/request_simulation/client_worker_test.cc @@ -33,7 +33,6 @@ namespace kv_server { -using privacy_sandbox::server_common::MockMetricsRecorder; using privacy_sandbox::server_common::SimulatedSteadyClock; using privacy_sandbox::server_common::SteadyTime; using testing::_; @@ -70,7 +69,6 @@ class ClientWorkerTest : public ::testing::Test { SimulatedSteadyClock sim_clock_; std::unique_ptr 
sleep_for_metrics_collector_; std::unique_ptr sleep_for_; - MockMetricsRecorder metrics_recorder_; }; TEST_F(ClientWorkerTest, SingleClientWorkerTest) { @@ -92,10 +90,9 @@ TEST_F(ClientWorkerTest, SingleClientWorkerTest) { std::move(sleep_for_), absl::Seconds(0)); EXPECT_CALL(*sleep_for_metrics_collector_, Duration(_)) .WillRepeatedly(Return(true)); - EXPECT_CALL(metrics_recorder_, RegisterHistogram(_, _, _, _)).Times(5); std::unique_ptr metrics_collector = std::make_unique( - metrics_recorder_, std::move(sleep_for_metrics_collector_)); + std::move(sleep_for_metrics_collector_)); EXPECT_CALL(*metrics_collector, IncrementServerResponseStatusEvent(_)) .Times(requests_per_second); EXPECT_CALL(*metrics_collector, IncrementRequestSentPerInterval()) @@ -136,10 +133,9 @@ TEST_F(ClientWorkerTest, MultipleClientWorkersTest) { std::move(sleep_for_), absl::Seconds(0)); EXPECT_CALL(*sleep_for_metrics_collector_, Duration(_)) .WillRepeatedly(Return(true)); - EXPECT_CALL(metrics_recorder_, RegisterHistogram(_, _, _, _)).Times(5); std::unique_ptr metrics_collector = std::make_unique( - metrics_recorder_, std::move(sleep_for_metrics_collector_)); + std::move(sleep_for_metrics_collector_)); EXPECT_CALL(*metrics_collector, IncrementServerResponseStatusEvent(_)) .Times(requests_per_second); EXPECT_CALL(*metrics_collector, IncrementRequestSentPerInterval()) diff --git a/tools/request_simulation/delta_based_request_generator.h b/tools/request_simulation/delta_based_request_generator.h index ada28ca2..d2fbe0d2 100644 --- a/tools/request_simulation/delta_based_request_generator.h +++ b/tools/request_simulation/delta_based_request_generator.h @@ -30,7 +30,6 @@ #include "components/data/realtime/realtime_notifier.h" #include "public/data_loading/readers/riegeli_stream_io.h" #include "public/data_loading/readers/stream_record_reader_factory.h" -#include "src/telemetry/metrics_recorder.h" #include "tools/request_simulation/message_queue.h" #include 
"tools/request_simulation/request_generation_util.h" @@ -53,13 +52,11 @@ class DeltaBasedRequestGenerator { }; DeltaBasedRequestGenerator( Options options, - absl::AnyInvocable request_generation_fn, - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder) + absl::AnyInvocable request_generation_fn) : options_(std::move(options)), data_load_thread_manager_( ThreadManager::Create("Delta file loading thread")), - request_generation_fn_(std::move(request_generation_fn)), - metrics_recorder_(metrics_recorder) {} + request_generation_fn_(std::move(request_generation_fn)) {} ~DeltaBasedRequestGenerator() = default; // DeltaBasedRequestGenerator is neither copyable nor movable. @@ -95,7 +92,6 @@ class DeltaBasedRequestGenerator { std::unique_ptr data_load_thread_manager_; // Callback function to generate KV request from a given key absl::AnyInvocable request_generation_fn_; - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; }; } // namespace kv_server diff --git a/tools/request_simulation/delta_based_request_generator_test.cc b/tools/request_simulation/delta_based_request_generator_test.cc index ecc92de7..b117503b 100644 --- a/tools/request_simulation/delta_based_request_generator_test.cc +++ b/tools/request_simulation/delta_based_request_generator_test.cc @@ -74,7 +74,8 @@ BlobStorageClient::DataLocation GetTestLocation( absl::AnyInvocable GetRequestGenFn() { return [](std::string_view key) { - return kv_server::CreateKVDSPRequestBodyInJson({std::string(key)}); + return kv_server::CreateKVDSPRequestBodyInJson( + {std::string(key)}, "debug_token", "generation_id"); }; } @@ -94,7 +95,6 @@ class GenerateRequestsFromDeltaFilesTest : public ::testing::Test { MockDeltaFileNotifier notifier_; MockBlobStorageChangeNotifier change_notifier_; MockStreamRecordReaderFactory delta_stream_reader_factory_; - MockMetricsRecorder metrics_recorder_; MessageQueue message_queue_; DeltaBasedRequestGenerator::Options options_; }; @@ -102,8 +102,8 @@ class 
GenerateRequestsFromDeltaFilesTest : public ::testing::Test { TEST_F(GenerateRequestsFromDeltaFilesTest, LoadingDataFromDeltaFiles) { ON_CALL(blob_client_, ListBlobs) .WillByDefault(Return(std::vector({}))); - DeltaBasedRequestGenerator request_generator( - std::move(options_), std::move(GetRequestGenFn()), metrics_recorder_); + DeltaBasedRequestGenerator request_generator(std::move(options_), + std::move(GetRequestGenFn())); const std::string last_basename = ""; EXPECT_CALL(notifier_, Start(_, GetTestLocation(), @@ -147,7 +147,8 @@ TEST_F(GenerateRequestsFromDeltaFilesTest, LoadingDataFromDeltaFiles) { auto message_in_the_queue = message_queue_.Pop(); EXPECT_TRUE(message_in_the_queue.ok()); EXPECT_EQ(message_in_the_queue.value(), - kv_server::CreateKVDSPRequestBodyInJson({std::string("key")})); + kv_server::CreateKVDSPRequestBodyInJson( + {std::string("key")}, "debug_token", "generation_id")); } } // namespace diff --git a/tools/request_simulation/main.cc b/tools/request_simulation/main.cc index 0690a362..a289c155 100644 --- a/tools/request_simulation/main.cc +++ b/tools/request_simulation/main.cc @@ -21,7 +21,6 @@ #include "absl/log/initialize.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" -#include "src/telemetry/metrics_recorder.h" #include "src/telemetry/telemetry_provider.h" #include "tools/request_simulation/grpc_client.h" #include "tools/request_simulation/request_simulation_system.h" @@ -40,10 +39,7 @@ int main(int argc, char** argv) { absl::SetProgramUsageMessage(absl::StrCat( "Key Value Server Request Simulation System. 
Sample usage:\n", argv[0])); kv_server::RequestSimulationSystem::InitializeTelemetry(); - auto metric_recorder = - TelemetryProvider::GetInstance().CreateMetricsRecorder(); kv_server::RequestSimulationSystem system( - *metric_recorder, privacy_sandbox::server_common::SteadyClock::RealClock(), [](const std::string& server_address, const kv_server::GrpcAuthenticationMode& mode) { diff --git a/tools/request_simulation/metrics_collector.cc b/tools/request_simulation/metrics_collector.cc index fdfae0ae..3de1ffbe 100644 --- a/tools/request_simulation/metrics_collector.cc +++ b/tools/request_simulation/metrics_collector.cc @@ -23,37 +23,21 @@ ABSL_FLAG(absl::Duration, metrics_report_interval, absl::Minutes(1), namespace kv_server { -constexpr char* kP50GrpcLatency = "P50GrpcLatency"; -constexpr char* kP90GrpcLatency = "P90GrpcLatency"; -constexpr char* kP99GrpcLatency = "P99GrpcLatency"; -constexpr char* kEstimatedQPS = "EstimatedQPS"; -constexpr char* kRequestsSent = "RequestsSent"; -constexpr char* KServerResponseStatus = "ServerResponseStatus"; - +namespace { constexpr double kDefaultHistogramResolution = 0.1; constexpr double kDefaultHistogramMaxBucket = 60e9; +} // namespace -MetricsCollector::MetricsCollector( - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder, - std::unique_ptr sleep_for) +MetricsCollector::MetricsCollector(std::unique_ptr sleep_for) : requests_sent_per_interval_(0), requests_with_ok_response_per_interval_(0), requests_with_error_response_per_interval_(0), report_interval_(std::move(absl::GetFlag(FLAGS_metrics_report_interval))), report_thread_manager_( ThreadManager::Create("Metrics periodic report thread")), - metrics_recorder_(metrics_recorder), sleep_for_(std::move(sleep_for)) { histogram_per_interval_ = grpc_histogram_create(kDefaultHistogramResolution, kDefaultHistogramMaxBucket); - metrics_recorder_.RegisterHistogram(kRequestsSent, "Requests sent", ""); - metrics_recorder_.RegisterHistogram(kEstimatedQPS, "Estimated QPS", 
""); - metrics_recorder_.RegisterHistogram(kP50GrpcLatency, "P50 Latency", - "microsecond"); - metrics_recorder_.RegisterHistogram(kP90GrpcLatency, "P90 Latency", - "microsecond"); - metrics_recorder_.RegisterHistogram(kP99GrpcLatency, "P99 Latency", - "microsecond"); } void MetricsCollector::AddLatencyToHistogram(absl::Duration latency) { @@ -98,14 +82,16 @@ void MetricsCollector::PublishMetrics() { auto p50_latency = GetPercentileLatency(0.5); auto p90_latency = GetPercentileLatency(0.9); auto p99_latency = GetPercentileLatency(0.99); - metrics_recorder_.RecordHistogramEvent(kRequestsSent, requests_sent); - metrics_recorder_.RecordHistogramEvent(kEstimatedQPS, estimated_qps); - metrics_recorder_.RecordHistogramEvent( - kP50GrpcLatency, absl::ToInt64Microseconds(p50_latency)); - metrics_recorder_.RecordHistogramEvent( - kP90GrpcLatency, absl::ToInt64Microseconds(p90_latency)); - metrics_recorder_.RecordHistogramEvent( - kP99GrpcLatency, absl::ToInt64Microseconds(p99_latency)); + RequestSimulationContextMap()->SafeMetric().LogUpDownCounter( + (int)requests_sent); + RequestSimulationContextMap()->SafeMetric().LogUpDownCounter( + (int)estimated_qps); + RequestSimulationContextMap()->SafeMetric().LogHistogram( + ((int)absl::ToInt64Milliseconds(p99_latency))); + RequestSimulationContextMap()->SafeMetric().LogHistogram( + ((int)absl::ToInt64Milliseconds(p90_latency))); + RequestSimulationContextMap()->SafeMetric().LogHistogram( + ((int)absl::ToInt64Milliseconds(p50_latency))); LOG(INFO) << "Metrics Summary: "; LOG(INFO) << "Number of requests sent:" << requests_sent; LOG(INFO) << "Number of requests with ok responses:" @@ -142,7 +128,10 @@ int64_t MetricsCollector::GetQPS() { } void MetricsCollector::IncrementServerResponseStatusEvent( const absl::Status& status) { - metrics_recorder_.IncrementEventStatus(KServerResponseStatus, status); + RequestSimulationContextMap() + ->SafeMetric() + .LogUpDownCounter( + {{absl::StatusCodeToString(status.code()), 1}}); } } // 
namespace kv_server diff --git a/tools/request_simulation/metrics_collector.h b/tools/request_simulation/metrics_collector.h index 49721a15..5b552643 100644 --- a/tools/request_simulation/metrics_collector.h +++ b/tools/request_simulation/metrics_collector.h @@ -18,22 +18,75 @@ #define TOOLS_REQUEST_SIMULATION_METRICS_COLLECTOR_H_ #include +#include +#include #include "absl/flags/flag.h" #include "components/data/common/thread_manager.h" #include "components/util/sleepfor.h" -#include "src/telemetry/metrics_recorder.h" #include "test/core/util/histogram.h" namespace kv_server { +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kPartitionedCounter> + kServerResponseStatus( + "ServerResponseStatus", "Server responses partitioned by status", + "status", + privacy_sandbox::server_common::metrics::kEmptyPublicPartition); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> + kRequestsSent("RequestsSent", + "Total number of requests sent to the server"); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kUpDownCounter> + kEstimatedQPS("EstimatedQPS", "Estimated QPS"); + +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kP50GrpcLatencyMs("P50GrpcLatency", "P50 Grpc request latency", + privacy_sandbox::server_common::metrics::kTimeHistogram); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, 
privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kP90GrpcLatencyMs("P90GrpcLatency", "P90 Grpc request latency", + privacy_sandbox::server_common::metrics::kTimeHistogram); +inline constexpr privacy_sandbox::server_common::metrics::Definition< + int, privacy_sandbox::server_common::metrics::Privacy::kNonImpacting, + privacy_sandbox::server_common::metrics::Instrument::kHistogram> + kP99GrpcLatencyMs("P99GrpcLatency", "P99 Grpc request latency", + privacy_sandbox::server_common::metrics::kTimeHistogram); + +inline constexpr const privacy_sandbox::server_common::metrics::DefinitionName* + kRequestSimulationMetricsList[] = { + &kRequestsSent, &kServerResponseStatus, &kEstimatedQPS, + &kP50GrpcLatencyMs, &kP90GrpcLatencyMs, &kP99GrpcLatencyMs}; +inline constexpr absl::Span< + const privacy_sandbox::server_common::metrics::DefinitionName* const> + kRequestSimulationMetricsSpan = kRequestSimulationMetricsList; + +inline auto* RequestSimulationContextMap( + std::optional< + privacy_sandbox::server_common::telemetry::BuildDependentConfig> + config = std::nullopt, + std::unique_ptr provider = nullptr, + absl::string_view service = "Request-simulation", + absl::string_view version = "") { + return privacy_sandbox::server_common::metrics::GetContextMap< + const std::string, kRequestSimulationMetricsSpan>( + std::move(config), std::move(provider), service, version, {}); +} // MetricsCollector periodically collects metrics // periodically prints metrics to the log and publishes metrics to -// MetricsRecorder +// Otel class MetricsCollector { public: - MetricsCollector( - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder, - std::unique_ptr sleep_for); + explicit MetricsCollector(std::unique_ptr sleep_for); // Increments server response status event virtual void IncrementServerResponseStatusEvent(const absl::Status& status); // Increments requests sent counter for the current
interval @@ -78,7 +131,6 @@ class MetricsCollector { mutable std::atomic requests_with_error_response_per_interval_; absl::Duration report_interval_; std::unique_ptr report_thread_manager_; - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; std::unique_ptr sleep_for_; grpc_histogram* histogram_per_interval_ ABSL_GUARDED_BY(mutex_); friend class MetricsCollectorPeer; diff --git a/tools/request_simulation/metrics_collector_test.cc b/tools/request_simulation/metrics_collector_test.cc index 014e6708..58e895fb 100644 --- a/tools/request_simulation/metrics_collector_test.cc +++ b/tools/request_simulation/metrics_collector_test.cc @@ -59,11 +59,10 @@ namespace { class MetricsCollectorTest : public ::testing::Test { protected: MetricsCollectorTest() { - metrics_collector_ = std::make_unique( - metrics_recorder_, std::make_unique()); + metrics_collector_ = + std::make_unique(std::make_unique()); } SimulatedSteadyClock sim_clock_; - MockMetricsRecorder metrics_recorder_; std::unique_ptr metrics_collector_; }; diff --git a/tools/request_simulation/mocks.h b/tools/request_simulation/mocks.h index fc6fc3aa..a863eebe 100644 --- a/tools/request_simulation/mocks.h +++ b/tools/request_simulation/mocks.h @@ -27,10 +27,8 @@ namespace kv_server { class MockMetricsCollector : public MetricsCollector { public: - MockMetricsCollector( - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder, - std::unique_ptr sleep_for) - : MetricsCollector(metrics_recorder, std::move(sleep_for)) {} + explicit MockMetricsCollector(std::unique_ptr sleep_for) + : MetricsCollector(std::move(sleep_for)) {} MOCK_METHOD(void, IncrementServerResponseStatusEvent, (const absl::Status& status), (override)); MOCK_METHOD(void, IncrementRequestSentPerInterval, (), (override)); diff --git a/tools/request_simulation/request_generation_util.cc b/tools/request_simulation/request_generation_util.cc index 8cd8f91a..0d909578 100644 --- a/tools/request_simulation/request_generation_util.cc +++ 
b/tools/request_simulation/request_generation_util.cc @@ -24,7 +24,7 @@ namespace kv_server { // external configure file constexpr std::string_view kKVV2KeyValueDSPRequestBodyFormat = R"json( -{"metadata": {},"partitions": [{ "id": 0, "compressionGroupId": 0,"arguments": [{ "tags": [ "custom", "keys" ],"data": [ %s ] }] }] })json"; +{"metadata": {}, "log_context": {"generation_id": "%s", "adtech_debug_id": "debug_id"}, "consented_debug_config": {"is_consented": true, "token": "%s"}, "partitions": [{ "id": 0, "compressionGroupId": 0,"arguments": [{ "tags": [ "custom", "keys" ],"data": [ %s ] }] }] })json"; std::vector GenerateRandomKeys(int number_of_keys, int key_size) { std::vector result; @@ -34,12 +34,19 @@ std::vector GenerateRandomKeys(int number_of_keys, int key_size) { return result; } -std::string CreateKVDSPRequestBodyInJson(const std::vector& keys) { +std::string CreateKVDSPRequestBodyInJson( + const std::vector& keys, + std::string_view consented_debug_token, + std::optional generation_id_override) { const std::string comma_seperated_keys = absl::StrJoin(keys, ",", [](std::string* out, const std::string& key) { absl::StrAppend(out, "\"", key, "\""); }); - return absl::StrFormat(kKVV2KeyValueDSPRequestBodyFormat, + const std::string generation_id = generation_id_override.has_value() + ? 
generation_id_override.value() + : std::to_string(std::rand()); + return absl::StrFormat(kKVV2KeyValueDSPRequestBodyFormat, generation_id, + std::string(consented_debug_token), comma_seperated_keys); } diff --git a/tools/request_simulation/request_generation_util.h b/tools/request_simulation/request_generation_util.h index 2dbf0cae..6e55ca7d 100644 --- a/tools/request_simulation/request_generation_util.h +++ b/tools/request_simulation/request_generation_util.h @@ -28,7 +28,10 @@ namespace kv_server { std::vector GenerateRandomKeys(int number_of_keys, int key_size); // Creates KV DSP request body in json -std::string CreateKVDSPRequestBodyInJson(const std::vector& keys); +std::string CreateKVDSPRequestBodyInJson( + const std::vector& keys, + std::string_view consented_debug_token, + std::optional generation_id_override = std::nullopt); // Creates proto message from request body in json kv_server::RawRequest CreatePlainTextRequest( diff --git a/tools/request_simulation/request_generation_util_test.cc b/tools/request_simulation/request_generation_util_test.cc index 208ced1b..da492681 100644 --- a/tools/request_simulation/request_generation_util_test.cc +++ b/tools/request_simulation/request_generation_util_test.cc @@ -25,7 +25,7 @@ namespace { TEST(TestCreateMessage, ProtoMessageMatchJson) { const auto keys = kv_server::GenerateRandomKeys(10, 3); const std::string request_in_json = - kv_server::CreateKVDSPRequestBodyInJson(keys); + kv_server::CreateKVDSPRequestBodyInJson(keys, "debug_token"); const auto request = kv_server::CreatePlainTextRequest(request_in_json); EXPECT_EQ(request_in_json, request.raw_body().data()); std::string encoded_request_body; diff --git a/tools/request_simulation/request_simulation_system.cc b/tools/request_simulation/request_simulation_system.cc index 83a85095..039e104e 100644 --- a/tools/request_simulation/request_simulation_system.cc +++ b/tools/request_simulation/request_simulation_system.cc @@ -22,6 +22,7 @@ #include "absl/log/log.h" 
#include "components/tools/concurrent_publishing_engine.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "grpcpp/grpcpp.h" #include "opentelemetry/sdk/resource/resource.h" #include "opentelemetry/sdk/resource/semantic_conventions.h" @@ -92,6 +93,14 @@ ABSL_FLAG(int32_t, realtime_publisher_insertion_num_threads, 1, "Number of threads used to write to pubsub in parallel."); ABSL_FLAG(int32_t, realtime_publisher_files_insertion_rate, 15, "Number of messages sent per insertion thread to pubsub per second"); +ABSL_FLAG(std::string, consented_debug_token, "", + "Consented debug token, if non-empty value is provided," + "consented requests will be sent to the test server"); +ABSL_FLAG(bool, use_default_generation_id, true, + "Whether to send consented requests with default generation_id," + "which is a constant for all consented requests. This value should " + "be set to true if sending high volume of consented traffic to the " + "server."); namespace kv_server { @@ -99,10 +108,11 @@ constexpr char* kServiceName = "request-simulation"; constexpr char* kTestingServer = "testing.server"; constexpr int kMetricsExportIntervalInMs = 5000; constexpr int kMetricsExportTimeoutInMs = 500; +constexpr char* kDefaultGenerationIdForConsentedRequests = "consented"; using opentelemetry::sdk::resource::Resource; using opentelemetry::sdk::resource::ResourceAttributes; -using privacy_sandbox::server_common::ConfigureMetrics; +using privacy_sandbox::server_common::ConfigurePrivateMetrics; using privacy_sandbox::server_common::InitTelemetry; using privacy_sandbox::server_common::SteadyClock; namespace semantic_conventions = @@ -133,6 +143,10 @@ absl::Status RequestSimulationSystem::Init( std::unique_ptr metrics_collector) { server_address_ = absl::GetFlag(FLAGS_server_address); server_method_ = absl::GetFlag(FLAGS_server_method); + consented_debug_token_ = absl::GetFlag(FLAGS_consented_debug_token); + if (absl::GetFlag(FLAGS_use_default_generation_id)) { + 
generation_id_override_ = kDefaultGenerationIdForConsentedRequests; + } concurrent_number_of_requests_ = absl::GetFlag(FLAGS_concurrency); synthetic_request_gen_option_.number_of_keys_per_request = absl::GetFlag(FLAGS_number_of_keys_per_request); @@ -167,14 +181,14 @@ absl::Status RequestSimulationSystem::Init( const auto keys = kv_server::GenerateRandomKeys( synthetic_request_gen_option_.number_of_keys_per_request, synthetic_request_gen_option_.key_size_in_bytes); - return kv_server::CreateKVDSPRequestBodyInJson(keys); + return kv_server::CreateKVDSPRequestBodyInJson( + keys, consented_debug_token_, generation_id_override_); }); // Telemetry must be initialized before initializing metrics collector metrics_collector_ = metrics_collector == nullptr - ? std::make_unique(metrics_recorder_, - std::make_unique()) + ? std::make_unique(std::make_unique()) : std::move(metrics_collector); // Initialize no-op telemetry for the new Telemetry API // TODO(b/304306398): deprecate metric recorder and use new telemetry API to @@ -225,7 +239,7 @@ absl::Status RequestSimulationSystem::Init( .delta_notifier = *delta_file_notifier_, .change_notifier = *blob_change_notifier_, .delta_stream_reader_factory = *delta_stream_reader_factory_}, - CreateRequestFromKeyFn(), metrics_recorder_); + CreateRequestFromKeyFn()); PS_ASSIGN_OR_RETURN(realtime_message_batcher_, RealtimeMessageBatcher::Create( realtime_messages_, realtime_messages_mutex_, @@ -381,8 +395,9 @@ RequestSimulationSystem::CreateStreamRecordReaderFactory() { } absl::AnyInvocable RequestSimulationSystem::CreateRequestFromKeyFn() { - return [](std::string_view key) { - return kv_server::CreateKVDSPRequestBodyInJson({std::string(key)}); + return [this](std::string_view key) { + return kv_server::CreateKVDSPRequestBodyInJson( + {std::string(key)}, consented_debug_token_, generation_id_override_); }; } void RequestSimulationSystem::InitializeTelemetry() { @@ -400,8 +415,14 @@ void RequestSimulationSystem::InitializeTelemetry() { 
{semantic_conventions::kHostArch, std::string(BuildPlatform())}, {kTestingServer, server_address}}; auto resource = Resource::Create(attributes); - ConfigureMetrics(resource, metrics_options); - kv_server::InitMetricsContextMap(); + kv_server::ConfigureTelemetryForTools(); + privacy_sandbox::server_common::telemetry::TelemetryConfig config_proto; + config_proto.set_mode( + privacy_sandbox::server_common::telemetry::TelemetryConfig::EXPERIMENT); + auto* context_map = RequestSimulationContextMap( + privacy_sandbox::server_common::telemetry::BuildDependentConfig( + config_proto), + ConfigurePrivateMetrics(resource, metrics_options)); } } // namespace kv_server diff --git a/tools/request_simulation/request_simulation_system.h b/tools/request_simulation/request_simulation_system.h index 0d7579d6..5e856c47 100644 --- a/tools/request_simulation/request_simulation_system.h +++ b/tools/request_simulation/request_simulation_system.h @@ -34,7 +34,6 @@ #include "grpcpp/grpcpp.h" #include "public/data_loading/readers/riegeli_stream_io.h" #include "public/query/get_values.grpc.pb.h" -#include "src/telemetry/metrics_recorder.h" #include "test/core/util/histogram.h" #include "tools/request_simulation/client_worker.h" #include "tools/request_simulation/delta_based_request_generator.h" @@ -83,7 +82,6 @@ namespace kv_server { class RequestSimulationSystem { public: RequestSimulationSystem( - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder, privacy_sandbox::server_common::SteadyClock& steady_clock, absl::AnyInvocable( const std::string& server_address, @@ -91,8 +89,7 @@ class RequestSimulationSystem { channel_creation_fn, std::unique_ptr parameter_fetcher_for_unit_testing = nullptr) - : metrics_recorder_(metrics_recorder), - steady_clock_(steady_clock), + : steady_clock_(steady_clock), channel_creation_fn_(std::move(channel_creation_fn)) { if (parameter_fetcher_for_unit_testing != nullptr) { parameter_fetcher_ = std::move(parameter_fetcher_for_unit_testing); @@ 
-136,7 +133,6 @@ class RequestSimulationSystem { absl::AnyInvocable CreateRequestFromKeyFn(); // This must be first, otherwise the AWS SDK will crash when it's called: PlatformInitializer platform_initializer_; - privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; std::unique_ptr metrics_collector_; privacy_sandbox::server_common::SteadyClock& steady_clock_; absl::AnyInvocable( @@ -145,6 +141,8 @@ class RequestSimulationSystem { channel_creation_fn_; std::string server_address_; std::string server_method_; + std::string consented_debug_token_; + std::optional generation_id_override_; int concurrent_number_of_requests_; int64_t synthetic_requests_fill_qps_; SyntheticRequestGenOption synthetic_request_gen_option_; diff --git a/tools/request_simulation/request_simulation_system_local_test.cc b/tools/request_simulation/request_simulation_system_local_test.cc index 30b3eaf7..99c0cd2b 100644 --- a/tools/request_simulation/request_simulation_system_local_test.cc +++ b/tools/request_simulation/request_simulation_system_local_test.cc @@ -99,7 +99,6 @@ class SimulationSystemTest : public ::testing::Test { } std::unique_ptr fake_get_value_service_; std::unique_ptr server_; - MockMetricsRecorder metrics_recorder_; // std::unique_ptr metrics_collector_; SimulatedSteadyClock sim_clock_; std::unique_ptr sleep_for_request_generator_; @@ -144,7 +143,7 @@ TEST_F(SimulationSystemTest, TestSimulationSystemRunning) { .WillRepeatedly(Return(true)); auto metrics_collector = std::make_unique( - metrics_recorder_, std::move(sleep_for_metrics_collector_)); + std::move(sleep_for_metrics_collector_)); EXPECT_CALL(*metrics_collector, Start()) .WillRepeatedly(Return(absl::OkStatus())); @@ -170,7 +169,7 @@ TEST_F(SimulationSystemTest, TestSimulationSystemRunning) { .WillRepeatedly(Return( LocalNotifierMetadata{.local_directory = ::testing::TempDir()})); RequestSimulationSystem system( - metrics_recorder_, sim_clock_, channel_creation_fn, + sim_clock_, channel_creation_fn, 
std::move(mock_request_simulation_parameter_fetcher_)); EXPECT_TRUE(system .Init(std::move(sleep_for_request_generator_), diff --git a/tools/request_simulation/synthetic_request_generator_test.cc b/tools/request_simulation/synthetic_request_generator_test.cc index c987dc9a..477299d1 100644 --- a/tools/request_simulation/synthetic_request_generator_test.cc +++ b/tools/request_simulation/synthetic_request_generator_test.cc @@ -55,7 +55,7 @@ TEST_F(TestSyntheticRequestGeneratorTest, TestGenerateRequestsAtFixedRate) { requests_per_second, [option]() { const auto keys = kv_server::GenerateRandomKeys( option.number_of_keys_per_request, option.key_size_in_bytes); - return kv_server::CreateKVDSPRequestBodyInJson(keys); + return kv_server::CreateKVDSPRequestBodyInJson(keys, "debug_token"); }); sim_clock_.AdvanceTime(absl::Seconds(1)); EXPECT_TRUE(request_generator.Start().ok()); @@ -83,7 +83,7 @@ TEST_F(TestSyntheticRequestGeneratorTest, message_queue, rate_limiter, std::move(sleep_for_request_generator_), requests_per_second, [num_of_keys, key_size]() { const auto keys = kv_server::GenerateRandomKeys(num_of_keys, key_size); - return kv_server::CreateKVDSPRequestBodyInJson(keys); + return kv_server::CreateKVDSPRequestBodyInJson(keys, "debug_token"); }); sim_clock_.AdvanceTime(absl::Seconds(1)); EXPECT_TRUE(request_generator.Start().ok()); diff --git a/tools/server_diagnostic/BUILD.bazel b/tools/server_diagnostic/BUILD.bazel index 2112b99b..99ad3778 100644 --- a/tools/server_diagnostic/BUILD.bazel +++ b/tools/server_diagnostic/BUILD.bazel @@ -12,13 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load( - "@io_bazel_rules_docker//container:container.bzl", - "container_image", - "container_layer", -) -load("@io_bazel_rules_docker//go:image.bzl", "go_image") load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load("@rules_oci//oci:defs.bzl", "oci_image", "oci_tarball") load( "@rules_pkg//pkg:mappings.bzl", "pkg_attributes", @@ -39,13 +34,85 @@ go_binary( visibility = ["//visibility:public"], ) +pkg_files( + name = "diagnostic_cli", + srcs = [ + ":diagnostic", + ], + attributes = pkg_attributes(mode = "0555"), + prefix = "/tools/diagnostic_cli", +) + +pkg_tar( + name = "diagnostic_tar", + srcs = [":diagnostic_cli"], +) + +pkg_files( + name = "query_api_descriptor_set", + srcs = [ + "//public/query:query_api_descriptor_set", + ], + attributes = pkg_attributes(mode = "0555"), + prefix = "/tools/query", +) + +pkg_tar( + name = "query_api_descriptor_set_tar", + srcs = [":query_api_descriptor_set"], +) + +pkg_files( + name = "helloworld_server_executables", + srcs = [ + "//tools/server_diagnostic/helloworld_server", + ], + attributes = pkg_attributes(mode = "0555"), + prefix = "/tools/helloworld_server", +) + +pkg_tar( + name = "helloworld_server_binaries_tar", + srcs = [":helloworld_server_executables"], +) + [ - go_image( - name = "diagnostic_go_image_{}".format(arch), - embed = [":diagnostic_lib"], - goarch = arch, - goos = "linux", - visibility = ["//visibility:public"], + genrule( + name = "grpcurl_{}_file".format(arch), + srcs = ["@grpcurl_{}//file".format(arch)], + outs = ["grpcurl_{}".format(arch)], + cmd = "tar -xzf $(location @grpcurl_{arch}//file) --exclude=LICENSE --to-stdout >$(@D)/grpcurl_{arch}".format(arch = arch), + ) + for arch in [ + "x86_64", + "aarch64", + ] +] + +pkg_files( + name = "grpcurl_files_amd64", + srcs = [":grpcurl_x86_64"], + attributes = pkg_attributes(mode = "0555"), + prefix = "/usr/bin", + renames = { + ":grpcurl_x86_64": "grpcurl", + }, +) + +pkg_files( + name = "grpcurl_files_arm64", + srcs = 
[":grpcurl_aarch64"], + attributes = pkg_attributes(mode = "0555"), + prefix = "/usr/bin", + renames = { + ":grpcurl_aarch64": "grpcurl", + }, +) + +[ + pkg_tar( + name = "grpcurl_tar_{}".format(arch), + srcs = [":grpcurl_files_{}".format(arch)], ) for arch in [ "arm64", @@ -54,10 +121,15 @@ go_binary( ] [ - container_image( - name = "diagnostic_docker_image_{}".format(arch), - base = ":diagnostic_go_image_{}".format(arch), - visibility = ["//visibility:public"], + oci_image( + name = "diagnostic_tools_image_{}".format(arch), + base = "@runtime-ubuntu-fulldist-debug-root-{}//image".format(arch), + tars = [ + ":helloworld_server_binaries_tar", + ":diagnostic_tar", + ":query_api_descriptor_set_tar", + ":grpcurl_tar_{}".format(arch), + ], ) for arch in [ "arm64", @@ -65,36 +137,11 @@ go_binary( ] ] -pkg_files( - name = "server_executables", - srcs = [ - "//tools/server_diagnostic/helloworld_server", - ], - attributes = pkg_attributes(mode = "0555"), - prefix = "/", -) - -pkg_tar( - name = "server_binaries_tar", - srcs = [":server_executables"], -) - -container_layer( - name = "helloworld_server_binary_layer", - directory = "/", - tars = [ - ":server_binaries_tar", - ], -) - [ - container_image( - name = "helloworld_server_docker_image_{}".format(arch), - architecture = arch, - base = "@runtime-debian-debug-nonroot-{}//image".format(arch), - layers = [ - ":helloworld_server_binary_layer", - ], + oci_tarball( + name = "diagnostic_tools_docker_image_{}".format(arch), + image = ":diagnostic_tools_image_{}".format(arch), + repo_tags = ["bazel/tools/server_diagnostic:diagnostic_tools_docker_image"], ) for arch in [ "arm64", @@ -105,17 +152,15 @@ container_layer( genrule( name = "copy_to_dist", srcs = [ - ":diagnostic_docker_image_arm64.tar", - ":diagnostic_docker_image_amd64.tar", - ":helloworld_server_docker_image_arm64.tar", - ":helloworld_server_docker_image_amd64.tar", + ":diagnostic_tools_docker_image_arm64", + ":diagnostic_tools_docker_image_amd64", ], outs = 
["copy_to_dist.bin"], cmd_bash = """cat << EOF > '$@' mkdir -p dist/tools/arm64/server_diagnostic -cp $(execpath :diagnostic_docker_image_arm64.tar) $(execpath :helloworld_server_docker_image_arm64.tar) dist/tools/arm64/server_diagnostic +cp $(execpath :diagnostic_tools_docker_image_arm64) dist/tools/arm64/server_diagnostic/diagnostic_tools_docker_image_arm64.tar mkdir -p dist/tools/amd64/server_diagnostic -cp $(execpath :diagnostic_docker_image_amd64.tar) $(execpath :helloworld_server_docker_image_amd64.tar) dist/tools/amd64/server_diagnostic +cp $(execpath :diagnostic_tools_docker_image_amd64) dist/tools/amd64/server_diagnostic/diagnostic_tools_docker_image_amd64.tar builders/tools/normalize-dist EOF""", executable = True, diff --git a/tools/serving_data_generator/test_serving_data_generator.cc b/tools/serving_data_generator/test_serving_data_generator.cc index aebae233..d3d33854 100644 --- a/tools/serving_data_generator/test_serving_data_generator.cc +++ b/tools/serving_data_generator/test_serving_data_generator.cc @@ -18,11 +18,8 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" -#include "absl/log/flags.h" #include "absl/log/initialize.h" #include "absl/log/log.h" -#include "absl/strings/substitute.h" -#include "google/protobuf/text_format.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/filename_utils.h" #include "public/data_loading/records_utils.h" @@ -40,13 +37,20 @@ ABSL_FLAG(int, num_shards, 1, "Number of shards"); ABSL_FLAG(int, shard_number, 0, "Shard number"); ABSL_FLAG(int64_t, timestamp, absl::ToUnixMicros(absl::Now()), "Record timestamp. 
Increases by 1 for each record."); -ABSL_FLAG(bool, generate_set_record, false, - "Whether to generate set record or not"); -ABSL_FLAG(std::string, set_value_key, "bar", +ABSL_FLAG(bool, generate_string_set_records, false, + "Whether to generate string set records or not"); +ABSL_FLAG(bool, generate_int_set_records, false, + "Whether to generate int set records or not"); +ABSL_FLAG(bool, use_random_elements, false, + "Whether to select set elements from a random range. If false, " + "elements are selected from generated keys."); +ABSL_FLAG(std::string, set_key_prefix, "set", "Specify the set value key prefix for lookups"); ABSL_FLAG(int, num_values_in_set, 10, "Number of values in the set to generate"); ABSL_FLAG(int, num_set_records, 5, "Number of records to generate"); +ABSL_FLAG(uint32_t, range_min, 0, "Minimum element in set records."); +ABSL_FLAG(uint32_t, range_max, 2147483647, "Maximum element in set records."); using kv_server::DataRecordStruct; using kv_server::KeyValueMutationRecordStruct; @@ -57,6 +61,8 @@ using kv_server::ToDeltaFileName; using kv_server::ToFlatBufferBuilder; using kv_server::ToStringView; +const std::array kSetOps = {" - ", " | ", " & "}; + void WriteKeyValueRecord(std::string_view key, std::string_view value, int64_t logical_commit_time, riegeli::RecordWriterBase& writer) { @@ -95,32 +101,52 @@ std::vector WriteKeyValueRecords( } void WriteKeyValueSetRecords(const std::vector& keys, - std::string_view set_value_key_prefix, - int64_t timestamp, + std::string_view set_key_prefix, int64_t timestamp, riegeli::RecordWriterBase& writer) { const int num_set_records = absl::GetFlag(FLAGS_num_set_records); const int num_values_in_set = absl::GetFlag(FLAGS_num_values_in_set); - const int keys_max_index = keys.size() - 1; std::string query(" "); for (int i = 0; i < num_set_records; ++i) { - std::vector set_copy; + std::vector uint32_set; + std::vector string_set; for (int j = 0; j < num_values_in_set; ++j) { - // Add a random element from keys - 
set_copy.emplace_back(keys[std::rand() % keys_max_index]); - } - std::vector set; - for (const auto& v : set_copy) { - set.emplace_back(v); + // Add a random element + std::srand(absl::GetCurrentTimeNanos()); + auto element = absl::GetFlag(FLAGS_range_min) + + (std::rand() % (absl::GetFlag(FLAGS_range_max) - + absl::GetFlag(FLAGS_range_min))); + if (absl::GetFlag(FLAGS_generate_int_set_records)) { + uint32_set.emplace_back(element); + } + if (absl::GetFlag(FLAGS_generate_string_set_records)) { + if (absl::GetFlag(FLAGS_use_random_elements)) { + string_set.emplace_back(absl::StrCat(element)); + } else { + string_set.emplace_back(keys[std::rand() % (keys.size() - 1)]); + } + } } - std::string set_value_key = absl::StrCat(set_value_key_prefix, i); - absl::StrAppend(&query, set_value_key, " | "); + auto set_value_key = absl::StrCat(set_key_prefix, i); KeyValueMutationRecordStruct record; - record.value = set; record.mutation_type = KeyValueMutationType::Update; record.logical_commit_time = timestamp++; record.key = set_value_key; - writer.WriteRecord(ToStringView( - ToFlatBufferBuilder(DataRecordStruct{.record = std::move(record)}))); + if (absl::GetFlag(FLAGS_generate_int_set_records)) { + record.value = uint32_set; + writer.WriteRecord(ToStringView( + ToFlatBufferBuilder(DataRecordStruct{.record = std::move(record)}))); + } + if (absl::GetFlag(FLAGS_generate_string_set_records)) { + std::vector string_set_view; + for (const auto& v : string_set) { + string_set_view.emplace_back(v); + } + record.value = string_set_view; + writer.WriteRecord(ToStringView( + ToFlatBufferBuilder(DataRecordStruct{.record = std::move(record)}))); + } + absl::StrAppend(&query, set_value_key, + kSetOps[std::rand() % kSetOps.size()]); } LOG(INFO) << "Example set query for all keys" << query; LOG(INFO) << "write done for set records"; @@ -147,7 +173,7 @@ int main(int argc, char** argv) { auto write_records = [](std::ostream* os) { const std::string key = absl::GetFlag(FLAGS_key); const int 
value_size = absl::GetFlag(FLAGS_value_size); - const std::string set_value_key_prefix = absl::GetFlag(FLAGS_set_value_key); + const std::string set_key_prefix = absl::GetFlag(FLAGS_set_key_prefix); int64_t timestamp = absl::GetFlag(FLAGS_timestamp); auto os_writer = riegeli::OStreamWriter(os); @@ -160,10 +186,9 @@ int main(int argc, char** argv) { auto record_writer = riegeli::RecordWriter(std::move(os_writer), options); const auto keys = WriteKeyValueRecords(key, value_size, timestamp, record_writer); - if (absl::GetFlag(FLAGS_generate_set_record)) { - timestamp += keys.size(); - WriteKeyValueSetRecords(keys, set_value_key_prefix, timestamp, - record_writer); + if (absl::GetFlag(FLAGS_generate_int_set_records) || + absl::GetFlag(FLAGS_generate_string_set_records)) { + WriteKeyValueSetRecords(keys, set_key_prefix, timestamp++, record_writer); } record_writer.Close(); }; diff --git a/tools/udf/inline_wasm/examples/hello_world/BUILD.bazel b/tools/udf/inline_wasm/examples/hello_world/BUILD.bazel index e1936042..6f1fb57e 100644 --- a/tools/udf/inline_wasm/examples/hello_world/BUILD.bazel +++ b/tools/udf/inline_wasm/examples/hello_world/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@google_privacysandbox_servers_common//build_defs/cc:wasm.bzl", "inline_wasm_cc_binary") +load("@google_privacysandbox_servers_common//build_defs/cc:wasm.bzl", "cc_inline_wasm_udf_js", "inline_wasm_cc_binary") load("//tools/udf/inline_wasm:wasm.bzl", "cc_inline_wasm_udf_delta", "inline_wasm_udf_delta") inline_wasm_cc_binary( @@ -45,3 +45,15 @@ cc_inline_wasm_udf_delta( output_file_name = "DELTA_0000000000000007", tags = ["manual"], ) + +# builders/tools/bazel-debian run --config=emscripten \ +# //tools/udf/inline_wasm/examples/hello_world:hello_wasm_js +# +# Creates a JS file with inline WASM + emscripten generated glue JS. 
+# Custom JS needs to be appended and call `const module = await getModule();` +# See tools/udf/inline_wasm/examples/hello_world/my_udf.js for a custom JS +# that loads the WASM module. +cc_inline_wasm_udf_js( + name = "hello_wasm_js", + srcs = ["hello.cc"], +) diff --git a/tools/udf/sample_udf/BUILD.bazel b/tools/udf/sample_udf/BUILD.bazel index 91022232..920ce70c 100644 --- a/tools/udf/sample_udf/BUILD.bazel +++ b/tools/udf/sample_udf/BUILD.bazel @@ -47,3 +47,20 @@ run_binary( ], tool = "//tools/udf/udf_generator:udf_delta_file_generator", ) + +run_binary( + name = "generate_run_set_query_int_delta", + srcs = [ + ":run_set_query_int_udf.js", + ], + outs = [ + "DELTA_0000000000000007", + ], + args = [ + "--udf_file_path", + "$(location :run_set_query_int_udf.js)", + "--output_path", + "$(location DELTA_0000000000000007)", + ], + tool = "//tools/udf/udf_generator:udf_delta_file_generator", +) diff --git a/tools/udf/sample_udf/run_set_query_int_udf.js b/tools/udf/sample_udf/run_set_query_int_udf.js new file mode 100644 index 00000000..ef405476 --- /dev/null +++ b/tools/udf/sample_udf/run_set_query_int_udf.js @@ -0,0 +1,42 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +function HandleRequest(executionMetadata, ...input) { + let keyGroupOutputs = []; + for (const keyGroup of input) { + let keyGroupOutput = {}; + if (!keyGroup.tags.includes('custom') || !keyGroup.tags.includes('queries')) { + continue; + } + keyGroupOutput.tags = keyGroup.tags; + if (!Array.isArray(keyGroup.data) || !keyGroup.data.length) { + continue; + } + + // Get the first key in the data. + const runQueryArray = runSetQueryInt(keyGroup.data[0]); + // runSetQueryInt returns an Uint8Array of 'uint32' ints and "code" on failure. + // Ignore failures and only add successful runQuery results to output. + if (runQueryArray instanceof Uint8Array) { + const keyValuesOutput = {}; + const value = Array.from(new Uint32Array(runQueryArray.buffer)); + keyValuesOutput['result'] = { value: value }; + keyGroupOutput.keyValues = keyValuesOutput; + keyGroupOutputs.push(keyGroupOutput); + } + } + return { keyGroupOutputs, udfOutputApiVersion: 1 }; +} diff --git a/tools/udf/udf_generator/BUILD.bazel b/tools/udf/udf_generator/BUILD.bazel index 78e56bf3..4ffd816d 100644 --- a/tools/udf/udf_generator/BUILD.bazel +++ b/tools/udf/udf_generator/BUILD.bazel @@ -15,11 +15,7 @@ load("@rules_cc//cc:defs.bzl", "cc_binary") package(default_visibility = [ - "//docs/protected_app_signals:__subpackages__", - "//getting_started:__subpackages__", - "//production/packaging/tools:__subpackages__", - "//testing:__subpackages__", - "//tools:__subpackages__", + "//visibility:public", ]) cc_binary( diff --git a/tools/udf/udf_tester/BUILD.bazel b/tools/udf/udf_tester/BUILD.bazel index de2873b8..baed738f 100644 --- a/tools/udf/udf_tester/BUILD.bazel +++ b/tools/udf/udf_tester/BUILD.bazel @@ -26,6 +26,7 @@ cc_binary( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/internal_server:local_lookup", + "//components/tools/util:configure_telemetry_tools", "//components/udf:udf_client", "//components/udf:udf_config_builder", 
"//components/udf/hooks:get_values_hook", diff --git a/tools/udf/udf_tester/udf_delta_file_tester.cc b/tools/udf/udf_tester/udf_delta_file_tester.cc index d7e178fa..3caf6661 100644 --- a/tools/udf/udf_tester/udf_delta_file_tester.cc +++ b/tools/udf/udf_tester/udf_delta_file_tester.cc @@ -21,6 +21,7 @@ #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" #include "components/internal_server/local_lookup.h" +#include "components/tools/util/configure_telemetry_tools.h" #include "components/udf/hooks/get_values_hook.h" #include "components/udf/udf_client.h" #include "components/udf/udf_config_builder.h" @@ -40,30 +41,39 @@ ABSL_FLAG(std::string, input_arguments, "", "be equivalent to a UDFArgument."); namespace kv_server { +namespace { +class UDFDeltaFileTestLogContext + : public privacy_sandbox::server_common::log::SafePathContext { + public: + UDFDeltaFileTestLogContext() = default; +}; +} // namespace using google::protobuf::util::JsonStringToMessage; // If the arg is const&, the Span construction complains about converting const // string_view to non-const string_view. Since this tool is for simple testing, // the current solution is to pass by value. 
-absl::Status LoadCacheFromKVMutationRecord(KeyValueMutationRecordStruct record, - Cache& cache) { +absl::Status LoadCacheFromKVMutationRecord( + UDFDeltaFileTestLogContext& log_context, + KeyValueMutationRecordStruct record, Cache& cache) { switch (record.mutation_type) { case KeyValueMutationType::Update: { LOG(INFO) << "Updating cache with key " << record.key << ", logical commit time " << record.logical_commit_time; std::visit( - [&cache, &record](auto& value) { + [&cache, &record, &log_context](auto& value) { using VariantT = std::decay_t; if constexpr (std::is_same_v) { - cache.UpdateKeyValue(record.key, value, + cache.UpdateKeyValue(log_context, record.key, value, record.logical_commit_time); return; } constexpr bool is_list = (std::is_same_v>); if constexpr (is_list) { - cache.UpdateKeyValueSet(record.key, absl::MakeSpan(value), + cache.UpdateKeyValueSet(log_context, record.key, + absl::MakeSpan(value), record.logical_commit_time); return; } @@ -72,7 +82,7 @@ absl::Status LoadCacheFromKVMutationRecord(KeyValueMutationRecordStruct record, break; } case KeyValueMutationType::Delete: { - cache.DeleteKey(record.key, record.logical_commit_time); + cache.DeleteKey(log_context, record.key, record.logical_commit_time); break; } default: @@ -83,15 +93,17 @@ absl::Status LoadCacheFromKVMutationRecord(KeyValueMutationRecordStruct record, return absl::OkStatus(); } -absl::Status LoadCacheFromFile(std::string file_path, Cache& cache) { +absl::Status LoadCacheFromFile(UDFDeltaFileTestLogContext& log_context, + std::string file_path, Cache& cache) { std::ifstream delta_file(file_path); DeltaRecordStreamReader record_reader(delta_file); - absl::Status status = - record_reader.ReadRecords([&cache](const DataRecordStruct& data_record) { + absl::Status status = record_reader.ReadRecords( + [&cache, &log_context](const DataRecordStruct& data_record) { // Only load KVMutationRecords into cache. 
if (std::holds_alternative( data_record.record)) { return LoadCacheFromKVMutationRecord( + log_context, std::get(data_record.record), cache); } @@ -137,10 +149,11 @@ void ShutdownUdf(UdfClient& udf_client) { absl::Status TestUdf(const std::string& kv_delta_file_path, const std::string& udf_delta_file_path, const std::string& input_arguments) { - InitMetricsContextMap(); + ConfigureTelemetryForTools(); LOG(INFO) << "Loading cache from delta file: " << kv_delta_file_path; std::unique_ptr cache = KeyValueCache::Create(); - PS_RETURN_IF_ERROR(LoadCacheFromFile(kv_delta_file_path, *cache)) + UDFDeltaFileTestLogContext log_context; + PS_RETURN_IF_ERROR(LoadCacheFromFile(log_context, kv_delta_file_path, *cache)) << "Error loading cache from file"; LOG(INFO) << "Loading udf code config from delta file: " @@ -157,20 +170,21 @@ absl::Status TestUdf(const std::string& kv_delta_file_path, auto binary_get_values_hook = GetValuesHook::Create(GetValuesHook::OutputType::kBinary); binary_get_values_hook->FinishInit(CreateLocalLookup(*cache)); - auto run_query_hook = RunQueryHook::Create(); - run_query_hook->FinishInit(CreateLocalLookup(*cache)); + auto run_set_query_string_hook = RunSetQueryStringHook::Create(); + run_set_query_string_hook->FinishInit(CreateLocalLookup(*cache)); absl::StatusOr> udf_client = UdfClient::Create(std::move( config_builder.RegisterStringGetValuesHook(*string_get_values_hook) .RegisterBinaryGetValuesHook(*binary_get_values_hook) - .RegisterRunQueryHook(*run_query_hook) + .RegisterRunSetQueryStringHook(*run_set_query_string_hook) .RegisterLoggingFunction() .SetNumberOfWorkers(1) .Config())); PS_RETURN_IF_ERROR(udf_client.status()) << "Error starting UDF execution engine"; - auto code_object_status = udf_client.value()->SetCodeObject(code_config); + auto code_object_status = + udf_client.value()->SetCodeObject(code_config, log_context); if (!code_object_status.ok()) { LOG(ERROR) << "Error setting UDF code object: " << code_object_status; 
ShutdownUdf(*udf_client.value()); @@ -185,9 +199,13 @@ absl::Status TestUdf(const std::string& kv_delta_file_path, JsonStringToMessage(req_partition_json, &req_partition); LOG(INFO) << "Calling UDF for partition: " << req_partition.DebugString(); - auto metrics_context = std::make_unique(); + auto request_context_factory = std::make_unique( + privacy_sandbox::server_common::LogContext(), + privacy_sandbox::server_common::ConsentedDebugConfiguration()); + ExecutionMetadata execution_metadata; auto udf_result = udf_client.value()->ExecuteCode( - RequestContext(*metrics_context), {}, req_partition.arguments()); + *request_context_factory, {}, req_partition.arguments(), + execution_metadata); if (!udf_result.ok()) { LOG(ERROR) << "UDF execution failed: " << udf_result.status(); ShutdownUdf(*udf_client.value()); @@ -205,7 +223,6 @@ absl::Status TestUdf(const std::string& kv_delta_file_path, int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); - const std::string kv_delta_file_path = absl::GetFlag(FLAGS_kv_delta_file_path); const std::string udf_delta_file_path = diff --git a/version.txt b/version.txt index d183d4ac..07feb823 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.16.0 \ No newline at end of file +0.17.0 \ No newline at end of file