diff --git a/.bazelrc b/.bazelrc index a1f59b5d..7996c2cd 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,3 +1,4 @@ +build --announce_rc build --verbose_failures build --client_env=CC=clang build --cxxopt=-std=c++17 @@ -61,3 +62,7 @@ build:ubsan --copt -O1 build:ubsan --copt -fno-omit-frame-pointer build:ubsan --linkopt -fsanitize=undefined build:ubsan --linkopt -lubsan + +# --config local_instance: builds the service to run with the instance=local flag +build:local_instance --//:instance=local +build:local_instance --@google_privacysandbox_servers_common//:instance=local diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1eff7a75..52495579 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,8 +51,13 @@ repos: - id: shellcheck exclude: '^(google_internal|builders/images)/.*$' +- repo: https://github.com/bufbuild/buf + rev: v1.23.1 + hooks: + - id: buf-format + - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.2 + rev: v16.0.6 hooks: - id: clang-format types_or: @@ -65,7 +70,7 @@ repos: name: addlicense language: golang additional_dependencies: - - github.com/google/addlicense@v1.1.0 + - github.com/google/addlicense@v1.1.1 always_run: false pass_filenames: true entry: addlicense -v -ignore google_internal/third_party/** @@ -76,7 +81,7 @@ repos: name: addlicense check language: golang additional_dependencies: - - github.com/google/addlicense@v1.1.0 + - github.com/google/addlicense@v1.1.1 always_run: false pass_filenames: true entry: addlicense -check -ignore google_internal/third_party/** @@ -105,7 +110,7 @@ repos: )$ - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.7.1 + rev: v0.8.1 hooks: - id: markdownlint-cli2 name: lint markdown @@ -117,7 +122,7 @@ repos: description: Format bazel WORKSPACE, BUILD and .bzl files with a standard convention. 
language: golang additional_dependencies: - - github.com/bazelbuild/buildtools/buildifier@5.1.0 + - github.com/bazelbuild/buildtools/buildifier@6.1.1 always_run: true pass_filenames: true types_or: diff --git a/.versionrc.json b/.versionrc.json index a254d606..991dace6 100644 --- a/.versionrc.json +++ b/.versionrc.json @@ -15,6 +15,36 @@ ], "tagPrefix": "release-", "types": [ + { + "section": "API: Features", + "type": "feat", + "scope": "api" + }, + { + "section": "API: Fixes", + "type": "fix", + "scope": "api" + }, + { + "section": "Terraform", + "type": "feat", + "scope": "terraform" + }, + { + "section": "Terraform", + "type": "fix", + "scope": "terraform" + }, + { + "section": "Build Tools: Features", + "type": "feat", + "scope": "build" + }, + { + "section": "Build Tools: Fixes", + "type": "fix", + "scope": "build" + }, { "section": "Features", "type": "feat" @@ -27,10 +57,6 @@ "section": "Documentation", "type": "docs" }, - { - "section": "Terraform", - "type": "terraform" - }, { "hidden": true, "type": "internal" diff --git a/BUILD b/BUILD index 4d828322..024012b7 100644 --- a/BUILD +++ b/BUILD @@ -113,3 +113,7 @@ EOF""", local = True, message = "copy bazel build and test logs", ) + +exports_files([ + "buf.yaml", +]) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30b7885e..b28edfa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,83 @@ All notable changes to this project will be documented in this file. See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. -## Release 0.10.0 (2023-05-04) +## 0.11.0 (2023-07-11) + + +### Features + +* [Breaking change] Use UserDefinedFunctionsConfig instead of KVs for loading UDFs. +* [Sharding] Add hpke for s2s communication +* [Sharding] Allow for partial data lookups +* [Sharding] Making downstream requests in parallel +* Add bazel build flag --announce_rc +* Add bool parameter to allow routing V1 requests through V2. 
+* Add buf format pre-commit hook +* Add build time directive for reentrant parser. +* Add functions to retrieve instance information. +* Add internal run query client and server. +* Add JS hook for set query. +* Add lookup client and server for communication with shards +* Add MessageQueue for the request simulation system +* Add query grammar and interface for set queries. +* Add rate limiter for the request simulation system +* Add second map to store key value set and add set value update interfaces +* Add shard metadata for supporting sharded files +* Add simple microbenchmarks for key value cache +* Add UDF support for format data command. +* Add unit tests for query lexer. +* Adding cluster mappings manager +* Adding padding +* Apply custom lockings on the cache +* Connect InternalRunQuery to the parser +* Extend and simplify collect-logs to capture test outputs +* Extend use of scp deps via data-plane-shared repo +* Implement shard manager +* Move sharding function to public so it's available for file sharding +* Register a logging hook with the UDF. +* Register run query hook with udf framework. +* Sharding - realtime updates +* Sharding read flow fixes +* Simplify work done in set operations. Set operations can be passed by +* Snapshot files support UDF configs. +* Support reading and writing set queries to data files. +* Support reading and writing set values for csv files +* Support reading/writing DataRecords. Requires new DELTA format. +* Support writing sharded files +* Update data_loading.fb to support UDF code updates. +* Update pre-commit hook versions +* Update shard manager mappings continuously +* Upgrade build-system to release-0.28.0 +* Upgrade build-system to v0.30.1 +* Upgrade scp to 0.72.0 +* Use Unix domain socket for internal lookup server. +* Utilize AWS deps via data-plane-shared repo + + +### Bug Fixes + +* Add internal lookup client deadline. +* Catch error if insufficient args specified +* Fix aggregation logic for set values. 
+* Fix ASAN potential deadlock errors in key_value_cache_test +* Proper memory management of callback hook wrappers. +* Specify 2 workers for UDF execution. +* Upgrade pre-commit hooks +* Use shared pointer for UDF absl::Notification. + + +### Build Tools: Fixes + +* **build:** Add scope-based sections in release notes + + +### Documentation + +* Add docs for data loading capabilities +* Add explanation that access control is managed by IAM for writes. +* Point readme to a new sharding public explainer + +## 0.10.0 (2023-05-04) ### Features @@ -48,6 +124,459 @@ All notable changes to this project will be documented in this file. See [commit * Add Protected Audience API rename banner +## 0.9.0 (2023-04-10) + + +### Features + +* Add a total realtime QPS metric +* Add aws supplied e2e latency +* Add basic UDF functional tests for v2 +* Add error counters for realtime updates +* Add functional test stubs for v2 +* Add target to generate delta for sample udf.js +* Add test data artifacts to dist/test_data/deltas +* Add UDF delta file generator tool. +* Add UDF delta file upload through terraform config. +* Add udf.js delta file to test set +* Upgrade to build-system 0.22.0 and functionaltest-system 0.2.0 + + +### Bug Fixes + +* Add a dashboard for environments +* Add documentation for editing dashboards +* Change envoy log level to debug +* Check that recovery function is valid before calling it. +* Enable docker network cleanup +* Ensure changelog notes use specific version +* ignore interestGroupNames argument +* MetricsRecorder no longer a singleton. +* MetricsRecorder now optional for retry templates. +* Return missing key error status from internal lookup server. +* Upgrade gRPC and make lookup client a singleton +* Use dynamic_cast to get metric_sdk::MeterProvider provider. + + +### Terraform + +* Add us-west-1 terraform + + +### Documentation + +* Add documentation on roma child processes. 
+* Add instructions on realtime updates +* Add note to use grpcurl for v2 since http has a bug. +* Add v2 request JSON schema +* AWS realtime update capabilities +* Correct udf target name +* Update documentation for building data cli. +* Update realtime metrics querying docs + +## 0.8.0 (2023-03-28) + + +### Features + +* add AWS SQS ReceiveMessage latency histogram +* Add command line flags for parameters +* Add configurable thread count for realtime updater +* Add functional test +* Adding e2e latency measurement for realtime updates +* Allow specifying explicit histogram bucket boundaries +* Allow the blob_storage_util cp command to work with local files. +* Allow the DeltaFileRecordChangeNotifier to read local files as well as from S3 +* Batch delete SQS messages +* Build delta files from csv +* clean up realtime queues +* Configure AWS hosted Prometheus. +* Disable the use of exceptions +* Enhance/Simplify local export of telemetry with OTLP. +* Functional testing of local server with delta files +* make AwsSnsSqsManager thread safe +* Make the blob_storage_change_watcher tool work for local files +* Make the blob_storage_util cat and rm commands work for local files +* Make the blob_storage_util ls command work for local files and refactor out common parts from the AWS binary +* Make the delta_file_watcher tool work for local files +* Move the platform specific server configuration logic to a separate file +* multi-threaded realtime notifier +* realtime tester in a container +* Reuse SQS client +* Speed up test updates publisher +* Tools for generating and inserting realtime test data. +* Upgrade build-system to release-0.18.0 +* Upgrade build-system to release-0.20.0 +* Upgrade debian runtime images to 15 Feb 2023 +* Upgrade to build-system 0.17.0 +* Use a PlatformInitializer so the data_cli will compile for --platform=local + + +### Bug Fixes + +* Add ability to interrupt a SleepFor Duration. +* Add minimum shard size threshold for concurrent reader. 
+* Launch Envoy first before all other processes. +* Make MetricsRecorder a shared global instance. +* Only run functional tests against local server +* Remove functionaltest/run-server, update docs accordingly +* Remove submodule section from docs +* Run server in background, and reduce noise +* Run server using delta files in run-server-docker +* Update generate_load_test_data to use bazel-debian +* Use symlink for identical test replies +* Use VLOG for concurrent reader debugging logs. +* Wait for envoy to respond before launching enclave service. + + +### Documentation + +* Add a playbook template for new alerts +* Add description for backup_poll_frequency_secs +* Add docs about how to run the server locally +* fix /tmp/deltas discrepancy +* remove obsolete service +* Updating instructions on how to copy `eif` manually. + + +### Terraform + +* Convert tfvar files to json +* Support Prometheus service running in a different region + +## 0.7.0 (2023-02-16) + + +### Features + +* Add --platform flag to build_and_test_all_in_docker +* Add a concurrent record reader for data files. +* Add a helper service for protocol testing +* Add a stub class for reading local blob files +* add automated PCR0 updates to kokoro continuous +* Add base streambuf with seeking support for reading blobs. +* Add delta writer and custom audience data parser +* add github personal access token validation for release scripts +* Add instance id to metrics. +* Add seeking to S3 blob reader. +* Add support for Zipkin exports for local builds. +* Add terraform logic to create SNS for real time updates +* Check timestamps for cache update +* Implement BinaryHTTP version of V2 API +* Implement delta file record change notifier to retrieve high priority updates +* Implement test OHTTP V2 query handling +* Integrating high priority updates in data server +* Memory cleanup for delete timestamps in the cache +* Record metrics for all RetryUntilOk events. 
Export them to stdout or +* Upgrade black to 23.1.0 +* Upgrade to build-system 0.13.0 +* Upgrade to build-system 0.14.0 +* Upgrade to build-system 0.16.0 +* Use concurrent reader for reading snapshot and delta files. + + +### Bug Fixes + +* Add docker compose config for testing locally. +* add empty bug id to automated PCR0 CL +* Add unit test for delta file backup poll. Fix bug where we can't +* Don't ListBlobs to poll Delta files on notifications that don't +* Don't use default number of cores for small test files. +* fetch git remote for automated pcr0 updates +* Fix typos and remove unreachable branches. +* flaky delta_file_notifier test. +* Listing non-delta files from bucket shouldn't cause state change. +* Only read the most recent snapshot file. +* path for local envoy +* Prefer github release artifacts over archive artifacts +* remove duplicate open telemetry entry. +* remove spaces from automated PCR0 CL commit footer +* Switch jaeger over to using OTLP directly. Jaeger otel component is +* Upgrade to rules_buf 0.1.1 +* Uprev Otel to pull in semantic resource convensions. Use them +* Use shared libraries for proxy + + +### Build System + +* Hide build stdout/stderr for third_party + + +### Documentation + +* Add docs for data loading library. + +## 0.6.0 (2023-01-10) + + +### Features + +* Add --no-precommit flag to build_and_test_all_in_docker +* Add command to generate snapshots to data cli +* Add support for reading snapshots during server startup. +* Implement a snapshot writer. 
+* Produce PCR0.json for server EIF +* Remove VCS commit info from server --buildinfo +* Reorg build_and_test* scripts +* Store and validate arch-specific PCR0 hash +* update dev to staging copybara to include github workflows +* update GitHub presubmit workflow to trigger on pull request +* Upgrade to build-system 0.10.0 +* Upgrade to build-system 0.5.0 +* Upgrade to build-system 0.6.0 +* Upgrade to gRPC v1.51.1 + + +### Bug Fixes + +* add missing "xray" to vpc_interface_endpoint_services references. +* Adjust git global config +* Attach initial_launch_hook to autoscaling group. +* Avoid non-zero exit on PCR0 hash mismatch +* Correct documentation on endpoint to test +* Fix the region doc for local development. +* Ignore builders/ when executing pre-commit +* LifecycleHeartbeat only Finish once. Fixed unit test. +* Upgrade to addlicense v1.1 +* Use absolute path for kokoro_release.sh +* Use bazel-debian to build and run test_serving_data_generator + + +### Build System + +* Add presubmit GitHub workflow +* Upgrade to bazel 5.4.0 + + +### Documentation + +* Add a default AWS region to push command +* Correct command to run server locally +* Update ECR format and improve the AWS doc order + +## 0.5.0 (2022-11-28) + + +### Features + +* Add basic smoke test +* Add builders/utils docker image +* Add hadolint to lint Dockerfiles +* Add toolchain short hash to bazel output_user_root path +* Add tools/lib/builder.sh +* Add utils for working with snapshot files. +* Adopt build-system release-0.2.0 +* Allow AMI building to specify AWS region. 
+* Bump debian runtime to stable-20221004-slim +* Rename nitro_artifacts to aws_artifacts +* Set BUILD_ARCH env var in docker images +* Simplify use of --with-ami flag +* Tag small tests +* Upgrade build-debian to python3.9 +* Upgrade to build-system 0.3.1 +* Upgrade to build-system 0.4.3 +* Upgrade to build-system 0.4.4 +* Upgrade to clang v14 on bazel-debian + + +### Bug Fixes + +* Add get_workspace_mount function to encapsulate code block +* Allow server script to accept any flags +* Avoid installing recommended debian packages +* Copy get_values_descriptor_set.pb to dist dir +* Correct shell quoting +* Correct workspace volume when tools/terraform is executed in a nested container +* Execute tests prior to copy_to_dist +* Guess user/group for files when running container as root bazel +* Ignore InvalidArgument error on completing lifecycle hook. +* include a backoff on errors to long poll 'push' notifications. +* Invoke addlicense for all text files +* Invoke unzip via utils image +* Migrate duration code into KV server. 
+* Minor improvements to shell scripts +* Modify normalize-dist to use builder::id function +* Mount $HOME/aws in aws-cli container +* Move builder-related configs to builders/etc +* Move WORKSPACE definition to cbuild script global +* multi-region support for sqs_lambda +* Propagate AWS env vars and $HOME/.aws into terraform container +* Propagate gcloud stderr +* Reduce noise from tools/collect-logs +* Remove build timestamp to afford stability of binary +* Remove debugging statement +* Remove docker flags -i and -t +* Remove dockerfile linter ignore and correct ENTRYPOINT +* Remove pre-commit config from build-debian +* Rename bazel image name debian-slim to runtime-debian +* Set architecture in container_image declaration +* Set bazel output_base to accommodate distinct workspaces +* Set WORKSPACE variable +* Support regions outside us-east-1 +* unzip should overwrite files +* Update gazelle to v0.28.0 +* Upgrade bazel-skylib to 1.3.0 +* Upgrade rules_pkg to 0.8.0 +* Use builder library functions + + +### Documentation + +* Add error handling guidelines. +* Add submodule instructions +* Fix build command +* fix typo in aws doc +* recommend the use of native AWS CLI in documentation +* Remove an unnecessary step in server doc + +## 0.4.0 (2022-10-11) + + +### Features + +* Add //:buildifier rule as an alias to the pre-commit buildifier hook +* Add a debugging endpoint for Binary Http GetValues. +* Add aws-cli helper script +* Add csv reader and writer +* Add delta record reader based on riegeli stream io. +* Add library for reading and writing delta files +* Add utility for data generation (currently supports csv to delta and vice versa) +* Add version info to server binaries +* Determine workspace mount point from docker inspect if inside docker container +* Display pre-commit error log if it exists +* Implement API call to record lifecycle heartbeat. +* Log bazel build flags during server startup runtime. 
+* Overhaul building on amazonlinux2 +* Repeating timer with callback implementation +* Set working dir to current workspace-relative path in tools/terraform + + +### Bug Fixes + +* Add bazel rule to copy files to dist dir +* Add builders/tools/normalize-dist to chmod/chgrp/chown dist/ directory tree +* Add fetch git tags from remote prior to syncing repos +* Add files to subject to chown and chgrp +* Adjust chown/chgrp to be silent +* Adopt shellcheck +* Clean bazel_root for smaller docker image +* Correct the WORKSPACE path in production/packaging/aws/build_and_test +* Correct variable used to check for valid repo name +* Drop packer from build-debian image +* Fix a typo and improve some logging for DataOrchestrator's loading +* Improve cbuild help text +* Improve git push message to accommodate patch branches +* Increase SQS cleanup lamabda timeout +* Modifications as indicated by shellcheck +* Modify build and packaging for AWS SQS Lambda +* Move definition from header to cc to eliminate linker error +* Only propagate AWS env vars into amazonlinux2 build container +* pre-commit CLEANUP should default to zero +* Print pre-commit version rather than help +* Print timestamps in UTC timezone +* Remove container when get-architecture exits +* Remove duplicate text "instance:" in build flavor +* Remove shellcheck from build-debian +* Remove unused nitro_enclave_image bazel rule +* Set bazel output_user_root in image bazelrc +* Set locale in build-debian +* Strip commit hashes from CHANGELOG.md +* Switch from hardcoded arch to using dpkg --print-architecture +* Update author name +* Update pre-commit to use cbuild +* Use --with-ami flag to determine bazel flags instance/platform +* Use default health check grace priod (300s) now that we have heartbeats. 
+* Use git rather than bazel to determine workspace root +* Use PRE_COMMIT_TOOL env var + + +### Build System + +* Add arch to docker image tags +* Add get_builder_image_tagged tool to determine a content-based tag +* Add get-architecture helper script +* Propagate status code in exit functions + + +### Documentation + +* Correct command to load docker image locally +* Instructions to make changes to a dependency +* Sugggest use of python virtualenv +* Use concise form of passing env vars into docker container + +## 0.3.0 (2022-09-14) + + +### Features + +* Add --env flag to cbuild +* Update release image to node v18 + + +### Bug Fixes + +* Bump to latest version of bazelisk +* Consolidate bazel/ dir into third_party/ dir +* Ensure appropriate ownership of modified files +* fix local docker run command +* Improve shell string quoting +* Invoke bash via /usr/bin/env +* Propagate SKIP env var into pre-commit container + +## 0.2.0 (2022-09-07) + + +### Features + +* Add arg processing to cbuild script +* Add bazel-debian helper script +* Add black python formatter to pre-commit +* Add optional flags to cbuild tool +* Inject clang-version as a bazel action_env +* Migrate generation of builders/release container image to Dockerfile +* Remove python build dependencies from bazel +* Support installation of git pre-commit hook +* Support running the server container locally +* Use EC2 instance connect for ssh access. 
+ + +### Bug Fixes + +* Add public/ to presubmit tests +* Add python version to action_env +* Add require-ascii pre-commit hook +* Add/remove basic pre-commit hooks +* Allow for short(er) presubmit builds +* Change ownership of dist/ to user +* Correct git instructions at the end of cut_release +* Define python3 toolchain +* Ensure /etc/gitconfig is readable by all +* Install bazel version as specified in .bazelversion +* Move bazel env vars from comments to help text +* Pin version of bazelisk +* Pin version of libc++-dev +* Pin version of python3.8 +* Remove check for uncommitted changes, tools/pre-commit exit status should suffice +* Tidy cut_release +* Upgrade zstd to v1.5.2 + + +### Build System + +* Add buildifier tool to pre-commit +* Add terraform fmt to pre-commit +* Add terraform to tools +* Add tools/pre-commit +* Adjust branch specification for mirroring to GitHub +* Move commit-and-tag-version into tools dir +* Move gh into tools dir +* Reduce redundant installation commands +* Reinstate use of cpplint, via pre-commit tool +* Remove buildifier from bazel build as it is redundant with pre-commit +* Remove release-please tool +* Rename builders/bazel to build-debian + ### 0.1.0 (2022-08-18) diff --git a/README.md b/README.md index 89532523..fe68a635 100644 --- a/README.md +++ b/README.md @@ -73,8 +73,10 @@ changes. 
- [Local server quickstart guide](/docs/developing_the_server.md) - [AWS server user deployment documentation](/docs/deploying_on_aws.md) - [Integrating the K/V server with FLEDGE](/docs/integrating_with_fledge.md) +- [FLEDGE K/V server sharding explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_sharding.md) - Operating documentation - [Data loading API and operations](/docs/loading_data.md) + - [Generating and loading UDF files](/docs/generating_udf_files.md) - Error handling explainer (_to be published_) - Developer guide - [Codebase structure](/docs/repo_layout.md) diff --git a/WORKSPACE b/WORKSPACE index a12f3dad..bceb01cc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,171 +5,52 @@ local_repository( path = "testing/functionaltest-system", ) -http_archive( - name = "boringssl", - sha256 = "0cd64ecff9e5f757988b84b7685e968775de08ea9157656d0b9fee0fa62d67ec", - strip_prefix = "boringssl-c2837229f381f5fcd8894f0cca792a94b557ac52", - urls = ["https://github.com/google/boringssl/archive/c2837229f381f5fcd8894f0cca792a94b557ac52.tar.gz"], -) - -http_archive( - name = "bazel_skylib", - sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - ], -) - -load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") - -bazel_skylib_workspace() - -http_archive( - name = "com_google_protobuf", - sha256 = "e51cc8fc496f893e2a48beb417730ab6cbcb251142ad8b2cd1951faa5c76fe3d", # Last updated 2022-09-29 - strip_prefix = "protobuf-3.20.3", - urls = ["https://github.com/protocolbuffers/protobuf/releases/download/v3.20.3/protobuf-cpp-3.20.3.tar.gz"], -) - -http_archive( - name = "io_bazel_rules_go", - sha256 = "16e9fca53ed6bd4ff4ad76facc9b7b651a89db1689a2877d6fd7b82aa824e366", - urls = [ - 
"https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.34.0/rules_go-v0.34.0.zip", - "https://github.com/bazelbuild/rules_go/releases/download/v0.34.0/rules_go-v0.34.0.zip", - ], -) - load("//builders/bazel:deps.bzl", "python_deps") python_deps("//builders/bazel") http_archive( - name = "bazel_gazelle", - sha256 = "448e37e0dbf61d6fa8f00aaa12d191745e14f07c31cabfa731f0c8e8a4f41b97", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.28.0/bazel-gazelle-v0.28.0.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.28.0/bazel-gazelle-v0.28.0.tar.gz", - ], -) - -load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies") -load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies") - -go_rules_dependencies() - -### go_register_toolchains will be called by grpc_extra_deps -# go_register_toolchains(go_version = "1.18") -### gRPC -http_archive( - name = "com_github_grpc_grpc", - sha256 = "ec125d7fdb77ecc25b01050a0d5d32616594834d3fe163b016768e2ae42a2df6", - strip_prefix = "grpc-1.52.1", + name = "google_privacysandbox_servers_common", + # commit 1afdb3d4e59bcc422ac769025ccde4460b48569c 2023-05-31 + sha256 = "32eeade4bad14fef6d2884e866e538c1312b8a11a8d55df0b176c925ca17fe5e", + strip_prefix = "data-plane-shared-libraries-1afdb3d4e59bcc422ac769025ccde4460b48569c", urls = [ - "https://github.com/grpc/grpc/archive/v1.52.1.tar.gz", + "https://github.com/privacysandbox/data-plane-shared-libraries/archive/1afdb3d4e59bcc422ac769025ccde4460b48569c.zip", ], ) -load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") - -grpc_deps() - -load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") - -grpc_extra_deps() - -### gazelle deps must be loaded after go toolchains registered -gazelle_dependencies() - -http_archive( - name = "rules_pkg", - sha256 = "eea0f59c28a9241156a47d7a8e32db9122f3d50b505fae0f33de6ce4d9b61834", - urls = [ - 
"https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.8.0/rules_pkg-0.8.0.tar.gz", - "https://github.com/bazelbuild/rules_pkg/releases/download/0.8.0/rules_pkg-0.8.0.tar.gz", - ], +load( + "@google_privacysandbox_servers_common//third_party:cpp_deps.bzl", + data_plane_shared_deps_cpp = "cpp_dependencies", ) -load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") - -rules_pkg_dependencies() - -### rules_buf (https://docs.buf.build/build-systems/bazel) -http_archive( - name = "rules_buf", - sha256 = "523a4e06f0746661e092d083757263a249fedca535bd6dd819a8c50de074731a", - strip_prefix = "rules_buf-0.1.1", - urls = ["https://github.com/bufbuild/rules_buf/archive/refs/tags/v0.1.1.zip"], -) +data_plane_shared_deps_cpp() -load("@rules_buf//buf:repositories.bzl", "rules_buf_dependencies", "rules_buf_toolchains") +load("@google_privacysandbox_servers_common//third_party:deps1.bzl", data_plane_shared_deps1 = "deps1") -rules_buf_dependencies() +data_plane_shared_deps1() -rules_buf_toolchains(version = "v1.7.0") +load("@google_privacysandbox_servers_common//third_party:deps2.bzl", data_plane_shared_deps2 = "deps2") -http_archive( - name = "io_bazel_rules_docker", - sha256 = "b1e80761a8a8243d03ebca8845e9cc1ba6c82ce7c5179ce2b295cd36f7e394bf", - urls = ["https://github.com/bazelbuild/rules_docker/releases/download/v0.25.0/rules_docker-v0.25.0.tar.gz"], -) +data_plane_shared_deps2(go_toolchains_version = "1.19.9") -load( - "@io_bazel_rules_docker//repositories:repositories.bzl", - container_repositories = "repositories", -) +load("@google_privacysandbox_servers_common//third_party:deps3.bzl", data_plane_shared_deps3 = "deps3") -container_repositories() +data_plane_shared_deps3() -load("@io_bazel_rules_docker//repositories:deps.bzl", docker_container_deps = "deps") +load("@google_privacysandbox_servers_common//third_party:deps4.bzl", data_plane_shared_deps4 = "deps4") -docker_container_deps() - -load("//third_party:container_deps.bzl", "container_deps") - 
-container_deps() - -http_archive( - name = "google_privacysandbox_servers_common", - sha256 = "4158164f52e719e5948e5b43bae01b111e5f1cc38e66516d35e37927b0316ff1", - strip_prefix = "data-plane-shared-libraries-5f9c6fc89e32f944ca208ed4ee1a2c71777cc483", - urls = [ - "https://github.com/privacysandbox/data-plane-shared-libraries/archive/5f9c6fc89e32f944ca208ed4ee1a2c71777cc483.zip", - ], -) +data_plane_shared_deps4() load("//third_party:cpp_repositories.bzl", "cpp_repositories") cpp_repositories() -load("@google_privacysandbox_servers_common//third_party:scp_deps.bzl", "scp_deps") - -scp_deps() - -load("@google_privacysandbox_servers_common//third_party:scp_deps2.bzl", "scp_deps2") - -scp_deps2() - -load("@v8_python_deps//:requirements.bzl", install_v8_python_deps = "install_deps") - -install_v8_python_deps() - -load("//third_party:quiche.bzl", "quiche_dependencies") - -quiche_dependencies() - -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") - -protobuf_deps() - -# Load OpenTelemetry dependencies after load. 
-load("//third_party:open_telemetry.bzl", "open_telemetry_dependencies") +load("//third_party:container_deps.bzl", "container_deps") -open_telemetry_dependencies() +container_deps() # emscripten - http_archive( name = "emsdk", sha256 = "d55e3c73fc4f8d1fecb7aabe548de86bdb55080fe6b12ce593d63b8bade54567", @@ -193,17 +74,40 @@ http_archive( urls = ["https://github.com/googleapis/googleapis/archive/f91b6cf82e929280f6562f6110957c654bd9e2e6.tar.gz"], ) -load("@io_opentelemetry_cpp//bazel:repository.bzl", "opentelemetry_cpp_deps") - -opentelemetry_cpp_deps() - -load("@rules_foreign_cc//foreign_cc:repositories.bzl", "rules_foreign_cc_dependencies") - -rules_foreign_cc_dependencies() - http_archive( name = "distributed_point_functions", sha256 = "19cd27b36b0ceba683c02fc6c80e61339397afc3385b91d54210c5db0a254ef8", strip_prefix = "distributed_point_functions-45da5f54836c38b73a1392e846c9db999c548711", urls = ["https://github.com/google/distributed_point_functions/archive/45da5f54836c38b73a1392e846c9db999c548711.tar.gz"], ) + +# Dependencies for Flex/Bison build rules +http_archive( + name = "rules_m4", + sha256 = "10ce41f150ccfbfddc9d2394ee680eb984dc8a3dfea613afd013cfb22ea7445c", + urls = ["https://github.com/jmillikin/rules_m4/releases/download/v0.2.3/rules_m4-v0.2.3.tar.xz"], +) + +load("@rules_m4//m4:m4.bzl", "m4_register_toolchains") + +m4_register_toolchains(version = "1.4.18") + +http_archive( + name = "rules_bison", + sha256 = "2279183430e438b2dc77cacd7b1dbb63438971b2411406570f1ddd920b7c9145", + urls = ["https://github.com/jmillikin/rules_bison/releases/download/v0.2.2/rules_bison-v0.2.2.tar.xz"], +) + +load("@rules_bison//bison:bison.bzl", "bison_register_toolchains") + +bison_register_toolchains(version = "3.3.2") + +http_archive( + name = "rules_flex", + sha256 = "8929fedc40909d19a4b42548d0785f796c7677dcef8b5d1600b415e5a4a7749f", + urls = ["https://github.com/jmillikin/rules_flex/releases/download/v0.2.1/rules_flex-v0.2.1.tar.xz"], +) + 
+load("@rules_flex//flex:flex.bzl", "flex_register_toolchains") + +flex_register_toolchains(version = "2.6.4") diff --git a/buf.lock b/buf.lock new file mode 100644 index 00000000..65d8f1f7 --- /dev/null +++ b/buf.lock @@ -0,0 +1,8 @@ +# Generated by buf. DO NOT EDIT. +version: v1 +deps: + - remote: buf.build + owner: googleapis + repository: googleapis + commit: cc916c31859748a68fd229a3c8d7a2e8 + digest: shake256:469b049d0eb04203d5272062636c078decefc96fec69739159c25d85349c50c34c7706918a8b216c5c27f76939df48452148cff8c5c3ae77fa6ba5c25c1b8bf8 diff --git a/public/buf.yaml b/buf.yaml similarity index 89% rename from public/buf.yaml rename to buf.yaml index 0a4e9448..ede43b1c 100644 --- a/public/buf.yaml +++ b/buf.yaml @@ -13,6 +13,8 @@ # limitations under the License. version: v1 +deps: +- buf.build/googleapis/googleapis lint: use: - DEFAULT @@ -24,10 +26,9 @@ lint: - ENUM_VALUE_PREFIX ignore_only: RPC_RESPONSE_STANDARD_NAME: - - public/query/get_values.proto - - public/query/v2/get_values_v2.proto + - public/query/v2/get_values_v2.proto RPC_REQUEST_RESPONSE_UNIQUE: - - public/query/v2/get_values_v2.proto + - public/query/v2/get_values_v2.proto enum_zero_value_suffix: _UNSPECIFIED rpc_allow_same_request_response: false rpc_allow_google_protobuf_empty_requests: false diff --git a/builders/.pre-commit-config.yaml b/builders/.pre-commit-config.yaml index f0949c71..7f683b56 100644 --- a/builders/.pre-commit-config.yaml +++ b/builders/.pre-commit-config.yaml @@ -45,23 +45,27 @@ repos: - id: script-must-have-extension - id: require-ascii - id: shellcheck - exclude: '^(tools|google_internal|images)/.*$' - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v14.0.6 + rev: v16.0.4 hooks: - id: clang-format types_or: - c++ - c +- repo: https://github.com/bufbuild/buf + rev: v1.19.0 + hooks: + - id: buf-format + - repo: local hooks: - id: addlicense name: addlicense language: golang additional_dependencies: - - github.com/google/addlicense@v1.1.0 + - 
github.com/google/addlicense@v1.1.1 always_run: false pass_filenames: true entry: addlicense -v @@ -72,7 +76,7 @@ repos: name: addlicense check language: golang additional_dependencies: - - github.com/google/addlicense@v1.1.0 + - github.com/google/addlicense@v1.1.1 always_run: false pass_filenames: true entry: addlicense -check @@ -105,7 +109,7 @@ repos: - markdown - repo: https://github.com/DavidAnson/markdownlint-cli2 - rev: v0.6.0 + rev: v0.7.1 hooks: - id: markdownlint-cli2 name: lint markdown @@ -117,7 +121,7 @@ repos: description: Format bazel WORKSPACE, BUILD and .bzl files with a standard convention. language: golang additional_dependencies: - - github.com/bazelbuild/buildtools/buildifier@5.1.0 + - github.com/bazelbuild/buildtools/buildifier@6.1.1 always_run: true pass_filenames: true types_or: @@ -140,7 +144,7 @@ repos: - --quiet - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.3.0 hooks: - id: black name: black python formatter diff --git a/builders/CHANGELOG.md b/builders/CHANGELOG.md index 15f363c0..07d5709c 100644 --- a/builders/CHANGELOG.md +++ b/builders/CHANGELOG.md @@ -2,52 +2,157 @@ All notable changes to this project will be documented in this file. See [commit-and-tag-version](https://github.com/absolute-version/commit-and-tag-version) for commit guidelines. 
-## [0.23.0](https://team/kiwi-air-force-eng-team/build-system/compare/release-0.22.0...release-0.23.0) (2023-04-13) +## 0.30.1 (2023-06-27) + + +### Bug Fixes + +* Use = for --env flag +* Use = for --env flag for all tools + +## 0.30.0 (2023-06-26) + + +### Features + +* Install numpy for python3.9 +* Set PYTHON_BIN_PATH/PYTHON_LIB_PATH in build-debian +* Upgrade AmazonLinux2 to 20230530 +* Upgrade packer to v1.9.1 + + +### Bug Fixes + +* Add links for llvm-{cov,profdata} + +## 0.29.0 (2023-06-05) + + +### Features + +* Update pre-commit hook versions + + +### Bug Fixes + +* Catch error when shifting multiple args +* Remove golang from test-tools image +* Resolve WORKSPACE using realpath +* Use correct exit code in --fast mode + +## 0.28.0 (2023-05-24) + + +### Features + +* Update ca-certificates + + +### Bug Fixes + +* Downgrade to clang v15 +* Use builders version.txt for tarfile tag + +## 0.27.0 (2023-05-23) + + +### Features + +* Add buf to presubmit image + + +### Documentation + +* Add CONTRIBUTING.md + +## 0.26.0 (2023-05-16) + + +### Features + +* Remove zlib-dev package +* Upgrade clang to v16 +* Upgrade go to v1.20.4 + +## 0.25.0 (2023-05-11) + + +### Features + +* Update default bazel version to 5.4.1 +* Upgrade rules_python to 0.21.0 + + +### Bug Fixes + +* Add file utility to build-debian image + +## 0.24.0 (2023-05-05) + + +### Features + +* Add --build-images flag to tests/run-tests +* Reuse tar image if available + + +### Bug Fixes + +* Address linter warnings +* Address linter warnings for tools +* Correct mangled usage text +* Pin pre-commit to 3.x +* Remove .gz suffix from tar file +* Remove function keyword for busybox sh script +* Remove Release in changelog title +* Upgrade pre-commit hooks + +## 0.23.0 (2023-04-13) ### Features -* Add wrapper for commit-and-tag-version ([9a996c0]( )) -* Upgrade curl to version 8 ([9f30521]( )) -* Upgrade to amazonlinux 2.0.20230320.0 ([e5f9c1d]( )) +* Add wrapper for commit-and-tag-version +* Upgrade curl 
to version 8 +* Upgrade to amazonlinux 2.0.20230320.0 ### Bug Fixes -* Use commit-and-tag-version wrapper ([a5125a4]( )) +* Use commit-and-tag-version wrapper -## [0.22.0](https://team/kiwi-air-force-eng-team/build-system/compare/release-0.21.1...release-0.22.0) (2023-04-03) +## 0.22.0 (2023-04-03) ### Features -* Add awscurl wrapper script ([58b2ce1]( )) -* Add tests for misc CLI wrappers ([105c7ee]( )) -* Correctly quote bash args ([d57a0de]( )) -* Extend test-tool to support the release image ([61129e8]( )) -* Use login shell for interactive container ([67fe3e0]( )) +* Add awscurl wrapper script +* Add tests for misc CLI wrappers +* Correctly quote bash args +* Extend test-tool to support the release image +* Use login shell for interactive container ### Documentation -* Add section on tools to README ([d572517]( )) -* Remove section on building images directly ([19d47ec]( )) +* Add section on tools to README +* Remove section on building images directly -### [0.21.1] (2023-03-07) +## 0.21.1 (2023-03-07) ### Bug Fixes * Relax pinned version for apache2-utils -## [0.21.0] (2023-03-06) +## 0.21.0 (2023-03-06) ### Features * Add wrapper scripts for utils -## [0.20.0] (2023-03-01) +## 0.20.0 (2023-03-01) ### Features @@ -57,7 +162,7 @@ All notable changes to this project will be documented in this file. See [commit * Permit testing of a single image * Relax pinned versions in build-debian, presubmit and test-tools -## [0.19.0] (2023-03-01) +## 0.19.0 (2023-03-01) ### Features @@ -69,14 +174,14 @@ All notable changes to this project will be documented in this file. See [commit * Relax pinned version of openjdk to 11.0.* -## [0.18.0] (2023-02-23) +## 0.18.0 (2023-02-23) ### Features * Relax pinned versions for apk and yum packages to semver -## [0.17.0] (2023-02-21) +## 0.17.0 (2023-02-21) ### Features @@ -91,21 +196,21 @@ All notable changes to this project will be documented in this file. 
See [commit * Minor code cleanup in images/presubmit/install_apps * Upgrade ghz to 0.114.0 -## [0.16.0] (2023-02-05) +## 0.16.0 (2023-02-05) ### Features * Run test tools in docker interactive mode to admit std streams -### [0.15.1] (2023-02-04) +## 0.15.1 (2023-02-04) ### Bug Fixes * Return value from get_docker_workspace_mount() -## [0.15.0] (2023-02-03) +## 0.15.0 (2023-02-03) ### Features @@ -117,21 +222,21 @@ All notable changes to this project will be documented in this file. See [commit * Pin commit-and-tag-version to v10.1.0 -## [0.14.0] (2023-01-27) +## 0.14.0 (2023-01-27) ### Features * Improve verbose output for get-builder-image-tagged -### [0.13.1] (2023-01-26) +## 0.13.1 (2023-01-26) ### Bug Fixes * Upgrade software-properties-common -## [0.13.0] (2023-01-23) +## 0.13.0 (2023-01-23) ### Features @@ -151,7 +256,7 @@ All notable changes to this project will be documented in this file. See [commit * Upgrade amazonlinux2 base image * Upgrade git on amazonlinux2 -## [0.12.0] (2023-01-10) +## 0.12.0 (2023-01-10) ### Features @@ -159,7 +264,7 @@ All notable changes to this project will be documented in this file. See [commit * Modify ghz wrapper for generic use. Add curl * Use test-tools image for grpcurl -## [0.11.0] (2023-01-09) +## 0.11.0 (2023-01-09) ### Features @@ -171,14 +276,14 @@ All notable changes to this project will be documented in this file. See [commit * Clean up tmpdir via RETURN trap -## [0.10.0] (2023-01-06) +## 0.10.0 (2023-01-06) ### Features * Drop ubuntu package version minor for curl -## [0.9.0] (2023-01-04) +## 0.9.0 (2023-01-04) ### Features @@ -193,7 +298,7 @@ All notable changes to this project will be documented in this file. See [commit * Elide warnings from bazel info * Revert from clang-format v15 to v14 -## [0.8.0] (2022-12-29) +## 0.8.0 (2022-12-29) ### Features @@ -203,7 +308,7 @@ All notable changes to this project will be documented in this file. 
See [commit * Ensure run-tests includes all images * Skip symlinks that resolve in normalize-bazel-symlink -## [0.7.0] (2022-12-27) +## 0.7.0 (2022-12-27) ### Features @@ -211,7 +316,7 @@ All notable changes to this project will be documented in this file. See [commit * Add ghz wrapper script * Add test-tools image -## [0.6.0] (2022-12-12) +## 0.6.0 (2022-12-12) ### Features @@ -224,7 +329,7 @@ All notable changes to this project will be documented in this file. See [commit * Emit docker build output only on non-zero exit * Remove tempfile before exiting -## [0.5.0] (2022-12-06) +## 0.5.0 (2022-12-06) ### Features @@ -238,14 +343,14 @@ All notable changes to this project will be documented in this file. See [commit * Update version pin for ca-certificates * Use images subdirs for image list -### [0.4.4] (2022-11-18) +## 0.4.4 (2022-11-18) ### Bug Fixes * Retain execute permissions when normalizing dist -### [0.4.3] (2022-11-17) +## 0.4.3 (2022-11-17) ### Bug Fixes @@ -255,14 +360,14 @@ All notable changes to this project will be documented in this file. See [commit * Improve verbose output for get-builder-image-tagged * Pin apt and yum package versions -### [0.4.2] (2022-11-17) +## 0.4.2 (2022-11-17) ### Bug Fixes * Generate SHA within docker container -### [0.4.1] (2022-11-15) +## 0.4.1 (2022-11-15) ### Bug Fixes @@ -270,7 +375,7 @@ All notable changes to this project will be documented in this file. See [commit * Reduce noise creating presubmit image * Remove docker run --interactive flag -## [0.4.0] (2022-11-14) +## 0.4.0 (2022-11-14) ### Features @@ -288,14 +393,14 @@ All notable changes to this project will be documented in this file. 
See [commit * Add xz to build-debian * Explicitly add machine type and OS release to toolchains hash -### [0.3.1] (2022-11-01) +## 0.3.1 (2022-11-01) ### Bug Fixes * Add OpenJDK 11 in build-amazonlinux2 -## [0.3.0] (2022-11-01) +## 0.3.0 (2022-11-01) ### Features @@ -310,7 +415,7 @@ All notable changes to this project will be documented in this file. See [commit * Ensure builder::set_workspace does not overwrite WORKSPACE -## [0.2.0] (2022-10-26) +## 0.2.0 (2022-10-26) ### Features diff --git a/builders/CONTRIBUTING.md b/builders/CONTRIBUTING.md new file mode 100644 index 00000000..0d0e3e85 --- /dev/null +++ b/builders/CONTRIBUTING.md @@ -0,0 +1,3 @@ +# How to Contribute + +Presently this project is not accepting contributions. diff --git a/builders/bazel/deps.bzl b/builders/bazel/deps.bzl index ab98fb3b..55320399 100644 --- a/builders/bazel/deps.bzl +++ b/builders/bazel/deps.bzl @@ -26,8 +26,10 @@ def python_deps(bazel_package): """ http_archive( name = "rules_python", - sha256 = "8c8fe44ef0a9afc256d1e75ad5f448bb59b81aba149b8958f02f7b3a98f5d9b4", - strip_prefix = "rules_python-0.13.0", - url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.13.0.tar.gz", + sha256 = "94750828b18044533e98a129003b6a68001204038dc4749f40b195b24c38f49f", + strip_prefix = "rules_python-0.21.0", + urls = [ + "https://github.com/bazelbuild/rules_python/releases/download/0.21.0/rules_python-0.21.0.tar.gz", + ], ) native.register_toolchains("{}:py_toolchain".format(bazel_package)) diff --git a/builders/etc/.bazelversion b/builders/etc/.bazelversion index 84197c89..ade65226 100644 --- a/builders/etc/.bazelversion +++ b/builders/etc/.bazelversion @@ -1 +1 @@ -5.3.2 +5.4.1 diff --git a/builders/images/build-amazonlinux2/Dockerfile b/builders/images/build-amazonlinux2/Dockerfile index 4be3919d..ace0169e 100644 --- a/builders/images/build-amazonlinux2/Dockerfile +++ b/builders/images/build-amazonlinux2/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language 
governing permissions and # limitations under the License. -FROM amazonlinux:2.0.20230320.0 +FROM amazonlinux:2.0.20230530.0 COPY /install_apps install_golang_apps install_go.sh generate_system_bazelrc .bazelversion /scripts/ COPY get_workspace_mount /usr/local/bin diff --git a/builders/images/build-amazonlinux2/install_apps b/builders/images/build-amazonlinux2/install_apps index c0f884b8..1e6fdc75 100755 --- a/builders/images/build-amazonlinux2/install_apps +++ b/builders/images/build-amazonlinux2/install_apps @@ -73,7 +73,7 @@ function install_misc() { function install_packer() { yum install -y "yum-utils-1.1.31*" yum-config-manager --add-repo https://rpm.releases.hashicorp.com/AmazonLinux/hashicorp.repo - yum -y install "packer-1.8.5*" + yum -y install "packer-1.9.1*" update-alternatives --install /usr/local/bin/packer packer /usr/bin/packer 100 /usr/local/bin/packer version diff --git a/builders/images/build-debian/Dockerfile b/builders/images/build-debian/Dockerfile index 2b748984..5236eed9 100644 --- a/builders/images/build-debian/Dockerfile +++ b/builders/images/build-debian/Dockerfile @@ -38,4 +38,6 @@ RUN \ /scripts/install_golang_apps && \ rm -rf /scripts -ENV PATH="${PATH}:/usr/local/go/bin:/opt/bin" +ENV PATH="${PATH}:/usr/local/go/bin:/opt/bin" \ + PYTHON_BIN_PATH="/opt/bin/python3" \ + PYTHON_LIB_PATH="/usr/lib/python3.9" diff --git a/builders/images/build-debian/install_apps b/builders/images/build-debian/install_apps index c2d69168..144e3e4d 100755 --- a/builders/images/build-debian/install_apps +++ b/builders/images/build-debian/install_apps @@ -3,8 +3,8 @@ set -o pipefail set -o errexit -VERBOSE=0 -INSTALL_LOCALE=en_US.UTF-8 +declare -i VERBOSE=0 +declare INSTALL_LOCALE=en_US.UTF-8 usage() { local exitval=${1-1} @@ -22,17 +22,13 @@ while [[ $# -gt 0 ]]; do case "$1" in --locale) INSTALL_LOCALE="$2" - shift - shift + shift 2 || usage ;; --verbose) VERBOSE=1 shift ;; - -h | --help) - usage 0 - break - ;; + -h | --help) usage 0 ;; *) usage 0 ;; esac 
done @@ -46,18 +42,31 @@ function apt_update() { apt-get --quiet -o 'Acquire::https::No-Cache=True' -o 'Acquire::http::No-Cache=True' update } +function install_python() { + mkdir -p /opt/bin + update-alternatives --install /opt/bin/python3 python3 /usr/bin/python3.9 100 + update-alternatives --install /opt/bin/python python /usr/bin/python3.9 100 + curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py + /usr/bin/python3.9 /tmp/get-pip.py + rm -f /tmp/get-pip.py + /usr/bin/python3.9 -m pip --version + /usr/bin/python3.9 -m pip install \ + "numpy~=1.25" +} + function install_misc() { DEBIAN_FRONTEND=noninteractive apt-get --quiet install -y --no-install-recommends \ apt-transport-https="2.0.*" \ - ca-certificates="20211016ubuntu0.20.04.1" \ + ca-certificates \ chrpath="0.16-*" \ libcurl4="7.68.*" \ curl="7.68.*" \ + file="1:5*" \ gettext="0.19.*" \ git="1:2.25.*" \ gnupg="2.2.*" \ locales="2.31-*" \ - lsb-release="11.1.0*" \ + lsb-release="11.1.*" \ openjdk-11-jdk="11.0.*" \ python3.9-venv="3.9.*" \ rename="1.10-*" \ @@ -65,11 +74,8 @@ function install_misc() { unzip="6.0-*" \ wget="1.20.*" \ xz-utils="5.2.*" \ - zip="3.0-*" \ - zlib1g-dev="1:1.2.*" - mkdir -p /opt/bin - update-alternatives --install /opt/bin/python3 python3 /usr/bin/python3.9 100 - update-alternatives --install /opt/bin/python python /usr/bin/python3.9 100 + zip="3.0-*" + install_python if [[ -n ${INSTALL_LOCALE} ]]; then printf "\nSetting locale to: %s\n" "${INSTALL_LOCALE}" locale-gen "${INSTALL_LOCALE}" @@ -78,15 +84,19 @@ function install_misc() { } function install_clang() { - declare -r VERSION="14" + declare -r -i clang_ver=15 curl --silent --fail --show-error --location --remote-name https://apt.llvm.org/llvm.sh chmod +x llvm.sh - ./llvm.sh ${VERSION} - apt-get --quiet install -y --no-install-recommends libc++-${VERSION}-dev - update-alternatives --install /usr/bin/clang clang /usr/bin/clang-${VERSION} 100 + ./llvm.sh ${clang_ver} + apt-get --quiet install -y --no-install-recommends 
libc++-${clang_ver}-dev + update-alternatives --install /usr/bin/clang clang /usr/bin/clang-${clang_ver} 100 + update-alternatives --install /usr/bin/llvm-cov llvm-cov /usr/bin/llvm-cov-${clang_ver} 100 + update-alternatives --install /usr/bin/llvm-profdata llvm-profdata /usr/bin/llvm-profdata-${clang_ver} 100 rm -f llvm.sh clang --version + llvm-cov --version + llvm-profdata show --version } # Install Docker (https://docs.docker.com/engine/install/debian/) @@ -94,9 +104,11 @@ function install_docker() { declare -r arch="$1" apt-get --quiet remove docker docker.io containerd runc mkdir -p /etc/apt/keyrings - declare -r DIST=ubuntu - curl --silent --fail --show-error --location https://download.docker.com/linux/${DIST}/gpg | gpg --dearmor -o /etc/apt/keyrings/docker.gpg - echo "deb [arch=${arch} signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/${DIST} $(lsb_release -cs) stable" | tee /etc/apt/sources.list.d/docker.list + declare -r dist=ubuntu + curl --silent --fail --show-error --location https://download.docker.com/linux/${dist}/gpg \ + | gpg --dearmor -o /etc/apt/keyrings/docker.gpg + echo "deb [arch=${arch} signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/${dist} $(lsb_release -cs) stable" \ + | tee /etc/apt/sources.list.d/docker.list apt_update apt-get --quiet install -y --no-install-recommends docker-ce docker-ce-cli containerd.io } diff --git a/builders/images/install_go.sh b/builders/images/install_go.sh index c284b132..a436d20a 100644 --- a/builders/images/install_go.sh +++ b/builders/images/install_go.sh @@ -22,12 +22,12 @@ function _golang_install_dir() { function install_golang() { declare -r _ARCH="$1" declare -r FNAME=gobin.tar.gz - declare -r VERSION=1.19 + declare -r VERSION=1.20.4 # shellcheck disable=SC2155 declare -r GO_INSTALL_DIR="$(_golang_install_dir)" declare -r -A GO_HASHES=( - [amd64]="464b6b66591f6cf055bc5df90a9750bf5fbc9d038722bb84a9d56a2bea974be6" - 
[arm64]="efa97fac9574fc6ef6c9ff3e3758fb85f1439b046573bf434cccb5e012bd00c8" + [amd64]="698ef3243972a51ddb4028e4a1ac63dc6d60821bf18e59a807e051fee0a385bd" + [arm64]="105889992ee4b1d40c7c108555222ca70ae43fccb42e20fbf1eebb822f5e72c6" ) declare -r GO_HASH=${GO_HASHES[${_ARCH}]} if [[ -z ${GO_HASH} ]]; then diff --git a/builders/images/presubmit/install_apps b/builders/images/presubmit/install_apps index 81557cca..64808cc1 100755 --- a/builders/images/presubmit/install_apps +++ b/builders/images/presubmit/install_apps @@ -55,7 +55,7 @@ function apt_update() { function install_packages() { DEBIAN_FRONTEND=noninteractive apt-get --quiet install -y --no-install-recommends \ apt-transport-https="2.0.*" \ - ca-certificates="20211016ubuntu0.20.04.1" \ + ca-certificates \ libcurl4="7.68.*" \ curl="7.68.*" \ gnupg="2.2.*" \ @@ -85,7 +85,7 @@ function install_docker() { function install_precommit() { /usr/bin/python3.9 -m venv "${PRE_COMMIT_VENV_DIR}" - "${PRE_COMMIT_VENV_DIR}"/bin/pip install pre-commit + "${PRE_COMMIT_VENV_DIR}"/bin/pip install pre-commit~=3.1 "${PRE_COMMIT_TOOL}" --version # initialize pre-commit cache, which needs a git repo (a temporary will suffice) diff --git a/builders/images/test-tools/Dockerfile b/builders/images/test-tools/Dockerfile index 5362fca6..21bb383a 100644 --- a/builders/images/test-tools/Dockerfile +++ b/builders/images/test-tools/Dockerfile @@ -19,12 +19,17 @@ WORKDIR /build ADD https://github.com/shekyan/slowhttptest/archive/refs/tags/v1.9.0.tar.gz /build RUN tar xz --strip-components 1 -f v1.9.0.tar.gz && ls -l && ./configure && make -FROM golang:1.19.4-alpine3.17 AS golang-1.19 +FROM golang:1.19.4-alpine3.17 AS golang +ENV BUILD_ARCH="${TARGETARCH}" \ + GOBIN=/usr/local/go/bin +COPY build_golang_apps /scripts/ +RUN /scripts/build_golang_apps + FROM fullstorydev/grpcurl:v1.8.7 AS grpcurl FROM alpine:3.17.2 -COPY --from=golang-1.19 /usr/local/go/ /usr/local/go/ -COPY --from=grpcurl /bin/grpcurl /usr/bin +COPY --from=golang 
/usr/local/go/bin/* /usr/local/bin/ +COPY --from=grpcurl /bin/grpcurl /usr/local/bin/ ARG TARGETARCH ENV BUILD_ARCH="${TARGETARCH}" \ diff --git a/builders/images/test-tools/build_golang_apps b/builders/images/test-tools/build_golang_apps new file mode 100755 index 00000000..b79e20c4 --- /dev/null +++ b/builders/images/test-tools/build_golang_apps @@ -0,0 +1,16 @@ +#!/bin/busybox sh + +set -o errexit + +install_ghz() { + go install github.com/bojand/ghz/cmd/ghz@v0.114.0 + ghz --help +} + +install_cassowary() { + go install github.com/rogerwelin/cassowary/cmd/cassowary@v0.16.0 + cassowary --help +} + +install_ghz +install_cassowary diff --git a/builders/images/test-tools/install_apps b/builders/images/test-tools/install_apps index fd250d33..597f937b 100755 --- a/builders/images/test-tools/install_apps +++ b/builders/images/test-tools/install_apps @@ -1,8 +1,8 @@ -#!/bin/sh +#!/bin/busybox sh set -o errexit -function install_packages() { +install_packages() { apk --no-cache add \ bash~=5 \ curl~=8 \ @@ -10,31 +10,19 @@ function install_packages() { libstdc++ } -function install_ghz() { - go install github.com/bojand/ghz/cmd/ghz@v0.114.0 - ghz --help -} - -function install_nghttp2() { +install_nghttp2() { apk --no-cache add \ nghttp2~=1 h2load --version } -function install_apache2_utils() { +install_apache2_utils() { apk --no-cache add \ apache2-utils~=2.4 ab -V } -function install_cassowary() { - go install github.com/rogerwelin/cassowary/cmd/cassowary@v0.16.0 - cassowary --help -} - apk --no-cache update install_packages -install_ghz install_nghttp2 install_apache2_utils -install_cassowary diff --git a/builders/tests/data/hashes/build-amazonlinux2 b/builders/tests/data/hashes/build-amazonlinux2 index 73ceb8dd..88a6d6b0 100644 --- a/builders/tests/data/hashes/build-amazonlinux2 +++ b/builders/tests/data/hashes/build-amazonlinux2 @@ -1 +1 @@ -2ff66f7605176dee84fc17664fadd93a6ed870f31fdd3a60f10c64b80aec16a9 
+f6a897c7a391b8fb064954d2654870541dee54ee6d46abf0b439c410ab65ded9 diff --git a/builders/tests/data/hashes/build-debian b/builders/tests/data/hashes/build-debian index 828fbaae..3f66372f 100644 --- a/builders/tests/data/hashes/build-debian +++ b/builders/tests/data/hashes/build-debian @@ -1 +1 @@ -a90e9bd7155bd76482f943c8cf961d9b59939d538c82cf174e48dccdd7ab4592 +826258b167a0563a96e3b7d25299f85fffc2d4d0e566b776d8916c1b6d9af3cf diff --git a/builders/tests/data/hashes/presubmit b/builders/tests/data/hashes/presubmit index 95ceb384..89f84244 100644 --- a/builders/tests/data/hashes/presubmit +++ b/builders/tests/data/hashes/presubmit @@ -1 +1 @@ -4d71054737c693f2b214adf6b3f1c1e4d503ebea0505c3946cf37fd51d6578df +3442bf68f187ad8f98cbea94f447a3c08ab6cb97d05ce999cecf1f33fdeafd08 diff --git a/builders/tests/data/hashes/test-tools b/builders/tests/data/hashes/test-tools index 1aba999e..72a3fa38 100644 --- a/builders/tests/data/hashes/test-tools +++ b/builders/tests/data/hashes/test-tools @@ -1 +1 @@ -2a0bc96dbbe843a18554f0a958495a09d179c03f0bc89eb43489e8e7835a4273 +e19aaa4e08668be8056a6064e27be159ee7e158b3a90d75b7c4e5483369ebe41 diff --git a/builders/tests/run-tests b/builders/tests/run-tests index 19f90693..f98389a4 100755 --- a/builders/tests/run-tests +++ b/builders/tests/run-tests @@ -21,7 +21,7 @@ set -o errexit trap _cleanup EXIT function _cleanup() { - declare -r -i STATUS=$? + local -r -i STATUS=$? 
if [[ -d ${TMP_HASHES_DIR1} ]]; then rm -rf "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" fi @@ -33,8 +33,9 @@ function _cleanup() { } function get_image_list() { - declare -r _images_dir="$1" - find "${_images_dir}" -maxdepth 1 -mindepth 1 -type d -printf "%P\n" | sort + local -r _images_dir="$1" + find "${_images_dir}" -maxdepth 1 -mindepth 1 -type d -printf "%P\n" \ + | sort } function usage() { @@ -44,6 +45,7 @@ usage: $0 --image Run tests only for specified image --fast Only generate hashes directly rather than also using cbuild + --build-images Build the images --verbose Produce verbose output USAGE # shellcheck disable=SC2086 @@ -52,6 +54,7 @@ USAGE declare -i FAST=0 declare -i VERBOSE=0 +declare -i BUILD_IMAGES=0 declare IMAGE while [[ $# -gt 0 ]]; do @@ -61,12 +64,16 @@ while [[ $# -gt 0 ]]; do if [[ -z ${IMAGE} ]]; then usage fi - shift 2 + shift 2 || usage ;; --fast) FAST=1 shift ;; + --build-images) + BUILD_IMAGES=1 + shift + ;; --verbose) VERBOSE=1 shift @@ -116,7 +123,7 @@ function cli_tests_test-tools() { function cli_tests_misc() { declare -a -r TOOLS=( - "${TOOLS_DIR}/aws-cli --help" + "${TOOLS_DIR}/aws-cli help" "${TOOLS_DIR}/awscurl --help" ) printf "Testing utils CLI tool wrappers\n" @@ -158,71 +165,83 @@ function cli_tests_release() { } function create_temp_hash_dir() { - declare -r DIR=$(mktemp --directory) + local -r DIR="$(mktemp --directory)" cp "${HASHES_DIR}"/* "${DIR}" printf "%s" "${DIR}" } -TMP_HASHES_DIR1=$(create_temp_hash_dir) - # warning when running inside a docker container if [[ -f /.dockerenv ]]; then printf "warning: Executing within docker container, which obviates testing in a non-docker environment\n" &>/dev/stderr fi -printf "Generating image hashes (direct mode)\n" -declare -a -r GET_BUILDER_IMAGE_ARGS=( - --no-build + +declare -a GET_BUILDER_IMAGE_ARGS=( --sha-only ) -for img in ${IMAGE_LIST}; do - # shellcheck disable=SC2086 - if ! 
"${TOOLS_DIR}"/get-builder-image-tagged "${GET_BUILDER_IMAGE_ARGS[@]}" --image ${img} >"${TMP_HASHES_DIR1}/${img}"; then - printf "Error generating image hash: %s\n" "${img}" &>/dev/stderr - RETCODE+=1 - fi -done - -BASELINE_UPDATES="$(diff --brief "${HASHES_DIR}" "${TMP_HASHES_DIR1}" || true)" -readonly BASELINE_UPDATES -if [[ -n $BASELINE_UPDATES ]]; then - # shellcheck disable=SC2086 - printf "detected shift in baseline files:\n%s\n" "${BASELINE_UPDATES}" - if [[ ${VERBOSE} -eq 1 ]]; then - diff "${HASHES_DIR}" "${TMP_HASHES_DIR1}" || true - fi - cp --force "${TMP_HASHES_DIR1}"/* "${HASHES_DIR}" - RETCODE+=10 -else - printf "hashes unchanged\n" +if [[ ${BUILD_IMAGES} -eq 0 ]]; then + GET_BUILDER_IMAGE_ARGS+=(--no-build) fi -if [[ ${FAST} -eq 1 ]]; then - exit ${RETCODE} -fi +function generate_hashes_direct() { + TMP_HASHES_DIR1=$(create_temp_hash_dir) + printf "Generating image hashes (direct mode)\n" -TMP_HASHES_DIR2=$(create_temp_hash_dir) + for img in ${IMAGE_LIST}; do + # shellcheck disable=SC2086 + if ! "${TOOLS_DIR}"/get-builder-image-tagged "${GET_BUILDER_IMAGE_ARGS[@]}" --image ${img} >"${TMP_HASHES_DIR1}/${img}"; then + printf "Error generating image hash: %s\n" "${img}" &>/dev/stderr + RETCODE+=1 + fi + done -printf "Generating image hashes (cbuild mode)\n" -for img in ${IMAGE_LIST}; do - if ! 
"${TOOLS_DIR}"/cbuild --image build-debian --cmd "tools/get-builder-image-tagged ${GET_BUILDER_IMAGE_ARGS[*]} --image ${img}" >"${TMP_HASHES_DIR2}/${img}"; then - printf "Error generating image hash: %s\n" "${img}" &>/dev/stderr - RETCODE+=1 + BASELINE_UPDATES="$(diff --brief "${HASHES_DIR}" "${TMP_HASHES_DIR1}" || true)" + readonly BASELINE_UPDATES + if [[ -n $BASELINE_UPDATES ]]; then + # shellcheck disable=SC2086 + printf "detected shift in baseline files:\n%s\n" "${BASELINE_UPDATES}" + if [[ ${VERBOSE} -eq 1 ]]; then + diff "${HASHES_DIR}" "${TMP_HASHES_DIR1}" || true + fi + cp --force "${TMP_HASHES_DIR1}"/* "${HASHES_DIR}" + RETCODE+=10 + else + printf "hashes unchanged\n" fi -done +} -MODE_MISMATCH="$(diff --brief "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" || true)" -readonly MODE_MISMATCH -if [[ -n $MODE_MISMATCH ]]; then - # shellcheck disable=SC2086 - printf "Error: mismatch between direct and cbuild modes\n%s" "${MODE_MISMATCH}" &>/dev/stderr - if [[ ${VERBOSE} -eq 1 ]]; then - diff "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" || true +function generate_hashes_cbuild() { + TMP_HASHES_DIR2=$(create_temp_hash_dir) + + printf "Generating image hashes (cbuild mode)\n" + for img in ${IMAGE_LIST}; do + if ! 
"${TOOLS_DIR}"/cbuild --image build-debian --cmd "tools/get-builder-image-tagged ${GET_BUILDER_IMAGE_ARGS[*]} --image ${img}" >"${TMP_HASHES_DIR2}/${img}"; then + printf "Error generating image hash: %s\n" "${img}" &>/dev/stderr + RETCODE+=1 + fi + done + + MODE_MISMATCH="$(diff --brief "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" || true)" + readonly MODE_MISMATCH + if [[ -n $MODE_MISMATCH ]]; then + # shellcheck disable=SC2086 + printf "Error: mismatch between direct and cbuild modes\n%s" "${MODE_MISMATCH}" &>/dev/stderr + if [[ ${VERBOSE} -eq 1 ]]; then + diff "${TMP_HASHES_DIR1}" "${TMP_HASHES_DIR2}" || true + fi + RETCODE+=100 + else + printf "hashes unchanged\n" fi - RETCODE+=100 -else - printf "hashes unchanged\n" +} + +generate_hashes_direct + +if [[ ${FAST} -eq 1 ]]; then + exit ${RETCODE} fi +generate_hashes_cbuild + # CLI tests cli_tests_misc diff --git a/builders/tools/aws-cli b/builders/tools/aws-cli index d8839ecb..3301d277 100755 --- a/builders/tools/aws-cli +++ b/builders/tools/aws-cli @@ -29,7 +29,7 @@ set -o errexit -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")"/builder.sh declare -a ENV_VARS @@ -46,7 +46,7 @@ DOCKER_RUN_ARGS+=( for evar in "${ENV_VARS[@]}" do DOCKER_RUN_ARGS+=( - "--env ${evar}" + "--env=${evar}" ) done if [[ -t 0 ]] && [[ -t 1 ]]; then diff --git a/builders/tools/awscurl b/builders/tools/awscurl index 0a95942a..1f783942 100755 --- a/builders/tools/awscurl +++ b/builders/tools/awscurl @@ -29,14 +29,14 @@ set -o errexit -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")"/builder.sh declare -a ENV_VARS builder::add_aws_env_vars ENV_VARS ENV_VARS+=( - HOME=/home - PYTHONPATH=/ + "HOME=/home" + "PYTHONPATH=/" ) declare -a DOCKER_RUN_ARGS @@ -48,7 +48,7 @@ DOCKER_RUN_ARGS+=( for evar in "${ENV_VARS[@]}" do DOCKER_RUN_ARGS+=( - --env "${evar}" + "--env=${evar}" ) done if [[ -t 0 ]] && [[ -t 1 ]]; then diff 
--git a/builders/tools/builder.sh b/builders/tools/builder.sh index 22e16356..93cf5176 100644 --- a/builders/tools/builder.sh +++ b/builders/tools/builder.sh @@ -32,12 +32,14 @@ function builder::set_workspace() { if [[ -v WORKSPACE ]]; then return fi - declare -r GIT_TOPLEVEL="$(git rev-parse --show-superproject-working-tree)" + local -r GIT_TOPLEVEL="$(git rev-parse --show-superproject-working-tree)" if [[ -n ${GIT_TOPLEVEL} ]]; then WORKSPACE="${GIT_TOPLEVEL}" else WORKSPACE="$(git rev-parse --show-toplevel)" fi + local -r ws_path="$(realpath "${WORKSPACE}"/WORKSPACE)" + WORKSPACE="$(dirname "${ws_path}")" } ####################################### @@ -59,7 +61,7 @@ function builder::get_docker_workspace_mount() { # determined by git or bazel or inspecting the filesystem itself. Instead, we # need to use docker to expose its mount info for the /src/workspace path. # determine the current container's ID - declare -r CONTAINER_ID="$(uname --nodename)" + local -r CONTAINER_ID="$(uname --nodename)" # use docker inspect to extract the current mount path for /src/workspace # this format string is a golang template (https://pkg.go.dev/text/template) processed # by docker's --format flag, per https://docs.docker.com/config/formatting/ @@ -72,7 +74,7 @@ function builder::get_docker_workspace_mount() { {{end -}} {{end -}} ' - declare -r MOUNT_PATH="$(docker inspect --format "${FORMAT_STR}" "${CONTAINER_ID}")" + local -r MOUNT_PATH="$(docker inspect --format "${FORMAT_STR}" "${CONTAINER_ID}")" if [[ -z ${MOUNT_PATH} ]]; then printf "Error: Unable to determine mount point for /src/workspace. 
Exiting\n" &>/dev/stderr exit 1 @@ -88,7 +90,7 @@ function builder::get_tools_dir() { # Invoke cbuild tool in a build-debian container ####################################### function builder::cbuild_debian() { - declare -r CBUILD="$(builder::get_tools_dir)"/cbuild + local -r CBUILD="$(builder::get_tools_dir)"/cbuild printf "=== cbuild debian action envs ===\n" # shellcheck disable=SC2086 "${CBUILD}" ${CBUILD_ARGS} --image build-debian --cmd "grep -o 'action_env.*' /etc/bazel.bazelrc 1>/dev/stderr 2>/dev/null" @@ -117,13 +119,12 @@ function builder::add_aws_env_vars() { # Invoke cbuild tool in a build-amazonlinux2 container ####################################### function builder::cbuild_al2() { - declare -r CBUILD="$(builder::get_tools_dir)"/cbuild + local -r CBUILD="$(builder::get_tools_dir)"/cbuild declare -a env_vars builder::add_aws_env_vars env_vars declare env_args - for evar in "${env_vars[@]}" - do - env_args+=("--env" "${evar}") + for evar in "${env_vars[@]}"; do + env_args+=(--env "${evar}") done printf "=== cbuild amazonlinux2 action envs ===\n" # shellcheck disable=SC2086 diff --git a/builders/tools/cbuild b/builders/tools/cbuild index ba8bc6d1..5fabefa2 100755 --- a/builders/tools/cbuild +++ b/builders/tools/cbuild @@ -119,7 +119,7 @@ fi TOOLS_DIR="$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" readonly TOOLS_DIR -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "${TOOLS_DIR}"/builder.sh WORKSPACE_MOUNT="$(builder::get_docker_workspace_mount)" @@ -138,10 +138,10 @@ readonly IMAGE_TAGGED declare -a DOCKER_RUN_ARGS DOCKER_RUN_ARGS+=( - "--rm" - "--entrypoint=/bin/bash" - "--volume ${WORKSPACE_MOUNT}:/src/workspace" - "--workdir /src/workspace" + --rm + --entrypoint /bin/bash + --volume "${WORKSPACE_MOUNT}":/src/workspace + --workdir /src/workspace "$(echo "${EXTRA_DOCKER_RUN_ARGS}" | envsubst)" ) @@ -166,14 +166,14 @@ fi for evar in "${ENV_VARS[@]}" do DOCKER_RUN_ARGS+=( - "--env ${evar}" + "--env=${evar}" ) done if [[ -t 0 ]] && [[ -t 
1 ]]; then # stdin and stdout are open, assume it's an interactive tty session DOCKER_RUN_ARGS+=( - "--interactive" - "--tty" + --interactive + --tty ) fi @@ -184,12 +184,12 @@ if [[ -z ${CMD} ]]; then # shellcheck disable=SC2068 docker run \ ${DOCKER_RUN_ARGS[@]} \ - ${IMAGE_TAGGED} \ + "${IMAGE_TAGGED}" \ --login else # shellcheck disable=SC2068 docker run \ ${DOCKER_RUN_ARGS[@]} \ - ${IMAGE_TAGGED} \ + "${IMAGE_TAGGED}" \ --login -c "$CMD" fi diff --git a/builders/tools/get-builder-image-tagged b/builders/tools/get-builder-image-tagged index 5c7236ac..a725faad 100755 --- a/builders/tools/get-builder-image-tagged +++ b/builders/tools/get-builder-image-tagged @@ -23,9 +23,6 @@ set -o errexit trap _cleanup EXIT function _cleanup() { declare -r -i STATUS=$? - if [[ -n ${TAR_IMAGE} ]]; then - docker image rm --force "${TAR_IMAGE}" &>/dev/null - fi if [[ -n ${TEMPTAR} ]]; then rm -f "${TEMPTAR}" "${SHAFILE}" fi @@ -66,23 +63,20 @@ function usage() { cat &>/dev/stderr < - --no-build Do not build image if it doesn't exist - --image Image name for the build runtime. Valid names: - -environment variables (all optional): - IMAGE_BUILD_VERBOSE Capture docker build output if set + --no-build Do not build image if it doesn't exist + --image Image name for the build runtime. Valid names: USAGE for elem in $(get_image_list "${IMAGES_DIR}"); do - if [[ ${IMAGE} == "${elem}" ]]; then - local EXTRA=" (default)" - fi - printf " * %s%s\n" "${elem}" "${EXTRA}" &>/dev/stderr + printf " * %s\n" "${elem}" &>/dev/stderr done cat &>/dev/stderr <"${BUILD_OUTPUT}" -rm -f "${BUILD_OUTPUT}" -if ! docker image inspect "${TAR_IMAGE}" &>/dev/null; then - printf "error creating docker image [%s]\n" "${TAR_IMAGE}" &>/dev/stderr - exit 1 + } | docker buildx build "${DOCKER_BUILD_ARGS[@]}" --no-cache --output=type=docker --tag "${TAR_IMAGE}" - &>"${BUILD_OUTPUT}" + rm -f "${BUILD_OUTPUT}" + if ! 
docker image inspect "${TAR_IMAGE}" &>/dev/null; then + printf "error creating docker image [%s]\n" "${TAR_IMAGE}" &>/dev/stderr + exit 1 + fi +} + +BUILD_OUTPUT="$(make_temp .log)" +readonly BUILD_OUTPUT +BUILDSYS_VERSION="$(<"${BUILDERS_DIR}"/version.txt)" +readonly BUILDSYS_VERSION +readonly TAR_IMAGE="builders/tar-get-builder-image-tagged:v${BUILDSYS_VERSION}" +TAR_IMAGE_HASH="$(docker image ls --filter "reference=${TAR_IMAGE}" --quiet)" +readonly TAR_IMAGE_HASH +if [[ -z ${TAR_IMAGE_HASH} ]]; then + generate_image fi # Create a deterministic tar file for the specified file path, returning @@ -182,19 +187,20 @@ function _tar_for_dir() { # expand TMP_IMAGE_DIR immediately trap "rm -rf '${TMP_IMAGE_DIR}'" RETURN - WS_FILE_TAR="$(realpath "${FILE_TAR}" --relative-to="${WORKSPACE}")" - WS_FILE_SHA="$(realpath "${FILE_SHA}" --relative-to="${WORKSPACE}")" - WS_FILEPATH="$(realpath "${FILEPATH}" --relative-to="${WORKSPACE}")" - WS_TMP_IMAGE_DIR="$(realpath "${TMP_IMAGE_DIR}" --relative-to="${WORKSPACE}")" + local -r WS_FILE_TAR="$(realpath "${FILE_TAR}" --relative-to="${WORKSPACE}")" + local -r WS_FILE_SHA="$(realpath "${FILE_SHA}" --relative-to="${WORKSPACE}")" + local -r WS_FILEPATH="$(realpath "${FILEPATH}" --relative-to="${WORKSPACE}")" + local -r WS_TMP_IMAGE_DIR="$(realpath "${TMP_IMAGE_DIR}" --relative-to="${WORKSPACE}")" # find workspace etc files that are also in the image dir and the builders etc dir - WORKSPACE_ETC_FILES="$({ + local -r WORKSPACE_ETC_FILES="$({ # shellcheck disable=SC2012 ls -A -1 "${FILEPATH}" "${ETC_DIR}" | sort | uniq -d ls -A -1 "${WORKSPACE}" } | sort | uniq -d)" # create a deterministic tarball of the collected files - docker run --rm \ - --entrypoint=/bin/sh \ + docker run \ + --rm \ + --entrypoint /bin/sh \ --volume "${WORKSPACE_MOUNT}":/workspace \ --workdir /workspace \ "${TAR_IMAGE}" -c " @@ -218,13 +224,13 @@ tar --create --dereference --sort=name --owner=0 --group=0 --numeric-owner --for " } -TEMPTAR="$(make_temp .tar.gz)" 
+TEMPTAR="$(make_temp .tar)" readonly TEMPTAR SHAFILE="$(make_temp .sha)" readonly SHAFILE # use the tarfile size and file content to generate a sha256 hash _tar_for_dir "${TEMPTAR}" "${SHAFILE}" "${IMAGE_PATH_FULL}" -SHA="$(cat "${SHAFILE}")" +SHA="$(<"${SHAFILE}")" readonly SHA ARCH="$("${TOOLS_DIR}"/get-architecture)" readonly ARCH diff --git a/builders/tools/hadolint b/builders/tools/hadolint index c0395219..4051dbde 100755 --- a/builders/tools/hadolint +++ b/builders/tools/hadolint @@ -17,7 +17,7 @@ set -o errexit TOOLS_DIR="$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" readonly TOOLS_DIR -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "${TOOLS_DIR}"/builder.sh declare -a DOCKER_RUN_ARGS diff --git a/builders/tools/normalize-bazel-symlinks b/builders/tools/normalize-bazel-symlinks index a1ebb722..a66f5f34 100755 --- a/builders/tools/normalize-bazel-symlinks +++ b/builders/tools/normalize-bazel-symlinks @@ -21,16 +21,22 @@ declare -r BAZEL_CACHE_DIR="${HOME}/.cache/bazel" function normalize_symlink() { declare -r link_name="$1" - if readlink --canonicalize-existing ${link_name} &>/dev/null ; then - printf "symlink %s resolves fully, skipping\n" ${link_name} + if readlink --canonicalize-existing "${link_name}" &>/dev/null ; then + printf "symlink %s resolves fully, skipping\n" "${link_name}" return fi - local -r link_path=$(readlink ${link_name}) - local -r output_user_root=${link_path///bazel_root\/} - rm -f ${link_name} - ln -s "${BAZEL_CACHE_DIR}/${output_user_root}" ${link_name} + local -r link_path="$(readlink "${link_name}")" + local -r output_user_root="${link_path///bazel_root\/}" + rm -f "${link_name}" + ln -s "${BAZEL_CACHE_DIR}/${output_user_root}" "${link_name}" } -for link in bazel-{bin,out,testlogs,workspace}; do - normalize_symlink ${link} +declare -a -r LINK_DIRS=( + bazel-bin + bazel-out + bazel-testlogs + bazel-workspace +) +for link in "${LINK_DIRS[@]}"; do + normalize_symlink "${link}" done diff --git 
a/builders/tools/normalize-dist b/builders/tools/normalize-dist index b861d409..740d81d8 100755 --- a/builders/tools/normalize-dist +++ b/builders/tools/normalize-dist @@ -32,7 +32,7 @@ function _cleanup() { TOOLS_DIR="$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" readonly TOOLS_DIR -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "${TOOLS_DIR}"/builder.sh readonly IMAGE=build-debian GROUP="$(builder::id g)" diff --git a/builders/tools/pre-commit b/builders/tools/pre-commit index 80b0e738..be8c5c7c 100755 --- a/builders/tools/pre-commit +++ b/builders/tools/pre-commit @@ -18,7 +18,7 @@ set -o errexit TOOLS_DIR="$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" readonly TOOLS_DIR -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "${TOOLS_DIR}"/builder.sh trap __cleanup EXIT @@ -34,7 +34,7 @@ function __cleanup() { exit ${STATUS} fi # shellcheck disable=SC2086 - "${TOOLS_DIR}"/cbuild ${CBUILD_COMMON_ARGS} --cmd $" + "${TOOLS_DIR}"/cbuild "${CBUILD_COMMON_ARGS[@]}" --cmd $" # change file ownership back to user { git ls-files . 
--modified @@ -111,9 +111,13 @@ function __init() { SKIP_ENV="SKIP=${SKIP_HOOKS}" } -CLEANUP=0 - -CBUILD_COMMON_ARGS="--without-shared-cache --image presubmit" +declare -i CLEANUP=0 +declare -a -r CBUILD_COMMON_ARGS=( + "--without-shared-cache" + "--image" + presubmit +) +declare -r PRECOMMIT=/usr/pre-commit-venv/bin/pre-commit # TODO: run bazel //:precommit-hooks rather than just the pre-commit tool if [[ $# -gt 0 ]]; then @@ -127,7 +131,7 @@ if [[ $# -gt 0 ]]; then shift __init # shellcheck disable=SC2086 - "${TOOLS_DIR}"/cbuild ${CBUILD_COMMON_ARGS} --env "${SKIP_ENV}" --cmd "/usr/pre-commit-venv/bin/pre-commit ${PRECOMMIT_CMD} --config ./.pre-commit-config.yaml $*" + "${TOOLS_DIR}"/cbuild "${CBUILD_COMMON_ARGS[@]}" --env "${SKIP_ENV}" --cmd "${PRECOMMIT} ${PRECOMMIT_CMD} --config ./.pre-commit-config.yaml $*" ;; hook-impl) @@ -139,7 +143,7 @@ if [[ $# -gt 0 ]]; then --entrypoint=/usr/pre-commit-venv/bin/pre-commit \ --volume "${WORKSPACE}":/src/workspace \ -v /var/run/docker.sock:/var/run/docker.sock \ - --env "${SKIP_ENV}" \ + --env="${SKIP_ENV}" \ --workdir /src/workspace \ "${IMAGE_TAGGED}" \ "${PRECOMMIT_CMD}" --config ./.pre-commit-config.yaml "$@" @@ -150,12 +154,12 @@ if [[ $# -gt 0 ]]; then CLEANUP=1 for HOOK in "$@"; do # shellcheck disable=SC2086 - "${TOOLS_DIR}"/cbuild ${CBUILD_COMMON_ARGS} --env "${SKIP_ENV}" --cmd "/usr/pre-commit-venv/bin/pre-commit run --config ./.pre-commit-config.yaml --all-files ${HOOK}" + "${TOOLS_DIR}"/cbuild "${CBUILD_COMMON_ARGS[@]}" --env "${SKIP_ENV}" --cmd "${PRECOMMIT} run --config ./.pre-commit-config.yaml --all-files ${HOOK}" done esac else __init CLEANUP=1 # shellcheck disable=SC2086 - "${TOOLS_DIR}"/cbuild ${CBUILD_COMMON_ARGS} --env "${SKIP_ENV}" --cmd "/usr/pre-commit-venv/bin/pre-commit run --config ./.pre-commit-config.yaml --all-files" + "${TOOLS_DIR}"/cbuild "${CBUILD_COMMON_ARGS[@]}" --env "${SKIP_ENV}" --cmd "${PRECOMMIT} run --config ./.pre-commit-config.yaml --all-files" fi diff --git 
a/builders/tools/terraform b/builders/tools/terraform index 6a860aee..8c896412 100755 --- a/builders/tools/terraform +++ b/builders/tools/terraform @@ -29,7 +29,7 @@ set -o errexit -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")"/builder.sh declare -a ENV_VARS @@ -44,7 +44,7 @@ declare -a DOCKER_RUN_ARGS=( for evar in "${ENV_VARS[@]}" do DOCKER_RUN_ARGS+=( - "--env ${evar}" + "--env=${evar}" ) done if [[ -t 0 ]] && [[ -t 1 ]]; then diff --git a/builders/tools/test-tool b/builders/tools/test-tool index 41f567fe..4ce075da 100755 --- a/builders/tools/test-tool +++ b/builders/tools/test-tool @@ -43,7 +43,7 @@ case "${APP_LINK_TARGET}" in esac readonly IMAGE -# shellcheck disable=SC1091 +# shellcheck disable=SC1090 source "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")"/builder.sh TOOLS_DIR="$(builder::get_tools_dir)" readonly TOOLS_DIR @@ -63,6 +63,7 @@ declare -a DOCKER_RUN_ARGS=( --workdir /src/workspace/"${REL_PWD}" ) if [[ -n ${EXTRA_DOCKER_RUN_ARGS} ]]; then + # shellcheck disable=SC2207 DOCKER_RUN_ARGS+=( $(echo "${EXTRA_DOCKER_RUN_ARGS}" | envsubst) ) diff --git a/builders/version.txt b/builders/version.txt index 2ba6141d..6a932058 100644 --- a/builders/version.txt +++ b/builders/version.txt @@ -1 +1 @@ -0.23.0 \ No newline at end of file +0.30.1 \ No newline at end of file diff --git a/components/cloud_config/BUILD b/components/cloud_config/BUILD index f15a9e16..6161eb71 100644 --- a/components/cloud_config/BUILD +++ b/components/cloud_config/BUILD @@ -38,15 +38,30 @@ cc_library( ) cc_test( - name = "parameter_client_local_test", + name = "parameter_client_test", size = "small", - srcs = ["parameter_client_local_test.cc"], - deps = [ - ":parameter_client_local", - "//components/data/common:mocks", - "//components/util:sleepfor_mock", - "@com_github_google_glog//:glog", - "@com_github_grpc_grpc//:grpc++", + srcs = select({ + "//:aws_platform": [ + "parameter_client_aws_test.cc", + ], + 
"//:local_platform": [ + ":parameter_client_local_test.cc", + ], + }), + deps = select({ + "//:aws_platform": [ + "//components/util:platform_initializer", + "@aws_sdk_cpp//:core", + "@aws_sdk_cpp//:ssm", + ], + "//:local_platform": [ + "//components/data/common:mocks", + "//components/util:sleepfor_mock", + "@com_github_google_glog//:glog", + "@com_github_grpc_grpc//:grpc++", + ], + }) + [ + ":parameter_client", "@com_google_googletest//:gtest_main", ], ) @@ -62,15 +77,15 @@ cc_library( ], deps = select({ "//:aws_platform": [ + "//components/errors:aws_error_util", "@aws_sdk_cpp//:core", "@aws_sdk_cpp//:ssm", - "//components/errors:aws_error_util", ], "//conditions:default": [], }) + [ - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/status", "@com_github_google_glog//:glog", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) @@ -85,18 +100,20 @@ cc_library( ], deps = select({ "//:aws_instance": [ + "//components/errors:aws_error_util", "@aws_sdk_cpp//:autoscaling", "@aws_sdk_cpp//:core", "@aws_sdk_cpp//:ec2", - "//components/errors:aws_error_util", ], "//:local_instance": [ "@com_google_absl//absl/flags:flag", ], "//conditions:default": [], }) + [ - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/status", "@com_github_google_glog//:glog", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@google_privacysandbox_servers_common//src/cpp/telemetry:metrics_recorder", ], ) diff --git a/components/cloud_config/instance_client.h b/components/cloud_config/instance_client.h index 7f0084de..05d2e8ea 100644 --- a/components/cloud_config/instance_client.h +++ b/components/cloud_config/instance_client.h @@ -16,16 +16,35 @@ #include #include +#include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "src/cpp/telemetry/metrics_recorder.h" // TODO: Replace config cpio client once ready 
namespace kv_server { +// State to indicate whether an instance is servicing requests or not. +enum class InstanceServiceStatus : int8_t { + kUnknown = 0, + kPreService, + kInService, + kPostService, +}; + +struct InstanceInfo { + std::string id; + InstanceServiceStatus service_status; + std::string instance_group; + std::string private_ip_address; +}; + // Client to perform instance-specific operations. class InstanceClient { public: - static std::unique_ptr Create(); + static std::unique_ptr Create( + privacy_sandbox::server_common::MetricsRecorder& metrics_recorder); virtual ~InstanceClient() = default; // Retrieves all tags for the current instance and returns the tag with the @@ -46,6 +65,17 @@ class InstanceClient { // Retrieves all tags for the current instance and returns the tag with the // key "shard_num". virtual absl::StatusOr GetShardNumTag() = 0; + + // Retrieves descriptive information about all instances managed by the + // specified by instance_groups, e.g., on AWS, instance groups would map to + // auto scaling groups. + virtual absl::StatusOr> + DescribeInstanceGroupInstances( + const absl::flat_hash_set& instance_group_names) = 0; + + // Retrieves descriptive information about the given instances. 
+ virtual absl::StatusOr> DescribeInstances( + const absl::flat_hash_set& instance_ids) = 0; }; } // namespace kv_server diff --git a/components/cloud_config/instance_client_aws.cc b/components/cloud_config/instance_client_aws.cc index 471e6fba..f7da7c60 100644 --- a/components/cloud_config/instance_client_aws.cc +++ b/components/cloud_config/instance_client_aws.cc @@ -21,6 +21,8 @@ #include "absl/strings/str_cat.h" #include "aws/autoscaling/AutoScalingClient.h" #include "aws/autoscaling/model/CompleteLifecycleActionRequest.h" +#include "aws/autoscaling/model/DescribeAutoScalingGroupsRequest.h" +#include "aws/autoscaling/model/DescribeAutoScalingGroupsResult.h" #include "aws/autoscaling/model/DescribeAutoScalingInstancesRequest.h" #include "aws/autoscaling/model/RecordLifecycleActionHeartbeatRequest.h" #include "aws/core/Aws.h" @@ -30,6 +32,7 @@ #include "aws/core/internal/AWSHttpResourceClient.h" #include "aws/core/utils/Outcome.h" #include "aws/ec2/EC2Client.h" +#include "aws/ec2/model/DescribeInstancesRequest.h" #include "aws/ec2/model/DescribeTagsRequest.h" #include "aws/ec2/model/DescribeTagsResponse.h" #include "aws/ec2/model/Filter.h" @@ -40,6 +43,13 @@ namespace kv_server { namespace { +using Aws::AutoScaling::Model::DescribeAutoScalingGroupsRequest; +using Aws::AutoScaling::Model::Instance; +using Aws::AutoScaling::Model::LifecycleState; +using Aws::EC2::Model::DescribeInstancesRequest; +using privacy_sandbox::server_common::MetricsRecorder; +using privacy_sandbox::server_common::ScopeLatencyRecorder; + constexpr char kEnvironmentTag[] = "environment"; constexpr char kShardNumTag[] = "shard-num"; constexpr char kResourceIdFilter[] = "resource-id"; @@ -50,8 +60,35 @@ constexpr char kImdsTokenTtlHeader[] = "x-aws-ec2-metadata-token-ttl-seconds"; constexpr char kImdsTokenResourcePath[] = "/latest/api/token"; constexpr char kImdsEndpoint[] = "http://169.254.169.254"; constexpr char kInstanceIdResourcePath[] = "/latest/meta-data/instance-id"; - constexpr char 
kContinueAction[] = "CONTINUE"; +constexpr char kDescribeInstanceGroupInstancesEvent[] = + "DescribeInstanceGroupInstances"; +constexpr char kDescribeInstancesEvent[] = "DescribeInstances"; + +const absl::flat_hash_set kInstancePreServiceStatuses = { + LifecycleState::Pending, + LifecycleState::Pending_Wait, + LifecycleState::Pending_Proceed, + LifecycleState::Warmed_Pending, + LifecycleState::Warmed_Pending_Wait, + LifecycleState::Warmed_Pending_Proceed, + LifecycleState::Warmed_Running, + LifecycleState::EnteringStandby, + LifecycleState::Standby, +}; +const absl::flat_hash_set kInstancePostServiceStatuses = { + LifecycleState::Terminated, + LifecycleState::Terminating, + LifecycleState::Terminating_Proceed, + LifecycleState::Terminating_Wait, + LifecycleState::Warmed_Terminated, + LifecycleState::Warmed_Terminating, + LifecycleState::Warmed_Terminating_Proceed, + LifecycleState::Warmed_Terminating_Wait, + LifecycleState::Quarantined, + LifecycleState::Detached, + LifecycleState::Detaching, +}; absl::StatusOr GetAwsHttpResource( const Aws::Internal::AWSHttpResourceClient& http_client, @@ -117,6 +154,19 @@ absl::StatusOr GetAutoScalingGroupName( .GetAutoScalingGroupName(); } +InstanceServiceStatus GetInstanceServiceStatus(const Instance& instance) { + if (instance.GetLifecycleState() == LifecycleState::InService) { + return InstanceServiceStatus::kInService; + } + if (kInstancePreServiceStatuses.contains(instance.GetLifecycleState())) { + return InstanceServiceStatus::kPreService; + } + if (kInstancePostServiceStatuses.contains(instance.GetLifecycleState())) { + return InstanceServiceStatus::kPostService; + } + return InstanceServiceStatus::kUnknown; +} + class AwsInstanceClient : public InstanceClient { public: absl::StatusOr GetEnvironmentTag() override { @@ -206,8 +256,73 @@ class AwsInstanceClient : public InstanceClient { return machine_id_; } - AwsInstanceClient() - : ec2_client_(std::make_unique()), + absl::StatusOr> DescribeInstanceGroupInstances( + const 
absl::flat_hash_set& instance_groups) override { + std::vector instances; + DescribeAutoScalingGroupsRequest request; + request.SetAutoScalingGroupNames( + {instance_groups.begin(), instance_groups.end()}); + std::string next_token; + while (true) { + if (!next_token.empty()) { + request.SetNextToken(next_token); + } + auto outcome = auto_scaling_client_->DescribeAutoScalingGroups(request); + if (!outcome.IsSuccess()) { + return AwsErrorToStatus(outcome.GetError()); + } + const auto& result = outcome.GetResultWithOwnership(); + for (const auto& auto_scaling_group : result.GetAutoScalingGroups()) { + for (const auto& instance : auto_scaling_group.GetInstances()) { + InstanceInfo instance_info; + instance_info.instance_group = + auto_scaling_group.GetAutoScalingGroupName(); + instance_info.id = instance.GetInstanceId(); + instance_info.service_status = GetInstanceServiceStatus(instance); + instances.push_back(instance_info); + } + } + if (next_token = result.GetNextToken(); next_token.empty()) { + break; + } + } + return instances; + } + + absl::StatusOr> DescribeInstances( + const absl::flat_hash_set& instance_ids) override { + std::vector instances; + DescribeInstancesRequest request; + request.SetInstanceIds({instance_ids.begin(), instance_ids.end()}); + std::string next_token; + while (true) { + if (!next_token.empty()) { + request.SetNextToken(next_token); + } + auto outcome = ec2_client_->DescribeInstances(request); + if (!outcome.IsSuccess()) { + return AwsErrorToStatus(outcome.GetError()); + } + const auto& result = outcome.GetResultWithOwnership(); + for (const auto& reservation : result.GetReservations()) { + for (const auto& instance : reservation.GetInstances()) { + InstanceInfo instance_info; + instance_info.id = instance.GetInstanceId(); + instance_info.private_ip_address = instance.GetPrivateIpAddress(); + instance_info.service_status = InstanceServiceStatus::kUnknown; + instances.push_back(instance_info); + } + } + if (next_token = 
result.GetNextToken(); next_token.empty()) { + break; + } + } + return instances; + } + + explicit AwsInstanceClient(MetricsRecorder& metrics_recorder) + : metrics_recorder_(metrics_recorder), + ec2_client_(std::make_unique()), // EC2MetadataClient does not fall back to the default client // configuration, needs to specify it to // fall back default configuration such as connectTimeoutMs (1000ms) @@ -218,11 +333,11 @@ class AwsInstanceClient : public InstanceClient { std::make_unique()) {} private: + MetricsRecorder& metrics_recorder_; std::unique_ptr ec2_client_; std::unique_ptr ec2_metadata_client_; std::unique_ptr auto_scaling_client_; std::string machine_id_; - int32_t shard_num_; absl::StatusOr GetTag(std::string tag) { absl::StatusOr instance_id = GetInstanceId(); @@ -258,8 +373,9 @@ class AwsInstanceClient : public InstanceClient { } // namespace -std::unique_ptr InstanceClient::Create() { - return std::make_unique(); +std::unique_ptr InstanceClient::Create( + MetricsRecorder& metrics_recorder) { + return std::make_unique(metrics_recorder); } } // namespace kv_server diff --git a/components/cloud_config/local_instance_client.cc b/components/cloud_config/local_instance_client.cc index 80064d4a..a8da8c92 100644 --- a/components/cloud_config/local_instance_client.cc +++ b/components/cloud_config/local_instance_client.cc @@ -28,6 +28,8 @@ ABSL_FLAG(std::string, shard_num, "0", "Shard number."); namespace kv_server { namespace { +using privacy_sandbox::server_common::MetricsRecorder; + class LocalInstanceClient : public InstanceClient { public: absl::StatusOr GetEnvironmentTag() override { @@ -60,11 +62,27 @@ class LocalInstanceClient : public InstanceClient { hostname.resize(strlen(hostname.c_str())); return hostname; } + + absl::StatusOr> DescribeInstanceGroupInstances( + const absl::flat_hash_set& instance_groups) override { + auto id = GetInstanceId(); + return DescribeInstances({}); + } + + absl::StatusOr> DescribeInstances( + const absl::flat_hash_set& 
instance_ids) { + auto id = GetInstanceId(); + if (!id.ok()) { + return id.status(); + } + return std::vector{InstanceInfo{.id = *id}}; + } }; } // namespace -std::unique_ptr InstanceClient::Create() { +std::unique_ptr InstanceClient::Create( + MetricsRecorder& metrics_recorder) { return std::make_unique(); } diff --git a/components/cloud_config/parameter_client.h b/components/cloud_config/parameter_client.h index 26d7de8f..b0764355 100644 --- a/components/cloud_config/parameter_client.h +++ b/components/cloud_config/parameter_client.h @@ -20,13 +20,27 @@ #include "absl/status/statusor.h" +namespace Aws { +namespace SSM { +class SSMClient; +} // namespace SSM +} // namespace Aws + // TODO: Replace config cpio client once ready namespace kv_server { // Client to interact with Parameter storage. class ParameterClient { public: - static std::unique_ptr Create(); + struct ClientOptions { + ClientOptions() {} + // ParameterClient takes ownership of this if it's set: + ::Aws::SSM::SSMClient* ssm_client_for_unit_testing_ = nullptr; + }; + + static std::unique_ptr Create( + ClientOptions client_options = ClientOptions()); + virtual ~ParameterClient() = default; virtual absl::StatusOr GetParameter( @@ -34,6 +48,9 @@ class ParameterClient { virtual absl::StatusOr GetInt32Parameter( std::string_view parameter_name) const = 0; + + virtual absl::StatusOr GetBoolParameter( + std::string_view parameter_name) const = 0; }; } // namespace kv_server diff --git a/components/cloud_config/parameter_client_aws.cc b/components/cloud_config/parameter_client_aws.cc index de11fe3f..2f6e5d8a 100644 --- a/components/cloud_config/parameter_client_aws.cc +++ b/components/cloud_config/parameter_client_aws.cc @@ -57,26 +57,58 @@ class AwsParameterClient : public ParameterClient { int32_t parameter_int32; if (!absl::SimpleAtoi(*parameter, ¶meter_int32)) { - std::string error = + const std::string error = absl::StrFormat("Failed converting %s parameter: %s to int32.", parameter_name, *parameter); - 
LOG(INFO) << error; + LOG(ERROR) << error; return absl::InvalidArgumentError(error); } return parameter_int32; }; - AwsParameterClient() : ssm_client_(std::make_unique()) {} + absl::StatusOr GetBoolParameter( + std::string_view parameter_name) const override { + // https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetParameter.html + // AWS SDK only returns "string" value, so we need to do the conversion + // ourselves. + absl::StatusOr parameter = GetParameter(parameter_name); + + if (!parameter.ok()) { + return parameter.status(); + } + + bool parameter_bool; + if (!absl::SimpleAtob(*parameter, ¶meter_bool)) { + const std::string error = + absl::StrFormat("Failed converting %s parameter: %s to bool.", + parameter_name, *parameter); + LOG(ERROR) << error; + return absl::InvalidArgumentError(error); + } + + return parameter_bool; + }; + + explicit AwsParameterClient(ParameterClient::ClientOptions client_options) + : client_options_(std::move(client_options)) { + if (client_options.ssm_client_for_unit_testing_ != nullptr) { + ssm_client_.reset(client_options.ssm_client_for_unit_testing_); + } else { + ssm_client_ = std::make_unique(); + } + } private: + ClientOptions client_options_; std::unique_ptr ssm_client_; }; } // namespace -std::unique_ptr ParameterClient::Create() { - return std::make_unique(); +std::unique_ptr ParameterClient::Create( + ParameterClient::ClientOptions client_options) { + return std::make_unique(std::move(client_options)); } } // namespace kv_server diff --git a/components/cloud_config/parameter_client_aws_test.cc b/components/cloud_config/parameter_client_aws_test.cc new file mode 100644 index 00000000..a3b996fb --- /dev/null +++ b/components/cloud_config/parameter_client_aws_test.cc @@ -0,0 +1,140 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include + +#include "aws/ssm/SSMClient.h" +#include "aws/ssm/SSMErrors.h" +#include "aws/ssm/model/GetParameterRequest.h" +#include "components/cloud_config/parameter_client.h" +#include "components/util/platform_initializer.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using testing::_; +using testing::Return; + +class MockSsmClient : public ::Aws::SSM::SSMClient { + public: + MOCK_METHOD(Aws::SSM::Model::GetParameterOutcome, GetParameter, + (const Aws::SSM::Model::GetParameterRequest& request), + (const, override)); +}; + +class ParameterClientAwsTest : public ::testing::Test { + protected: + PlatformInitializer initializer_; +}; + +Aws::SSM::Model::GetParameterResult BuildParameterResult(std::string value) { + Aws::SSM::Model::Parameter parameter; + parameter.SetValue(std::move(value)); + Aws::SSM::Model::GetParameterResult result; + result.SetParameter(parameter); + return result; +} + +TEST_F(ParameterClientAwsTest, GetInt32ParameterSuccess) { + auto ssm_client = std::make_unique(); + auto parameter_result = BuildParameterResult("1"); + Aws::SSM::Model::GetParameterOutcome outcome(parameter_result); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetInt32Parameter("my_param"); + EXPECT_TRUE(result_or_status.ok()); + 
EXPECT_EQ(result_or_status.value(), 1); +} + +TEST_F(ParameterClientAwsTest, GetInt32ParameterSSMClientFailsReturnsError) { + auto ssm_client = std::make_unique(); + Aws::SSM::SSMError ssm_error; + Aws::SSM::Model::GetParameterOutcome outcome(ssm_error); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetInt32Parameter("my_param"); + EXPECT_FALSE(result_or_status.ok()); +} + +TEST_F(ParameterClientAwsTest, GetInt32ParameterNotIntReturnsError) { + auto ssm_client = std::make_unique(); + auto parameter_result = BuildParameterResult("not an int"); + Aws::SSM::Model::GetParameterOutcome outcome(parameter_result); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetInt32Parameter("my_param"); + EXPECT_FALSE(result_or_status.ok()); +} + +TEST_F(ParameterClientAwsTest, GetBoolParameterSuccess) { + auto ssm_client = std::make_unique(); + auto parameter_result = BuildParameterResult("true"); + Aws::SSM::Model::GetParameterOutcome outcome(parameter_result); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetBoolParameter("my_param"); + EXPECT_TRUE(result_or_status.ok()); + EXPECT_TRUE(result_or_status.value()); +} + +TEST_F(ParameterClientAwsTest, GetBoolParameterSSMClientFailsReturnsError) { + auto ssm_client = std::make_unique(); + Aws::SSM::SSMError ssm_error; + 
Aws::SSM::Model::GetParameterOutcome outcome(ssm_error); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetBoolParameter("my_param"); + EXPECT_FALSE(result_or_status.ok()); +} + +TEST_F(ParameterClientAwsTest, GetBoolParameterNotBoolReturnsError) { + auto ssm_client = std::make_unique(); + auto parameter_result = BuildParameterResult("not a bool"); + Aws::SSM::Model::GetParameterOutcome outcome(parameter_result); + EXPECT_CALL(*ssm_client, GetParameter(_)).WillOnce(Return(outcome)); + + ParameterClient::ClientOptions options; + options.ssm_client_for_unit_testing_ = ssm_client.release(); + + auto parameter_client = ParameterClient::Create(options); + auto result_or_status = parameter_client->GetBoolParameter("my_param"); + EXPECT_FALSE(result_or_status.ok()); +} + +} // namespace + +} // namespace kv_server diff --git a/components/cloud_config/parameter_client_local.cc b/components/cloud_config/parameter_client_local.cc index 6e2e1cfa..2ce045ab 100644 --- a/components/cloud_config/parameter_client_local.cc +++ b/components/cloud_config/parameter_client_local.cc @@ -76,6 +76,9 @@ ABSL_FLAG(int32_t, s3client_max_connections, 1, ABSL_FLAG(int32_t, s3client_max_range_bytes, 1, "S3Client max range bytes for reading data files."); ABSL_FLAG(int32_t, num_shards, 1, "Total number of shards."); +ABSL_FLAG(int32_t, udf_num_workers, 2, "Number of workers for UDF execution."); +ABSL_FLAG(bool, route_v1_to_v2, false, + "Whether to route V1 requests through V2."); namespace kv_server { namespace { @@ -124,7 +127,13 @@ class LocalParameterClient : public ParameterClient { absl::GetFlag(FLAGS_s3client_max_range_bytes)}); int32_t_flag_values_.insert( {"kv-server-local-num-shards", absl::GetFlag(FLAGS_num_shards)}); + 
int32_t_flag_values_.insert({"kv-server-local-udf-num-workers", + absl::GetFlag(FLAGS_udf_num_workers)}); // Insert more int32 flag values here. + + bool_flag_values_.insert({"kv-server-local-route-v1-to-v2", + absl::GetFlag(FLAGS_route_v1_to_v2)}); + // Insert more bool flag values here. } absl::StatusOr GetParameter( @@ -149,14 +158,27 @@ class LocalParameterClient : public ParameterClient { } } + absl::StatusOr GetBoolParameter( + std::string_view parameter_name) const override { + const auto& it = bool_flag_values_.find(parameter_name); + if (it != bool_flag_values_.end()) { + return it->second; + } else { + return absl::InvalidArgumentError( + absl::StrCat("Unknown local bool parameter: ", parameter_name)); + } + } + private: absl::flat_hash_map int32_t_flag_values_; absl::flat_hash_map string_flag_values_; + absl::flat_hash_map bool_flag_values_; }; } // namespace -std::unique_ptr ParameterClient::Create() { +std::unique_ptr ParameterClient::Create( + ParameterClient::ClientOptions client_options) { return std::make_unique(); } diff --git a/components/cloud_config/parameter_client_local_test.cc b/components/cloud_config/parameter_client_local_test.cc index 92c36b2b..91b91eea 100644 --- a/components/cloud_config/parameter_client_local_test.cc +++ b/components/cloud_config/parameter_client_local_test.cc @@ -97,6 +97,18 @@ TEST(ParameterClientLocal, ExpectedFlagDefaultsArePresent) { ASSERT_TRUE(statusor.ok()); EXPECT_EQ(1, *statusor); } + { + const auto statusor = + client->GetInt32Parameter("kv-server-local-udf-num-workers"); + ASSERT_TRUE(statusor.ok()); + EXPECT_EQ(2, *statusor); + } + { + const auto statusor = + client->GetBoolParameter("kv-server-local-route-v1-to-v2"); + ASSERT_TRUE(statusor.ok()); + EXPECT_EQ(false, *statusor); + } } } // namespace diff --git a/components/data/blob_storage/BUILD b/components/data/blob_storage/BUILD index 3b77140d..3b8d19e5 100644 --- a/components/data/blob_storage/BUILD +++ b/components/data/blob_storage/BUILD @@ -52,7 
+52,7 @@ cc_library( name = "blob_storage_client", srcs = select({ "//:aws_platform": ["blob_storage_client_s3.cc"], - "//:local_platform": [":blob_storage_client_local"], + "//:local_platform": ["blob_storage_client_local.cc"], }), hdrs = [ "blob_storage_client.h", @@ -73,29 +73,23 @@ cc_library( ], ) -cc_library( - name = "blob_storage_client_local", - srcs = ["blob_storage_client_local.cc"], - hdrs = [ - "blob_storage_client.h", - ], - deps = [ - ":seeking_input_streambuf", - "@com_github_google_glog//:glog", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - cc_test( - name = "local_blob_storage_client_test", + name = "blob_storage_client_test", size = "small", - srcs = [ - "blob_storage_client_local_test.cc", - ], - deps = [ - ":blob_storage_client_local", + srcs = select({ + "//:aws_platform": ["blob_storage_client_s3_test.cc"], + "//:local_platform": ["blob_storage_client_local_test.cc"], + }), + deps = select({ + "//:aws_platform": [ + "//components/errors:aws_error_util", + "//components/util:platform_initializer", + "@aws_sdk_cpp//:s3", + "@aws_sdk_cpp//:transfer", + ], + "//conditions:default": [], + }) + [ + ":blob_storage_client", "//components/data/common:mocks", "//components/util:sleepfor_mock", "//public/data_loading:filename_utils", @@ -134,7 +128,15 @@ cc_test( ], "//:local_platform": ["blob_storage_change_notifier_local_test.cc"], }), - deps = [ + deps = select({ + "//:aws_platform": [ + "//components/errors:aws_error_util", + "//components/util:platform_initializer", + "@aws_sdk_cpp//:sqs", + ], + "//:local_platform": [ + ], + }) + [ ":blob_storage_change_notifier", "@com_github_google_glog//:glog", "@com_github_grpc_grpc//:grpc++", @@ -154,7 +156,7 @@ cc_library( deps = [ ":blob_storage_change_notifier", ":blob_storage_client", - "//components/data/common:thread_notifier", + "//components/data/common:thread_manager", "//components/errors:retry", 
"//components/util:sleepfor", "//public:constants", diff --git a/components/data/blob_storage/blob_storage_change_notifier_s3.cc b/components/data/blob_storage/blob_storage_change_notifier_s3.cc index f1c540d7..a7188512 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_s3.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_s3.cc @@ -25,10 +25,14 @@ namespace { using privacy_sandbox::server_common::MetricsRecorder; +constexpr char* kAwsJsonParseError = "AwsJsonParseError"; + class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { public: - explicit S3BlobStorageChangeNotifier(std::unique_ptr notifier) - : change_notifier_(std::move(notifier)) {} + explicit S3BlobStorageChangeNotifier(std::unique_ptr notifier, + MetricsRecorder& metrics_recorder) + : change_notifier_(std::move(notifier)), + metrics_recorder_(metrics_recorder) {} absl::StatusOr> GetNotifications( absl::Duration max_wait, @@ -37,6 +41,7 @@ class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { change_notifier_->GetNotifications(max_wait, should_stop_callback); if (!notifications.ok()) { + // No need to increment metrics here, that happens in ChangeNotifier. return notifications.status(); } @@ -45,7 +50,9 @@ class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { const absl::StatusOr parsedMessage = ParseObjectKeyFromJson(message); if (!parsedMessage.ok()) { - LOG(ERROR) << "Failed to parse JSON: " << message; + LOG(ERROR) << "Failed to parse JSON. 
Error: " << parsedMessage.status() + << " Message:" << message; + metrics_recorder_.IncrementEventCounter(kAwsJsonParseError); continue; } parsed_notifications.push_back(std::move(*parsedMessage)); @@ -95,6 +102,7 @@ class S3BlobStorageChangeNotifier : public BlobStorageChangeNotifier { } std::unique_ptr change_notifier_; + MetricsRecorder& metrics_recorder_; }; } // namespace @@ -113,7 +121,8 @@ BlobStorageChangeNotifier::Create(NotifierMetadata notifier_metadata, return status_or.status(); } - return std::make_unique(std::move(*status_or)); + return std::make_unique(std::move(*status_or), + metrics_recorder); } } // namespace kv_server diff --git a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc index 2d76aadd..1ee78905 100644 --- a/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc +++ b/components/data/blob_storage/blob_storage_change_notifier_s3_test.cc @@ -14,15 +14,348 @@ * limitations under the License. */ +#include "absl/status/statusor.h" +#include "aws/sqs/SQSClient.h" +#include "aws/sqs/model/ReceiveMessageRequest.h" #include "components/data/blob_storage/blob_storage_change_notifier.h" +#include "components/data/common/msg_svc.h" +#include "components/util/platform_initializer.h" +#include "glog/logging.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "src/cpp/telemetry/mocks.h" namespace kv_server { namespace { -TEST(BlobStorageChangeNotifierS3Test, NotImplemented) { - // TODO(b/237669491): Add unit tests for the S3 BlobStorageChangeNotifier. 
- EXPECT_TRUE(true); +class MockMessageService : public MessageService { + public: + MOCK_METHOD(bool, IsSetupComplete, (), (const)); + + MOCK_METHOD(const std::string&, GetSqsUrl, (), (const)); + + MOCK_METHOD(absl::Status, SetupQueue, (), ()); + + MOCK_METHOD(void, Reset, (), ()); +}; + +class MockSqsClient : public ::Aws::SQS::SQSClient { + public: + MOCK_METHOD(Aws::SQS::Model::ReceiveMessageOutcome, ReceiveMessage, + (const Aws::SQS::Model::ReceiveMessageRequest& request), + (const, override)); + + MOCK_METHOD(Aws::SQS::Model::DeleteMessageBatchOutcome, DeleteMessageBatch, + (const Aws::SQS::Model::DeleteMessageBatchRequest& request), + (const, override)); +}; + +// See this link for the JSON format of AWS S3 notifications to SQS that're +// parsing: +// https://docs.aws.amazon.com/AmazonS3/latest/userguide/notification-content-structure.html +class BlobStorageChangeNotifierS3Test : public ::testing::Test { + protected: + void CreateRequiredSqsCallExpectations() { + static const std::string mock_sqs_url("mock sqs url"); + EXPECT_CALL(mock_message_service_, IsSetupComplete) + .WillOnce(::testing::Return(true)); + EXPECT_CALL(mock_message_service_, GetSqsUrl()) + .WillRepeatedly(::testing::ReturnRef(mock_sqs_url)); + } + + void SetMockMessage(const std::string& mock_message, MockSqsClient& client) { + Aws::SQS::Model::ReceiveMessageResult result; + Aws::SQS::Model::Message message; + message.SetBody(mock_message); + result.AddMessages(message); + // Because we populate the Outcome with a Result that means that IsSuccess() + // will return true. 
+ Aws::SQS::Model::ReceiveMessageOutcome outcome(result); + EXPECT_CALL(client, ReceiveMessage(::testing::_)) + .WillOnce(::testing::Return(outcome)); + } + + PlatformInitializer initializer_; + privacy_sandbox::server_common::MockMetricsRecorder metrics_recorder_; + MockMessageService mock_message_service_; +}; + +TEST_F(BlobStorageChangeNotifierS3Test, AwsSqsUnavailable) { + CreateRequiredSqsCallExpectations(); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + auto mock_sqs_client = std::make_unique(); + // A default ReceiveMessageOutcome will be returned for calls to + // mock_sqs_client.ReceiveMessage(_). + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + EXPECT_EQ(::absl::StatusCode::kUnavailable, notifications.status().code()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, InvalidJsonMessage) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage("this is not valid json", *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. 
+ EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. + EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, JsonHasNoMessageObject) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage("{}", *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. + EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. 
+ EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, MessageObjectIsNotAString) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": {} + })", + *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. + EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. + EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, RecordsIsNotAList) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\": {} }" + })", + *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. 
+ EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. + EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, NoS3RecordPresent) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\":[]}" + })", + *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. + EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. 
+ EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, S3RecordIsNull) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\":[ {\"s3\":null}]}" + })", + *mock_sqs_client); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. + EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, S3ObjectIsNull) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\":[ {\"s3\":{\"object\":null}}]}" + })", + *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. + EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. 
+ EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, S3KeyIsNotAString) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\":[ {\"s3\":{\"object\":{\"key\":{}}}}}]}" + })", + *mock_sqs_client); + + // Make sure that the metric for this error is incremented but ignore any + // other metrics. + EXPECT_CALL(metrics_recorder_, IncrementEventCounter(::testing::_)) + .Times(::testing::AnyNumber()); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter("AwsJsonParseError")) + .Times(1); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + // The invalid json message is dropped. 
+ EXPECT_EQ(0, notifications->size()); +} + +TEST_F(BlobStorageChangeNotifierS3Test, ValidJson) { + CreateRequiredSqsCallExpectations(); + + auto mock_sqs_client = std::make_unique(); + SetMockMessage(R"({ + "Message": "{\"Records\":[ {\"s3\":{\"object\":{\"key\":\"HappyFace.jpg\"}}}]}" + })", + *mock_sqs_client); + + CloudNotifierMetadata notifier_metadata; + notifier_metadata.queue_manager = &mock_message_service_; + notifier_metadata.only_for_testing_sqs_client_ = mock_sqs_client.release(); + + absl::StatusOr> notifier = + BlobStorageChangeNotifier::Create(notifier_metadata, metrics_recorder_); + ASSERT_TRUE(notifier.status().ok()); + + const absl::StatusOr> notifications = + (*notifier)->GetNotifications(absl::Seconds(1), [] { return false; }); + ASSERT_TRUE(notifications.ok()); + EXPECT_THAT(*notifications, + testing::UnorderedElementsAreArray({"HappyFace.jpg"})); } } // namespace diff --git a/components/data/blob_storage/blob_storage_client.h b/components/data/blob_storage/blob_storage_client.h index 8b2e87cf..a15cad0b 100644 --- a/components/data/blob_storage/blob_storage_client.h +++ b/components/data/blob_storage/blob_storage_client.h @@ -28,6 +28,12 @@ #include "absl/status/statusor.h" #include "src/cpp/telemetry/metrics_recorder.h" +namespace Aws { +namespace S3 { +class S3Client; +} // namespace S3 +} // namespace Aws + namespace kv_server { // Contains a stream of content data read from a cloud object. 
@@ -60,6 +66,9 @@ class BlobStorageClient { ClientOptions() {} int64_t max_connections = std::thread::hardware_concurrency(); int64_t max_range_bytes = 8 * 1024 * 1024; // 8MB + + // BlobStorageClient takes ownership of this if it's set: + ::Aws::S3::S3Client* s3_client_for_unit_testing_ = nullptr; }; // TODO(b/237669491): Replace these factory methods with one based off the diff --git a/components/data/blob_storage/blob_storage_client_s3.cc b/components/data/blob_storage/blob_storage_client_s3.cc index c6808d75..45f938a3 100644 --- a/components/data/blob_storage/blob_storage_client_s3.cc +++ b/components/data/blob_storage/blob_storage_client_s3.cc @@ -203,9 +203,14 @@ class S3BlobStorageClient : public BlobStorageClient { BlobStorageClient::ClientOptions client_options) : metrics_recorder_(metrics_recorder), client_options_(std::move(client_options)) { - Aws::Client::ClientConfiguration config; - config.maxConnections = client_options_.max_connections; - client_ = std::make_shared(config); + if (client_options.s3_client_for_unit_testing_ != nullptr) { + client_.reset(client_options.s3_client_for_unit_testing_); + } else { + Aws::Client::ClientConfiguration config; + config.maxConnections = client_options_.max_connections; + client_ = std::make_shared(config); + } + executor_ = std::make_unique( std::thread::hardware_concurrency()); Aws::Transfer::TransferManagerConfiguration transfer_config( diff --git a/components/data/blob_storage/blob_storage_client_s3_test.cc b/components/data/blob_storage/blob_storage_client_s3_test.cc new file mode 100644 index 00000000..7d1fd349 --- /dev/null +++ b/components/data/blob_storage/blob_storage_client_s3_test.cc @@ -0,0 +1,171 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "aws/core/Aws.h" +#include "aws/s3/S3Client.h" +#include "aws/s3/model/DeleteObjectRequest.h" +#include "aws/s3/model/ListObjectsV2Request.h" +#include "aws/s3/model/Object.h" +#include "components/data/blob_storage/blob_storage_client.h" +#include "components/util/platform_initializer.h" +#include "gtest/gtest.h" +#include "src/cpp/telemetry/mocks.h" + +namespace kv_server { +namespace { + +using privacy_sandbox::server_common::MockMetricsRecorder; + +class MockS3Client : public ::Aws::S3::S3Client { + public: + MOCK_METHOD(Aws::S3::Model::DeleteObjectOutcome, DeleteObject, + (const Aws::S3::Model::DeleteObjectRequest& request), + (const, override)); + + MOCK_METHOD(Aws::S3::Model::ListObjectsV2Outcome, ListObjectsV2, + (const Aws::S3::Model::ListObjectsV2Request& request), + (const, override)); +}; + +class BlobStorageClientS3Test : public ::testing::Test { + protected: + PlatformInitializer initializer_; + privacy_sandbox::server_common::MockMetricsRecorder metrics_recorder_; +}; + +TEST_F(BlobStorageClientS3Test, DeleteBlobSucceeds) { + auto mock_s3_client = std::make_unique(); + Aws::S3::Model::DeleteObjectResult result; // An empty result means success. 
+ EXPECT_CALL(*mock_s3_client, DeleteObject(::testing::_)) + .WillOnce(::testing::Return(result)); + + BlobStorageClient::ClientOptions options; + options.s3_client_for_unit_testing_ = mock_s3_client.release(); + std::unique_ptr client = + BlobStorageClient::Create(metrics_recorder_, options); + ASSERT_TRUE(client != nullptr); + + BlobStorageClient::DataLocation location; + EXPECT_TRUE(client->DeleteBlob(location).ok()); +} + +TEST_F(BlobStorageClientS3Test, DeleteBlobFails) { + // By default an error is returned for calls to DeleteBlob(). + auto mock_s3_client = std::make_unique(); + + BlobStorageClient::ClientOptions options; + options.s3_client_for_unit_testing_ = mock_s3_client.release(); + std::unique_ptr client = + BlobStorageClient::Create(metrics_recorder_, options); + ASSERT_TRUE(client != nullptr); + + BlobStorageClient::DataLocation location; + EXPECT_EQ(absl::StatusCode::kUnknown, client->DeleteBlob(location).code()); +} + +TEST_F(BlobStorageClientS3Test, ListBlobsSucceeds) { + auto mock_s3_client = std::make_unique(); + { + Aws::S3::Model::ListObjectsV2Result + result; // An empty result means success. 
+ Aws::S3::Model::Object object_to_return; + object_to_return.SetKey("HappyFace.jpg"); + Aws::Vector objects_to_return = {object_to_return}; + result.SetContents(objects_to_return); + EXPECT_CALL(*mock_s3_client, ListObjectsV2(::testing::_)) + .WillOnce(::testing::Return(result)); + } + + BlobStorageClient::ClientOptions options; + options.s3_client_for_unit_testing_ = mock_s3_client.release(); + std::unique_ptr client = + BlobStorageClient::Create(metrics_recorder_, options); + ASSERT_TRUE(client != nullptr); + + BlobStorageClient::DataLocation location; + BlobStorageClient::ListOptions list_options; + absl::StatusOr> response = + client->ListBlobs(location, list_options); + ASSERT_TRUE(response.ok()); + EXPECT_THAT(*response, testing::UnorderedElementsAreArray({"HappyFace.jpg"})); +} + +TEST_F(BlobStorageClientS3Test, ListBlobsSucceedsWithContinuedRequests) { + auto mock_s3_client = std::make_unique(); + // Set up two expected requests. The first one is marked as truncated so that + // the second one will happen. + ::testing::InSequence in_sequence; + { + Aws::S3::Model::ListObjectsV2Result + result; // An empty result means success. + result.SetIsTruncated(true); + Aws::S3::Model::Object object_to_return; + object_to_return.SetKey("SomewhatHappyFace.jpg"); + Aws::Vector objects_to_return = {object_to_return}; + result.SetContents(objects_to_return); + EXPECT_CALL(*mock_s3_client, ListObjectsV2(::testing::_)) + .WillOnce(::testing::Return(result)); + } + { + Aws::S3::Model::ListObjectsV2Result + result; // An empty result means success. 
+ Aws::S3::Model::Object object_to_return; + object_to_return.SetKey("VeryHappyFace.jpg"); + Aws::Vector objects_to_return = {object_to_return}; + result.SetContents(objects_to_return); + EXPECT_CALL(*mock_s3_client, ListObjectsV2(::testing::_)) + .WillOnce(::testing::Return(result)); + } + + BlobStorageClient::ClientOptions options; + options.s3_client_for_unit_testing_ = mock_s3_client.release(); + std::unique_ptr client = + BlobStorageClient::Create(metrics_recorder_, options); + ASSERT_TRUE(client != nullptr); + + BlobStorageClient::DataLocation location; + BlobStorageClient::ListOptions list_options; + absl::StatusOr> response = + client->ListBlobs(location, list_options); + ASSERT_TRUE(response.ok()); + EXPECT_THAT(*response, testing::UnorderedElementsAreArray( + {"SomewhatHappyFace.jpg", "VeryHappyFace.jpg"})); +} + +TEST_F(BlobStorageClientS3Test, ListBlobsFails) { + // By default an error is returned for calls to ListObjectsV2(). + auto mock_s3_client = std::make_unique(); + + BlobStorageClient::ClientOptions options; + options.s3_client_for_unit_testing_ = mock_s3_client.release(); + std::unique_ptr client = + BlobStorageClient::Create(metrics_recorder_, options); + ASSERT_TRUE(client != nullptr); + + BlobStorageClient::DataLocation location; + BlobStorageClient::ListOptions list_options; + EXPECT_EQ(absl::StatusCode::kUnknown, + client->ListBlobs(location, list_options).status().code()); +} + +} // namespace +} // namespace kv_server diff --git a/components/data/blob_storage/delta_file_notifier.cc b/components/data/blob_storage/delta_file_notifier.cc index 5addf315..ac67cdb4 100644 --- a/components/data/blob_storage/delta_file_notifier.cc +++ b/components/data/blob_storage/delta_file_notifier.cc @@ -21,7 +21,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" -#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/errors/retry.h" #include "glog/logging.h" 
#include "public/constants.h" @@ -40,7 +40,7 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { const absl::Duration poll_frequency, std::unique_ptr sleep_for, SteadyClock& clock) - : thread_notifier_(ThreadNotifier::Create("Delta file notifier")), + : thread_manager_(TheadManager::Create("Delta file notifier")), client_(client), poll_frequency_(poll_frequency), sleep_for_(std::move(sleep_for)), @@ -50,10 +50,10 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { BlobStorageChangeNotifier& change_notifier, BlobStorageClient::DataLocation location, std::string start_after, std::function callback) override { - return thread_notifier_->Start([this, location = std::move(location), - start_after = std::move(start_after), - callback = std::move(callback), - &change_notifier]() mutable { + return thread_manager_->Start([this, location = std::move(location), + start_after = std::move(start_after), + callback = std::move(callback), + &change_notifier]() mutable { Watch(change_notifier, std::move(location), std::move(start_after), std::move(callback)); }); @@ -61,11 +61,11 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { absl::Status Stop() override { absl::Status status = sleep_for_->Stop(); - status.Update(thread_notifier_->Stop()); + status.Update(thread_manager_->Stop()); return status; } - bool IsRunning() const override { return thread_notifier_->IsRunning(); } + bool IsRunning() const override { return thread_manager_->IsRunning(); } private: // Returns max DeltaFile in alphabetical order from notification @@ -75,7 +75,7 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { absl::Duration wait_duration) { absl::StatusOr> changes = change_notifier.GetNotifications( - wait_duration, [this]() { return thread_notifier_->ShouldStop(); }); + wait_duration, [this]() { return thread_manager_->ShouldStop(); }); if (!changes.ok()) { return changes.status(); } @@ -125,7 +125,7 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { // Flag 
starts expired, and forces an initial poll. ExpiringFlag expiring_flag(clock_); uint32_t sequential_failures = 0; - while (!thread_notifier_->ShouldStop()) { + while (!thread_manager_->ShouldStop()) { const absl::StatusOr should_list_blobs = ShouldListBlobs(change_notifier, expiring_flag, last_key); if (!should_list_blobs.ok()) { @@ -171,7 +171,7 @@ class DeltaFileNotifierImpl : public DeltaFileNotifier { } } - std::unique_ptr thread_notifier_; + std::unique_ptr thread_manager_; BlobStorageClient& client_; const absl::Duration poll_frequency_; std::unique_ptr sleep_for_; diff --git a/components/data/blob_storage/delta_file_notifier.h b/components/data/blob_storage/delta_file_notifier.h index 37638659..c18a647f 100644 --- a/components/data/blob_storage/delta_file_notifier.h +++ b/components/data/blob_storage/delta_file_notifier.h @@ -22,7 +22,7 @@ #include "components/data/blob_storage/blob_storage_change_notifier.h" #include "components/data/blob_storage/blob_storage_client.h" -#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/errors/retry.h" #include "components/util/sleepfor.h" #include "src/cpp/util/duration.h" diff --git a/components/data/common/BUILD b/components/data/common/BUILD index 5c007f91..645eea52 100644 --- a/components/data/common/BUILD +++ b/components/data/common/BUILD @@ -110,12 +110,12 @@ cc_test( ) cc_library( - name = "thread_notifier", + name = "thread_manager", srcs = [ - "thread_notifier.cc", + "thread_manager.cc", ], hdrs = [ - "thread_notifier.h", + "thread_manager.h", ], deps = [ "//components/errors:retry", diff --git a/components/data/common/msg_svc.h b/components/data/common/msg_svc.h index 9b5653c8..ec07ccb3 100644 --- a/components/data/common/msg_svc.h +++ b/components/data/common/msg_svc.h @@ -18,6 +18,7 @@ #define COMPONENTS_DATA_COMMON_MSG_SVC_H_ #include +#include #include #include @@ -46,7 +47,8 @@ class MessageService { virtual void Reset() = 0; static 
absl::StatusOr> Create( - NotifierMetadata notifier_metadata); + NotifierMetadata notifier_metadata, + std::optional shard_num = std::nullopt); }; } // namespace kv_server diff --git a/components/data/common/msg_svc_aws.cc b/components/data/common/msg_svc_aws.cc index bce7e0a2..ae3b3e54 100644 --- a/components/data/common/msg_svc_aws.cc +++ b/components/data/common/msg_svc_aws.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "absl/random/random.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -54,6 +56,10 @@ constexpr char kPolicyTemplate[] = R"({ ] })"; +constexpr char kFilterPolicyTemplate[] = R"({ + "shard_num": ["%d"] +})"; + constexpr std::string_view alphanum = "_-0123456789" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -63,8 +69,11 @@ class AwsMessageService : public MessageService { public: // `prefix` is the prefix of randomly generated SQS Queue name. // The queue is subscribed to the topic at `sns_arn`. - AwsMessageService(std::string prefix, std::string sns_arn) - : prefix_(std::move(prefix)), sns_arn_(std::move(sns_arn)) {} + AwsMessageService(std::string prefix, std::string sns_arn, + std::optional shard_num) + : prefix_(std::move(prefix)), + sns_arn_(std::move(sns_arn)), + shard_num_(shard_num) {} bool IsSetupComplete() const { absl::ReaderMutexLock lock(&mutex_); @@ -174,6 +183,10 @@ class AwsMessageService : public MessageService { req.SetTopicArn(sns_arn); req.SetProtocol("sqs"); req.SetEndpoint(queue_url); + if (prefix_ == "QueueNotifier_" && shard_num_.has_value()) { + req.AddAttributes("FilterPolicy", absl::StrFormat(kFilterPolicyTemplate, + shard_num_.value())); + } const auto outcome = sns.Subscribe(req); return outcome.IsSuccess() ? 
absl::OkStatus() : AwsErrorToStatus(outcome.GetError()); @@ -188,15 +201,16 @@ class AwsMessageService : public MessageService { std::string sqs_url_; std::string sqs_arn_; bool are_attributes_set_ = false; + std::optional shard_num_; }; } // namespace absl::StatusOr> MessageService::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, std::optional shard_num) { auto metadata = std::get(notifier_metadata); - return std::make_unique(std::move(metadata.queue_prefix), - std::move(metadata.sns_arn)); + return std::make_unique( + std::move(metadata.queue_prefix), std::move(metadata.sns_arn), shard_num); } } // namespace kv_server diff --git a/components/data/common/msg_svc_local.cc b/components/data/common/msg_svc_local.cc index 414dc2b0..a8a1fead 100644 --- a/components/data/common/msg_svc_local.cc +++ b/components/data/common/msg_svc_local.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "components/data/common/msg_svc.h" namespace kv_server { @@ -28,7 +30,7 @@ class LocalMessageService : public MessageService { } // namespace absl::StatusOr> MessageService::Create( - NotifierMetadata notifier_metadata) { + NotifierMetadata notifier_metadata, std::optional shard_num) { auto metadata = std::get(notifier_metadata); return std::make_unique( std::move(metadata.local_directory)); diff --git a/components/data/common/thread_notifier.cc b/components/data/common/thread_manager.cc similarity index 70% rename from components/data/common/thread_notifier.cc rename to components/data/common/thread_manager.cc index e083646d..7d00ecff 100644 --- a/components/data/common/thread_notifier.cc +++ b/components/data/common/thread_manager.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include #include @@ -30,30 +30,30 @@ namespace kv_server { namespace { -class ThreadNotifierImpl : public ThreadNotifier { +class TheadManagerImpl : public TheadManager { public: - explicit ThreadNotifierImpl(std::string notifier_name) - : notifier_name_(std::move(notifier_name)) {} + explicit TheadManagerImpl(std::string thread_name) + : thread_name_(std::move(thread_name)) {} - ~ThreadNotifierImpl() { + ~TheadManagerImpl() { if (!IsRunning()) return; if (const auto s = Stop(); !s.ok()) { - LOG(ERROR) << notifier_name_ << " failed to stop notifier: " << s; + LOG(ERROR) << thread_name_ << " failed to stop: " << s; } } absl::Status Start(std::function watch) override { if (IsRunning()) { - return absl::FailedPreconditionError("Already notifying"); + return absl::FailedPreconditionError("Already running"); } - LOG(INFO) << notifier_name_ << "Creating thread for watching files"; + LOG(INFO) << thread_name_ << " Creating thread for processing"; thread_ = std::make_unique(watch); return absl::OkStatus(); } absl::Status Stop() override { if (!IsRunning()) { - return absl::FailedPreconditionError("Not currently notifying"); + return absl::FailedPreconditionError("Not currently running"); } should_stop_ = true; thread_->join(); @@ -69,14 +69,13 @@ class ThreadNotifierImpl : public ThreadNotifier { private: std::unique_ptr thread_; std::atomic should_stop_ = false; - std::string notifier_name_; + std::string thread_name_; }; } // namespace -std::unique_ptr ThreadNotifier::Create( - std::string notifier_name) { - return std::make_unique(std::move(notifier_name)); +std::unique_ptr TheadManager::Create(std::string thread_name) { + return std::make_unique(std::move(thread_name)); } } // namespace kv_server diff --git a/components/data/common/thread_notifier.h b/components/data/common/thread_manager.h similarity index 77% rename from components/data/common/thread_notifier.h 
rename to components/data/common/thread_manager.h index 32de4189..abd9f24b 100644 --- a/components/data/common/thread_notifier.h +++ b/components/data/common/thread_manager.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef COMPONENTS_DATA_COMMON_THREAD_NOTIFIER_H_ -#define COMPONENTS_DATA_COMMON_THREAD_NOTIFIER_H_ +#ifndef COMPONENTS_DATA_COMMON_THREAD_MANAGER_H_ +#define COMPONENTS_DATA_COMMON_THREAD_MANAGER_H_ #include #include @@ -24,11 +24,11 @@ namespace kv_server { -class ThreadNotifier { +class TheadManager { public: - virtual ~ThreadNotifier() = default; + virtual ~TheadManager() = default; - // Checks if the ThreadNotifier is already running. + // Checks if the TheadManager is already running. // If not, starts a thread on which `watch` is executed. // Start and Stop should be called on the same thread as // the constructor. @@ -43,8 +43,8 @@ class ThreadNotifier { virtual bool ShouldStop() = 0; - static std::unique_ptr Create(std::string notifier_name); + static std::unique_ptr Create(std::string thread_name); }; } // namespace kv_server -#endif // COMPONENTS_DATA_COMMON_THREAD_NOTIFIER_H_ +#endif // COMPONENTS_DATA_COMMON_THREAD_MANAGER_H_ diff --git a/components/data/realtime/BUILD b/components/data/realtime/BUILD index 47195425..8ae94f02 100644 --- a/components/data/realtime/BUILD +++ b/components/data/realtime/BUILD @@ -73,7 +73,7 @@ cc_library( ], deps = [ ":delta_file_record_change_notifier", - "//components/data/common:thread_notifier", + "//components/data/common:thread_manager", "//components/errors:retry", "//components/util:sleepfor", "//public:constants", diff --git a/components/data/realtime/realtime_notifier.cc b/components/data/realtime/realtime_notifier.cc index 01e802a1..ee0e6203 100644 --- a/components/data/realtime/realtime_notifier.cc +++ b/components/data/realtime/realtime_notifier.cc @@ -21,7 +21,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" -#include 
"components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/data/realtime/delta_file_record_change_notifier.h" #include "components/errors/retry.h" #include "glog/logging.h" @@ -56,7 +56,7 @@ class RealtimeNotifierImpl : public RealtimeNotifier { public: explicit RealtimeNotifierImpl(MetricsRecorder& metrics_recorder, std::unique_ptr sleep_for) - : thread_notifier_(ThreadNotifier::Create("Realtime notifier")), + : thread_manager_(TheadManager::Create("Realtime notifier")), metrics_recorder_(metrics_recorder), sleep_for_(std::move(sleep_for)) { metrics_recorder.RegisterHistogram(kReceivedLowLatencyNotificationsE2E, @@ -72,7 +72,7 @@ class RealtimeNotifierImpl : public RealtimeNotifier { DeltaFileRecordChangeNotifier& change_notifier, std::function(const std::string& key)> callback) override { - return thread_notifier_->Start( + return thread_manager_->Start( [this, callback = std::move(callback), &change_notifier]() mutable { Watch(change_notifier, std::move(callback)); }); @@ -80,11 +80,11 @@ class RealtimeNotifierImpl : public RealtimeNotifier { absl::Status Stop() override { absl::Status status = sleep_for_->Stop(); - status.Update(thread_notifier_->Stop()); + status.Update(thread_manager_->Stop()); return status; } - bool IsRunning() const override { return thread_notifier_->IsRunning(); } + bool IsRunning() const override { return thread_manager_->IsRunning(); } private: void Watch( @@ -95,9 +95,9 @@ class RealtimeNotifierImpl : public RealtimeNotifier { // Later polls are long polls. 
auto max_wait = absl::ZeroDuration(); uint32_t sequential_failures = 0; - while (!thread_notifier_->ShouldStop()) { + while (!thread_manager_->ShouldStop()) { auto updates = change_notifier.GetNotifications( - max_wait, [this]() { return thread_notifier_->ShouldStop(); }); + max_wait, [this]() { return thread_manager_->ShouldStop(); }); if (absl::IsDeadlineExceeded(updates.status())) { sequential_failures = 0; @@ -154,7 +154,7 @@ class RealtimeNotifierImpl : public RealtimeNotifier { } } - std::unique_ptr thread_notifier_; + std::unique_ptr thread_manager_; MetricsRecorder& metrics_recorder_; std::unique_ptr sleep_for_; }; diff --git a/components/data/realtime/realtime_notifier.h b/components/data/realtime/realtime_notifier.h index 04b8b4f4..e06af62b 100644 --- a/components/data/realtime/realtime_notifier.h +++ b/components/data/realtime/realtime_notifier.h @@ -20,7 +20,7 @@ #include #include -#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/data/realtime/delta_file_record_change_notifier.h" #include "components/errors/retry.h" #include "components/util/sleepfor.h" diff --git a/components/data_server/cache/BUILD b/components/data_server/cache/BUILD index af421543..16477789 100644 --- a/components/data_server/cache/BUILD +++ b/components/data_server/cache/BUILD @@ -19,13 +19,29 @@ package(default_visibility = [ "//tools:__subpackages__", ]) +cc_library( + name = "get_key_value_set_result_impl", + srcs = [ + "get_key_value_set_result_impl.cc", + ], + hdrs = [ + "get_key_value_set_result.h", + ], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + cc_library( name = "cache", hdrs = [ "cache.h", ], deps = [ + ":get_key_value_set_result_impl", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -39,6 +55,7 @@ cc_library( ], deps = [ ":cache", + 
":get_key_value_set_result_impl", "//public:base_types_cc_proto", "@com_github_google_glog//:glog", "@com_google_absl//absl/base", diff --git a/components/data_server/cache/cache.h b/components/data_server/cache/cache.h index 54549901..eaa4b34a 100644 --- a/components/data_server/cache/cache.h +++ b/components/data_server/cache/cache.h @@ -25,6 +25,8 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "components/data_server/cache/get_key_value_set_result.h" namespace kv_server { @@ -38,14 +40,31 @@ class Cache { virtual absl::flat_hash_map GetKeyValuePairs( const std::vector& key_list) const = 0; + // Looks up and returns key-value set result for the given key set. + virtual std::unique_ptr GetKeyValueSet( + const absl::flat_hash_set& key_set) const = 0; + // Inserts or updates the key with the new value. virtual void UpdateKeyValue(std::string_view key, std::string_view value, int64_t logical_commit_time) = 0; + // Inserts or updates values in the set for a given key, if a value exists, + // updates its timestamp to the latest logical commit time. + virtual void UpdateKeyValueSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) = 0; + // Deletes a particular (key, value) pair. virtual void DeleteKey(std::string_view key, int64_t logical_commit_time) = 0; - // Remove the values that were deleted before the specified + // Deletes values in the set for a given key. The deletion, this object + // still exist and is marked "deleted", in case there are + // late-arriving updates to this value. + virtual void DeleteValuesInSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) = 0; + + // Removes the values that were deleted before the specified // logical_commit_time. 
virtual void RemoveDeletedKeys(int64_t logical_commit_time) = 0; diff --git a/components/data_server/cache/get_key_value_set_result.h b/components/data_server/cache/get_key_value_set_result.h new file mode 100644 index 00000000..b9a36861 --- /dev/null +++ b/components/data_server/cache/get_key_value_set_result.h @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_DATA_SERVER_CACHE_GET_KEY_VALUE_SET_RESULT_H_ +#define COMPONENTS_DATA_SERVER_CACHE_GET_KEY_VALUE_SET_RESULT_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" + +namespace kv_server { +// Class that holds the data retrieved from cache lookup and read locks for +// the lookup keys +class GetKeyValueSetResult { + public: + virtual ~GetKeyValueSetResult() = default; + // Looks up and returns key-value set result for the given key set. 
+ virtual absl::flat_hash_set GetValueSet( + std::string_view key) const = 0; + + private: + // Adds key, value_set to the result data map, creates a read lock for + // the key mutex + virtual void AddKeyValueSet( + absl::Mutex& key_mutex, std::string_view key, + const absl::flat_hash_set& value_set) = 0; + static std::unique_ptr Create(); + friend class KeyValueCache; +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_CACHE_GET_KEY_VALUE_SET_RESULT_H_ diff --git a/components/data_server/cache/get_key_value_set_result_impl.cc b/components/data_server/cache/get_key_value_set_result_impl.cc new file mode 100644 index 00000000..ee894b63 --- /dev/null +++ b/components/data_server/cache/get_key_value_set_result_impl.cc @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "components/data_server/cache/get_key_value_set_result.h" + +namespace kv_server { +namespace { + +// Class that holds the data retrieved from cache lookup and read locks for +// the lookup keys +class GetKeyValueSetResultImpl : public GetKeyValueSetResult { + public: + GetKeyValueSetResultImpl() {} + // Looks up the key in the data map and returns value set. If the value_set + // for the key is missing, returns empty set. 
+ absl::flat_hash_set GetValueSet( + std::string_view key) const override { + static const absl::flat_hash_set* kEmptySet = + new absl::flat_hash_set(); + auto key_itr = data_map_.find(key); + return key_itr == data_map_.end() ? *kEmptySet : key_itr->second; + } + GetKeyValueSetResultImpl(const GetKeyValueSetResultImpl&) = delete; + GetKeyValueSetResultImpl& operator=(const GetKeyValueSetResultImpl&) = delete; + GetKeyValueSetResultImpl(GetKeyValueSetResultImpl&& other) = default; + GetKeyValueSetResultImpl& operator=(GetKeyValueSetResultImpl&& other) = + default; + + private: + std::vector> read_locks_; + absl::flat_hash_map> + data_map_; + + // Adds key, value_set to the result data map, creates a read lock for + // the key mutex + void AddKeyValueSet( + absl::Mutex& key_mutex, std::string_view key, + const absl::flat_hash_set& value_set) override { + read_locks_.emplace_back(std::move(new absl::ReaderMutexLock(&key_mutex))); + data_map_.emplace(key, value_set); + } +}; +} // namespace + +std::unique_ptr GetKeyValueSetResult::Create() { + return std::make_unique(); +} + +} // namespace kv_server diff --git a/components/data_server/cache/key_value_cache.cc b/components/data_server/cache/key_value_cache.cc index 9d99d618..4a8cb2d6 100644 --- a/components/data_server/cache/key_value_cache.cc +++ b/components/data_server/cache/key_value_cache.cc @@ -15,21 +15,17 @@ #include #include -#include #include -#include #include -#include "absl/base/optimization.h" -#include "absl/base/thread_annotations.h" -#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "components/data_server/cache/cache.h" +#include "components/data_server/cache/get_key_value_set_result.h" #include "glog/logging.h" -#include "public/base_types.pb.h" namespace kv_server { + absl::flat_hash_map KeyValueCache::GetKeyValuePairs( const std::vector& key_list) const { absl::flat_hash_map kv_pairs; @@ -39,18 +35,48 @@ 
absl::flat_hash_map KeyValueCache::GetKeyValuePairs( if (key_iter == map_.end() || key_iter->second.value == nullptr) { continue; } else { + VLOG(9) << "Get called for " << key + << ". returning value: " << *(key_iter->second.value); kv_pairs.insert_or_assign(key, *(key_iter->second.value)); } } return kv_pairs; } +std::unique_ptr KeyValueCache::GetKeyValueSet( + const absl::flat_hash_set& key_set) const { + // lock the cache map + absl::ReaderMutexLock lock(&set_map_mutex_); + auto result = GetKeyValueSetResult::Create(); + for (const auto& key : key_set) { + VLOG(8) << "Getting key: " << key; + const auto key_itr = key_to_value_set_map_.find(key); + if (key_itr != key_to_value_set_map_.end()) { + absl::flat_hash_set value_set; + for (const auto& v : key_itr->second->second) { + if (!v.second.is_deleted) { + value_set.emplace(v.first); + } + } + // Add key value set to the result + result->AddKeyValueSet(key_itr->second->first, key, value_set); + } + } + return result; +} + // Replaces the current key-value entry with the new key-value entry. void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, int64_t logical_commit_time) { + VLOG(9) << "Received update for [" << key << "] at " << logical_commit_time + << ". 
value will be set to: " << value; absl::MutexLock lock(&mutex_); if (logical_commit_time <= max_cleanup_logical_commit_time_) { + VLOG(1) << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is older than the current cutoff time:" + << max_cleanup_logical_commit_time_; + return; } @@ -58,6 +84,9 @@ void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, if (key_iter != map_.end() && key_iter->second.last_logical_commit_time >= logical_commit_time) { + VLOG(1) << "Skipping the update as its logical_commit_time: " + << logical_commit_time << " is older than the current value's time:" + << key_iter->second.last_logical_commit_time; return; } @@ -65,7 +94,8 @@ void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, key_iter->second.last_logical_commit_time < logical_commit_time && key_iter->second.value == nullptr) { // should always have this, but checking just in case - auto dl_key_iter = deleted_nodes_.find(logical_commit_time); + auto dl_key_iter = + deleted_nodes_.find(key_iter->second.last_logical_commit_time); if (dl_key_iter != deleted_nodes_.end() && dl_key_iter->second == key) { deleted_nodes_.erase(dl_key_iter); } @@ -75,12 +105,79 @@ void KeyValueCache::UpdateKeyValue(std::string_view key, std::string_view value, .last_logical_commit_time = logical_commit_time}); } +void KeyValueCache::UpdateKeyValueSet( + std::string_view key, absl::Span input_value_set, + int64_t logical_commit_time) { + VLOG(9) << "Received update for [" << key << "] at " << logical_commit_time; + std::unique_ptr key_lock; + absl::flat_hash_map* existing_value_set; + // The max cleanup time needs to be locked before doing this comparison + { + absl::MutexLock lock_map(&set_map_mutex_); + + if (logical_commit_time <= max_cleanup_logical_commit_time_for_set_cache_) { + VLOG(1) << "Skipping the update as its logical_commit_time: " + << logical_commit_time + << " is older than the current cutoff time:" + << 
max_cleanup_logical_commit_time_for_set_cache_; + return; + } else if (input_value_set.empty()) { + VLOG(1) << "Skipping the update as it has no value in the set."; + return; + } + auto key_itr = key_to_value_set_map_.find(key); + if (key_itr == key_to_value_set_map_.end()) { + VLOG(9) << key << " is a new key. Adding it"; + // There is no existing value set for the given key, + // simply insert the key value set to the map, no need to update deleted + // set nodes + auto mutex_value_map_pair = std::make_unique>>(); + + for (const auto& value : input_value_set) { + mutex_value_map_pair->second.emplace( + value, SetValueMeta{logical_commit_time, /*is_deleted=*/false}); + } + key_to_value_set_map_.emplace(key, std::move(mutex_value_map_pair)); + return; + } + // The given key has an existing value set, then + // update the existing value if update is suggested by the comparison result + // on the logical commit times. + // Lock the key + key_lock = std::make_unique(&key_itr->second->first); + existing_value_set = &key_itr->second->second; + } // end locking map; + + for (const auto& value : input_value_set) { + auto& current_value_state = (*existing_value_set)[value]; + if (current_value_state.last_logical_commit_time >= logical_commit_time) { + // no need to update + continue; + } + // Insert new value or update existing value with + // the recent logical commit time. 
If the existing value was marked + // deleted, update is_deleted boolean to false + current_value_state.is_deleted = false; + current_value_state.last_logical_commit_time = logical_commit_time; + } + // end locking key +} + void KeyValueCache::DeleteKey(std::string_view key, int64_t logical_commit_time) { absl::MutexLock lock(&mutex_); + + if (logical_commit_time <= max_cleanup_logical_commit_time_) { + return; + } const auto key_iter = map_.find(key); - if (key_iter != map_.end() && - key_iter->second.last_logical_commit_time < logical_commit_time) { + if ((key_iter != map_.end() && + key_iter->second.last_logical_commit_time < logical_commit_time) || + key_iter == map_.end()) { + // If key is missing, we still need to add a null value to the map to + // avoid the late coming update with smaller logical commit time + // inserting value to the map for the given key map_.insert_or_assign( key, {.value = nullptr, .last_logical_commit_time = logical_commit_time}); @@ -95,7 +192,75 @@ void KeyValueCache::DeleteKey(std::string_view key, } } +void KeyValueCache::DeleteValuesInSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) { + std::unique_ptr key_lock; + absl::flat_hash_map* existing_value_set; + // The max cleanup time needs to be locked before doing this comparison + { + absl::MutexLock lock_map(&set_map_mutex_); + + if (logical_commit_time <= max_cleanup_logical_commit_time_for_set_cache_ || + value_set.empty()) { + return; + } + auto key_itr = key_to_value_set_map_.find(key); + if (key_itr == key_to_value_set_map_.end()) { + // If the key is missing, still need to add all the deleted values to the + // map to avoid late arriving update with smaller logical commit time + // inserting values same as the deleted ones for the key + auto mutex_value_map_pair = std::make_unique>>(); + + for (const auto& value : value_set) { + mutex_value_map_pair->second.emplace( + value, SetValueMeta{logical_commit_time, /*is_deleted=*/true}); + } + 
key_to_value_set_map_.emplace(key, std::move(mutex_value_map_pair)); + // Add to deleted set nodes + for (const std::string_view value : value_set) { + deleted_set_nodes_[logical_commit_time][key].emplace(value); + } + return; + } + // Lock the key + key_lock = std::make_unique(&key_itr->second->first); + existing_value_set = &key_itr->second->second; + } // end locking map + // Keep track of the values to be added to the deleted set nodes + std::vector values_to_delete; + for (const auto& value : value_set) { + auto& current_value_state = (*existing_value_set)[value]; + if (current_value_state.last_logical_commit_time >= logical_commit_time) { + // No need to delete + continue; + } + // Add a value that represents a deleted value, or mark the existing value + // deleted. We need to add the value in deleted state to the map to avoid + // late arriving update with smaller logical commit time + // inserting the same value + current_value_state.last_logical_commit_time = logical_commit_time; + current_value_state.is_deleted = true; + values_to_delete.push_back(value); + } + if (!values_to_delete.empty()) { + // Release key lock before locking the map to avoid potential deadlock + // caused by cycle in the ordering of lock acquisitions + key_lock.reset(); + absl::MutexLock lock_map(&set_map_mutex_); + for (const std::string_view value : values_to_delete) { + deleted_set_nodes_[logical_commit_time][key].emplace(value); + } + } +} + void KeyValueCache::RemoveDeletedKeys(int64_t logical_commit_time) { + CleanUpKeyValueMap(logical_commit_time); + CleanUpKeyValueSetMap(logical_commit_time); +} + +void KeyValueCache::CleanUpKeyValueMap(int64_t logical_commit_time) { absl::MutexLock lock(&mutex_); auto it = deleted_nodes_.begin(); @@ -113,11 +278,43 @@ void KeyValueCache::RemoveDeletedKeys(int64_t logical_commit_time) { ++it; } - + deleted_nodes_.erase(deleted_nodes_.begin(), it); max_cleanup_logical_commit_time_ = std::max(max_cleanup_logical_commit_time_, 
logical_commit_time); +} - deleted_nodes_.erase(deleted_nodes_.begin(), it); +void KeyValueCache::CleanUpKeyValueSetMap(int64_t logical_commit_time) { + absl::MutexLock lock_set_map(&set_map_mutex_); + auto delete_itr = deleted_set_nodes_.begin(); + while (delete_itr != deleted_set_nodes_.end()) { + if (delete_itr->first > logical_commit_time) { + break; + } + for (const auto& [key, values] : delete_itr->second) { + if (auto key_itr = key_to_value_set_map_.find(key); + key_itr != key_to_value_set_map_.end()) { + absl::MutexLock(&key_itr->second->first); + for (const auto& v_to_delete : values) { + auto existing_value_itr = key_itr->second->second.find(v_to_delete); + if (existing_value_itr != key_itr->second->second.end() && + existing_value_itr->second.is_deleted && + existing_value_itr->second.last_logical_commit_time <= + logical_commit_time) { + // Delete the existing value that is marked deleted from set + key_itr->second->second.erase(existing_value_itr); + } + } + if (key_itr->second->second.empty()) { + // If the value set is empty, erase the key-value_set from cache map + key_to_value_set_map_.erase(key); + } + } + } + ++delete_itr; + } + deleted_set_nodes_.erase(deleted_set_nodes_.begin(), delete_itr); + max_cleanup_logical_commit_time_for_set_cache_ = std::max( + max_cleanup_logical_commit_time_for_set_cache_, logical_commit_time); } std::unique_ptr KeyValueCache::Create() { diff --git a/components/data_server/cache/key_value_cache.h b/components/data_server/cache/key_value_cache.h index 96565d8a..341ef7c7 100644 --- a/components/data_server/cache/key_value_cache.h +++ b/components/data_server/cache/key_value_cache.h @@ -18,30 +18,22 @@ #define COMPONENTS_DATA_SERVER_CACHE_KEY_VALUE_CACHE_H_ #include +#include #include #include #include +#include #include #include #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "components/data_server/cache/cache.h" +#include 
"components/data_server/cache/get_key_value_set_result.h" #include "public/base_types.pb.h" namespace kv_server { -struct CacheValue { - // We need to be able to set the value to null. For deletion we're keeping - // the timestamp of the key (to prevent a specific type of out of order - // delete-update messages issue) until it is later cleaned up. - // We've also considered using optional, but it takes more space. - // sizeof(string) + sizeof(bool) -- for optional - // sizeof(string*) when null, sizeof(string*) + sizeof(string) otherwise - // -- for the unique pointer - std::unique_ptr value; - int64_t last_logical_commit_time; -}; - // In-memory datastore. // One cache object is only for keys in one namespace. class KeyValueCache : public Cache { @@ -50,14 +42,31 @@ class KeyValueCache : public Cache { absl::flat_hash_map GetKeyValuePairs( const std::vector& key_list) const override; + // Looks up and returns key-value set result for the given key set. + std::unique_ptr GetKeyValueSet( + const absl::flat_hash_set& key_set) const override; + // Inserts or updates the key with the new value. void UpdateKeyValue(std::string_view key, std::string_view value, int64_t logical_commit_time) override; + // Inserts or updates values in the set for a given key, if a value exists, + // updates its timestamp to the latest logical commit time. + void UpdateKeyValueSet(std::string_view key, + absl::Span input_value_set, + int64_t logical_commit_time) override; + // Deletes a particular (key, value) pair. void DeleteKey(std::string_view key, int64_t logical_commit_time) override; - // Remove the values that were deleted before the specified + // Deletes values in the set for a given key. The deletion, this object + // still exist and is marked "deleted", in case there are + // late-arriving updates to this value. 
+ void DeleteValuesInSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) override; + + // Removes the values that were deleted before the specified // logical_commit_time. // TODO: b/267182790 -- Cache cleanup should be done periodically from a // background thread @@ -66,7 +75,33 @@ class KeyValueCache : public Cache { static std::unique_ptr Create(); private: + struct CacheValue { + // We need to be able to set the value to null. For deletion we're keeping + // the timestamp of the key (to prevent a specific type of out of order + // delete-update messages issue) until it is later cleaned up. + // We've also considered using optional, but it takes more space. + // sizeof(string) + sizeof(bool) -- for optional + // sizeof(string*) when null, sizeof(string*) + sizeof(string) otherwise + // -- for the unique pointer + std::unique_ptr value; + int64_t last_logical_commit_time; + }; + struct SetValueMeta { + // Last logical commit time for a value + int64_t last_logical_commit_time; + // Boolean to mark if the value should be deleted or not. + // We need this to represent its deleted state, + // because after deletion, this value should still exist in case + // there are late-arriving updates to this. + bool is_deleted; + SetValueMeta() : last_logical_commit_time(0), is_deleted(false) {} + SetValueMeta(int64_t logical_commit_time, bool deleted) + : last_logical_commit_time(logical_commit_time), is_deleted(deleted) {} + }; + // mutex for key value map; mutable absl::Mutex mutex_; + // mutex for key value set map; + mutable absl::Mutex set_map_mutex_; // Mapping from a key to its value absl::flat_hash_map map_ ABSL_GUARDED_BY(mutex_); @@ -77,6 +112,38 @@ class KeyValueCache : public Cache { // The maximum value that was passed to RemoveDeletedKeys. int64_t max_cleanup_logical_commit_time_ ABSL_GUARDED_BY(mutex_) = 0; + // The maximum value of logical commit time that is used to do update/delete + // for key-value set map. 
+ // TODO(b/284474892) Need to evaluate if we really need to make this variable + // guarded b mutex, if not, we may want to remove it and use one + // max_cleanup_logical_commit_time in update/deletion for both maps + int64_t max_cleanup_logical_commit_time_for_set_cache_ + ABSL_GUARDED_BY(set_map_mutex_) = 0; + + // Mapping from a key to its value map. The key in the inner map is the + // value string, and value is the ValueMeta. The inner map allows value + // look up to check the meta data to determine to state of the value + // in the cache, like logical commit time and whether the value + // is deleted or not. + absl::flat_hash_map< + std::string, + std::unique_ptr>>> + key_to_value_set_map_ ABSL_GUARDED_BY(set_map_mutex_); + // Sorted mapping from logical timestamp to key-value_set map to keep track of + // deleted key-values to handle out of order update case. In the inner map, + // the key string is the key for the values, and the string + // in the flat_hash_set is the value + absl::btree_map>> + deleted_set_nodes_ ABSL_GUARDED_BY(set_map_mutex_); + + // Removes deleted keys from key-value map + void CleanUpKeyValueMap(int64_t logical_commit_time); + + // Removes deleted key-values from key-value_set map + void CleanUpKeyValueSetMap(int64_t logical_commit_time); + friend class KeyValueCacheTestPeer; }; } // namespace kv_server diff --git a/components/data_server/cache/key_value_cache_test.cc b/components/data_server/cache/key_value_cache_test.cc index 15564080..607744b3 100644 --- a/components/data_server/cache/key_value_cache_test.cc +++ b/components/data_server/cache/key_value_cache_test.cc @@ -14,14 +14,18 @@ #include "components/data_server/cache/key_value_cache.h" +#include #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/synchronization/notification.h" #include "components/data_server/cache/cache.h" +#include "components/data_server/cache/get_key_value_set_result.h" #include 
"components/data_server/cache/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -37,11 +41,47 @@ class KeyValueCacheTestPeer { absl::MutexLock lock(&c.mutex_); return c.deleted_nodes_; } - static absl::flat_hash_map& ReadNodes( - KeyValueCache& c) { + static absl::flat_hash_map& + ReadNodes(KeyValueCache& c) { absl::MutexLock lock(&c.mutex_); return c.map_; } + + static int GetDeletedSetNodesMapSize(const KeyValueCache& c) { + absl::MutexLock lock(&c.set_map_mutex_); + return c.deleted_set_nodes_.size(); + } + + static absl::flat_hash_set ReadDeletedSetNodesForTimestamp( + const KeyValueCache& c, int64_t logical_commit_time, + std::string_view key) { + absl::MutexLock lock(&c.set_map_mutex_); + return c.deleted_set_nodes_.find(logical_commit_time) + ->second.find(key) + ->second; + } + + static int GetCacheKeyValueSetMapSize(KeyValueCache& c) { + absl::MutexLock lock(&c.set_map_mutex_); + return c.key_to_value_set_map_.size(); + } + + static KeyValueCache::SetValueMeta GetSetValueMeta(const KeyValueCache& c, + std::string_view key, + std::string_view value) { + absl::MutexLock lock(&c.set_map_mutex_); + auto iter = c.key_to_value_set_map_.find(key); + return iter->second->second.find(value)->second; + } + static int GetSetValueSize(const KeyValueCache& c, std::string_view key) { + absl::MutexLock lock(&c.set_map_mutex_); + auto iter = c.key_to_value_set_map_.find(key); + return iter->second->second.size(); + } + + static void CallCacheCleanup(KeyValueCache& c, int64_t logical_commit_time) { + c.CleanUpKeyValueMap(logical_commit_time); + } }; namespace { @@ -113,6 +153,26 @@ TEST(CacheTest, GetForEmptyCacheReturnsEmptyList) { EXPECT_EQ(kv_pairs.size(), 0); } +TEST(CacheTest, GetForCacheReturnsValueSet) { + std::unique_ptr cache = KeyValueCache::Create(); + std::vector values = {"v1", "v2"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + 
EXPECT_THAT(value_set, UnorderedElementsAre("v1", "v2")); +} + +TEST(CacheTest, GetForCacheMissingKeyReturnsEmptySet) { + std::unique_ptr cache = KeyValueCache::Create(); + std::vector values = {"v1", "v2"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + auto get_key_value_set_result = + cache->GetKeyValueSet({"missing_key", "my_key"}); + EXPECT_EQ(get_key_value_set_result->GetValueSet("missing_key").size(), 0); + EXPECT_THAT(get_key_value_set_result->GetValueSet("my_key"), + UnorderedElementsAre("v1", "v2")); +} + TEST(DeleteKeyTest, RemovesKeyEntry) { std::unique_ptr cache = KeyValueCache::Create(); cache->UpdateKeyValue("my_key", "my_value", 1); @@ -123,7 +183,7 @@ TEST(DeleteKeyTest, RemovesKeyEntry) { EXPECT_EQ(kv_pairs.size(), 0); } -TEST(DeleteKeyTest, WrongkeyDoesNotRemoveEntry) { +TEST(DeleteKeyValueSetTest, WrongkeyDoesNotRemoveEntry) { std::unique_ptr cache = KeyValueCache::Create(); cache->UpdateKeyValue("my_key", "my_value", 1); cache->DeleteKey("wrong_key", 1); @@ -133,6 +193,84 @@ TEST(DeleteKeyTest, WrongkeyDoesNotRemoveEntry) { EXPECT_THAT(kv_pairs, UnorderedElementsAre(KVPairEq("my_key", "my_value"))); } +TEST(DeleteKeyValueSetTest, RemovesValueEntry) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1", "v2", "v3"}; + std::vector values_to_delete = {"v1", "v2"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", + absl::Span(values_to_delete), 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_THAT(value_set, UnorderedElementsAre("v3")); + auto value_meta_v3 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v3"); + EXPECT_EQ(value_meta_v3.last_logical_commit_time, 1); + EXPECT_EQ(value_meta_v3.is_deleted, false); + + auto value_meta_v1_deleted = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1_deleted.last_logical_commit_time, 2); + 
EXPECT_EQ(value_meta_v1_deleted.is_deleted, true); + + auto value_meta_v2_deleted = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v2"); + EXPECT_EQ(value_meta_v2_deleted.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v2_deleted.is_deleted, true); +} + +TEST(DeleteKeyValueSetTest, WrongKeyDoesNotRemoveKeyValueEntry) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1", "v2", "v3"}; + std::vector values_to_delete = {"v1"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("wrong_key", + absl::Span(values_to_delete), 2); + + EXPECT_THAT( + cache->GetKeyValueSet({"my_key", "wrong_key"})->GetValueSet("my_key"), + UnorderedElementsAre("v1", "v2", "v3")); + EXPECT_EQ(cache->GetKeyValueSet({"my_key", "wrong_key"}) + ->GetValueSet("wrong_key") + .size(), + 0); + + auto value_meta_v1 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1.last_logical_commit_time, 1); + EXPECT_EQ(value_meta_v1.is_deleted, false); + + auto value_meta_v1_deleted_for_wrong_key = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "wrong_key", "v1"); + EXPECT_EQ(value_meta_v1_deleted_for_wrong_key.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v1_deleted_for_wrong_key.is_deleted, true); +} + +TEST(DeleteKeyValueSetTest, WrongValueDoesNotRemoveEntry) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1", "v2", "v3"}; + std::vector values_to_delete = {"v4"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", + absl::Span(values_to_delete), 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_THAT(value_set, UnorderedElementsAre("v1", "v2", "v3")); + auto value_meta_v1 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1.last_logical_commit_time, 1); + EXPECT_EQ(value_meta_v1.is_deleted, false); + + auto 
value_meta_v4_deleted = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v4"); + EXPECT_EQ(value_meta_v4_deleted.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v4_deleted.is_deleted, true); + + int value_set_in_cache_size = + KeyValueCacheTestPeer::GetSetValueSize(*cache, "my_key"); + EXPECT_EQ(value_set_in_cache_size, 4); +} + TEST(CacheTest, OutOfOrderUpdateAfterUpdateWorks) { std::unique_ptr cache = KeyValueCache::Create(); cache->UpdateKeyValue("my_key", "my_value", 2); @@ -157,8 +295,7 @@ TEST(DeleteKeyTest, OutOfOrderDeleteAfterUpdateWorks) { std::vector full_keys = {"my_key"}; absl::flat_hash_map kv_pairs = cache->GetKeyValuePairs(full_keys); - EXPECT_EQ(kv_pairs.size(), 1); - EXPECT_THAT(kv_pairs, UnorderedElementsAre(KVPairEq("my_key", "my_value"))); + EXPECT_EQ(kv_pairs.size(), 0); } TEST(DeleteKeyTest, OutOfOrderUpdateAfterDeleteWorks) { @@ -193,6 +330,97 @@ TEST(DeleteKeyTest, InOrderDeleteAfterUpdateWorks) { EXPECT_EQ(kv_pairs.size(), 0); } +TEST(UpateKeyValueSetTest, UpdateAfterUpdateWithSameValue) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_THAT(value_set, UnorderedElementsAre("v1")); + auto value_meta = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta.last_logical_commit_time, 2); + EXPECT_EQ(value_meta.is_deleted, false); +} + +TEST(UpateKeyValueSetTest, UpdateAfterUpdateWithDifferentValue) { + std::unique_ptr cache = std::make_unique(); + std::vector first_value = {"v1"}; + std::vector second_value = {"v2"}; + cache->UpdateKeyValueSet("my_key", absl::Span(first_value), + 1); + cache->UpdateKeyValueSet("my_key", absl::Span(second_value), + 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + 
EXPECT_THAT(value_set, UnorderedElementsAre("v1", "v2")); + auto value_meta_v1 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1.last_logical_commit_time, 1); + EXPECT_EQ(value_meta_v1.is_deleted, false); + auto value_meta_v2 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v2"); + EXPECT_EQ(value_meta_v2.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v2.is_deleted, false); +} + +TEST(InOrderUpateKeyValueSetTest, InsertAfterDeleteExpectInsert) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1"}; + cache->DeleteValuesInSet("my_key", absl::Span(values), 1); + cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_THAT(value_set, UnorderedElementsAre("v1")); + auto value_meta = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta.last_logical_commit_time, 2); + EXPECT_EQ(value_meta.is_deleted, false); +} + +TEST(InOrderUpateKeyValueSetTest, DeleteAfterInsert) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_EQ(value_set.size(), 0); + auto value_meta_v1 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v1.is_deleted, true); +} + +TEST(OutOfOrderUpateKeyValueSetTest, InsertAfterDeleteExpectNoInsert) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1"}; + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + 
EXPECT_EQ(value_set.size(), 0); + auto value_meta = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta.last_logical_commit_time, 2); + EXPECT_EQ(value_meta.is_deleted, true); +} + +TEST(OutOfOrderUpateKeyValueSetTest, DeleteAfterInsertExpectNoDelete) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + cache->DeleteValuesInSet("my_key", absl::Span(values), 1); + absl::flat_hash_set value_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_THAT(value_set, UnorderedElementsAre("v1")); + auto value_meta_v1 = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "v1"); + EXPECT_EQ(value_meta_v1.last_logical_commit_time, 2); + EXPECT_EQ(value_meta_v1.is_deleted, false); +} + TEST(CleanUpTimestamps, InsertAKeyDoesntUpdateDeletedNodes) { std::unique_ptr cache = std::make_unique(); cache->UpdateKeyValue("my_key", "my_value", 1); @@ -274,5 +502,487 @@ TEST(CleanUpTimestamps, CantInsertOldRecordsAfterCleanup) { EXPECT_EQ(kv_pairs.size(), 0); } +TEST(CleanUpTimestampsForSetCache, InsertKeyValueSetDoesntUpdateDeletedNodes) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 0); +} + +TEST(CleanUpTimestampsForSetCache, DeleteKeyValueSetExpectUpdateDeletedNodes) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->DeleteValuesInSet("my_key", absl::Span(values), 1); + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 1); + EXPECT_EQ(KeyValueCacheTestPeer::ReadDeletedSetNodesForTimestamp(*cache, 1, + "my_key") + .size(), + 1); +} + +TEST(CleanUpTimestampsForSetCache, 
RemoveDeletedKeyValuesRemovesOldRecords) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 1); + + cache->RemoveDeletedKeys(3); + deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 0); + EXPECT_EQ(KeyValueCacheTestPeer::GetCacheKeyValueSetMapSize(*cache), 0); +} + +TEST(CleanUpTimestampsForSetCache, + RemoveDeletedKeyValuesDoesntAffectNewRecords) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 5); + cache->DeleteValuesInSet("my_key", absl::Span(values), 6); + + cache->RemoveDeletedKeys(2); + + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 1); + EXPECT_EQ(KeyValueCacheTestPeer::ReadDeletedSetNodesForTimestamp(*cache, 6, + "my_key") + .size(), + 1); +} + +TEST(CleanUpTimestampsForSetCache, + RemoveDeletedKeysRemovesOldRecordsDoesntAffectNewRecords) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"v1", "v2"}; + std::vector values_to_delete = {"v1"}; + cache->UpdateKeyValueSet("my_key1", absl::Span(values), 1); + cache->UpdateKeyValueSet("my_key2", absl::Span(values), 2); + cache->UpdateKeyValueSet("my_key3", absl::Span(values), 3); + cache->UpdateKeyValueSet("my_key4", absl::Span(values), 4); + + cache->DeleteValuesInSet("my_key3", + absl::Span(values_to_delete), 4); + cache->DeleteValuesInSet("my_key1", + absl::Span(values_to_delete), 5); + cache->DeleteValuesInSet("my_key2", + absl::Span(values_to_delete), 6); + + cache->RemoveDeletedKeys(5); + + int deleted_nodes_map_size = + 
KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 1); + EXPECT_EQ(KeyValueCacheTestPeer::ReadDeletedSetNodesForTimestamp(*cache, 6, + "my_key2") + .size(), + 1); + auto get_value_set_result = + cache->GetKeyValueSet({"my_key1", "my_key4", "my_key3"}); + EXPECT_THAT(get_value_set_result->GetValueSet("my_key4"), + UnorderedElementsAre("v1", "v2")); + EXPECT_THAT(get_value_set_result->GetValueSet("my_key3"), + UnorderedElementsAre("v2")); + EXPECT_THAT(get_value_set_result->GetValueSet("my_key1"), + UnorderedElementsAre("v2")); +} + +TEST(CleanUpTimestampsForSetCache, CantInsertOldRecordsAfterCleanup) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->RemoveDeletedKeys(3); + + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 0); + EXPECT_EQ(KeyValueCacheTestPeer::GetCacheKeyValueSetMapSize(*cache), 0); + + cache->UpdateKeyValueSet("my_key", absl::Span(values), 2); + + absl::flat_hash_set kv_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_EQ(kv_set.size(), 0); +} + +TEST(CleanUpTimestampsForSetCache, CantAddOldDeletedRecordsAfterCleanup) { + std::unique_ptr cache = std::make_unique(); + std::vector values = {"my_value"}; + cache->UpdateKeyValueSet("my_key", absl::Span(values), 1); + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + cache->RemoveDeletedKeys(3); + + int deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 0); + EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache), 0); + + // Old delete + cache->DeleteValuesInSet("my_key", absl::Span(values), 2); + deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + 
EXPECT_EQ(deleted_nodes_map_size, 0); + EXPECT_EQ(KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache), 0); + + // New delete + cache->DeleteValuesInSet("my_key", absl::Span(values), 4); + deleted_nodes_map_size = + KeyValueCacheTestPeer::GetDeletedSetNodesMapSize(*cache); + EXPECT_EQ(deleted_nodes_map_size, 1); + EXPECT_EQ(KeyValueCacheTestPeer::GetCacheKeyValueSetMapSize(*cache), 1); + auto value_meta = + KeyValueCacheTestPeer::GetSetValueMeta(*cache, "my_key", "my_value"); + EXPECT_EQ(value_meta.is_deleted, true); + EXPECT_EQ(value_meta.last_logical_commit_time, 4); + + absl::flat_hash_set kv_set = + cache->GetKeyValueSet({"my_key"})->GetValueSet("my_key"); + EXPECT_EQ(kv_set.size(), 0); +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetAndGet) { + auto cache = std::make_unique(); + absl::flat_hash_set keys_lookup_request = {"key1", "key2"}; + std::vector values_for_key1 = {"v1"}; + std::vector values_for_key2 = {"v2"}; + cache->UpdateKeyValueSet("key1", + absl::Span(values_for_key1), 1); + cache->UpdateKeyValueSet("key2", + absl::Span(values_for_key2), 1); + absl::Notification start; + auto lookup_fn = [&cache, &keys_lookup_request, &start]() { + start.WaitForNotification(); + auto result = cache->GetKeyValueSet(keys_lookup_request); + EXPECT_THAT(result->GetValueSet("key1"), UnorderedElementsAre("v1")); + EXPECT_THAT(result->GetValueSet("key2"), UnorderedElementsAre("v2")); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + ++i) { + threads.emplace_back(lookup_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetAndUpdateExpectNoUpdate) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1"}; + std::vector existing_values = {"v1"}; + cache->UpdateKeyValueSet("key1", + absl::Span(existing_values), 3); + absl::Notification start; + auto lookup_fn = [&cache, &keys, &start]() { + 
start.WaitForNotification(); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + std::vector new_values = {"v1"}; + auto update_fn = [&cache, &new_values, &start]() { + start.WaitForNotification(); + cache->UpdateKeyValueSet("key1", absl::Span(new_values), + 1); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(lookup_fn); + threads.emplace_back(update_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetAndUpdateExpectUpdate) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1", "key2"}; + std::vector existing_values = {"v1"}; + cache->UpdateKeyValueSet("key1", + absl::Span(existing_values), 1); + absl::Notification start; + auto lookup_fn = [&cache, &keys, &start]() { + start.WaitForNotification(); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + std::vector new_values_for_key2 = {"v2"}; + auto update_fn = [&cache, &new_values_for_key2, &start]() { + // expect new value is inserted for key2 + start.WaitForNotification(); + cache->UpdateKeyValueSet( + "key2", absl::Span(new_values_for_key2), 2); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(lookup_fn); + threads.emplace_back(update_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetAndDeleteExpectNoDelete) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1"}; + std::vector existing_values = {"v1"}; + cache->UpdateKeyValueSet("key1", + absl::Span(existing_values), 3); + absl::Notification start; + auto lookup_fn = [&cache, &keys, &start]() { + start.WaitForNotification(); + 
EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + std::vector delete_values = {"v1"}; + auto delete_fn = [&cache, &delete_values, &start]() { + // expect no delete + start.WaitForNotification(); + cache->DeleteValuesInSet("key1", + absl::Span(delete_values), 1); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(lookup_fn); + threads.emplace_back(delete_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetAndCleanUp) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1", "key2"}; + std::vector existing_values = {"v1"}; + cache->UpdateKeyValueSet("key1", + absl::Span(existing_values), 3); + cache->UpdateKeyValueSet("key2", + absl::Span(existing_values), 1); + cache->DeleteValuesInSet("key2", + absl::Span(existing_values), 2); + absl::Notification start; + auto lookup_fn = [&cache, &keys, &start]() { + start.WaitForNotification(); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + EXPECT_EQ(cache->GetKeyValueSet(keys)->GetValueSet("key2").size(), 0); + }; + auto cleanup_fn = [&cache, &start]() { + // clean up old records + start.WaitForNotification(); + KeyValueCacheTestPeer::CallCacheCleanup(*cache, 3); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(lookup_fn); + threads.emplace_back(cleanup_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentUpdateAndUpdateExpectUpdateBoth) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1", "key2"}; + std::vector values_for_key1 = {"v1"}; + absl::Notification start; + auto update_key1 = [&cache, &keys, &values_for_key1, &start]() { + 
start.WaitForNotification(); + // expect new value is inserted for key1 + cache->UpdateKeyValueSet("key1", + absl::Span(values_for_key1), 1); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + std::vector values_for_key2 = {"v2"}; + auto update_key2 = [&cache, &keys, &values_for_key2, &start]() { + // expect new value is inserted for key2 + start.WaitForNotification(); + cache->UpdateKeyValueSet("key2", + absl::Span(values_for_key2), 2); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key2"), + UnorderedElementsAre("v2")); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(update_key1); + threads.emplace_back(update_key2); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentUpdateAndDelete) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1", "key2"}; + std::vector values_for_key1 = {"v1"}; + absl::Notification start; + auto update_key1 = [&cache, &keys, &values_for_key1, &start]() { + start.WaitForNotification(); + // expect new value is inserted for key1 + cache->UpdateKeyValueSet("key1", + absl::Span(values_for_key1), 1); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + // Update existing value for key2 + std::vector existing_values_for_key2 = {"v1", "v2"}; + cache->UpdateKeyValueSet( + "key2", absl::Span(existing_values_for_key2), 1); + std::vector values_to_delete_for_key2 = {"v1"}; + + auto delete_key2 = [&cache, &keys, &values_to_delete_for_key2, &start]() { + start.WaitForNotification(); + // expect value is deleted for key2 + cache->DeleteValuesInSet( + "key2", absl::Span(values_to_delete_for_key2), 2); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key2"), + UnorderedElementsAre("v2")); + }; + + std::vector threads; + for (int i = 0; i < std::min(20, 
(int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(update_key1); + threads.emplace_back(delete_key2); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentUpdateAndCleanUp) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1"}; + std::vector values_for_key1 = {"v1"}; + absl::Notification start; + auto update_fn = [&cache, &keys, &values_for_key1, &start]() { + start.WaitForNotification(); + cache->UpdateKeyValueSet("key1", + absl::Span(values_for_key1), 1); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + auto cleanup_fn = [&cache, &start]() { + start.WaitForNotification(); + KeyValueCacheTestPeer::CallCacheCleanup(*cache, 2); + }; + + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(update_fn); + threads.emplace_back(cleanup_fn); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentDeleteAndCleanUp) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1"}; + std::vector values_for_key1 = {"v1"}; + cache->UpdateKeyValueSet("key1", + absl::Span(values_for_key1), 1); + absl::Notification start; + auto delete_fn = [&cache, &keys, &values_for_key1, &start]() { + start.WaitForNotification(); + // expect new value is deleted for key1 + cache->DeleteValuesInSet("key1", + absl::Span(values_for_key1), 2); + EXPECT_EQ(cache->GetKeyValueSet(keys)->GetValueSet("key1").size(), 0); + }; + auto cleanup_fn = [&cache, &start]() { + start.WaitForNotification(); + KeyValueCacheTestPeer::CallCacheCleanup(*cache, 2); + }; + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(delete_fn); + threads.emplace_back(cleanup_fn); + } + start.Notify(); + for (auto& thread : 
threads) { + thread.join(); + } +} + +TEST(ConcurrentSetMemoryAccessTest, ConcurrentGetUpdateDeleteCleanUp) { + auto cache = std::make_unique(); + absl::flat_hash_set keys = {"key1", "key2"}; + std::vector existing_values_for_key1 = {"v1"}; + std::vector existing_values_for_key2 = {"v1"}; + cache->UpdateKeyValueSet( + "key1", absl::Span(existing_values_for_key1), 1); + cache->UpdateKeyValueSet( + "key2", absl::Span(existing_values_for_key2), 1); + + std::vector values_to_insert_for_key2 = {"v2"}; + std::vector values_to_delete_for_key2 = {"v1"}; + absl::Notification start; + auto insert_for_key2 = [&cache, &values_to_insert_for_key2, &start]() { + start.WaitForNotification(); + cache->UpdateKeyValueSet( + "key2", absl::Span(values_to_insert_for_key2), 2); + }; + auto delete_for_key2 = [&cache, &values_to_delete_for_key2, &start]() { + start.WaitForNotification(); + cache->DeleteValuesInSet( + "key2", absl::Span(values_to_delete_for_key2), 2); + }; + auto cleanup = [&cache, &start]() { + start.WaitForNotification(); + KeyValueCacheTestPeer::CallCacheCleanup(*cache, 2); + }; + + auto lookup_for_key1 = [&cache, &keys, &start]() { + start.WaitForNotification(); + EXPECT_THAT(cache->GetKeyValueSet(keys)->GetValueSet("key1"), + UnorderedElementsAre("v1")); + }; + + std::vector threads; + for (int i = 0; i < std::min(20, (int)std::thread::hardware_concurrency()); + i++) { + threads.emplace_back(insert_for_key2); + threads.emplace_back(cleanup); + threads.emplace_back(delete_for_key2); + threads.emplace_back(lookup_for_key1); + } + start.Notify(); + for (auto& thread : threads) { + thread.join(); + } + auto look_up_result_for_key2 = + cache->GetKeyValueSet(keys)->GetValueSet("key2"); + EXPECT_THAT(look_up_result_for_key2, UnorderedElementsAre("v2")); +} } // namespace } // namespace kv_server diff --git a/components/data_server/cache/mocks.h b/components/data_server/cache/mocks.h index 8250cb0d..f1e07f76 100644 --- a/components/data_server/cache/mocks.h +++ 
b/components/data_server/cache/mocks.h @@ -15,6 +15,7 @@ #ifndef COMPONENTS_DATA_SERVER_CACHE_MOCKS_H_ #define COMPONENTS_DATA_SERVER_CACHE_MOCKS_H_ +#include #include #include #include @@ -34,13 +35,34 @@ class MockCache : public Cache { MOCK_METHOD((absl::flat_hash_map), GetKeyValuePairs, (const std::vector& key_list), (const, override)); + MOCK_METHOD((std::unique_ptr), GetKeyValueSet, + (const absl::flat_hash_set&), + (const, override)); MOCK_METHOD(void, UpdateKeyValue, (std::string_view key, std::string_view value, int64_t ts), (override)); + MOCK_METHOD(void, UpdateKeyValueSet, + (std::string_view key, absl::Span value_set, + int64_t logical_commit_time), + (override)); + MOCK_METHOD(void, DeleteValuesInSet, + (std::string_view key, absl::Span value_set, + int64_t logical_commit_time), + (override)); MOCK_METHOD(void, DeleteKey, (std::string_view key, int64_t ts), (override)); MOCK_METHOD(void, RemoveDeletedKeys, (int64_t ts), (override)); }; +class MockGetKeyValueSetResult : public GetKeyValueSetResult { + public: + MOCK_METHOD((absl::flat_hash_set), GetValueSet, + (std::string_view), (const, override)); + MOCK_METHOD(void, AddKeyValueSet, + (absl::Mutex & key_mutex, std::string_view key, + const absl::flat_hash_set& value_set), + (override)); +}; + } // namespace kv_server #endif // COMPONENTS_DATA_SERVER_CACHE_MOCKS_H_ diff --git a/components/data_server/data_loading/BUILD b/components/data_server/data_loading/BUILD index bd081e55..be6989b0 100644 --- a/components/data_server/data_loading/BUILD +++ b/components/data_server/data_loading/BUILD @@ -34,17 +34,19 @@ cc_library( "//components/data/realtime:realtime_notifier", "//components/data_server/cache", "//components/errors:retry", + "//components/udf:udf_client", "//public:constants", "//public/data_loading:data_loading_fbs", "//public/data_loading:filename_utils", + "//public/data_loading:records_utils", "//public/data_loading/readers:riegeli_stream_io", + "//public/sharding:sharding_function", 
"@com_github_google_glog//:glog", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@distributed_point_functions//pir/hashing:sha256_hash_family", "@google_privacysandbox_servers_common//src/cpp/telemetry:metrics_recorder", "@google_privacysandbox_servers_common//src/cpp/telemetry:tracing", ], @@ -60,12 +62,13 @@ cc_test( ":data_orchestrator", "//components/data/common:mocks", "//components/data_server/cache:mocks", + "//components/udf:code_config", + "//components/udf:mocks", "//public/data_loading:filename_utils", "//public/data_loading:records_utils", "//public/test_util:mocks", "//public/test_util:proto_matcher", "@com_github_google_glog//:glog", - "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@google_privacysandbox_servers_common//src/cpp/telemetry:mocks", diff --git a/components/data_server/data_loading/data_orchestrator.cc b/components/data_server/data_loading/data_orchestrator.cc index c356e5d7..ba68aa52 100644 --- a/components/data_server/data_loading/data_orchestrator.cc +++ b/components/data_server/data_loading/data_orchestrator.cc @@ -23,10 +23,11 @@ #include "absl/strings/str_cat.h" #include "components/errors/retry.h" #include "glog/logging.h" -#include "pir/hashing/sha256_hash_family.h" #include "public/constants.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/filename_utils.h" +#include "public/data_loading/records_utils.h" +#include "public/sharding/sharding_function.h" #include "src/cpp/telemetry/tracing.h" namespace kv_server { @@ -35,7 +36,7 @@ namespace { using privacy_sandbox::server_common::MetricsRecorder; using privacy_sandbox::server_common::TraceWithStatusOr; -constexpr char* kTotalRowsDroppedIncorrectShardNumber = +constexpr char kTotalRowsDroppedIncorrectShardNumber[] = 
"kTotalRowsDroppedIncorrectShardNumber"; // Holds an input stream pointing to a blob of Riegeli records. @@ -49,76 +50,131 @@ class BlobRecordStream : public RecordStream { std::unique_ptr blob_reader_; }; +absl::Status ApplyUpdateMutation(const KeyValueMutationRecord& record, + Cache& cache) { + if (record.value_type() == Value::String) { + cache.UpdateKeyValue(record.key()->string_view(), + GetRecordValue(record), + record.logical_commit_time()); + return absl::OkStatus(); + } + if (record.value_type() == Value::StringSet) { + auto values = GetRecordValue>(record); + cache.UpdateKeyValueSet(record.key()->string_view(), absl::MakeSpan(values), + record.logical_commit_time()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("Record with key: ", record.key()->string_view(), + " has unsupported value type: ", record.value_type())); +} + +absl::Status ApplyDeleteMutation(const KeyValueMutationRecord& record, + Cache& cache) { + if (record.value_type() == Value::String) { + cache.DeleteKey(record.key()->string_view(), record.logical_commit_time()); + return absl::OkStatus(); + } + if (record.value_type() == Value::StringSet) { + auto values = GetRecordValue>(record); + cache.DeleteValuesInSet(record.key()->string_view(), absl::MakeSpan(values), + record.logical_commit_time()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("Record with key: ", record.key()->string_view(), + " has unsupported value type: ", record.value_type())); +} + +bool ShouldProcessRecord(const KeyValueMutationRecord& record, + int64_t num_shards, int64_t server_shard_num, + MetricsRecorder& metrics_recorder) { + if (num_shards <= 1) { + return true; + } + auto shard_num = + ShardingFunction(/*seed=*/"") + .GetShardNumForKey(record.key()->string_view(), num_shards); + if (shard_num == server_shard_num) { + return true; + } + metrics_recorder.IncrementEventCounter(kTotalRowsDroppedIncorrectShardNumber); + auto error_message = 
absl::StrFormat( + "Data does not belong to this shard replica. Key: %s, Actual " + "shard id: %d, Server's shard id: %d.", + record.key()->string_view(), shard_num, server_shard_num); + LOG(ERROR) << error_message; + return false; +} + +absl::Status ApplyKeyValueMutationToCache( + const KeyValueMutationRecord& record, Cache& cache, int64_t& max_timestamp, + DataLoadingStats& data_loading_stats) { + switch (record.mutation_type()) { + case KeyValueMutationType::Update: { + if (auto status = ApplyUpdateMutation(record, cache); !status.ok()) { + return status; + } + max_timestamp = std::max(max_timestamp, record.logical_commit_time()); + data_loading_stats.total_updated_records++; + break; + } + case KeyValueMutationType::Delete: { + if (auto status = ApplyDeleteMutation(record, cache); !status.ok()) { + return status; + } + max_timestamp = std::max(max_timestamp, record.logical_commit_time()); + data_loading_stats.total_deleted_records++; + break; + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Invalid mutation type: ", + EnumNameKeyValueMutationType(record.mutation_type()))); + } + return absl::OkStatus(); +} + absl::StatusOr LoadCacheWithData( StreamRecordReader& record_reader, Cache& cache, int64_t& max_timestamp, const int32_t server_shard_num, - const int32_t num_shards, MetricsRecorder& metrics_recorder) { + const int32_t num_shards, MetricsRecorder& metrics_recorder, + UdfClient& udf_client) { DataLoadingStats data_loading_stats; - // TODO: propagate this from terraform parameters - std::string hashing_seed = ""; - auto hash_function = - distributed_point_functions::SHA256HashFunction(hashing_seed); - - auto status = record_reader.ReadStreamRecords( + const auto process_data_record_fn = [&cache, &max_timestamp, &data_loading_stats, server_shard_num, - num_shards, &hash_function, &metrics_recorder](std::string_view raw) { - auto record = flatbuffers::GetRoot(raw.data()); - auto recordVerifier = flatbuffers::Verifier( - 
reinterpret_cast(raw.data()), raw.size()); - if (!record->Verify(recordVerifier)) { - // TODO(b/239061954): Publish metrics for alerting - return absl::InvalidArgumentError("Invalid flatbuffer format"); - } - - if (num_shards > 1) { - int32_t shard_num = - hash_function(record->key()->string_view(), num_shards); - - if (shard_num != server_shard_num) { - metrics_recorder.IncrementEventCounter( - kTotalRowsDroppedIncorrectShardNumber); - - auto error_message = absl::StrFormat( - "Data does not belong to this shard replica. Key: %s, Actual " - "shard id: %d, Server's shard id: %d.", - record->key()->string_view(), shard_num, server_shard_num); - LOG(ERROR) << error_message; + num_shards, &metrics_recorder, + &udf_client](const DataRecord& data_record) { + if (data_record.record_type() == Record::KeyValueMutationRecord) { + const auto* record = data_record.record_as_KeyValueMutationRecord(); + if (!ShouldProcessRecord(*record, num_shards, server_shard_num, + metrics_recorder)) { // NOTE: currently upstream logic retries on non-ok status // this will get us in a loop return absl::OkStatus(); } + return ApplyKeyValueMutationToCache(*record, cache, max_timestamp, + data_loading_stats); + } else if (data_record.record_type() == + Record::UserDefinedFunctionsConfig) { + const auto* udf_config = + data_record.record_as_UserDefinedFunctionsConfig(); + return udf_client.SetCodeObject(CodeConfig{ + .js = udf_config->code_snippet()->str(), + .udf_handler_name = udf_config->handler_name()->str(), + .logical_commit_time = udf_config->logical_commit_time()}); } + LOG(ERROR) << "Received unsupported record "; + return absl::InvalidArgumentError("Record type not supported."); + }; - switch (record->mutation_type()) { - case DeltaMutationType::Update: { - cache.UpdateKeyValue(record->key()->string_view(), - record->value()->string_view(), - record->logical_commit_time()); - max_timestamp = - std::max(max_timestamp, record->logical_commit_time()); - 
data_loading_stats.total_updated_records++; - break; - } - case DeltaMutationType::Delete: { - cache.DeleteKey(record->key()->string_view(), - record->logical_commit_time()); - max_timestamp = - std::max(max_timestamp, record->logical_commit_time()); - data_loading_stats.total_deleted_records++; - break; - } - default: - return absl::InvalidArgumentError(absl::StrCat( - "Invalid mutation type: ", - EnumNameDeltaMutationType(record->mutation_type()))); - } - return absl::OkStatus(); + auto status = record_reader.ReadStreamRecords( + [&process_data_record_fn](std::string_view raw) { + return DeserializeDataRecord(raw, process_data_record_fn); }); - if (!status.ok()) { return status; } - return data_loading_stats; } @@ -137,9 +193,9 @@ absl::StatusOr LoadCacheWithDataFromFile( return std::make_unique( options.blob_client.GetBlobReader(location)); }); - auto status = - LoadCacheWithData(*record_reader, cache, max_timestamp, options.shard_num, - options.num_shards, metrics_recorder); + auto status = LoadCacheWithData(*record_reader, cache, max_timestamp, + options.shard_num, options.num_shards, + metrics_recorder, options.udf_client); if (status.ok()) { cache.RemoveDeletedKeys(max_timestamp); } @@ -160,8 +216,8 @@ absl::StatusOr TraceLoadCacheWithDataFromFile( class DataOrchestratorImpl : public DataOrchestrator { public: - // `last_basename` is the last file seen during init. The cache is up to date - // until this file. + // `last_basename` is the last file seen during init. The cache is up to + // date until this file. DataOrchestratorImpl(Options options, std::string last_basename, MetricsRecorder& metrics_recorder) : options_(std::move(options)), @@ -279,8 +335,8 @@ class DataOrchestratorImpl : public DataOrchestrator { // Reads new files, if any, from the `unprocessed_basenames_` queue and // processes them one by one. // - // On failure, puts the file back to the end of the queue and retry at a later - // point. 
+ // On failure, puts the file back to the end of the queue and retry at a + // later point. void ProcessNewFiles() { LOG(INFO) << "Thread for new file processing started"; absl::Condition has_new_event(this, @@ -381,14 +437,14 @@ class DataOrchestratorImpl : public DataOrchestrator { auto record_reader = delta_stream_reader_factory.CreateReader(is); return LoadCacheWithData(*record_reader, cache, max_timestamp, options_.shard_num, options_.num_shards, - metrics_recorder_); + metrics_recorder_, options_.udf_client); } const Options options_; absl::Mutex mu_; - std::deque unprocessed_basenames_ GUARDED_BY(mu_); + std::deque unprocessed_basenames_ ABSL_GUARDED_BY(mu_); std::unique_ptr data_loader_thread_; - bool stop_ GUARDED_BY(mu_) = false; + bool stop_ ABSL_GUARDED_BY(mu_) = false; // last basename of file in initialization. const std::string last_basename_of_init_; MetricsRecorder& metrics_recorder_; diff --git a/components/data_server/data_loading/data_orchestrator.h b/components/data_server/data_loading/data_orchestrator.h index e15103e2..b904343b 100644 --- a/components/data_server/data_loading/data_orchestrator.h +++ b/components/data_server/data_loading/data_orchestrator.h @@ -28,6 +28,7 @@ #include "components/data/realtime/delta_file_record_change_notifier.h" #include "components/data/realtime/realtime_notifier.h" #include "components/data_server/cache/cache.h" +#include "components/udf/udf_client.h" #include "public/data_loading/readers/riegeli_stream_io.h" #include "src/cpp/telemetry/metrics_recorder.h" @@ -53,9 +54,9 @@ class DataOrchestrator { BlobStorageClient& blob_client; DeltaFileNotifier& delta_notifier; BlobStorageChangeNotifier& change_notifier; + UdfClient& udf_client; StreamRecordReaderFactory& delta_stream_reader_factory; std::vector& realtime_options; - const absl::AnyInvocable& udf_update_callback = []() {}; const int32_t shard_num = 0; const int32_t num_shards = 1; }; diff --git 
a/components/data_server/data_loading/data_orchestrator_test.cc b/components/data_server/data_loading/data_orchestrator_test.cc index ae646e91..31973d34 100644 --- a/components/data_server/data_loading/data_orchestrator_test.cc +++ b/components/data_server/data_loading/data_orchestrator_test.cc @@ -22,6 +22,8 @@ #include "components/data/common/mocks.h" #include "components/data_server/cache/cache.h" #include "components/data_server/cache/mocks.h" +#include "components/udf/code_config.h" +#include "components/udf/mocks.h" #include "glog/logging.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" @@ -35,11 +37,13 @@ using kv_server::BlobStorageChangeNotifier; using kv_server::BlobStorageClient; +using kv_server::CodeConfig; using kv_server::DataOrchestrator; -using kv_server::DeltaFileRecordStruct; -using kv_server::DeltaMutationType; +using kv_server::DataRecordStruct; using kv_server::FilePrefix; using kv_server::FileType; +using kv_server::KeyValueMutationRecordStruct; +using kv_server::KeyValueMutationType; using kv_server::KVFileMetadata; using kv_server::MockBlobReader; using kv_server::MockBlobStorageChangeNotifier; @@ -50,9 +54,13 @@ using kv_server::MockDeltaFileRecordChangeNotifier; using kv_server::MockRealtimeNotifier; using kv_server::MockStreamRecordReader; using kv_server::MockStreamRecordReaderFactory; +using kv_server::MockUdfClient; using kv_server::ToDeltaFileName; +using kv_server::ToFlatBufferBuilder; using kv_server::ToSnapshotFileName; using kv_server::ToStringView; +using kv_server::UserDefinedFunctionsConfigStruct; +using kv_server::UserDefinedFunctionsLanguage; using privacy_sandbox::server_common::MockMetricsRecorder; using testing::_; using testing::AllOf; @@ -80,12 +88,14 @@ class DataOrchestratorTest : public ::testing::Test { .blob_client = blob_client_, .delta_notifier = notifier_, .change_notifier = change_notifier_, + .udf_client = udf_client_, .delta_stream_reader_factory = delta_stream_reader_factory_, 
.realtime_options = realtime_options_}) {} MockBlobStorageClient blob_client_; MockDeltaFileNotifier notifier_; MockBlobStorageChangeNotifier change_notifier_; + MockUdfClient udf_client_; MockStreamRecordReaderFactory delta_stream_reader_factory_; MockCache cache_; std::vector realtime_options_; @@ -233,10 +243,12 @@ TEST_F(DataOrchestratorTest, InitCacheSuccess) { .Times(1) .WillOnce( [](const std::function& callback) { - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Update, 3, - "bar", "bar value"} - .ToFlatBuffer(); - callback(ToStringView(fb)).IgnoreError(); + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Update, + 3, "bar", "bar value"}}))) + .IgnoreError(); return absl::OkStatus(); }); auto delete_reader = std::make_unique(); @@ -244,10 +256,12 @@ TEST_F(DataOrchestratorTest, InitCacheSuccess) { .Times(1) .WillOnce( [](const std::function& callback) { - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Delete, 3, - "bar", "bar value"} - .ToFlatBuffer(); - callback(ToStringView(fb)).IgnoreError(); + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Delete, + 3, "bar", "bar value"}}))) + .IgnoreError(); return absl::OkStatus(); }); EXPECT_CALL(delta_stream_reader_factory_, CreateConcurrentReader) @@ -269,6 +283,108 @@ TEST_F(DataOrchestratorTest, InitCacheSuccess) { EXPECT_FALSE((*maybe_orchestrator)->Start().ok()); } +TEST_F(DataOrchestratorTest, UpdateUdfCodeSuccess) { + const std::vector fnames({ToDeltaFileName(1).value()}); + EXPECT_CALL( + blob_client_, + ListBlobs(GetTestLocation(), + AllOf(Field(&BlobStorageClient::ListOptions::start_after, ""), + Field(&BlobStorageClient::ListOptions::prefix, + FilePrefix())))) + .WillOnce(Return(std::vector())); + EXPECT_CALL( + blob_client_, + ListBlobs(GetTestLocation(), + AllOf(Field(&BlobStorageClient::ListOptions::start_after, ""), + 
Field(&BlobStorageClient::ListOptions::prefix, + FilePrefix())))) + .WillOnce(Return(fnames)); + + KVFileMetadata metadata; + auto reader = std::make_unique(); + EXPECT_CALL(*reader, ReadStreamRecords) + .WillOnce( + [](const std::function& callback) { + callback(ToStringView(ToFlatBufferBuilder(DataRecordStruct{ + .record = + UserDefinedFunctionsConfigStruct{ + .code_snippet = "function hello(){}", + .handler_name = "hello", + .language = + UserDefinedFunctionsLanguage::Javascript, + .logical_commit_time = 1}}))) + .IgnoreError(); + return absl::OkStatus(); + }); + auto delete_reader = std::make_unique(); + EXPECT_CALL(delta_stream_reader_factory_, CreateConcurrentReader) + .WillOnce(Return(ByMove(std::move(reader)))); + + EXPECT_CALL(udf_client_, SetCodeObject(CodeConfig{.js = "function hello(){}", + .udf_handler_name = "hello", + .logical_commit_time = 1})) + .WillOnce(Return(absl::OkStatus())); + auto maybe_orchestrator = + DataOrchestrator::TryCreate(options_, metrics_recorder_); + ASSERT_TRUE(maybe_orchestrator.ok()); + + const std::string last_basename = ToDeltaFileName(1).value(); + EXPECT_CALL(notifier_, Start(_, GetTestLocation(), last_basename, _)) + .WillOnce(Return(absl::UnknownError(""))); + EXPECT_FALSE((*maybe_orchestrator)->Start().ok()); +} + +TEST_F(DataOrchestratorTest, UpdateUdfCodeFails_OrchestratorContinues) { + const std::vector fnames({ToDeltaFileName(1).value()}); + EXPECT_CALL( + blob_client_, + ListBlobs(GetTestLocation(), + AllOf(Field(&BlobStorageClient::ListOptions::start_after, ""), + Field(&BlobStorageClient::ListOptions::prefix, + FilePrefix())))) + .WillOnce(Return(std::vector())); + EXPECT_CALL( + blob_client_, + ListBlobs(GetTestLocation(), + AllOf(Field(&BlobStorageClient::ListOptions::start_after, ""), + Field(&BlobStorageClient::ListOptions::prefix, + FilePrefix())))) + .WillOnce(Return(fnames)); + + KVFileMetadata metadata; + auto reader = std::make_unique(); + EXPECT_CALL(*reader, ReadStreamRecords) + .WillOnce( + [](const 
std::function& callback) { + callback(ToStringView(ToFlatBufferBuilder(DataRecordStruct{ + .record = + UserDefinedFunctionsConfigStruct{ + .code_snippet = "function hello(){}", + .handler_name = "hello", + .language = + UserDefinedFunctionsLanguage::Javascript, + .logical_commit_time = 1}}))) + .IgnoreError(); + return absl::OkStatus(); + }); + auto delete_reader = std::make_unique(); + EXPECT_CALL(delta_stream_reader_factory_, CreateConcurrentReader) + .WillOnce(Return(ByMove(std::move(reader)))); + + EXPECT_CALL(udf_client_, SetCodeObject(CodeConfig{.js = "function hello(){}", + .udf_handler_name = "hello", + .logical_commit_time = 1})) + .WillOnce(Return(absl::UnknownError("Some error."))); + auto maybe_orchestrator = + DataOrchestrator::TryCreate(options_, metrics_recorder_); + ASSERT_TRUE(maybe_orchestrator.ok()); + + const std::string last_basename = ToDeltaFileName(1).value(); + EXPECT_CALL(notifier_, Start(_, GetTestLocation(), last_basename, _)) + .WillOnce(Return(absl::UnknownError(""))); + EXPECT_FALSE((*maybe_orchestrator)->Start().ok()); +} + TEST_F(DataOrchestratorTest, StartLoading) { ON_CALL(blob_client_, ListBlobs) .WillByDefault(Return(std::vector({}))); @@ -299,10 +415,12 @@ TEST_F(DataOrchestratorTest, StartLoading) { .Times(1) .WillOnce( [](const std::function& callback) { - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Update, 3, - "bar", "bar value"} - .ToFlatBuffer(); - callback(ToStringView(fb)).IgnoreError(); + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Update, + 3, "bar", "bar value"}}))) + .IgnoreError(); return absl::OkStatus(); }); auto delete_reader = std::make_unique(); @@ -311,10 +429,12 @@ TEST_F(DataOrchestratorTest, StartLoading) { .WillOnce( [&all_records_loaded]( const std::function& callback) { - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Delete, 3, - "bar", "bar value"} - .ToFlatBuffer(); - 
callback(ToStringView(fb)).IgnoreError(); + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Delete, + 3, "bar", "bar value"}}))) + .IgnoreError(); all_records_loaded.Notify(); return absl::OkStatus(); }); @@ -367,11 +487,13 @@ TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .Times(1) .WillOnce( [](const std::function& callback) { - // key: "shard2" -> shard num: 0 - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Update, 3, - "shard1", "bar value"} - .ToFlatBuffer(); - callback(ToStringView(fb)).IgnoreError(); + // key: "shard1" -> shard num: 0 + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Update, + 3, "shard1", "bar value"}}))) + .IgnoreError(); return absl::OkStatus(); }); auto delete_reader = std::make_unique(); @@ -380,10 +502,12 @@ TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .WillOnce( [](const std::function& callback) { // key: "shard2" -> shard num: 1 - const auto fb = DeltaFileRecordStruct{DeltaMutationType::Delete, 3, - "shard2", "bar value"} - .ToFlatBuffer(); - callback(ToStringView(fb)).IgnoreError(); + callback(ToStringView(ToFlatBufferBuilder( + DataRecordStruct{.record = + KeyValueMutationRecordStruct{ + KeyValueMutationType::Delete, + 3, "shard2", "bar value"}}))) + .IgnoreError(); return absl::OkStatus(); }); EXPECT_CALL(delta_stream_reader_factory_, CreateConcurrentReader) @@ -391,6 +515,7 @@ TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .WillOnce(Return(ByMove(std::move(update_reader)))) .WillOnce(Return(ByMove(std::move(delete_reader)))); + EXPECT_CALL(metrics_recorder_, IncrementEventCounter).Times(1); EXPECT_CALL(strict_cache, RemoveDeletedKeys(0)).Times(1); EXPECT_CALL(strict_cache, DeleteKey("shard2", 3)).Times(1); EXPECT_CALL(strict_cache, RemoveDeletedKeys(3)).Times(1); @@ -401,10 +526,12 @@ 
TEST_F(DataOrchestratorTest, InitCacheShardedSuccessSkipRecord) { .blob_client = blob_client_, .delta_notifier = notifier_, .change_notifier = change_notifier_, + .udf_client = udf_client_, .delta_stream_reader_factory = delta_stream_reader_factory_, .realtime_options = realtime_options_, + .shard_num = 1, .num_shards = 2, - .shard_num = 1}; + }; auto maybe_orchestrator = DataOrchestrator::TryCreate(sharded_options, metrics_recorder_); diff --git a/components/data_server/request_handler/BUILD b/components/data_server/request_handler/BUILD index 68b424d4..50b0f840 100644 --- a/components/data_server/request_handler/BUILD +++ b/components/data_server/request_handler/BUILD @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") +load("@rules_proto//proto:defs.bzl", "proto_library") package(default_visibility = [ "//components/data_server:__subpackages__", + "//components/internal_server:__subpackages__", ]) cc_library( @@ -27,6 +29,7 @@ cc_library( "get_values_handler.h", ], deps = [ + ":get_values_adapter", "//components/data_server/cache", "//public:base_types_cc_proto", "//public:constants", @@ -48,6 +51,7 @@ cc_test( ], deps = [ ":get_values_handler", + ":mocks", "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/cache:mocks", @@ -70,14 +74,13 @@ cc_library( ], deps = [ ":compression", + ":ohttp_server_encryptor", "//components/data_server/cache", "//components/udf:udf_client", "//public:base_types_cc_proto", - "//public:constants", "//public/query/v2:get_values_v2_cc_grpc", "@com_github_google_glog//:glog", "@com_github_google_quiche//quiche:binary_http_unstable_api", - "@com_github_google_quiche//quiche:oblivious_http_unstable_api", "@com_github_grpc_grpc//:grpc++", 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", @@ -163,10 +166,12 @@ cc_library( ], deps = [ ":get_values_v2_handler", + ":v2_response_data_cc_proto", "//public/query:get_values_cc_grpc", "//public/query/v2:get_values_v2_cc_grpc", "@com_github_google_glog//:glog", "@com_github_grpc_grpc//:grpc++", + "@com_google_protobuf//:protobuf", ], ) @@ -178,6 +183,7 @@ cc_test( ], deps = [ ":get_values_adapter", + ":mocks", "//components/udf:mocks", "//public/query:get_values_cc_grpc", "//public/test_util:proto_matcher", @@ -187,3 +193,86 @@ cc_test( "@google_privacysandbox_servers_common//src/cpp/telemetry:mocks", ], ) + +cc_test( + name = "v2_response_data_proto_test", + size = "small", + srcs = [ + "v2_response_data_proto_test.cc", + ], + deps = [ + ":v2_response_data_cc_proto", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +proto_library( + name = "v2_response_data_proto", + srcs = ["v2_response_data.proto"], + deps = [ + "@com_google_protobuf//:struct_proto", + ], +) + +cc_proto_library( + name = "v2_response_data_cc_proto", + deps = [":v2_response_data_proto"], +) + +cc_library( + name = "mocks", + testonly = 1, + hdrs = ["mocks.h"], + deps = [ + ":get_values_adapter", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "ohttp_client_encryptor", + srcs = [ + "ohttp_client_encryptor.cc", + ], + hdrs = [ + "ohttp_client_encryptor.h", + ], + deps = [ + "//public:constants", + "@com_github_google_quiche//quiche:oblivious_http_unstable_api", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "ohttp_server_encryptor", + srcs = [ + "ohttp_server_encryptor.cc", + ], + hdrs = [ + "ohttp_server_encryptor.h", + ], + deps = [ + "//public:constants", + "@com_github_google_quiche//quiche:oblivious_http_unstable_api", + 
"@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "ohttp_encryptor_test", + size = "small", + srcs = [ + "ohttp_encryptor_test.cc", + ], + deps = [ + ":ohttp_client_encryptor", + ":ohttp_server_encryptor", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/data_server/request_handler/get_values_adapter.cc b/components/data_server/request_handler/get_values_adapter.cc index cb290901..5a229215 100644 --- a/components/data_server/request_handler/get_values_adapter.cc +++ b/components/data_server/request_handler/get_values_adapter.cc @@ -19,17 +19,26 @@ #include #include #include +#include +#include "components/data_server/request_handler/v2_response_data.pb.h" #include "glog/logging.h" +#include "google/protobuf/util/json_util.h" namespace kv_server { namespace { using google::protobuf::RepeatedPtrField; +using google::protobuf::Struct; +using google::protobuf::Value; +using google::protobuf::util::JsonStringToMessage; constexpr char kKeysTag[] = "keys"; constexpr char kRenderUrlsTag[] = "renderUrls"; constexpr char kAdComponentRenderUrlsTag[] = "adComponentRenderUrls"; constexpr char kKvInternalTag[] = "kvInternal"; +constexpr char kCustomTag[] = "custom"; + +constexpr int kUdfInputApiVersion = 1; nlohmann::json BuildKeyGroup(const RepeatedPtrField& keys, std::string namespace_tag) { @@ -73,70 +82,149 @@ v2::GetValuesRequest BuildV2Request(const v1::GetValuesRequest& v1_request) { partition["keyGroups"] = keyGroups; get_values_v2["partitions"] = nlohmann::json::array({partition}); + get_values_v2["udfInputApiVersion"] = kUdfInputApiVersion; v2::GetValuesRequest v2_request; v2_request.mutable_raw_body()->set_data(get_values_v2.dump()); return v2_request; } -absl::Status ProcessV2ResponseJson(const nlohmann::json& v2_response_json, - v1::GetValuesResponse& v1_response) { - // Process the partitions in the response - nlohmann::json partitions; - if (const auto iter = 
v2_response_json.find("partitions"); - iter == v2_response_json.end()) { - // V2 does not require partitions, so ignore missing partitions. - return absl::OkStatus(); - } else { - partitions = std::move(iter.value()); +// Add key value pairs to the result struct +void ProcessKeyValues(KeyGroupOutput key_group_output, Struct& result_struct) { + for (auto&& [k, v] : std::move(key_group_output.key_values())) { + if (v.value().has_string_value()) { + Value value_proto; + google::protobuf::util::Status status = + google::protobuf::util::JsonStringToMessage(v.value().string_value(), + &value_proto); + if (status.ok()) { + (*result_struct.mutable_fields())[std::move(k)] = value_proto; + } + } + (*result_struct.mutable_fields())[std::move(k)] = v.value(); + } +} + +// Find the namespace tag that is paired with the "custom" tag. +absl::StatusOr FindNamespace(RepeatedPtrField tags) { + if (tags.size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("Expected 2 tags, found ", tags.size())); } - // TODO(b/278764114): Implement + bool has_custom_tag = false; + std::string maybe_namespace_tag; + for (auto&& tag : std::move(tags)) { + if (tag == kCustomTag) { + has_custom_tag = true; + } else { + maybe_namespace_tag = std::move(tag); + } + } + + if (has_custom_tag) { + return maybe_namespace_tag; + } + return absl::InvalidArgumentError("No namespace tags found"); +} + +absl::Status ProcessKeyGroupOutput(KeyGroupOutput key_group_output, + v1::GetValuesResponse& v1_response) { + // Ignore if no valid namespace tag that is paired with a 'custom' tag + auto tag_namespace_status_or = + FindNamespace(std::move(key_group_output.tags())); + if (!tag_namespace_status_or.ok()) { + return tag_namespace_status_or.status(); + } + if (tag_namespace_status_or.value() == kKeysTag) { + ProcessKeyValues(std::move(key_group_output), *v1_response.mutable_keys()); + } + if (tag_namespace_status_or.value() == kRenderUrlsTag) { + ProcessKeyValues(std::move(key_group_output), + 
*v1_response.mutable_render_urls()); + } + if (tag_namespace_status_or.value() == kAdComponentRenderUrlsTag) { + ProcessKeyValues(std::move(key_group_output), + *v1_response.mutable_ad_component_render_urls()); + } + if (tag_namespace_status_or.value() == kKvInternalTag) { + ProcessKeyValues(std::move(key_group_output), + *v1_response.mutable_kv_internal()); + } return absl::OkStatus(); } -absl::Status BuildV1Response(const google::api::HttpBody& v2_response, +// Process a V2 response object. The response JSON consists of an array of +// compression groups, each of which is a group of partition outputs. +absl::Status BuildV1Response(const nlohmann::json& v2_response_json, v1::GetValuesResponse& v1_response) { - nlohmann::json v2_response_json = - nlohmann::json::parse(v2_response.data(), nullptr, - /*allow_exceptions=*/false, - /*ignore_comments=*/true); - if (v2_response_json.is_discarded()) { + if (v2_response_json.is_null()) { + return absl::InternalError("v2 GetValues response is null"); + } + if (!v2_response_json.is_array()) { return absl::InvalidArgumentError( - "Error while parsing v2 GetValues response body."); + "Response should be an array of compression groups."); } - return ProcessV2ResponseJson(v2_response_json, v1_response); + for (const auto& compression_group_json : v2_response_json) { + V2CompressionGroup compression_group_proto; + auto status = JsonStringToMessage(compression_group_json.dump(), + &compression_group_proto); + if (!status.ok()) { + return absl::InternalError( + absl::StrCat("Could not convert compression group json to proto: ", + status.message().as_string())); + } + for (auto&& partition_proto : compression_group_proto.partitions()) { + for (auto&& key_group_output_proto : + partition_proto.key_group_outputs()) { + const auto key_group_output_status = ProcessKeyGroupOutput( + std::move(key_group_output_proto), v1_response); + if (!key_group_output_status.ok()) { + // Skip and log failed key group outputs + LOG(ERROR) << "Error 
processing key group output: " + << key_group_output_status; + } + } + } + } + return absl::OkStatus(); } } // namespace class GetValuesAdapterImpl : public GetValuesAdapter { public: - explicit GetValuesAdapterImpl(const GetValuesV2Handler& v2_handler) - : v2_handler_(v2_handler) {} + explicit GetValuesAdapterImpl(std::unique_ptr v2_handler) + : v2_handler_(std::move(v2_handler)) {} grpc::Status CallV2Handler(const v1::GetValuesRequest& v1_request, v1::GetValuesResponse& v1_response) const { v2::GetValuesRequest v2_request = BuildV2Request(v1_request); - google::api::HttpBody v2_response; - auto v2_response_status = v2_handler_.GetValues(v2_request, &v2_response); - if (!v2_response_status.ok()) { - return v2_response_status; + auto maybe_v2_response_json = + v2_handler_->GetValuesJsonResponse(v2_request); + if (!maybe_v2_response_json.ok()) { + return grpc::Status( + grpc::StatusCode::INTERNAL, + std::string(maybe_v2_response_json.status().message())); } - BuildV1Response(v2_response, v1_response); - // TODO(b/278764114): process response status + auto build_response_status = + BuildV1Response(maybe_v2_response_json.value(), v1_response); + if (!build_response_status.ok()) { + return grpc::Status(grpc::StatusCode::INTERNAL, + std::string(build_response_status.message())); + } return grpc::Status::OK; } private: - const GetValuesV2Handler& v2_handler_; + std::unique_ptr v2_handler_; }; std::unique_ptr GetValuesAdapter::Create( - const GetValuesV2Handler& v2_handler) { - return std::make_unique(v2_handler); + std::unique_ptr v2_handler) { + return std::make_unique(std::move(v2_handler)); } } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_adapter.h b/components/data_server/request_handler/get_values_adapter.h index 5608c17f..aa4343ec 100644 --- a/components/data_server/request_handler/get_values_adapter.h +++ b/components/data_server/request_handler/get_values_adapter.h @@ -37,7 +37,7 @@ class GetValuesAdapter { 
v1::GetValuesResponse& v1_response) const = 0; static std::unique_ptr Create( - const GetValuesV2Handler& v2_handler); + std::unique_ptr v2_handler); }; } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_adapter_test.cc b/components/data_server/request_handler/get_values_adapter_test.cc index d0523a37..3c9f7e9e 100644 --- a/components/data_server/request_handler/get_values_adapter_test.cc +++ b/components/data_server/request_handler/get_values_adapter_test.cc @@ -15,6 +15,7 @@ #include "components/data_server/request_handler/get_values_adapter.h" #include +#include #include #include "components/udf/mocks.h" @@ -37,17 +38,196 @@ using testing::Return; class GetValuesAdapterTest : public ::testing::Test { protected: void SetUp() override { - GetValuesV2Handler v2_handler(mock_udf_client_, mock_metrics_recorder_); - get_values_adapter_ = GetValuesAdapter::Create(v2_handler); + v2_handler_ = std::make_unique(mock_udf_client_, + mock_metrics_recorder_); + get_values_adapter_ = GetValuesAdapter::Create(std::move(v2_handler_)); } std::unique_ptr get_values_adapter_; + std::unique_ptr v2_handler_; MockUdfClient mock_udf_client_; MockMetricsRecorder mock_metrics_recorder_; }; TEST_F(GetValuesAdapterTest, EmptyRequestReturnsEmptyResponse) { + nlohmann::json udf_input = + R"({"context":{"subkey":""},"keyGroups":[],"udfInputApiVersion":1})"_json; + nlohmann::json udf_output = + R"({"keyGroupOutputs": [], "udfOutputApiVersion": 1})"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb()pb", &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, V1RequestWithTwoKeysReturnsOk) 
{ + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1", "key2"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { + "key1": { "value": "value1" }, + "key2": { "value": "value2" } + }, + "tags": ["custom","keys"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1_request.add_keys("key2"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb( + keys { + fields { + key: "key1" + value { string_value: "value1" } + } + fields { + key: "key2" + value { string_value: "value2" } + } + + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, V1RequestWithTwoKeyGroupsReturnsOk) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","renderUrls"],"keyList": ["key1"]},{"keyList":["key2"],"tags":["custom","adComponentRenderUrls"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { "key1": { "value": "value1" } }, + "tags": ["custom","renderUrls"] + },{ + "keyValues": { "key2": { "value": "value2" } }, + "tags": ["custom","adComponentRenderUrls"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_render_urls("key1"); + v1_request.add_ad_component_render_urls("key2"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, 
v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb( + render_urls { + fields { + key: "key1" + value { string_value: "value1" } + } + } + ad_component_render_urls { + fields { + key: "key2" + value { string_value: "value2" } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, V2ResponseIsNullReturnsError) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutpus": [] + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_FALSE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb()pb", &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, KeyGroupOutputWithEmptyKVsReturnsOk) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": {}, + "tags": ["custom","keys"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb(keys {})pb", &v1_expected); + 
EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, KeyGroupOutputWithInvalidNamespaceTagIsIgnored) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { "key1": { "value": "value1" } }, + "tags": ["custom","invalidTag"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); v1::GetValuesResponse v1_response; auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); EXPECT_TRUE(status.ok()); @@ -56,5 +236,246 @@ TEST_F(GetValuesAdapterTest, EmptyRequestReturnsEmptyResponse) { EXPECT_THAT(v1_response, EqualsProto(v1_expected)); } +TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoCustomTagIsIgnored) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { "key1": { "value": "value1" } }, + "tags": ["keys", "somethingelse"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb()pb", &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, KeyGroupOutputWithNoNamespaceTagIsIgnored) { + nlohmann::json udf_input = R"({ + "context": 
{"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { "key1": { "value": "value1" } }, + "tags": ["custom"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb()pb", &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, + KeyGroupOutputHasDuplicateNamespaceTagReturnsAllKeys) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { "key1": { "value": "value1" } }, + "tags": ["custom", "keys"] + }, + { + "keyValues": { "key2": { "value": "value2" } }, + "tags": ["custom", "keys"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString(R"pb( + keys { + fields { + key: "key1" + value { string_value: "value1" } + } + fields { + key: "key2" + value { string_value: "value2" } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, 
KeyGroupOutputHasDifferentValueTypesReturnsOk) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + "keyValues": { + "key1": { "value": [[[1,2,3,4]],null,["123456789","123456789"],["v1"]] }, + "key2": { "value": {"k2":"v","k1":123} }, + "key3": { "value": "3"} + }, + "tags": ["custom", "keys"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString( + R"pb(keys { + fields { + key: "key1" + value { + list_value { + values { + list_value { + values { + list_value { + values { number_value: 1 } + values { number_value: 2 } + values { number_value: 3 } + values { number_value: 4 } + } + } + } + } + values { null_value: NULL_VALUE } + values { + list_value { + values { string_value: "123456789" } + values { string_value: "123456789" } + } + } + values { list_value { values { string_value: "v1" } } } + } + } + } + fields { + key: "key2" + value { + struct_value { + fields { + key: "k1" + value { number_value: 123 } + } + fields { + key: "k2" + value { string_value: "v" } + } + } + } + } + fields { + key: "key3" + value { string_value: "3" } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + +TEST_F(GetValuesAdapterTest, ValueWithStatusSuccess) { + nlohmann::json udf_input = R"({ + "context": {"subkey": ""}, + "keyGroups": [{"tags": ["custom","keys"],"keyList": ["key1"]}], + "udfInputApiVersion": 1 + })"_json; + + nlohmann::json udf_output = R"({ + "keyGroupOutputs": [{ + 
"keyValues": { "key1": { + "value": { + "status": { + "code": 1, + "message": "some error message" + } + } + } }, + "tags": ["custom", "keys"] + }], + "udfOutputApiVersion": 1 + })"_json; + EXPECT_CALL(mock_udf_client_, + ExecuteCode(std::vector({udf_input.dump()}))) + .WillOnce(Return(udf_output.dump())); + + v1::GetValuesRequest v1_request; + v1_request.add_keys("key1"); + v1::GetValuesResponse v1_response; + auto status = get_values_adapter_->CallV2Handler(v1_request, v1_response); + EXPECT_TRUE(status.ok()); + v1::GetValuesResponse v1_expected; + TextFormat::ParseFromString( + R"pb(keys { + fields { + key: "key1" + value { + struct_value { + fields { + key: "status" + value { + struct_value { + fields { + key: "code" + value { number_value: 1 } + } + fields { + key: "message" + value { string_value: "some error message" } + } + } + } + } + } + } + } + })pb", + &v1_expected); + EXPECT_THAT(v1_response, EqualsProto(v1_expected)); +} + } // namespace } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_handler.cc b/components/data_server/request_handler/get_values_handler.cc index 076de734..116261b1 100644 --- a/components/data_server/request_handler/get_values_handler.cc +++ b/components/data_server/request_handler/get_values_handler.cc @@ -20,7 +20,7 @@ #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" -#include "components/data_server/cache/cache.h" +#include "components/data_server/request_handler/get_values_adapter.h" #include "glog/logging.h" #include "grpcpp/grpcpp.h" #include "public/constants.h" @@ -123,6 +123,11 @@ grpc::Status GetValuesHandler::GetValues(const GetValuesRequest& request, return status; } + if (use_v2_) { + VLOG(5) << "Using V2 adapter for " << request.DebugString(); + return adapter_.CallV2Handler(request, *response); + } + VLOG(5) << "Processing kv_internal for " << request.DebugString(); if (!request.kv_internal().empty()) { VLOG(5) << "Processing keys for " << 
request.DebugString(); diff --git a/components/data_server/request_handler/get_values_handler.h b/components/data_server/request_handler/get_values_handler.h index 39899eea..3a0e40fd 100644 --- a/components/data_server/request_handler/get_values_handler.h +++ b/components/data_server/request_handler/get_values_handler.h @@ -19,8 +19,9 @@ #include #include +#include -#include "components/data_server/cache/cache.h" +#include "components/data_server/request_handler/get_values_adapter.h" #include "grpcpp/grpcpp.h" #include "public/query/get_values.grpc.pb.h" #include "src/cpp/telemetry/metrics_recorder.h" @@ -33,12 +34,14 @@ namespace kv_server { class GetValuesHandler { public: explicit GetValuesHandler( - const Cache& cache, + const Cache& cache, const GetValuesAdapter& adapter, privacy_sandbox::server_common::MetricsRecorder& metrics_recorder, - bool dsp_mode) - : cache_(cache), + bool dsp_mode, bool use_v2) + : cache_(std::move(cache)), + adapter_(std::move(adapter)), metrics_recorder_(metrics_recorder), - dsp_mode_(dsp_mode) {} + dsp_mode_(dsp_mode), + use_v2_(use_v2) {} // TODO: Implement subkey, ad/render url lookups. grpc::Status GetValues(const v1::GetValuesRequest& request, @@ -48,11 +51,15 @@ class GetValuesHandler { grpc::Status ValidateRequest(const v1::GetValuesRequest& request) const; const Cache& cache_; + const GetValuesAdapter& adapter_; privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; // Use DSP mode for request validation. If false, then automatically assumes // SSP mode. const bool dsp_mode_; + + // If true, routes requests through V2 (UDF). Otherwise, calls cache. 
+ const bool use_v2_; }; } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_handler_test.cc b/components/data_server/request_handler/get_values_handler_test.cc index 272a5750..80f311c0 100644 --- a/components/data_server/request_handler/get_values_handler_test.cc +++ b/components/data_server/request_handler/get_values_handler_test.cc @@ -22,6 +22,7 @@ #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/cache/mocks.h" +#include "components/data_server/request_handler/mocks.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "grpcpp/grpcpp.h" @@ -36,8 +37,11 @@ namespace { using google::protobuf::TextFormat; using grpc::StatusCode; using privacy_sandbox::server_common::MockMetricsRecorder; +using testing::_; +using testing::DoAll; using testing::Return; using testing::ReturnRef; +using testing::SetArgReferee; using testing::UnorderedElementsAre; using v1::GetValuesRequest; using v1::GetValuesResponse; @@ -46,6 +50,7 @@ class GetValuesHandlerTest : public ::testing::Test { protected: MockCache mock_cache_; MockMetricsRecorder mock_metrics_recorder_; + MockGetValuesAdapter mock_get_values_adapter_; }; TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { @@ -56,8 +61,9 @@ TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { GetValuesRequest request; request.add_keys("my_key"); GetValuesResponse response; - GetValuesHandler handler(mock_cache_, mock_metrics_recorder_, - /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/true, /*use_v2=*/false); const auto result = handler.GetValues(request, &response); ASSERT_TRUE(result.ok()) << "code: " << result.error_code() << ", msg: " << result.error_message(); @@ -77,7 +83,6 @@ TEST_F(GetValuesHandlerTest, ReturnsExistingKeyTwice) { } TEST_F(GetValuesHandlerTest, RepeatedKeys) { - MockCache mock_cache_; 
EXPECT_CALL(mock_cache_, GetKeyValuePairs(UnorderedElementsAre("key1", "key2", "key3"))) .Times(1) @@ -86,8 +91,9 @@ TEST_F(GetValuesHandlerTest, RepeatedKeys) { GetValuesRequest request; request.add_keys("key1,key2,key3"); GetValuesResponse response; - GetValuesHandler handler(mock_cache_, mock_metrics_recorder_, - /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/true, /*use_v2=*/false); ASSERT_TRUE(handler.GetValues(request, &response).ok()); GetValuesResponse expected; @@ -102,7 +108,6 @@ TEST_F(GetValuesHandlerTest, RepeatedKeys) { } TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysSameNamespace) { - MockCache mock_cache; EXPECT_CALL(mock_cache_, GetKeyValuePairs(UnorderedElementsAre("key1", "key2"))) .Times(1) @@ -112,8 +117,9 @@ TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysSameNamespace) { request.add_keys("key1"); request.add_keys("key2"); GetValuesResponse response; - GetValuesHandler handler(mock_cache_, mock_metrics_recorder_, - /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/true, /*use_v2=*/false); ASSERT_TRUE(handler.GetValues(request, &response).ok()); GetValuesResponse expected; @@ -132,12 +138,11 @@ TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysSameNamespace) { } TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysDifferentNamespace) { - MockCache mock_cache; - EXPECT_CALL(mock_cache, GetKeyValuePairs(UnorderedElementsAre("key1"))) + EXPECT_CALL(mock_cache_, GetKeyValuePairs(UnorderedElementsAre("key1"))) .Times(1) .WillOnce(Return( absl::flat_hash_map{{"key1", "value1"}})); - EXPECT_CALL(mock_cache, GetKeyValuePairs(UnorderedElementsAre("key2"))) + EXPECT_CALL(mock_cache_, GetKeyValuePairs(UnorderedElementsAre("key2"))) .Times(1) .WillOnce(Return( absl::flat_hash_map{{"key2", "value2"}})); @@ -145,8 +150,9 @@ TEST_F(GetValuesHandlerTest, 
ReturnsMultipleExistingKeysDifferentNamespace) { request.add_render_urls("key1"); request.add_ad_component_render_urls("key2"); GetValuesResponse response; - GetValuesHandler handler(mock_cache, mock_metrics_recorder_, - /*dsp_mode=*/false); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/false, /*use_v2=*/false); ASSERT_TRUE(handler.GetValues(request, &response).ok()); GetValuesResponse expected; @@ -167,33 +173,36 @@ TEST_F(GetValuesHandlerTest, ReturnsMultipleExistingKeysDifferentNamespace) { } TEST_F(GetValuesHandlerTest, DspModeErrorOnMissingKeysNamespace) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.set_subkey("my_subkey"); GetValuesResponse response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/true, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); EXPECT_EQ(status.error_details(), "Missing field 'keys'"); } TEST_F(GetValuesHandlerTest, ErrorOnMissingKeysInDspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; GetValuesResponse response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/true, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); EXPECT_EQ(status.error_details(), "Missing field 'keys'"); } TEST_F(GetValuesHandlerTest, ErrorOnRenderUrlInDspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.add_keys("my_key"); request.add_render_urls("my_render_url"); GetValuesResponse response; - GetValuesHandler 
handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/true, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); @@ -201,12 +210,13 @@ TEST_F(GetValuesHandlerTest, ErrorOnRenderUrlInDspMode) { } TEST_F(GetValuesHandlerTest, ErrorOnAdComponentRenderUrlInDspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.add_keys("my_key"); request.add_ad_component_render_urls("my_ad_component_render_url"); GetValuesResponse response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/true, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); @@ -214,11 +224,12 @@ TEST_F(GetValuesHandlerTest, ErrorOnAdComponentRenderUrlInDspMode) { } TEST_F(GetValuesHandlerTest, ErrorOnMissingRenderUrlInSspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.add_ad_component_render_urls("my_ad_component_render_url"); GetValuesResponse response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/false); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/false, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); @@ -226,12 +237,13 @@ TEST_F(GetValuesHandlerTest, ErrorOnMissingRenderUrlInSspMode) { } TEST_F(GetValuesHandlerTest, ErrorOnKeysInSspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.add_render_urls("my_render_url"); request.add_keys("my_key"); GetValuesResponse 
response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/false); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/false, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); @@ -239,12 +251,13 @@ TEST_F(GetValuesHandlerTest, ErrorOnKeysInSspMode) { } TEST_F(GetValuesHandlerTest, ErrorOnSubkeysInSspMode) { - std::unique_ptr cache = KeyValueCache::Create(); GetValuesRequest request; request.add_render_urls("my_render_url"); request.set_subkey("my_subkey"); GetValuesResponse response; - GetValuesHandler handler(*cache, mock_metrics_recorder_, /*dsp_mode=*/false); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, /*dsp_mode=*/false, + /*use_v2=*/false); grpc::Status status = handler.GetValues(request, &response); EXPECT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); @@ -252,7 +265,6 @@ TEST_F(GetValuesHandlerTest, ErrorOnSubkeysInSspMode) { } TEST_F(GetValuesHandlerTest, TestResponseOnDifferentValueFormats) { - MockCache mock_cache; std::string value1 = R"json([ [[1, 2, 3, 4]], null, @@ -337,8 +349,9 @@ TEST_F(GetValuesHandlerTest, TestResponseOnDifferentValueFormats) { request.add_keys("key2"); request.add_keys("key3"); GetValuesResponse response; - GetValuesHandler handler(mock_cache_, mock_metrics_recorder_, - /*dsp_mode=*/true); + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/true, /*use_v2=*/false); ASSERT_TRUE(handler.GetValues(request, &response).ok()); GetValuesResponse expected_from_pb; TextFormat::ParseFromString(response_pb_string, &expected_from_pb); @@ -349,5 +362,28 @@ TEST_F(GetValuesHandlerTest, TestResponseOnDifferentValueFormats) { EXPECT_THAT(response, EqualsProto(expected_from_json)); } +TEST_F(GetValuesHandlerTest, CallsV2Adapter) { + GetValuesResponse 
adapter_response; + TextFormat::ParseFromString(R"pb(keys { + fields { + key: "key1" + value { string_value: "value1" } + } + })pb", + &adapter_response); + EXPECT_CALL(mock_get_values_adapter_, CallV2Handler(_, _)) + .WillOnce( + DoAll(SetArgReferee<1>(adapter_response), Return(grpc::Status::OK))); + + GetValuesRequest request; + request.add_keys("key1"); + GetValuesResponse response; + GetValuesHandler handler(mock_cache_, mock_get_values_adapter_, + mock_metrics_recorder_, + /*dsp_mode=*/true, /*use_v2=*/true); + ASSERT_TRUE(handler.GetValues(request, &response).ok()); + EXPECT_THAT(response, EqualsProto(adapter_response)); +} + } // namespace } // namespace kv_server diff --git a/components/data_server/request_handler/get_values_v2_handler.cc b/components/data_server/request_handler/get_values_v2_handler.cc index 9938111f..dd81c5ab 100644 --- a/components/data_server/request_handler/get_values_v2_handler.cc +++ b/components/data_server/request_handler/get_values_v2_handler.cc @@ -23,6 +23,7 @@ #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "components/data_server/request_handler/ohttp_server_encryptor.h" #include "glog/logging.h" #include "grpcpp/grpcpp.h" #include "public/base_types.pb.h" @@ -33,7 +34,6 @@ #include "quiche/oblivious_http/oblivious_http_gateway.h" #include "src/cpp/telemetry/telemetry.h" -constexpr char* kGetValuesV2HandlerSpan = "GetValuesV2Handler"; constexpr char* kCacheKeyV2Hit = "CacheKeyHit"; constexpr char* kCacheKeyV2Miss = "CacheKeyMiss"; @@ -63,6 +63,7 @@ absl::StatusOr ExecuteUdfForKeyGroups( if (!maybe_udf_output_string.ok()) { return maybe_udf_output_string.status(); } + VLOG(5) << "UDF output: " << maybe_udf_output_string.value(); nlohmann::json key_group_outputs = nlohmann::json::parse(std::move(maybe_udf_output_string.value()), nullptr, /*allow_exceptions=*/false, @@ -212,38 +213,34 @@ nlohmann::json GetValuesV2Handler::BuildCompressionGroupsForDebugging( return 
output; } -grpc::Status GetValuesV2Handler::GetValues( - const GetValuesRequest& request, google::api::HttpBody* response) const { - auto span = GetTracer()->StartSpan(kGetValuesV2HandlerSpan); - auto scope = opentelemetry::trace::Scope(span); - +absl::StatusOr GetValuesV2Handler::GetValuesJsonResponse( + const v2::GetValuesRequest& request) const { absl::StatusOr maybe_core_request_json = Parse(request.raw_body().data()); if (!maybe_core_request_json.ok()) { - return grpc::Status( - StatusCode::INTERNAL, - std::string(maybe_core_request_json.status().message())); + return maybe_core_request_json.status(); } - if (auto maybe_compression_groups = ProcessGetValuesCoreRequest( - udf_client_, maybe_core_request_json.value()); - maybe_compression_groups.ok()) { - nlohmann::json response_json = BuildCompressionGroupsForDebugging( - std::move(maybe_compression_groups).value()); - - if (response_json.size() > 0) - metrics_recorder_.IncrementEventCounter(kCacheKeyV2Hit); - else - metrics_recorder_.IncrementEventCounter(kCacheKeyV2Miss); + auto maybe_compression_groups = + ProcessGetValuesCoreRequest(udf_client_, maybe_core_request_json.value()); + if (!maybe_compression_groups.ok()) { + return maybe_compression_groups.status(); + } + nlohmann::json response_json = BuildCompressionGroupsForDebugging( + std::move(maybe_compression_groups).value()); + VLOG(5) << "Uncompressed response: " << response_json.dump(1); + return response_json; +} - VLOG(5) << "Uncompressed response: " << response_json.dump(1); - response->set_data(response_json.dump()); +grpc::Status GetValuesV2Handler::GetValues( + const GetValuesRequest& request, google::api::HttpBody* response) const { + const auto maybe_response_json = GetValuesJsonResponse(request); + if (maybe_response_json.ok()) { + response->set_data(maybe_response_json.value().dump()); return grpc::Status::OK; - } else { - return grpc::Status( - StatusCode::INTERNAL, - std::string(maybe_compression_groups.status().message())); } + return 
grpc::Status(StatusCode::INTERNAL, + std::string(maybe_response_json.status().message())); } grpc::Status GetValuesV2Handler::BinaryHttpGetValues( @@ -332,54 +329,28 @@ grpc::Status GetValuesV2Handler::ObliviousGetValues( const ObliviousGetValuesRequest& oblivious_request, google::api::HttpBody* oblivious_response) const { VLOG(9) << "Received ObliviousGetValues request. "; - - const absl::StatusOr maybe_req_key_id = quiche:: - ObliviousHttpHeaderKeyConfig::ParseKeyIdFromObliviousHttpRequestPayload( - oblivious_request.raw_body().data()); - if (!maybe_req_key_id.ok()) { + OhttpServerEncryptor encryptor; + auto maybe_plain_text = + encryptor.DecryptRequest(oblivious_request.raw_body().data()); + if (!maybe_plain_text.ok()) { return grpc::Status(StatusCode::INTERNAL, - absl::StrCat("Unable to get OHTTP key id: ", - maybe_req_key_id.status().message())); + absl::StrCat(maybe_plain_text.status().code(), " : ", + maybe_plain_text.status().message())); } - const auto maybe_config = quiche::ObliviousHttpHeaderKeyConfig::Create( - *maybe_req_key_id, kKEMParameter, kKDFParameter, kAEADParameter); - if (!maybe_config.ok()) { - return grpc::Status(StatusCode::INTERNAL, - absl::StrCat("Unable to build OHTTP config: ", - maybe_config.status().message())); - } - - const auto ohttp_instance = - quiche::ObliviousHttpGateway::Create(test_private_key_, *maybe_config); - - auto decrypted_req = ohttp_instance->DecryptObliviousHttpRequest( - oblivious_request.raw_body().data()); - - if (!decrypted_req.ok()) { - return grpc::Status(StatusCode::INTERNAL, - std::string(decrypted_req.status().message())); - } - absl::string_view request_text = decrypted_req->GetPlaintextData(); - // Now process the binary http request std::string response; - if (const auto s = BinaryHttpGetValues(request_text, response); !s.ok()) { + if (const auto s = BinaryHttpGetValues(*maybe_plain_text, response); + !s.ok()) { return s; } - - // encrypt/encapsulate the response - google::api::HttpBody bhttp_response; - 
auto server_request_context = - std::move(decrypted_req).value().ReleaseContext(); - const auto encapsulate_resp = ohttp_instance->CreateObliviousHttpResponse( - response, server_request_context); - if (!encapsulate_resp.ok()) { - return grpc::Status(StatusCode::INTERNAL, - std::string(encapsulate_resp.status().message())); + auto encrypted_response = encryptor.EncryptResponse(std::move(response)); + if (!encrypted_response.ok()) { + return grpc::Status(grpc::StatusCode::INTERNAL, + absl::StrCat(encrypted_response.status().code(), " : ", + encrypted_response.status().message())); } oblivious_response->set_content_type(std::string(kOHTTPResponseContentType)); - oblivious_response->set_data(encapsulate_resp->EncapsulateAndSerialize()); - + oblivious_response->set_data(*encrypted_response); return grpc::Status::OK; } diff --git a/components/data_server/request_handler/get_values_v2_handler.h b/components/data_server/request_handler/get_values_v2_handler.h index bd7fe3ea..3f6af2f4 100644 --- a/components/data_server/request_handler/get_values_v2_handler.h +++ b/components/data_server/request_handler/get_values_v2_handler.h @@ -22,6 +22,7 @@ #include #include +#include "absl/status/statusor.h" #include "absl/strings/escaping.h" #include "components/data_server/cache/cache.h" #include "components/data_server/request_handler/compression.h" @@ -50,6 +51,9 @@ class GetValuesV2Handler { create_compression_group_concatenator_( std::move(create_compression_group_concatenator)) {} + absl::StatusOr GetValuesJsonResponse( + const v2::GetValuesRequest& request) const; + grpc::Status GetValues(const v2::GetValuesRequest& request, google::api::HttpBody* response) const; @@ -93,11 +97,6 @@ class GetValuesV2Handler { grpc::Status BinaryHttpGetValues(std::string_view bhttp_request_body, std::string& response) const; - // X25519 Secret key (private key). 
- // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-2 - const std::string test_private_key_ = absl::HexStringToBytes( - "3c168975674b2fa8e465970b79c8dcf09f1c741626480bd4c6162fc5b6a98e1a"); - const UdfClient& udf_client_; std::function create_compression_group_concatenator_; diff --git a/components/data_server/request_handler/mocks.h b/components/data_server/request_handler/mocks.h new file mode 100644 index 00000000..1b23035f --- /dev/null +++ b/components/data_server/request_handler/mocks.h @@ -0,0 +1,37 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_MOCKS_H +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_MOCKS_H + +#include "components/data_server/request_handler/get_values_adapter.h" +#include "gmock/gmock.h" +#include "grpcpp/grpcpp.h" +#include "public/query/get_values.grpc.pb.h" + +namespace kv_server { + +class MockGetValuesAdapter : public GetValuesAdapter { + public: + MOCK_METHOD((grpc::Status), CallV2Handler, + (const v1::GetValuesRequest& v1_request, + v1::GetValuesResponse& v1_response), + (const, override)); +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_MOCKS_H diff --git a/components/data_server/request_handler/ohttp_client_encryptor.cc b/components/data_server/request_handler/ohttp_client_encryptor.cc new file mode 100644 index 00000000..b34c1efd --- /dev/null +++ b/components/data_server/request_handler/ohttp_client_encryptor.cc @@ -0,0 +1,61 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/ohttp_client_encryptor.h" + +#include + +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace kv_server { +absl::StatusOr OhttpClientEncryptor::EncryptRequest( + std::string payload) { + auto maybe_config = quiche::ObliviousHttpHeaderKeyConfig::Create( + test_key_id, kKEMParameter, kKDFParameter, kAEADParameter); + if (!maybe_config.ok()) { + return absl::InternalError(std::string(maybe_config.status().message())); + } + auto http_client_maybe = + quiche::ObliviousHttpClient::Create(test_public_key_, *maybe_config); + if (!http_client_maybe.ok()) { + return absl::InternalError( + std::string(http_client_maybe.status().message())); + } + http_client_ = std::move(*http_client_maybe); + auto encrypted_req = + http_client_->CreateObliviousHttpRequest(std::move(payload)); + if (!encrypted_req.ok()) { + return absl::InternalError(std::string(encrypted_req.status().message())); + } + std::string serialized_encrypted_req = + encrypted_req->EncapsulateAndSerialize(); + http_request_context_ = std::move(encrypted_req.value()).ReleaseContext(); + return serialized_encrypted_req; +} + +absl::StatusOr +OhttpClientEncryptor::DecryptResponse(std::string encrypted_payload) { + if (!http_client_.has_value() || !http_request_context_.has_value()) { + return absl::InternalError( + "Emtpy `http_client_` or `http_request_context_`. 
You should call " + "`ClientEncryptRequest` first"); + } + auto decrypted_response = http_client_->DecryptObliviousHttpResponse( + std::move(encrypted_payload), *http_request_context_); + if (!decrypted_response.ok()) { + return decrypted_response.status(); + } + return *decrypted_response; +} +} // namespace kv_server diff --git a/components/data_server/request_handler/ohttp_client_encryptor.h b/components/data_server/request_handler/ohttp_client_encryptor.h new file mode 100644 index 00000000..64c768bc --- /dev/null +++ b/components/data_server/request_handler/ohttp_client_encryptor.h @@ -0,0 +1,53 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "public/constants.h" +#include "quiche/oblivious_http/oblivious_http_client.h" + +namespace kv_server { + +// Handles client side encyption of requests and decryptions of responses. +// Not thread safe. Supports serial encryption/decryption per request. +class OhttpClientEncryptor { + public: + // Encrypts ougoing request. + absl::StatusOr EncryptRequest(std::string payload); + // Decrypts incoming reponse. Since OHTTP is stateful, this method should be + // called after EncryptRequest. 
+ // In order to avoid an extra copy, leaking the `ObliviousHttpResponse`. + // Note that we have a CL for the underlying library that might allow us to + // not do leak this object and not do the copy. If/when that's merged, we + // should refactor this back to returning a string. + absl::StatusOr DecryptResponse( + std::string encrypted_payload); + + private: + std::optional http_client_; + std::optional http_request_context_; + + const std::string test_public_key_ = absl::HexStringToBytes(kTestPublicKey); + const uint8_t test_key_id = 1; +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_CLIENT_ENCRYPTOR_H_ diff --git a/components/data_server/request_handler/ohttp_encryptor_test.cc b/components/data_server/request_handler/ohttp_encryptor_test.cc new file mode 100644 index 00000000..7565abaa --- /dev/null +++ b/components/data_server/request_handler/ohttp_encryptor_test.cc @@ -0,0 +1,85 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/data_server/request_handler/ohttp_server_encryptor.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(OhttpEncryptorTest, FullCircleSuccess) { + const std::string kTestRequest = "request to encrypt"; + OhttpClientEncryptor client_encryptor; + OhttpServerEncryptor server_encryptor; + auto request_encrypted_status = client_encryptor.EncryptRequest(kTestRequest); + ASSERT_TRUE(request_encrypted_status.ok()); + auto request_decrypted_status = + server_encryptor.DecryptRequest(*request_encrypted_status); + ASSERT_TRUE(request_decrypted_status.ok()); + EXPECT_EQ(kTestRequest, *request_decrypted_status); + + const std::string kTestResponse = "response to encrypt"; + auto response_encrypted_status = + server_encryptor.EncryptResponse(kTestResponse); + ASSERT_TRUE(response_encrypted_status.ok()); + auto response_decrypted_status = + client_encryptor.DecryptResponse(*response_encrypted_status); + ASSERT_TRUE(response_decrypted_status.ok()); + EXPECT_EQ(kTestResponse, response_decrypted_status->GetPlaintextData()); +} + +TEST(OhttpEncryptorTest, ServerDecryptRequestFails) { + OhttpServerEncryptor server_encryptor; + auto request_decrypted_status = server_encryptor.DecryptRequest("garbage"); + ASSERT_FALSE(request_decrypted_status.ok()); +} + +TEST(OhttpEncryptorTest, ClientDecryptFails) { + const std::string kTestRequest = "request to encrypt"; + OhttpClientEncryptor client_encryptor; + auto request_encrypted_status = client_encryptor.EncryptRequest(kTestRequest); + ASSERT_TRUE(request_encrypted_status.ok()); + auto response_decrypted_status = client_encryptor.DecryptResponse("garbage"); + ASSERT_FALSE(response_decrypted_status.ok()); +} + +TEST(OhttpEncryptorTest, ServerEncryptResponseFails) { + const std::string kTestRequest = "request to encrypt"; + OhttpServerEncryptor server_encryptor; + auto request_encrypted_status = + 
server_encryptor.EncryptResponse(kTestRequest); + ASSERT_FALSE(request_encrypted_status.ok()); + EXPECT_EQ( + "Emtpy `ohttp_gateway_` or `decrypted_request_`. You should call " + "`ServerDecryptRequest` first", + request_encrypted_status.status().message()); +} + +TEST(OhttpEncryptorTest, ClientDecryptResponseFails) { + const std::string kTestRequest = "request to decrypt"; + OhttpClientEncryptor client_encryptor; + auto request_encrypted_status = + client_encryptor.DecryptResponse(kTestRequest); + ASSERT_FALSE(request_encrypted_status.ok()); + EXPECT_EQ( + "Emtpy `http_client_` or `http_request_context_`. You should call " + "`ClientEncryptRequest` first", + request_encrypted_status.status().message()); +} + +} // namespace +} // namespace kv_server diff --git a/components/data_server/request_handler/ohttp_server_encryptor.cc b/components/data_server/request_handler/ohttp_server_encryptor.cc new file mode 100644 index 00000000..58cd1800 --- /dev/null +++ b/components/data_server/request_handler/ohttp_server_encryptor.cc @@ -0,0 +1,68 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/data_server/request_handler/ohttp_server_encryptor.h" + +#include + +#include "quiche/oblivious_http/common/oblivious_http_header_key_config.h" + +namespace kv_server { +absl::StatusOr OhttpServerEncryptor::DecryptRequest( + absl::string_view encrypted_payload) { + const absl::StatusOr maybe_req_key_id = + quiche::ObliviousHttpHeaderKeyConfig:: + ParseKeyIdFromObliviousHttpRequestPayload(encrypted_payload); + if (!maybe_req_key_id.ok()) { + return absl::InternalError(absl::StrCat( + "Unable to get OHTTP key id: ", maybe_req_key_id.status().message())); + } + const auto maybe_config = quiche::ObliviousHttpHeaderKeyConfig::Create( + *maybe_req_key_id, kKEMParameter, kKDFParameter, kAEADParameter); + if (!maybe_config.ok()) { + return absl::InternalError(absl::StrCat( + "Unable to build OHTTP config: ", maybe_req_key_id.status().message())); + } + auto maybe_ohttp_gateway = + quiche::ObliviousHttpGateway::Create(test_private_key_, *maybe_config); + if (!maybe_ohttp_gateway.ok()) { + return maybe_ohttp_gateway.status(); + } + ohttp_gateway_ = std::move(*maybe_ohttp_gateway); + auto decrypted_request_maybe = + ohttp_gateway_->DecryptObliviousHttpRequest(encrypted_payload); + if (!decrypted_request_maybe.ok()) { + return decrypted_request_maybe.status(); + } + decrypted_request_ = std::move(*decrypted_request_maybe); + return decrypted_request_->GetPlaintextData(); +} + +absl::StatusOr OhttpServerEncryptor::EncryptResponse( + std::string payload) { + if (!ohttp_gateway_.has_value() || !decrypted_request_.has_value()) { + return absl::InternalError( + "Emtpy `ohttp_gateway_` or `decrypted_request_`. 
You should call " + "`ServerDecryptRequest` first"); + } + auto server_request_context = std::move(*decrypted_request_).ReleaseContext(); + const auto encapsulate_resp = ohttp_gateway_->CreateObliviousHttpResponse( + std::move(payload), server_request_context); + if (!encapsulate_resp.ok()) { + return absl::InternalError( + std::string(encapsulate_resp.status().message())); + } + return encapsulate_resp->EncapsulateAndSerialize(); +} +} // namespace kv_server diff --git a/components/data_server/request_handler/ohttp_server_encryptor.h b/components/data_server/request_handler/ohttp_server_encryptor.h new file mode 100644 index 00000000..d1d321ea --- /dev/null +++ b/components/data_server/request_handler/ohttp_server_encryptor.h @@ -0,0 +1,54 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ +#define COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/escaping.h" +#include "public/constants.h" +#include "quiche/oblivious_http/oblivious_http_gateway.h" + +namespace kv_server { + +// Handles server side decryption of requests and encryption of responses. +// Not thread safe. Supports serial decryption/encryption per request. +class OhttpServerEncryptor { + public: + // Decrypts incoming request. 
+ // The return value points to a string stored in decrypted_request_, so its + // lifetime is tied to that object, which lifetime is in turn tied to the + // instance of OhttpEncryptor. + absl::StatusOr DecryptRequest( + absl::string_view encrypted_payload); + // Encrypts outgoing response. Since OHTTP is stateful, this method should be + // called after DecryptRequest. + absl::StatusOr EncryptResponse(std::string payload); + + private: + std::optional ohttp_gateway_; + std::optional decrypted_request_; + + // X25519 Secret key (private key). + // https://www.ietf.org/archive/id/draft-ietf-ohai-ohttp-03.html#appendix-A-2 + const std::string test_private_key_ = absl::HexStringToBytes( + "3c168975674b2fa8e465970b79c8dcf09f1c741626480bd4c6162fc5b6a98e1a"); +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_REQUEST_HANDLER_OHTTP_SERVER_ENCRYPTOR_H_ diff --git a/components/data_server/request_handler/v2_response_data.proto b/components/data_server/request_handler/v2_response_data.proto new file mode 100644 index 00000000..309248c7 --- /dev/null +++ b/components/data_server/request_handler/v2_response_data.proto @@ -0,0 +1,41 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +syntax = "proto3"; + +package kv_server; + +import "google/protobuf/struct.proto"; + +// Proto equivalent of a compression group in the KV V2 API: +// https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md#schema-of-the-request +message V2CompressionGroup { + repeated Partition partitions = 1; +} + +message Partition { + int64 id = 1; + repeated KeyGroupOutput key_group_outputs = 2; +} + +message KeyGroupOutput { + repeated string tags = 1; + map key_values = 2; +} + +message ValueObject { + google.protobuf.Value value = 1; + int64 global_ttl_sec = 2; + int64 dedicated_ttl_sec = 3; +} diff --git a/components/data_server/request_handler/v2_response_data_proto_test.cc b/components/data_server/request_handler/v2_response_data_proto_test.cc new file mode 100644 index 00000000..28371643 --- /dev/null +++ b/components/data_server/request_handler/v2_response_data_proto_test.cc @@ -0,0 +1,130 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "components/data_server/request_handler/v2_response_data.pb.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/json_util.h" +#include "gtest/gtest.h" +#include "public/test_util/proto_matcher.h" + +using google::protobuf::TextFormat; +using testing::_; +using testing::Return; + +using google::protobuf::util::JsonStringToMessage; +using google::protobuf::util::MessageToJsonString; + +namespace kv_server { +namespace { + +TEST(V2CompressionGroupProtoTest, + SuccessfullyParsesV2ResponseCompressionGroup) { + V2CompressionGroup v2_response_data_proto; + std::string v2_response_data_json = R"( + { + "partitions": [ + { + "id": 0, + "keyGroupOutputs": [ + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "custom", + "keys" + ] + }, + { + "keyValues": { + "hello": { + "value": "world" + } + }, + "tags": [ + "structured", + "groupNames" + ] + } + ] + }, + { + "id": 1, + "keyGroupOutputs": [ + { + "keyValues": { + "hello2": { + "value": "world2" + } + }, + "tags": [ + "custom", + "keys" + ] + } + ] + } + ] + } +)"; + auto json_to_proto_status = + JsonStringToMessage(v2_response_data_json, &v2_response_data_proto); + EXPECT_TRUE(json_to_proto_status.ok()); + EXPECT_EQ(json_to_proto_status.message().as_string(), ""); + V2CompressionGroup expected; + TextFormat::ParseFromString( + R"pb(partitions { + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello" + value { value { string_value: "world" } } + } + } + key_group_outputs { + tags: "structured" + tags: "groupNames" + key_values { + key: "hello" + value { value { string_value: "world" } } + } + } + } + partitions { + id: 1 + key_group_outputs { + tags: "custom" + tags: "keys" + key_values { + key: "hello2" + value { value { string_value: "world2" } } + } + } + })pb", + &expected); + EXPECT_THAT(v2_response_data_proto, EqualsProto(expected)); +} + +} // namespace 
+} // namespace kv_server diff --git a/components/data_server/server/BUILD b/components/data_server/server/BUILD index 34903035..31898152 100644 --- a/components/data_server/server/BUILD +++ b/components/data_server/server/BUILD @@ -92,6 +92,7 @@ cc_test( ], deps = [ ":lifecycle_heartbeat", + ":mocks", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@google_privacysandbox_servers_common//src/cpp/telemetry:mocks", @@ -131,16 +132,19 @@ cc_library( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/data_loading:data_orchestrator", + "//components/data_server/request_handler:get_values_adapter", "//components/data_server/request_handler:get_values_handler", "//components/data_server/request_handler:get_values_v2_handler", "//components/errors:retry", - "//components/internal_lookup:constants", - "//components/internal_lookup:lookup_client_impl", - "//components/internal_lookup:lookup_server_impl", + "//components/internal_server:constants", + "//components/internal_server:lookup_client_impl", + "//components/internal_server:lookup_server_impl", + "//components/internal_server:sharded_lookup_server_impl", + "//components/sharding:cluster_mappings_manager", "//components/telemetry:kv_telemetry", - "//components/udf:code_fetcher", - "//components/udf:get_values_hook_impl", + "//components/udf:get_values_hook", "//components/udf:udf_client", + "//components/udf:udf_config_builder", "//components/util:periodic_closure", "//components/util:platform_initializer", "//components/util:version_linkstamp", @@ -154,6 +158,7 @@ cc_library( "@com_github_grpc_grpc//:grpc++_reflection", # for grpc_cli "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@google_privacysandbox_servers_common//src/cpp/telemetry", "@google_privacysandbox_servers_common//src/cpp/telemetry:init", @@ -170,6 +175,7 @@ 
cc_test( "//:local_platform": ["server_local_test.cc"], }), deps = [ + ":mocks", ":server_lib", "//components/udf:mocks", "@com_google_absl//absl/flags:flag", @@ -186,6 +192,7 @@ cc_binary( visibility = ["//production/packaging:__subpackages__"], deps = [ ":server_lib", + "//components/sharding:shard_manager", "//components/util:version_linkstamp", "@com_github_google_glog//:glog", "@com_google_absl//absl/debugging:failure_signal_handler", @@ -215,3 +222,13 @@ sh_test( "smoke", ], ) + +cc_library( + name = "mocks", + testonly = 1, + hdrs = ["mocks.h"], + visibility = ["//components/sharding:__subpackages__"], + deps = [ + "@com_google_googletest//:gtest", + ], +) diff --git a/components/data_server/server/key_value_service_impl.cc b/components/data_server/server/key_value_service_impl.cc index 02457309..f29b4963 100644 --- a/components/data_server/server/key_value_service_impl.cc +++ b/components/data_server/server/key_value_service_impl.cc @@ -23,8 +23,6 @@ #include "src/cpp/telemetry/metrics_recorder.h" #include "src/cpp/telemetry/telemetry.h" -constexpr char* kGetValuesSpan = "GetValues"; -constexpr char* kBinaryGetValuesSpan = "BinaryHttpGetValues"; constexpr char* kGetValuesSuccess = "GetValuesSuccess"; namespace kv_server { diff --git a/components/data_server/server/key_value_service_v2_impl.cc b/components/data_server/server/key_value_service_v2_impl.cc index fb66f6b8..bac80fb1 100644 --- a/components/data_server/server/key_value_service_v2_impl.cc +++ b/components/data_server/server/key_value_service_v2_impl.cc @@ -20,8 +20,6 @@ #include "src/cpp/telemetry/metrics_recorder.h" #include "src/cpp/telemetry/telemetry.h" -constexpr char* kGetValuesV2Span = "GetValuesv2"; - namespace kv_server { namespace { @@ -39,9 +37,6 @@ grpc::ServerUnaryReactor* HandleRequest( CallbackServerContext* context, const RequestT* request, ResponseT* response, const GetValuesV2Handler& handler, HandlerFunctionT handler_function) { - auto span = 
GetTracer()->StartSpan(kGetValuesV2Span); - auto scope = opentelemetry::trace::Scope(span); - grpc::Status status = (handler.*handler_function)(*request, response); auto* reactor = context->DefaultReactor(); diff --git a/components/data_server/server/lifecycle_heartbeat_test.cc b/components/data_server/server/lifecycle_heartbeat_test.cc index e8f61fc7..be7d1975 100644 --- a/components/data_server/server/lifecycle_heartbeat_test.cc +++ b/components/data_server/server/lifecycle_heartbeat_test.cc @@ -16,7 +16,9 @@ #include #include +#include +#include "components/data_server/server/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "src/cpp/telemetry/mocks.h" @@ -53,23 +55,14 @@ class FakePeriodicClosure : public PeriodicClosure { std::function closure_; }; -class MockInstanceClient : public InstanceClient { - public: - MOCK_METHOD(absl::StatusOr, GetEnvironmentTag, (), (override)); - MOCK_METHOD(absl::StatusOr, GetShardNumTag, (), (override)); - MOCK_METHOD(absl::Status, RecordLifecycleHeartbeat, - (std::string_view lifecycle_hook_name), (override)); - MOCK_METHOD(absl::Status, CompleteLifecycle, - (std::string_view lifecycle_hook_name), (override)); - MOCK_METHOD(absl::StatusOr, GetInstanceId, (), (override)); -}; - class MockParameterClient : public ParameterClient { public: MOCK_METHOD(absl::StatusOr, GetParameter, (std::string_view parameter_name), (const, override)); MOCK_METHOD(absl::StatusOr, GetInt32Parameter, (std::string_view parameter_name), (const, override)); + MOCK_METHOD(absl::StatusOr, GetBoolParameter, + (std::string_view parameter_name), (const, override)); }; class MockParameterFetcher : public ParameterFetcher { @@ -80,6 +73,8 @@ class MockParameterFetcher : public ParameterFetcher { (const, override)); MOCK_METHOD(int32_t, GetInt32Parameter, (std::string_view parameter_suffix), (const, override)); + MOCK_METHOD(bool, GetBoolParameter, (std::string_view parameter_suffix), + (const, override)); private: MockParameterClient client; 
diff --git a/components/data_server/server/mocks.h b/components/data_server/server/mocks.h new file mode 100644 index 00000000..284d2214 --- /dev/null +++ b/components/data_server/server/mocks.h @@ -0,0 +1,45 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMPONENTS_DATA_SERVER_SERVER_MOCKS_H_ +#define COMPONENTS_DATA_SERVER_SERVER_MOCKS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "components/cloud_config/instance_client.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +class MockInstanceClient : public InstanceClient { + public: + MOCK_METHOD(absl::StatusOr, GetEnvironmentTag, (), (override)); + MOCK_METHOD(absl::StatusOr, GetShardNumTag, (), (override)); + MOCK_METHOD(absl::Status, RecordLifecycleHeartbeat, + (std::string_view lifecycle_hook_name), (override)); + MOCK_METHOD(absl::Status, CompleteLifecycle, + (std::string_view lifecycle_hook_name), (override)); + MOCK_METHOD(absl::StatusOr, GetInstanceId, (), (override)); + MOCK_METHOD(absl::StatusOr>, + DescribeInstanceGroupInstances, + (const absl::flat_hash_set&), (override)); + MOCK_METHOD(absl::StatusOr>, DescribeInstances, + (const absl::flat_hash_set&), (override)); +}; + +} // namespace kv_server +#endif // COMPONENTS_DATA_SERVER_SERVER_MOCKS_H_ diff --git a/components/data_server/server/parameter_fetcher.cc b/components/data_server/server/parameter_fetcher.cc index 8742b1b4..b5aa0401 
100644 --- a/components/data_server/server/parameter_fetcher.cc +++ b/components/data_server/server/parameter_fetcher.cc @@ -53,6 +53,16 @@ int32_t ParameterFetcher::GetInt32Parameter( "GetParameter", metrics_recorder_, {{"param", param_name}}); } +bool ParameterFetcher::GetBoolParameter( + std::string_view parameter_suffix) const { + const std::string param_name = GetParamName(parameter_suffix); + return TraceRetryUntilOk( + [this, ¶m_name] { + return parameter_client_.GetBoolParameter(param_name); + }, + "GetParameter", metrics_recorder_, {{"param", param_name}}); +} + std::string ParameterFetcher::GetParamName( std::string_view parameter_suffix) const { const std::vector v = {kServiceName, environment_, diff --git a/components/data_server/server/parameter_fetcher.h b/components/data_server/server/parameter_fetcher.h index 49b0e504..16cc52e6 100644 --- a/components/data_server/server/parameter_fetcher.h +++ b/components/data_server/server/parameter_fetcher.h @@ -44,6 +44,9 @@ class ParameterFetcher { // This function will retry any necessary requests until it succeeds. virtual int32_t GetInt32Parameter(std::string_view parameter_suffix) const; + // This function will retry any necessary requests until it succeeds. 
+ virtual bool GetBoolParameter(std::string_view parameter_suffix) const; + virtual NotifierMetadata GetBlobStorageNotifierMetadata() const; virtual NotifierMetadata GetRealtimeNotifierMetadata() const; diff --git a/components/data_server/server/parameter_fetcher_local_test.cc b/components/data_server/server/parameter_fetcher_local_test.cc index 0b0182d0..ea695cc6 100644 --- a/components/data_server/server/parameter_fetcher_local_test.cc +++ b/components/data_server/server/parameter_fetcher_local_test.cc @@ -31,6 +31,8 @@ class MockParameterClient : public ParameterClient { (std::string_view parameter_name), (const, override)); MOCK_METHOD(absl::StatusOr, GetInt32Parameter, (std::string_view parameter_name), (const, override)); + MOCK_METHOD(absl::StatusOr, GetBoolParameter, + (std::string_view parameter_name), (const, override)); }; TEST(ParameterFetcherTest, CreateChangeNotifierSmokeTest) { diff --git a/components/data_server/server/server.cc b/components/data_server/server/server.cc index f532b770..d6bd00b5 100644 --- a/components/data_server/server/server.cc +++ b/components/data_server/server/server.cc @@ -14,22 +14,30 @@ #include "components/data_server/server/server.h" +#include + #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/flags/usage.h" +#include "absl/functional/bind_front.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "components/data_server/request_handler/get_values_adapter.h" #include "components/data_server/request_handler/get_values_handler.h" #include "components/data_server/request_handler/get_values_v2_handler.h" #include "components/data_server/server/key_value_service_impl.h" #include "components/data_server/server/key_value_service_v2_impl.h" -#include "components/data_server/server/lifecycle_heartbeat.h" #include "components/errors/retry.h" -#include "components/internal_lookup/constants.h" -#include "components/internal_lookup/lookup_client.h" -#include 
"components/internal_lookup/lookup_server_impl.h" +#include "components/internal_server/constants.h" +#include "components/internal_server/lookup_client.h" +#include "components/internal_server/lookup_server_impl.h" +#include "components/internal_server/run_query_client.h" +#include "components/internal_server/sharded_lookup_server_impl.h" +#include "components/sharding/cluster_mappings_manager.h" #include "components/telemetry/kv_telemetry.h" -#include "components/udf/get_values_hook_impl.h" +#include "components/udf/get_values_hook.h" +#include "components/udf/run_query_hook.h" +#include "components/udf/udf_config_builder.h" #include "components/util/build_info.h" #include "glog/logging.h" #include "grpcpp/ext/proto_server_reflection_plugin.h" @@ -44,6 +52,8 @@ ABSL_FLAG(uint16_t, port, 50051, "Port the server is listening on. Defaults to 50051."); +ABSL_FLAG(std::string, internal_server_address, "0.0.0.0:50099", + "Internal server address. Defaults to 0.0.0.0:50099."); namespace kv_server { namespace { @@ -72,6 +82,8 @@ constexpr absl::string_view kS3ClientMaxConnectionsParameterSuffix = constexpr absl::string_view kS3ClientMaxRangeBytesParameterSuffix = "s3client-max-range-bytes"; constexpr absl::string_view kNumShardsParameterSuffix = "num-shards"; +constexpr absl::string_view kUdfNumWorkersParameterSuffix = "udf-num-workers"; +constexpr absl::string_view kRouteV1ToV2Suffix = "route-v1-to-v2"; opentelemetry::sdk::metrics::PeriodicExportingMetricReaderOptions GetMetricsOptions(const ParameterClient& parameter_client, @@ -102,14 +114,17 @@ GetMetricsOptions(const ParameterClient& parameter_client, Server::Server() : metrics_recorder_( TelemetryProvider::GetInstance().CreateMetricsRecorder()), - cache_(KeyValueCache::Create()) { + cache_(KeyValueCache::Create()), + get_values_hook_(GetValuesHook::Create(absl::bind_front( + LookupClient::Create, absl::GetFlag(FLAGS_internal_server_address)))), + run_query_hook_(RunQueryHook::Create( + 
absl::bind_front(RunQueryClient::Create, + absl::GetFlag(FLAGS_internal_server_address)))) { cache_->UpdateKeyValue( "hi", "Hello, world! If you are seeing this, it means you can " "query me successfully", /*logical_commit_time = */ 1); - cache_->UpdateKeyValue(kUdfCodeSnippetKey, kDefaultUdfCodeSnippet, 1); - cache_->UpdateKeyValue(kUdfHandlerNameKey, kDefaultUdfHandlerName, 1); } void Server::InitializeTelemetry(const ParameterClient& parameter_client, @@ -127,52 +142,52 @@ void Server::InitializeTelemetry(const ParameterClient& parameter_client, metrics_recorder_ = TelemetryProvider::GetInstance().CreateMetricsRecorder(); } -absl::Status Server::CreateDefaultInstancesIfNecessary( +absl::Status Server::CreateDefaultInstancesIfNecessaryAndGetEnvironment( std::unique_ptr parameter_client, std::unique_ptr instance_client, - std::unique_ptr code_fetcher, std::unique_ptr udf_client) { - if (parameter_client == nullptr) { - parameter_client_ = std::move(ParameterClient::Create()); - } else { - parameter_client_ = std::move(parameter_client); - } + parameter_client_ = parameter_client == nullptr ? ParameterClient::Create() + : std::move(parameter_client); + instance_client_ = instance_client == nullptr + ? 
InstanceClient::Create(*metrics_recorder_) + : std::move(instance_client); + environment_ = TraceRetryUntilOk( + [this]() { return instance_client_->GetEnvironmentTag(); }, + "GetEnvironment", nullptr); + LOG(INFO) << "Retrieved environment: " << environment_; + ParameterFetcher parameter_fetcher(environment_, *parameter_client_, + metrics_recorder_.get()); - if (instance_client == nullptr) { - instance_client_ = std::move(InstanceClient::Create()); - } else { - instance_client_ = std::move(instance_client); - } + int32_t number_of_workers = + parameter_fetcher.GetInt32Parameter(kUdfNumWorkersParameterSuffix); - if (code_fetcher == nullptr) { - code_fetcher_ = std::move(CodeFetcher::Create()); - } else { - code_fetcher_ = std::move(code_fetcher); + if (udf_client != nullptr) { + udf_client_ = std::move(udf_client); + return absl::OkStatus(); } - if (udf_client == nullptr) { - absl::StatusOr> udf_client_or_status = - UdfClient::Create(UdfClient::ConfigWithGetValuesHook( - *NewGetValuesHook(&LookupClient::GetSingleton))); - if (!udf_client_or_status.ok()) { - return udf_client_or_status.status(); - } + UdfConfigBuilder config_builder; + // TODO(b/289244673): Once roma interface is updated, internal lookup client + // can be removed and we can own the unique ptr to the hooks. 
+ absl::StatusOr> udf_client_or_status = + UdfClient::Create(config_builder.RegisterGetValuesHook(*get_values_hook_) + .RegisterRunQueryHook(*run_query_hook_) + .RegisterLoggingHook() + .SetNumberOfWorkers(number_of_workers) + .Config()); + if (udf_client_or_status.ok()) { udf_client_ = std::move(*udf_client_or_status); - } else { - udf_client_ = std::move(udf_client); } - - return absl::OkStatus(); + return udf_client_or_status.status(); } absl::Status Server::Init( std::unique_ptr parameter_client, std::unique_ptr instance_client, - std::unique_ptr code_fetcher, std::unique_ptr udf_client) { { - absl::Status status = CreateDefaultInstancesIfNecessary( + absl::Status status = CreateDefaultInstancesIfNecessaryAndGetEnvironment( std::move(parameter_client), std::move(instance_client), - std::move(code_fetcher), std::move(udf_client)); + std::move(udf_client)); if (!status.ok()) { return status; } @@ -184,15 +199,6 @@ absl::Status Server::Init( } absl::Status Server::InitOnceInstancesAreCreated() { - { - InstanceClient* instance_client_ptr = instance_client_.get(); - environment_ = TraceRetryUntilOk( - [instance_client_ptr]() { - return instance_client_ptr->GetEnvironmentTag(); - }, - "GetEnvironment", nullptr); - } - LOG(INFO) << "Retrieved environment: " << environment_; InitializeTelemetry(*parameter_client_, *instance_client_); auto span = GetTracer()->StartSpan("InitServer"); @@ -207,6 +213,9 @@ absl::Status Server::InitOnceInstancesAreCreated() { status != absl::OkStatus()) { return status; } + + SetDefaultUdfCodeObject(); + const auto shard_num_status = instance_client_->GetShardNumTag(); if (!shard_num_status.ok()) { return shard_num_status.status(); @@ -222,6 +231,7 @@ absl::Status Server::InitOnceInstancesAreCreated() { num_shards_ = parameter_fetcher.GetInt32Parameter(kNumShardsParameterSuffix); LOG(INFO) << "Retrieved " << kNumShardsParameterSuffix << " parameter: " << num_shards_; + blob_client_ = CreateBlobClient(parameter_fetcher); 
delta_stream_reader_factory_ = CreateStreamRecordReaderFactory(parameter_fetcher); @@ -236,8 +246,7 @@ absl::Status Server::InitOnceInstancesAreCreated() { SetQueueManager(metadata, message_service_blob_.get()); grpc_server_ = CreateAndStartGrpcServer(); - internal_lookup_server_ = CreateAndStartInternalLookupServer(); - + remote_lookup_server_ = CreateAndStartRemoteLookupServer(); { auto status_or_notifier = BlobStorageChangeNotifier::Create( std::move(metadata), *metrics_recorder_); @@ -250,8 +259,9 @@ absl::Status Server::InitOnceInstancesAreCreated() { } auto realtime_notifier_metadata = parameter_fetcher.GetRealtimeNotifierMetadata(); - auto realtime_message_service_status = - MessageService::Create(realtime_notifier_metadata); + auto realtime_message_service_status = MessageService::Create( + realtime_notifier_metadata, + (num_shards_ > 1 ? std::optional(shard_num_) : std::nullopt)); if (!realtime_message_service_status.ok()) { return realtime_message_service_status.status(); } @@ -275,11 +285,22 @@ absl::Status Server::InitOnceInstancesAreCreated() { RealtimeNotifier::Create(*metrics_recorder_); realtime_options_.push_back(std::move(realtime_options)); } - data_orchestrator_ = CreateDataOrchestrator(parameter_fetcher, *udf_client_); + data_orchestrator_ = CreateDataOrchestrator(parameter_fetcher); TraceRetryUntilOk([this] { return data_orchestrator_->Start(); }, "StartDataOrchestrator", metrics_recorder_.get()); - SetUdfCodeObject(*code_fetcher_); - + if (num_shards_ > 1) { + // At this point the server is healthy and the initialization is over. + // The only missing piece is having a shard map, which is dependent on + // other instances being `healthy`. Mark this instance as healthy so that + // other instances can pull it in for their mapping. 
+ lifecycle_heartbeat->Finish(); + } + absl::StatusOr> lookup_server_or = + CreateAndStartInternalLookupServer(); + if (!lookup_server_or.ok()) { + return lookup_server_or.status(); + } + internal_lookup_server_ = std::move(*lookup_server_or); return absl::OkStatus(); } @@ -307,7 +328,9 @@ void Server::GracefulShutdown(absl::Duration timeout) { if (internal_lookup_server_) { internal_lookup_server_->Shutdown(); } - + if (remote_lookup_server_) { + remote_lookup_server_->Shutdown(); + } if (grpc_server_) { grpc_server_->Shutdown(absl::ToChronoTime(absl::Now() + timeout)); } else { @@ -319,7 +342,12 @@ void Server::GracefulShutdown(absl::Duration timeout) { LOG(ERROR) << "Failed to stop UDF client: " << status; } } - + if (cluster_mappings_manager_ && cluster_mappings_manager_->IsRunning()) { + const absl::Status status = cluster_mappings_manager_->Stop(); + if (!status.ok()) { + LOG(ERROR) << "Failed to stop cluster mappings manager: " << status; + } + } const absl::Status status = MaybeShutdownNotifiers(); if (!status.ok()) { LOG(ERROR) << "Failed to shutdown notifiers. 
Got status " << status; @@ -331,7 +359,9 @@ void Server::ForceShutdown() { if (internal_lookup_server_) { internal_lookup_server_->Shutdown(); } - + if (remote_lookup_server_) { + remote_lookup_server_->Shutdown(); + } if (grpc_server_) { grpc_server_->Shutdown(); } else { @@ -347,6 +377,12 @@ void Server::ForceShutdown() { LOG(ERROR) << "Failed to stop UDF client: " << status; } } + if (cluster_mappings_manager_ && cluster_mappings_manager_->IsRunning()) { + const absl::Status status = cluster_mappings_manager_->Stop(); + if (!status.ok()) { + LOG(ERROR) << "Failed to stop cluster mappings manager: " << status; + } + } } std::unique_ptr Server::CreateBlobClient( @@ -378,18 +414,11 @@ Server::CreateStreamRecordReaderFactory( } std::unique_ptr Server::CreateDataOrchestrator( - const ParameterFetcher& parameter_fetcher, UdfClient& udf_client) { + const ParameterFetcher& parameter_fetcher) { const std::string data_bucket = parameter_fetcher.GetParameter(kDataBucketParameterSuffix); LOG(INFO) << "Retrieved " << kDataBucketParameterSuffix << " parameter: " << data_bucket; - auto udf_update_callback = [&udf_client]() { - const absl::Status status = udf_client.SetCodeObject({}); - if (!status.ok()) { - LOG(ERROR) << "Error setting code object: " << status; - } - }; - return TraceRetryUntilOk( [&] { return DataOrchestrator::TryCreate( @@ -401,7 +430,7 @@ std::unique_ptr Server::CreateDataOrchestrator( .change_notifier = *change_notifier_, .delta_stream_reader_factory = *delta_stream_reader_factory_, .realtime_options = realtime_options_, - .udf_update_callback = udf_update_callback, + .udf_client = *udf_client_, .shard_num = shard_num_, .num_shards = num_shards_, }, @@ -413,7 +442,12 @@ std::unique_ptr Server::CreateDataOrchestrator( void Server::CreateGrpcServices(const ParameterFetcher& parameter_fetcher) { const std::string mode = parameter_fetcher.GetParameter(kModeParameterSuffix); LOG(INFO) << "Retrieved " << kModeParameterSuffix << " parameter: " << mode; - 
GetValuesHandler handler(*cache_, *metrics_recorder_, mode == "DSP"); + const bool use_v2 = parameter_fetcher.GetBoolParameter(kRouteV1ToV2Suffix); + LOG(INFO) << "Retrieved " << kRouteV1ToV2Suffix << " parameter: " << use_v2; + get_values_adapter_ = GetValuesAdapter::Create( + std::make_unique(*udf_client_, *metrics_recorder_)); + GetValuesHandler handler(*cache_, *get_values_adapter_, *metrics_recorder_, + mode == "DSP", use_v2); grpc_services_.push_back(std::make_unique( std::move(handler), *metrics_recorder_)); GetValuesV2Handler v2handler(*udf_client_, *metrics_recorder_); @@ -439,31 +473,74 @@ std::unique_ptr Server::CreateAndStartGrpcServer() { return builder.BuildAndStart(); } -std::unique_ptr Server::CreateAndStartInternalLookupServer() { - internal_lookup_service_ = std::make_unique(*cache_); +absl::Status Server::CreateShardManager() { + cluster_mappings_manager_ = std::make_unique( + environment_, num_shards_, *metrics_recorder_, *instance_client_); + auto& num_shards = num_shards_; + auto& cluster_mappings_manager = *cluster_mappings_manager_; + shard_manager_ = TraceRetryUntilOk( + [&cluster_mappings_manager, &num_shards] { + // It might be that the cluster mappings that are passed don't pass + // validation. E.g. a particular cluster might not have any replicas + // specified. In that case, we need to retry the creation. After an + // exponential backoff, that will trigger`GetClusterMappings` which + // at that point in time might have new replicas spun up. 
+ return ShardManager::Create( + num_shards, cluster_mappings_manager.GetClusterMappings()); + }, + "GetShardManager", metrics_recorder_.get()); + return cluster_mappings_manager_->Start(*shard_manager_); +} + +absl::StatusOr> +Server::CreateAndStartInternalLookupServer() { + if (num_shards_ <= 1) { + internal_lookup_service_ = std::make_unique(*cache_); + } else { + if (const absl::Status status = CreateShardManager(); !status.ok()) { + return status; + } + internal_lookup_service_ = std::make_unique( + *metrics_recorder_, *cache_, num_shards_, shard_num_, *shard_manager_); + } grpc::ServerBuilder internal_lookup_server_builder; + const std::string internal_server_address = + absl::GetFlag(FLAGS_internal_server_address); internal_lookup_server_builder.AddListeningPort( - kInternalLookupServerAddress, grpc::InsecureServerCredentials()); + internal_server_address, grpc::InsecureServerCredentials()); internal_lookup_server_builder.RegisterService( internal_lookup_service_.get()); - LOG(INFO) << "Internal lookup server listening on " - << kInternalLookupServerAddress << std::endl; + LOG(INFO) << "Internal lookup server listening on " << internal_server_address + << std::endl; return internal_lookup_server_builder.BuildAndStart(); } -void Server::SetUdfCodeObject(CodeFetcher& code_fetcher) { - LOG(INFO) << "Fetching UDF Code Snippet"; - auto code_config = TraceRetryUntilOk( - [&code_fetcher] { - return code_fetcher.FetchCodeConfig(LookupClient::GetSingleton()); - }, - "FetchUntrustedCodeConfig", metrics_recorder_.get()); +std::unique_ptr Server::CreateAndStartRemoteLookupServer() { + if (num_shards_ <= 1) { + return nullptr; + } + + remote_lookup_service_ = std::make_unique(*cache_); + grpc::ServerBuilder remote_lookup_server_builder; + auto remoteLookupServerAddress = + absl::StrCat(kLocalIp, ":", kRemoteLookupServerPort); + remote_lookup_server_builder.AddListeningPort( + remoteLookupServerAddress, grpc::InsecureServerCredentials()); + 
remote_lookup_server_builder.RegisterService(remote_lookup_service_.get()); + LOG(INFO) << "Remote lookup server listening on " << remoteLookupServerAddress + << std::endl; + return remote_lookup_server_builder.BuildAndStart(); +} - LOG(INFO) << "Setting UDF Code Snippet"; - const absl::Status status = - udf_client_->SetCodeObject(std::move(code_config)); +void Server::SetDefaultUdfCodeObject() { + VLOG(8) << "Setting default UDF code config. Snippet: " + << kDefaultUdfCodeSnippet; + const absl::Status status = udf_client_->SetCodeObject( + CodeConfig{.js = kDefaultUdfCodeSnippet, + .udf_handler_name = kDefaultUdfHandlerName, + .logical_commit_time = kDefaultLogicalCommitTime}); if (!status.ok()) { LOG(ERROR) << "Error setting code object: " << status; } diff --git a/components/data_server/server/server.h b/components/data_server/server/server.h index fda37c79..56c491e4 100644 --- a/components/data_server/server/server.h +++ b/components/data_server/server/server.h @@ -30,8 +30,13 @@ #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/data_loading/data_orchestrator.h" +#include "components/data_server/request_handler/get_values_adapter.h" +#include "components/data_server/server/lifecycle_heartbeat.h" #include "components/data_server/server/parameter_fetcher.h" -#include "components/udf/code_fetcher.h" +#include "components/sharding/cluster_mappings_manager.h" +#include "components/sharding/shard_manager.h" +#include "components/udf/get_values_hook.h" +#include "components/udf/run_query_hook.h" #include "components/udf/udf_client.h" #include "components/util/platform_initializer.h" #include "grpcpp/grpcpp.h" @@ -51,7 +56,6 @@ class Server { absl::Status Init( std::unique_ptr parameter_client = nullptr, std::unique_ptr instance_client = nullptr, - std::unique_ptr code_fetcher = nullptr, std::unique_ptr udf_client = nullptr); // Wait for the server to shut down. 
Note that some other thread must be @@ -65,10 +69,9 @@ class Server { private: // If objects were not passed in for unit testing purposes then create them. - absl::Status CreateDefaultInstancesIfNecessary( + absl::Status CreateDefaultInstancesIfNecessaryAndGetEnvironment( std::unique_ptr parameter_client, std::unique_ptr instance_client, - std::unique_ptr code_fetcher, std::unique_ptr udf_client); absl::Status InitOnceInstancesAreCreated(); @@ -78,7 +81,7 @@ class Server { std::unique_ptr> CreateStreamRecordReaderFactory(const ParameterFetcher& parameter_fetcher); std::unique_ptr CreateDataOrchestrator( - const ParameterFetcher& parameter_fetcher, UdfClient& udf_client); + const ParameterFetcher& parameter_fetcher); void CreateGrpcServices(const ParameterFetcher& parameter_fetcher); absl::Status MaybeShutdownNotifiers(); @@ -88,12 +91,15 @@ class Server { std::unique_ptr CreateDeltaFileNotifier( const ParameterFetcher& parameter_fetcher); - std::unique_ptr CreateAndStartInternalLookupServer(); + absl::StatusOr> + CreateAndStartInternalLookupServer(); + std::unique_ptr CreateAndStartRemoteLookupServer(); - void SetUdfCodeObject(CodeFetcher& code_fetcher); + void SetDefaultUdfCodeObject(); void InitializeTelemetry(const ParameterClient& parameter_client, InstanceClient& instance_client); + absl::Status CreateShardManager(); // This must be first, otherwise the AWS SDK will crash when it's called: PlatformInitializer platform_initializer_; @@ -106,7 +112,9 @@ class Server { std::vector> grpc_services_; std::unique_ptr grpc_server_; std::unique_ptr cache_; - std::unique_ptr code_fetcher_; + std::unique_ptr get_values_adapter_; + std::unique_ptr get_values_hook_; + std::unique_ptr run_query_hook_; // BlobStorageClient must outlive DeltaFileNotifier std::unique_ptr blob_client_; @@ -123,11 +131,22 @@ class Server { std::unique_ptr data_orchestrator_; - // Internal Lookup Server + // Internal Lookup Server -- lookup requests to this server originate (from + // UDF 
sandbox) and terminate on the same machine. std::unique_ptr internal_lookup_service_; std::unique_ptr internal_lookup_server_; + std::unique_ptr shard_manager_; + // Internal Sharded Lookup Server -- + // if `num_shards` > 1, then serves requests originating from servers with + // a different `shard_num`. Only has data for `shard_num` assigned to the + // server at the start up. if `num_shards` == 1, then null, since no remote + // lookups are necessray + std::unique_ptr remote_lookup_service_; + std::unique_ptr remote_lookup_server_; + std::unique_ptr udf_client_; + std::unique_ptr cluster_mappings_manager_; int32_t shard_num_; int32_t num_shards_; diff --git a/components/data_server/server/server_local_test.cc b/components/data_server/server/server_local_test.cc index 1184813d..2a35a0fc 100644 --- a/components/data_server/server/server_local_test.cc +++ b/components/data_server/server/server_local_test.cc @@ -16,6 +16,7 @@ #include +#include "components/data_server/server/mocks.h" #include "components/data_server/server/server.h" #include "components/udf/mocks.h" #include "gmock/gmock.h" @@ -29,23 +30,14 @@ using opentelemetry::sdk::resource::Resource; using privacy_sandbox::server_common::ConfigureMetrics; using testing::_; -class MockInstanceClient : public InstanceClient { - public: - MOCK_METHOD(absl::StatusOr, GetEnvironmentTag, (), (override)); - MOCK_METHOD(absl::StatusOr, GetShardNumTag, (), (override)); - MOCK_METHOD(absl::Status, RecordLifecycleHeartbeat, - (std::string_view lifecycle_hook_name), (override)); - MOCK_METHOD(absl::Status, CompleteLifecycle, - (std::string_view lifecycle_hook_name), (override)); - MOCK_METHOD(absl::StatusOr, GetInstanceId, (), (override)); -}; - class MockParameterClient : public ParameterClient { public: MOCK_METHOD(absl::StatusOr, GetParameter, (std::string_view parameter_name), (const, override)); MOCK_METHOD(absl::StatusOr, GetInt32Parameter, (std::string_view parameter_name), (const, override)); + 
MOCK_METHOD(absl::StatusOr, GetBoolParameter, + (std::string_view parameter_name), (const, override)); void RegisterRequiredTelemetryExpectations() { EXPECT_CALL(*this, @@ -96,6 +88,7 @@ TEST(ServerLocalTest, InitFailsWithNoDeltaDirectory) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); parameter_client->RegisterRequiredTelemetryExpectations(); + auto mock_udf_client = std::make_unique(); EXPECT_CALL(*instance_client, GetEnvironmentTag()) .WillOnce(::testing::Return("environment")); @@ -122,14 +115,17 @@ TEST(ServerLocalTest, InitFailsWithNoDeltaDirectory) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-num-shards")) .WillOnce(::testing::Return(1)); - - auto mock_udf_client = std::make_unique(); - auto mock_code_fetcher = std::make_unique(); + EXPECT_CALL(*parameter_client, + GetInt32Parameter("kv-server-environment-udf-num-workers")) + .WillOnce(::testing::Return(2)); + EXPECT_CALL(*parameter_client, + GetBoolParameter("kv-server-environment-route-v1-to-v2")) + .WillOnce(::testing::Return(false)); kv_server::Server server; absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), - std::move(mock_code_fetcher), std::move(mock_udf_client)); + std::move(mock_udf_client)); EXPECT_FALSE(status.ok()); } @@ -137,6 +133,7 @@ TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); parameter_client->RegisterRequiredTelemetryExpectations(); + auto mock_udf_client = std::make_unique(); EXPECT_CALL(*instance_client, GetEnvironmentTag()) .WillOnce(::testing::Return("environment")); @@ -173,19 +170,20 @@ TEST(ServerLocalTest, InitPassesWithDeltaDirectoryAndRealtimeDirectory) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-num-shards")) .WillOnce(::testing::Return(1)); + EXPECT_CALL(*parameter_client, + 
GetInt32Parameter("kv-server-environment-udf-num-workers")) + .WillOnce(::testing::Return(2)); + EXPECT_CALL(*parameter_client, + GetBoolParameter("kv-server-environment-route-v1-to-v2")) + .WillOnce(::testing::Return(false)); - auto mock_udf_client = std::make_unique(); - auto mock_code_fetcher = std::make_unique(); - CodeConfig code_config{.js = "function SomeUDFCode(){}", - .udf_handler_name = "SomeUDFCode"}; - - EXPECT_CALL(*mock_code_fetcher, FetchCodeConfig(_)) - .WillOnce(testing::Return(code_config)); + EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + .WillOnce(testing::Return(absl::OkStatus())); kv_server::Server server; absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), - std::move(mock_code_fetcher), std::move(mock_udf_client)); + std::move(mock_udf_client)); EXPECT_TRUE(status.ok()); } @@ -193,6 +191,7 @@ TEST(ServerLocalTest, GracefulServerShutdown) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); parameter_client->RegisterRequiredTelemetryExpectations(); + auto mock_udf_client = std::make_unique(); EXPECT_CALL(*instance_client, GetEnvironmentTag()) .WillOnce(::testing::Return("environment")); @@ -229,19 +228,20 @@ TEST(ServerLocalTest, GracefulServerShutdown) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-num-shards")) .WillOnce(::testing::Return(1)); + EXPECT_CALL(*parameter_client, + GetInt32Parameter("kv-server-environment-udf-num-workers")) + .WillOnce(::testing::Return(2)); + EXPECT_CALL(*parameter_client, + GetBoolParameter("kv-server-environment-route-v1-to-v2")) + .WillOnce(::testing::Return(false)); - auto mock_udf_client = std::make_unique(); - auto mock_code_fetcher = std::make_unique(); - CodeConfig code_config{.js = "function SomeUDFCode(){}", - .udf_handler_name = "SomeUDFCode"}; - - EXPECT_CALL(*mock_code_fetcher, FetchCodeConfig(_)) - .WillOnce(testing::Return(code_config)); + EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + 
.WillOnce(testing::Return(absl::OkStatus())); kv_server::Server server; absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), - std::move(mock_code_fetcher), std::move(mock_udf_client)); + std::move(mock_udf_client)); ASSERT_TRUE(status.ok()); std::thread server_thread(&kv_server::Server::Wait, &server); server.GracefulShutdown(absl::Seconds(5)); @@ -252,6 +252,7 @@ TEST(ServerLocalTest, ForceServerShutdown) { auto instance_client = std::make_unique(); auto parameter_client = std::make_unique(); parameter_client->RegisterRequiredTelemetryExpectations(); + auto mock_udf_client = std::make_unique(); EXPECT_CALL(*instance_client, GetEnvironmentTag()) .WillOnce(::testing::Return("environment")); @@ -288,19 +289,20 @@ TEST(ServerLocalTest, ForceServerShutdown) { EXPECT_CALL(*parameter_client, GetInt32Parameter("kv-server-environment-num-shards")) .WillOnce(::testing::Return(1)); + EXPECT_CALL(*parameter_client, + GetInt32Parameter("kv-server-environment-udf-num-workers")) + .WillOnce(::testing::Return(2)); + EXPECT_CALL(*parameter_client, + GetBoolParameter("kv-server-environment-route-v1-to-v2")) + .WillOnce(::testing::Return(false)); - auto mock_udf_client = std::make_unique(); - auto mock_code_fetcher = std::make_unique(); - CodeConfig code_config{.js = "function SomeUDFCode(){}", - .udf_handler_name = "SomeUDFCode"}; - - EXPECT_CALL(*mock_code_fetcher, FetchCodeConfig(_)) - .WillOnce(testing::Return(code_config)); + EXPECT_CALL(*mock_udf_client, SetCodeObject(_)) + .WillOnce(testing::Return(absl::OkStatus())); kv_server::Server server; absl::Status status = server.Init(std::move(parameter_client), std::move(instance_client), - std::move(mock_code_fetcher), std::move(mock_udf_client)); + std::move(mock_udf_client)); ASSERT_TRUE(status.ok()); std::thread server_thread(&kv_server::Server::Wait, &server); server.ForceShutdown(); diff --git a/components/errors/BUILD b/components/errors/BUILD index a29447dd..74a9b902 100644 --- 
a/components/errors/BUILD +++ b/components/errors/BUILD @@ -20,6 +20,9 @@ package(default_visibility = [ cc_library( name = "aws_error_util", + srcs = [ + "error_util_aws.cc", + ], hdrs = [ "error_util_aws.h", ], @@ -30,6 +33,20 @@ cc_library( ], ) +cc_test( + name = "error_util_aws_test", + size = "small", + srcs = [ + "error_util_aws_test.cc", + ], + deps = [ + ":aws_error_util", + "@aws_sdk_cpp//:core", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "retry", srcs = [ diff --git a/components/errors/error_util_aws.cc b/components/errors/error_util_aws.cc new file mode 100644 index 00000000..e66d7a35 --- /dev/null +++ b/components/errors/error_util_aws.cc @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "components/errors/error_util_aws.h" + +#include + +namespace kv_server { + +absl::StatusCode HttpResponseCodeToStatusCode( + const Aws::Http::HttpResponseCode& response_code) { + // https://sdk.amazonaws.com/cpp/api/0.12.9/d1/d33/_http_response_8h_source.html + // https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto + const int http_code = static_cast(response_code); + switch (http_code) { + case 400: + return absl::StatusCode::kInvalidArgument; + case 401: + return absl::StatusCode::kUnauthenticated; + case 403: + return absl::StatusCode::kPermissionDenied; + case 404: + return absl::StatusCode::kNotFound; + case 408: + case 440: + return absl::StatusCode::kDeadlineExceeded; + case 409: + return absl::StatusCode::kAlreadyExists; + case 412: + case 427: + return absl::StatusCode::kFailedPrecondition; + case 429: + return absl::StatusCode::kResourceExhausted; + case 499: + return absl::StatusCode::kCancelled; + case 500: + return absl::StatusCode::kInternal; + case 501: + return absl::StatusCode::kUnimplemented; + case 503: + return absl::StatusCode::kUnavailable; + case 504: + case 598: + case 599: + return absl::StatusCode::kDeadlineExceeded; + default: + if (http_code >= 200 && http_code < 300) { + return absl::StatusCode::kOk; + } + if (http_code >= 400 && http_code < 500) { + return absl::StatusCode::kFailedPrecondition; + } + if (http_code >= 500 && http_code < 600) { + return absl::StatusCode::kInternal; + } + return absl::StatusCode::kUnknown; + } +} + +} // namespace kv_server diff --git a/components/errors/error_util_aws.h b/components/errors/error_util_aws.h index b42ed022..2af319be 100644 --- a/components/errors/error_util_aws.h +++ b/components/errors/error_util_aws.h @@ -23,57 +23,9 @@ #include "aws/core/client/AWSError.h" namespace kv_server { -namespace { // NOLINT(build/namespaces_headers) + absl::StatusCode HttpResponseCodeToStatusCode( - Aws::Http::HttpResponseCode response_code) { - // 
https://sdk.amazonaws.com/cpp/api/0.12.9/d1/d33/_http_response_8h_source.html - // https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto - const int http_code = static_cast(response_code); - switch (http_code) { - case 400: - return absl::StatusCode::kInvalidArgument; - case 401: - return absl::StatusCode::kUnauthenticated; - case 403: - return absl::StatusCode::kPermissionDenied; - case 404: - return absl::StatusCode::kNotFound; - case 408: - case 440: - return absl::StatusCode::kDeadlineExceeded; - case 409: - return absl::StatusCode::kAlreadyExists; - case 412: - case 427: - return absl::StatusCode::kFailedPrecondition; - case 429: - return absl::StatusCode::kResourceExhausted; - case 499: - return absl::StatusCode::kCancelled; - case 500: - return absl::StatusCode::kInternal; - case 501: - return absl::StatusCode::kUnimplemented; - case 503: - return absl::StatusCode::kUnavailable; - case 504: - case 598: - case 599: - return absl::StatusCode::kDeadlineExceeded; - default: - if (http_code >= 200 && http_code < 300) { - return absl::StatusCode::kOk; - } - if (http_code >= 400 && http_code < 500) { - return absl::StatusCode::kFailedPrecondition; - } - if (http_code >= 500 && http_code < 600) { - return absl::StatusCode::kInternal; - } - return absl::StatusCode::kUnknown; - } -} -} // namespace + const Aws::Http::HttpResponseCode& response_code); template absl::Status AwsErrorToStatus(const Aws::Client::AWSError& error) { diff --git a/components/errors/error_util_aws_test.cc b/components/errors/error_util_aws_test.cc new file mode 100644 index 00000000..fd04e0e1 --- /dev/null +++ b/components/errors/error_util_aws_test.cc @@ -0,0 +1,103 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/errors/error_util_aws.h" + +#include "absl/status/status.h" +#include "aws/core/client/AWSError.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(ErrorUtilAwsTest, HttpStatuses) { + EXPECT_EQ(absl::StatusCode::kInvalidArgument, + HttpResponseCodeToStatusCode( + static_cast(400))); + EXPECT_EQ(absl::StatusCode::kUnauthenticated, + HttpResponseCodeToStatusCode( + static_cast(401))); + EXPECT_EQ(absl::StatusCode::kPermissionDenied, + HttpResponseCodeToStatusCode( + static_cast(403))); + EXPECT_EQ(absl::StatusCode::kNotFound, + HttpResponseCodeToStatusCode( + static_cast(404))); + EXPECT_EQ(absl::StatusCode::kDeadlineExceeded, + HttpResponseCodeToStatusCode( + static_cast(408))); + EXPECT_EQ(absl::StatusCode::kDeadlineExceeded, + HttpResponseCodeToStatusCode( + static_cast(440))); + EXPECT_EQ(absl::StatusCode::kAlreadyExists, + HttpResponseCodeToStatusCode( + static_cast(409))); + EXPECT_EQ(absl::StatusCode::kFailedPrecondition, + HttpResponseCodeToStatusCode( + static_cast(412))); + EXPECT_EQ(absl::StatusCode::kFailedPrecondition, + HttpResponseCodeToStatusCode( + static_cast(427))); + EXPECT_EQ(absl::StatusCode::kResourceExhausted, + HttpResponseCodeToStatusCode( + static_cast(429))); + EXPECT_EQ(absl::StatusCode::kCancelled, + HttpResponseCodeToStatusCode( + static_cast(499))); + EXPECT_EQ(absl::StatusCode::kInternal, + HttpResponseCodeToStatusCode( + static_cast(500))); + EXPECT_EQ(absl::StatusCode::kUnimplemented, + HttpResponseCodeToStatusCode( + static_cast(501))); + 
EXPECT_EQ(absl::StatusCode::kUnavailable,
+            HttpResponseCodeToStatusCode(
+                static_cast(503)));
+  EXPECT_EQ(absl::StatusCode::kDeadlineExceeded,
+            HttpResponseCodeToStatusCode(
+                static_cast(504)));
+  EXPECT_EQ(absl::StatusCode::kDeadlineExceeded,
+            HttpResponseCodeToStatusCode(
+                static_cast(598)));
+  EXPECT_EQ(absl::StatusCode::kDeadlineExceeded,
+            HttpResponseCodeToStatusCode(
+                static_cast(599)));
+  EXPECT_EQ(absl::StatusCode::kOk,
+            HttpResponseCodeToStatusCode(
+                static_cast(200)));
+  EXPECT_EQ(absl::StatusCode::kOk,
+            HttpResponseCodeToStatusCode(
+                static_cast(299)));  // upper bound of the 2xx -> kOk range
+  EXPECT_EQ(absl::StatusCode::kFailedPrecondition,
+            HttpResponseCodeToStatusCode(
+                static_cast(413)));
+  EXPECT_EQ(absl::StatusCode::kFailedPrecondition,
+            HttpResponseCodeToStatusCode(
+                static_cast(498)));
+  EXPECT_EQ(absl::StatusCode::kInternal,
+            HttpResponseCodeToStatusCode(
+                static_cast(505)));
+  EXPECT_EQ(absl::StatusCode::kInternal,
+            HttpResponseCodeToStatusCode(
+                static_cast(590)));
+  EXPECT_EQ(absl::StatusCode::kUnknown,
+            HttpResponseCodeToStatusCode(
+                static_cast(600)));
+  EXPECT_EQ(absl::StatusCode::kUnknown,
+            HttpResponseCodeToStatusCode(
+                static_cast(-1)));
+}
+
+}  // namespace
+}  // namespace kv_server
diff --git a/components/internal_lookup/BUILD b/components/internal_lookup/BUILD
deleted file mode 100644
index 3f815bc3..00000000
--- a/components/internal_lookup/BUILD
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright 2023 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") -load("@rules_buf//buf:defs.bzl", "buf_lint_test") -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") -load("@rules_proto//proto:defs.bzl", "proto_descriptor_set", "proto_library") - -package(default_visibility = [ - "//components:__subpackages__", -]) - -cc_library( - name = "lookup_client_impl", - srcs = [ - "lookup_client_impl.cc", - ], - hdrs = [ - "lookup_client.h", - ], - deps = [ - ":constants", - ":internal_lookup_cc_grpc", - "@com_github_google_glog//:glog", - "@com_github_grpc_grpc//:grpc++", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@google_privacysandbox_servers_common//src/cpp/telemetry", - ], -) - -cc_library( - name = "lookup_server_impl", - srcs = [ - "lookup_server_impl.cc", - ], - hdrs = [ - "lookup_server_impl.h", - ], - deps = [ - ":internal_lookup_cc_grpc", - "//components/data_server/cache", - "@com_github_grpc_grpc//:grpc++", - "@com_google_protobuf//:protobuf", - "@google_privacysandbox_servers_common//src/cpp/telemetry", - ], -) - -cc_test( - name = "lookup_server_impl_test", - size = "small", - srcs = [ - "lookup_server_impl_test.cc", - ], - deps = [ - ":internal_lookup_cc_grpc", - ":lookup_server_impl", - "//components/data_server/cache", - "//components/data_server/cache:mocks", - "//public/test_util:proto_matcher", - "@com_github_grpc_grpc//:grpc++", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "constants", - hdrs = [ - "constants.h", - ], - deps = [ - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "mocks", - testonly = 1, - hdrs = ["mocks.h"], - deps = [ - ":internal_lookup_cc_proto", - ":lookup_client_impl", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - ], -) - 
-proto_library( - name = "internal_lookup_proto", - srcs = ["lookup.proto"], - deps = [ - "@com_google_googleapis//google/rpc:status_proto", - ], -) - -buf_lint_test( - name = "internal_lookup_lint", - config = "//public:buf.yaml", - targets = [":internal_lookup_proto"], -) - -proto_descriptor_set( - name = "internal_lookup_descriptor_set", - deps = [":internal_lookup_proto"], -) - -cc_proto_library( - name = "internal_lookup_cc_proto", - deps = [":internal_lookup_proto"], -) - -cc_grpc_library( - name = "internal_lookup_cc_grpc", - srcs = [":internal_lookup_proto"], - grpc_only = True, - deps = [":internal_lookup_cc_proto"], -) diff --git a/components/internal_lookup/lookup_server_impl.cc b/components/internal_lookup/lookup_server_impl.cc deleted file mode 100644 index 18058e3b..00000000 --- a/components/internal_lookup/lookup_server_impl.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "components/internal_lookup/lookup_server_impl.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "components/data_server/cache/cache.h" -#include "components/internal_lookup/lookup.grpc.pb.h" -#include "google/protobuf/message.h" -#include "grpcpp/grpcpp.h" -#include "src/cpp/telemetry/telemetry.h" - -namespace kv_server { - -constexpr char kInternalLookupServerSpan[] = "InternalLookupServerHandler"; - -using google::protobuf::RepeatedPtrField; -using privacy_sandbox::server_common::GetTracer; - -void ProcessKeys(RepeatedPtrField keys, const Cache& cache, - InternalLookupResponse& response) { - if (keys.empty()) return; - std::vector key_list; - for (const auto& key : keys) { - key_list.emplace_back(std::move(key)); - } - auto kv_pairs = cache.GetKeyValuePairs(key_list); - - for (const auto& key : key_list) { - SingleLookupResult result; - const auto key_iter = kv_pairs.find(key); - if (key_iter == kv_pairs.end()) { - auto status = result.mutable_status(); - status->set_code(static_cast(absl::StatusCode::kNotFound)); - status->set_message("Key not found"); - } else { - result.set_value(std::move(key_iter->second)); - } - (*response.mutable_kv_pairs())[key] = std::move(result); - } -} - -grpc::ServerUnaryReactor* LookupServiceImpl::InternalLookup( - grpc::CallbackServerContext* context, const InternalLookupRequest* request, - InternalLookupResponse* response) { - auto span = GetTracer()->StartSpan(kInternalLookupServerSpan); - auto scope = opentelemetry::trace::Scope(span); - - ProcessKeys(request->keys(), cache_, *response); - - auto* reactor = context->DefaultReactor(); - reactor->Finish(grpc::Status::OK); - return reactor; -} - -} // namespace kv_server diff --git a/components/internal_server/BUILD b/components/internal_server/BUILD new file mode 100644 index 00000000..ff26772b --- /dev/null +++ b/components/internal_server/BUILD @@ -0,0 +1,264 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@rules_buf//buf:defs.bzl", "buf_lint_test") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library", "cc_test") +load("@rules_proto//proto:defs.bzl", "proto_descriptor_set", "proto_library") + +package(default_visibility = [ + "//components:__subpackages__", +]) + +cc_library( + name = "lookup_client_impl", + srcs = [ + "lookup_client_impl.cc", + ], + hdrs = [ + "lookup_client.h", + ], + deps = [ + ":constants", + ":internal_lookup_cc_grpc", + "@com_github_google_glog//:glog", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@google_privacysandbox_servers_common//src/cpp/telemetry", + ], +) + +cc_library( + name = "lookup_server_impl", + srcs = [ + "lookup_server_impl.cc", + ], + hdrs = [ + "lookup_server_impl.h", + ], + deps = [ + ":internal_lookup_cc_grpc", + ":string_padder", + "//components/data_server/cache", + "//components/data_server/request_handler:ohttp_server_encryptor", + "//components/query:driver", + "//components/query:scanner", + "@com_github_grpc_grpc//:grpc++", + "@com_google_protobuf//:protobuf", + "@google_privacysandbox_servers_common//src/cpp/telemetry", + ], +) + +cc_test( + name = "lookup_server_impl_test", + size = "small", + srcs = [ + "lookup_server_impl_test.cc", + ], 
+ deps = [ + ":internal_lookup_cc_grpc", + ":lookup_server_impl", + "//components/data_server/cache", + "//components/data_server/cache:key_value_cache", + "//components/data_server/cache:mocks", + "//public/test_util:proto_matcher", + "@com_github_grpc_grpc//:grpc++", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "constants", + hdrs = [ + "constants.h", + ], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "mocks", + testonly = 1, + hdrs = ["mocks.h"], + deps = [ + ":internal_lookup_cc_proto", + ":lookup_client_impl", + ":run_query_client_impl", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + ], +) + +proto_library( + name = "internal_lookup_proto", + srcs = ["lookup.proto"], + deps = [ + "@com_google_googleapis//google/rpc:status_proto", + ], +) + +buf_lint_test( + name = "internal_lookup_lint", + config = "//:buf.yaml", + targets = [ + ":internal_lookup_proto", + ], +) + +proto_descriptor_set( + name = "internal_lookup_descriptor_set", + deps = [":internal_lookup_proto"], +) + +cc_proto_library( + name = "internal_lookup_cc_proto", + deps = [":internal_lookup_proto"], +) + +cc_grpc_library( + name = "internal_lookup_cc_grpc", + srcs = [":internal_lookup_proto"], + grpc_only = True, + deps = [":internal_lookup_cc_proto"], +) + +cc_library( + name = "remote_lookup_client_impl", + srcs = [ + "remote_lookup_client_impl.cc", + ], + hdrs = [ + "remote_lookup_client.h", + ], + deps = [ + ":constants", + ":internal_lookup_cc_grpc", + ":string_padder", + "//components/data_server/request_handler:ohttp_client_encryptor", + "@com_github_google_glog//:glog", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "sharded_lookup_server_impl", + srcs = [ + "sharded_lookup_server_impl.cc", + ], + hdrs = [ + "sharded_lookup_server_impl.h", + ], + deps = [ + 
":internal_lookup_cc_grpc", + ":remote_lookup_client_impl", + "//components/data_server/cache", + "//components/sharding:shard_manager", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_protobuf//:protobuf", + "@distributed_point_functions//pir/hashing:sha256_hash_family", + "@google_privacysandbox_servers_common//src/cpp/telemetry", + "@google_privacysandbox_servers_common//src/cpp/telemetry:metrics_recorder", + ], +) + +cc_test( + name = "sharded_lookup_server_impl_test", + size = "small", + srcs = [ + "sharded_lookup_server_impl_test.cc", + ], + deps = [ + ":internal_lookup_cc_grpc", + ":sharded_lookup_server_impl", + "//components/data_server/cache:mocks", + "//components/sharding:mocks", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest_main", + "@google_privacysandbox_servers_common//src/cpp/telemetry:mocks", + ], +) + +cc_library( + name = "run_query_client_impl", + srcs = [ + "run_query_client_impl.cc", + ], + hdrs = [ + "run_query_client.h", + ], + deps = [ + ":constants", + ":internal_lookup_cc_grpc", + "@com_github_google_glog//:glog", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "string_padder", + srcs = [ + "string_padder.cc", + ], + hdrs = [ + "string_padder.h", + ], + deps = [ + "@com_github_google_glog//:glog", + "@com_github_google_quiche//quiche:quiche_unstable_api", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "string_padder_test", + size = "small", + srcs = [ + "string_padder_test.cc", + ], + deps = [ + ":string_padder", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "remote_lookup_client_impl_test", + size = "small", + srcs = [ + "remote_lookup_client_impl_test.cc", + ], + deps = [ + 
":lookup_server_impl", + ":remote_lookup_client_impl", + "//components/data_server/cache", + "//components/data_server/cache:mocks", + "//public/test_util:proto_matcher", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/internal_lookup/constants.h b/components/internal_server/constants.h similarity index 67% rename from components/internal_lookup/constants.h rename to components/internal_server/constants.h index f01bf712..bfc04863 100644 --- a/components/internal_lookup/constants.h +++ b/components/internal_server/constants.h @@ -14,14 +14,15 @@ * limitations under the License. */ -#ifndef COMPONENTS_INTERNAL_LOOKUP_CONSTANTS_H_ -#define COMPONENTS_INTERNAL_LOOKUP_CONSTANTS_H_ +#ifndef COMPONENTS_INTERNAL_SERVER_CONSTANTS_H_ +#define COMPONENTS_INTERNAL_SERVER_CONSTANTS_H_ namespace kv_server { -// TODO(b/276750518): Switch to unix socket -constexpr char kInternalLookupServerAddress[] = "0.0.0.0:50099"; +constexpr char kInternalServerAddress[] = "unix:///server/socket/internal.sock"; +constexpr char kRemoteLookupServerPort[] = "50100"; +constexpr char kLocalIp[] = "0.0.0.0"; } // namespace kv_server -#endif // COMPONENTS_INTERNAL_LOOKUP_CONSTANTS_H_ +#endif // COMPONENTS_INTERNAL_SERVER_CONSTANTS_H_ diff --git a/components/internal_lookup/lookup.proto b/components/internal_server/lookup.proto similarity index 53% rename from components/internal_lookup/lookup.proto rename to components/internal_server/lookup.proto index 9b0e7d22..33f38eff 100644 --- a/components/internal_lookup/lookup.proto +++ b/components/internal_server/lookup.proto @@ -14,15 +14,22 @@ syntax = "proto3"; -import "google/rpc/status.proto"; - package kv_server; +import "google/rpc/status.proto"; + // Internal Lookup Service API. service InternalLookupService { // Endpoint for querying the server's internal datastore. Should only be used // within TEEs. 
rpc InternalLookup(InternalLookupRequest) returns (InternalLookupResponse) {}
+
+  // Endpoint for querying the datastore over the network.
+  rpc SecureLookup(SecureLookupRequest) returns (SecureLookupResponse) {}
+
+  // Endpoint for running a query on the server's internal datastore. Should
+  // only be used within TEEs.
+  rpc InternalRunQuery(InternalRunQueryRequest) returns (InternalRunQueryResponse) {}
 }
 
 // Lookup request for internal datastore.
@@ -31,6 +38,28 @@ message InternalLookupRequest {
   repeated string keys = 1;
 }
 
+// Encryption key type
+enum EncryptionKeyType {
+  // Taken from the coordinator
+  PRODUCTION = 0;
+  // Hardcoded well known test key
+  TEST = 1;
+}
+
+// Encrypted and padded lookup request for internal datastore.
+// We are sending out `num_shards` of these at the same time. Only payload size
+// is observable from the outside. So it has to be the same, but the actual
+// keys we are looking up can be different.
+// If we padded `InternalLookupRequest`s based off the total length of all keys,
+// we can end up having different _serialized_ message payload sizes, due to how
+// the over-the-wire format for protobuf is constructed.
+// If we serialize InternalLookupRequest, and then pad the resulting string,
+// then we are guaranteed to serialize to the same length.
+message SecureLookupRequest {
+  bytes ohttp_request = 1;
+  EncryptionKeyType encryption_key_type = 2;
+}
+
 // Lookup response from internal datastore.
 //
 // Each key in the request has a corresponding map entry in the response.
@@ -45,6 +74,11 @@ message InternalLookupResponse {
   map kv_pairs = 1;
 }
 
+// Encrypted InternalLookupResponse
+message SecureLookupResponse {
+  bytes ohttp_response = 1;
+}
+
 // Lookup result for a single key that is either a string value or a status.
 message SingleLookupResult {
   oneof single_lookup_result {
@@ -52,3 +86,15 @@
     string value = 1;
     google.rpc.Status status = 2;
   }
 }
+
+// Run Query request. 
+message InternalRunQueryRequest { + // Query to run. + optional string query = 1; +} + +// Run Query response. +message InternalRunQueryResponse { + // Set of elements returned. + repeated string elements = 1; +} diff --git a/components/internal_lookup/lookup_client.h b/components/internal_server/lookup_client.h similarity index 76% rename from components/internal_lookup/lookup_client.h rename to components/internal_server/lookup_client.h index 3edecbf3..6e5b1331 100644 --- a/components/internal_lookup/lookup_client.h +++ b/components/internal_server/lookup_client.h @@ -14,14 +14,16 @@ * limitations under the License. */ -#ifndef COMPONENTS_INTERNAL_LOOKUP_LOOKUP_CLIENT_H_ -#define COMPONENTS_INTERNAL_LOOKUP_LOOKUP_CLIENT_H_ +#ifndef COMPONENTS_INTERNAL_SERVER_LOOKUP_CLIENT_H_ +#define COMPONENTS_INTERNAL_SERVER_LOOKUP_CLIENT_H_ +#include #include +#include #include #include "absl/status/statusor.h" -#include "components/internal_lookup/lookup.grpc.pb.h" +#include "components/internal_server/lookup.grpc.pb.h" namespace kv_server { @@ -35,9 +37,9 @@ class LookupClient { virtual absl::StatusOr GetValues( const std::vector& keys) const = 0; - static const LookupClient& GetSingleton(); + static std::unique_ptr Create(std::string_view server_address); }; } // namespace kv_server -#endif // COMPONENTS_INTERNAL_LOOKUP_LOOKUP_CLIENT_H_ +#endif // COMPONENTS_INTERNAL_SERVER_LOOKUP_CLIENT_H_ diff --git a/components/internal_lookup/lookup_client_impl.cc b/components/internal_server/lookup_client_impl.cc similarity index 63% rename from components/internal_lookup/lookup_client_impl.cc rename to components/internal_server/lookup_client_impl.cc index 16a434a0..1dfd13bf 100644 --- a/components/internal_lookup/lookup_client_impl.cc +++ b/components/internal_server/lookup_client_impl.cc @@ -13,44 +13,52 @@ // limitations under the License. 
#include #include +#include +#include "absl/flags/flag.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "components/internal_lookup/constants.h" -#include "components/internal_lookup/lookup.grpc.pb.h" -#include "components/internal_lookup/lookup_client.h" +#include "absl/time/time.h" +#include "components/internal_server/constants.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/lookup_client.h" #include "glog/logging.h" #include "grpcpp/grpcpp.h" #include "src/cpp/telemetry/telemetry.h" +ABSL_FLAG(absl::Duration, internal_lookup_deadline_duration, + absl::Milliseconds(50), + "Internal lookup RPC deadline. Default value is 50 milliseconds"); + namespace kv_server { namespace { using privacy_sandbox::server_common::GetTracer; -constexpr char kInternalLookupClientSpan[] = "InternalLookupClient"; - class LookupClientImpl : public LookupClient { public: LookupClientImpl(const LookupClientImpl&) = delete; LookupClientImpl& operator=(const LookupClientImpl&) = delete; - LookupClientImpl() - : stub_(InternalLookupService::NewStub( - grpc::CreateChannel(kInternalLookupServerAddress, - grpc::InsecureChannelCredentials()))) {} + explicit LookupClientImpl(std::string_view server_address) + : stub_(InternalLookupService::NewStub(grpc::CreateChannel( + std::string(server_address), grpc::InsecureChannelCredentials()))) { + } absl::StatusOr GetValues( const std::vector& keys) const override { - auto span = GetTracer()->StartSpan(kInternalLookupClientSpan); - auto scope = opentelemetry::trace::Scope(span); - InternalLookupRequest request; (*request.mutable_keys()) = {keys.begin(), keys.end()}; InternalLookupResponse response; grpc::ClientContext context; + absl::Duration deadline = + absl::GetFlag(FLAGS_internal_lookup_deadline_duration); + context.set_deadline( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(absl::ToInt64Milliseconds(deadline), + 
GPR_TIMESPAN))); grpc::Status status = stub_->InternalLookup(&context, request, &response); if (status.ok()) { @@ -69,9 +77,9 @@ class LookupClientImpl : public LookupClient { } // namespace -const LookupClient& LookupClient::GetSingleton() { - static const LookupClient* const kInstance = new LookupClientImpl(); - return *kInstance; +std::unique_ptr LookupClient::Create( + std::string_view server_address) { + return std::make_unique(server_address); } } // namespace kv_server diff --git a/components/internal_server/lookup_server_impl.cc b/components/internal_server/lookup_server_impl.cc new file mode 100644 index 00000000..f29b11cc --- /dev/null +++ b/components/internal_server/lookup_server_impl.cc @@ -0,0 +1,154 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/internal_server/lookup_server_impl.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "components/data_server/cache/cache.h" +#include "components/data_server/request_handler/ohttp_server_encryptor.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/string_padder.h" +#include "components/query/driver.h" +#include "components/query/scanner.h" +#include "google/protobuf/message.h" +#include "grpcpp/grpcpp.h" +#include "src/cpp/telemetry/telemetry.h" + +namespace kv_server { +namespace { +using google::protobuf::RepeatedPtrField; +using grpc::StatusCode; + +void ProcessKeys(RepeatedPtrField keys, const Cache& cache, + InternalLookupResponse& response) { + if (keys.empty()) return; + std::vector key_list; + for (const auto& key : keys) { + key_list.emplace_back(std::move(key)); + } + auto kv_pairs = cache.GetKeyValuePairs(key_list); + + for (const auto& key : key_list) { + SingleLookupResult result; + const auto key_iter = kv_pairs.find(key); + if (key_iter == kv_pairs.end()) { + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message("Key not found"); + } else { + result.set_value(std::move(key_iter->second)); + } + (*response.mutable_kv_pairs())[key] = std::move(result); + } +} + +absl::Status ProcessQuery(std::string query, const Cache& cache, + InternalRunQueryResponse& response) { + if (query.empty()) return absl::OkStatus(); + std::unique_ptr get_key_value_set_result; + kv_server::Driver driver([&get_key_value_set_result](std::string_view key) { + return get_key_value_set_result->GetValueSet(key); + }); + std::istringstream stream(query); + kv_server::Scanner scanner(stream); + kv_server::Parser parse(driver, scanner); + int parse_result = parse(); + if (parse_result) { + return absl::InvalidArgumentError("Parsing failure."); + } + get_key_value_set_result = 
cache.GetKeyValueSet(driver.GetRootNode()->Keys()); + + auto result = driver.GetResult(); + if (!result.ok()) { + return result.status(); + } + response.mutable_elements()->Assign(result->begin(), result->end()); + return result.status(); +} + +grpc::Status ToInternalGrpcStatus(const absl::Status& status) { + return grpc::Status(StatusCode::INTERNAL, + absl::StrCat(status.code(), " : ", status.message())); +} + +} // namespace + +grpc::Status LookupServiceImpl::InternalLookup( + grpc::ServerContext* context, const InternalLookupRequest* request, + InternalLookupResponse* response) { + if (context->IsCancelled()) { + return grpc::Status(grpc::StatusCode::CANCELLED, + "Deadline exceeded or client cancelled, abandoning."); + } + ProcessKeys(request->keys(), cache_, *response); + return grpc::Status::OK; +} + +grpc::Status LookupServiceImpl::SecureLookup( + grpc::ServerContext* context, + const SecureLookupRequest* secure_lookup_request, + SecureLookupResponse* secure_response) { + if (context->IsCancelled()) { + return grpc::Status(grpc::StatusCode::CANCELLED, + "Deadline exceeded or client cancelled, abandoning."); + } + OhttpServerEncryptor encryptor; + auto padded_serialized_request_maybe = + encryptor.DecryptRequest(secure_lookup_request->ohttp_request()); + if (!padded_serialized_request_maybe.ok()) { + return ToInternalGrpcStatus(padded_serialized_request_maybe.status()); + } + auto serialized_request_maybe = + kv_server::Unpad(*padded_serialized_request_maybe); + if (!serialized_request_maybe.ok()) { + return ToInternalGrpcStatus(serialized_request_maybe.status()); + } + InternalLookupRequest request; + if (!request.ParseFromString(*serialized_request_maybe)) { + return grpc::Status(grpc::StatusCode::INTERNAL, + "Failed parsing incoming request"); + } + InternalLookupResponse response; + ProcessKeys(request.keys(), cache_, response); + auto encrypted_response_payload = + encryptor.EncryptResponse(response.SerializeAsString()); + if 
(!encrypted_response_payload.ok()) { + return ToInternalGrpcStatus(encrypted_response_payload.status()); + } + secure_response->set_ohttp_response(*encrypted_response_payload); + return grpc::Status::OK; +} + +grpc::Status LookupServiceImpl::InternalRunQuery( + grpc::ServerContext* context, const InternalRunQueryRequest* request, + InternalRunQueryResponse* response) { + if (context->IsCancelled()) { + return grpc::Status(grpc::StatusCode::CANCELLED, + "Deadline exceeded or client cancelled, abandoning."); + } + const auto process_result = ProcessQuery(request->query(), cache_, *response); + if (!process_result.ok()) { + return ToInternalGrpcStatus(process_result); + } + return grpc::Status::OK; +} + +} // namespace kv_server diff --git a/components/internal_lookup/lookup_server_impl.h b/components/internal_server/lookup_server_impl.h similarity index 58% rename from components/internal_lookup/lookup_server_impl.h rename to components/internal_server/lookup_server_impl.h index 05871e9b..364864d9 100644 --- a/components/internal_lookup/lookup_server_impl.h +++ b/components/internal_server/lookup_server_impl.h @@ -14,31 +14,40 @@ * limitations under the License. */ -#ifndef COMPONENTS_INTERNAL_LOOKUP_LOOKUP_SERVER_IMPL_H_ -#define COMPONENTS_INTERNAL_LOOKUP_LOOKUP_SERVER_IMPL_H_ +#ifndef COMPONENTS_INTERNAL_SERVER_LOOKUP_SERVER_IMPL_H_ +#define COMPONENTS_INTERNAL_SERVER_LOOKUP_SERVER_IMPL_H_ #include "components/data_server/cache/cache.h" -#include "components/internal_lookup/lookup.grpc.pb.h" +#include "components/internal_server/lookup.grpc.pb.h" #include "grpcpp/grpcpp.h" namespace kv_server { // Implements the internal lookup service for the data store. 
class LookupServiceImpl final - : public kv_server::InternalLookupService::CallbackService { + : public kv_server::InternalLookupService::Service { public: LookupServiceImpl(const Cache& cache) : cache_(cache) {} ~LookupServiceImpl() override = default; - grpc::ServerUnaryReactor* InternalLookup( - grpc::CallbackServerContext* context, + grpc::Status InternalLookup( + grpc::ServerContext* context, const kv_server::InternalLookupRequest* request, kv_server::InternalLookupResponse* response) override; + grpc::Status SecureLookup(grpc::ServerContext* context, + const kv_server::SecureLookupRequest* request, + kv_server::SecureLookupResponse* response) override; + + grpc::Status InternalRunQuery( + grpc::ServerContext* context, + const kv_server::InternalRunQueryRequest* request, + kv_server::InternalRunQueryResponse* response) override; + private: const Cache& cache_; }; } // namespace kv_server -#endif // COMPONENTS_INTERNAL_LOOKUP_LOOKUP_SERVER_IMPL_H_ +#endif // COMPONENTS_INTERNAL_SERVER_LOOKUP_SERVER_IMPL_H_ diff --git a/components/internal_lookup/lookup_server_impl_test.cc b/components/internal_server/lookup_server_impl_test.cc similarity index 66% rename from components/internal_lookup/lookup_server_impl_test.cc rename to components/internal_server/lookup_server_impl_test.cc index 6a78fcd3..a74498c9 100644 --- a/components/internal_lookup/lookup_server_impl_test.cc +++ b/components/internal_server/lookup_server_impl_test.cc @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/internal_lookup/lookup_server_impl.h" +#include "components/internal_server/lookup_server_impl.h" #include #include @@ -21,7 +21,9 @@ #include #include "components/data_server/cache/cache.h" +#include "components/data_server/cache/key_value_cache.h" #include "components/data_server/cache/mocks.h" +#include "components/internal_server/string_padder.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "grpcpp/grpcpp.h" @@ -113,6 +115,47 @@ TEST_F(LookupServiceImplTest, MissingKeyFromCache) { EXPECT_THAT(response, EqualsProto(expected)); } +TEST_F(LookupServiceImplTest, InternalRunQuerySuccess) { + InternalRunQueryRequest request; + request.set_query("someset"); + + absl::flat_hash_set keys; + keys.emplace("someset"); + auto mock_get_key_value_set_result = + std::make_unique(); + EXPECT_CALL(*mock_get_key_value_set_result, GetValueSet(_)) + .WillOnce( + Return(absl::flat_hash_set{"value1", "value2"})); + EXPECT_CALL(mock_cache_, GetKeyValueSet(_)) + .WillOnce(Return(std::move(mock_get_key_value_set_result))); + InternalRunQueryResponse response; + grpc::ClientContext context; + grpc::Status status = stub_->InternalRunQuery(&context, request, &response); + auto results = response.elements(); + EXPECT_THAT(results, + testing::UnorderedElementsAreArray({"value1", "value2"})); +} + +TEST_F(LookupServiceImplTest, InternalRunQueryParseFailure) { + InternalRunQueryRequest request; + request.set_query("fail|||||now"); + InternalRunQueryResponse response; + grpc::ClientContext context; + grpc::Status status = stub_->InternalRunQuery(&context, request, &response); + auto results = response.elements(); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); +} + +TEST_F(LookupServiceImplTest, SecureLookupFailure) { + SecureLookupRequest secure_lookup_request; + secure_lookup_request.set_ohttp_request("garbage"); + SecureLookupResponse response; + grpc::ClientContext context; + grpc::Status status = + 
stub_->SecureLookup(&context, secure_lookup_request, &response); + EXPECT_EQ(status.error_code(), grpc::StatusCode::INTERNAL); +} + } // namespace } // namespace kv_server diff --git a/components/internal_lookup/mocks.h b/components/internal_server/mocks.h similarity index 66% rename from components/internal_lookup/mocks.h rename to components/internal_server/mocks.h index dd5691ef..d9d25496 100644 --- a/components/internal_lookup/mocks.h +++ b/components/internal_server/mocks.h @@ -14,16 +14,17 @@ * limitations under the License. */ -#ifndef COMPONENTS_INTERNAL_LOOKUP_MOCKS_H_ -#define COMPONENTS_INTERNAL_LOOKUP_MOCKS_H_ +#ifndef COMPONENTS_INTERNAL_SERVER_MOCKS_H_ +#define COMPONENTS_INTERNAL_SERVER_MOCKS_H_ #include #include #include #include "absl/status/statusor.h" -#include "components/internal_lookup/lookup.grpc.pb.h" -#include "components/internal_lookup/lookup_client.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/lookup_client.h" +#include "components/internal_server/run_query_client.h" #include "gmock/gmock.h" namespace kv_server { @@ -34,6 +35,12 @@ class MockLookupClient : public LookupClient { (const std::vector&), (const, override)); }; +class MockRunQueryClient : public RunQueryClient { + public: + MOCK_METHOD((absl::StatusOr), RunQuery, + (std::string), (const, override)); +}; + } // namespace kv_server -#endif // COMPONENTS_INTERNAL_LOOKUP_MOCKS_H_ +#endif // COMPONENTS_INTERNAL_SERVER_MOCKS_H_ diff --git a/components/internal_server/remote_lookup_client.h b/components/internal_server/remote_lookup_client.h new file mode 100644 index 00000000..670facf1 --- /dev/null +++ b/components/internal_server/remote_lookup_client.h @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_INTERNAL_SERVER_REMOTE_LOOKUP_CLIENT_H_ +#define COMPONENTS_INTERNAL_SERVER_REMOTE_LOOKUP_CLIENT_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "components/internal_server/lookup.grpc.pb.h" + +namespace kv_server { + +class RemoteLookupClient { + public: + virtual ~RemoteLookupClient() = default; + // Calls the remote internal lookup server with the given keys. + // Pads the request size with padding_length. + // Note that we need to pass in a `serialized_message` here because we need to + // figure out the correct padding length across multiple requests. That helps + // with preventing double serialization. + virtual absl::StatusOr GetValues( + std::string_view serialized_message, int32_t padding_length) const = 0; + virtual std::string_view GetIpAddress() const = 0; + static std::unique_ptr Create(std::string ip_address); + static std::unique_ptr Create( + std::unique_ptr stub); +}; + +} // namespace kv_server + +#endif // COMPONENTS_INTERNAL_SERVER_REMOTE_LOOKUP_CLIENT_H_ diff --git a/components/internal_server/remote_lookup_client_impl.cc b/components/internal_server/remote_lookup_client_impl.cc new file mode 100644 index 00000000..7daf07fe --- /dev/null +++ b/components/internal_server/remote_lookup_client_impl.cc @@ -0,0 +1,95 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "components/data_server/request_handler/ohttp_client_encryptor.h" +#include "components/internal_server/constants.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/remote_lookup_client.h" +#include "components/internal_server/string_padder.h" +#include "glog/logging.h" +#include "grpcpp/grpcpp.h" + +namespace kv_server { +class RemoteLookupClientImpl : public RemoteLookupClient { + public: + RemoteLookupClientImpl(const RemoteLookupClientImpl&) = delete; + RemoteLookupClientImpl& operator=(const RemoteLookupClientImpl&) = delete; + + explicit RemoteLookupClientImpl(std::string ip_address) + : ip_address_( + absl::StrFormat("%s:%s", ip_address, kRemoteLookupServerPort)), + stub_(InternalLookupService::NewStub(grpc::CreateChannel( + ip_address_, grpc::InsecureChannelCredentials()))) {} + + explicit RemoteLookupClientImpl( + std::unique_ptr stub) + : stub_(std::move(stub)) {} + + absl::StatusOr GetValues( + std::string_view serialized_message, + int32_t padding_length) const override { + OhttpClientEncryptor encryptor; + auto encrypted_padded_serialized_request_maybe = + encryptor.EncryptRequest(Pad(serialized_message, padding_length)); + if (!encrypted_padded_serialized_request_maybe.ok()) { + return encrypted_padded_serialized_request_maybe.status(); + } + SecureLookupRequest secure_lookup_request; + secure_lookup_request.set_ohttp_request( + 
*encrypted_padded_serialized_request_maybe); + SecureLookupResponse secure_response; + grpc::ClientContext context; + grpc::Status status = + stub_->SecureLookup(&context, secure_lookup_request, &secure_response); + if (!status.ok()) { + LOG(ERROR) << status.error_code() << ": " << status.error_message(); + return absl::Status((absl::StatusCode)status.error_code(), + status.error_message()); + } + auto decrypted_response_maybe = + encryptor.DecryptResponse(std::move(secure_response.ohttp_response())); + if (!decrypted_response_maybe.ok()) { + return decrypted_response_maybe.status(); + } + InternalLookupResponse response; + if (!response.ParseFromString( + decrypted_response_maybe->GetPlaintextData())) { + return absl::InvalidArgumentError("Failed parsing the response."); + } + return response; + } + + std::string_view GetIpAddress() const override { return ip_address_; } + + private: + const std::string ip_address_; + std::unique_ptr stub_; +}; + +std::unique_ptr RemoteLookupClient::Create( + std::string ip_address) { + return std::make_unique(std::move(ip_address)); +} + +std::unique_ptr RemoteLookupClient::Create( + std::unique_ptr stub) { + return std::make_unique(std::move(stub)); +} + +} // namespace kv_server diff --git a/components/internal_server/remote_lookup_client_impl_test.cc b/components/internal_server/remote_lookup_client_impl_test.cc new file mode 100644 index 00000000..ccc872e2 --- /dev/null +++ b/components/internal_server/remote_lookup_client_impl_test.cc @@ -0,0 +1,79 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/data_server/cache/cache.h" +#include "components/data_server/cache/mocks.h" +#include "components/internal_server/lookup_server_impl.h" +#include "components/internal_server/remote_lookup_client.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "grpcpp/grpcpp.h" +#include "gtest/gtest.h" +#include "public/test_util/proto_matcher.h" + +namespace kv_server { +namespace { + +class RemoteLookupClientImplTest : public ::testing::Test { + protected: + RemoteLookupClientImplTest() { + lookup_service_ = std::make_unique(mock_cache_); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + remote_lookup_client_ = + RemoteLookupClient::Create(InternalLookupService::NewStub( + server_->InProcessChannel(grpc::ChannelArguments()))); + } + + ~RemoteLookupClientImplTest() { + server_->Shutdown(); + server_->Wait(); + } + MockCache mock_cache_; + std::unique_ptr lookup_service_; + std::unique_ptr server_; + std::unique_ptr remote_lookup_client_; +}; + +TEST_F(RemoteLookupClientImplTest, EncryptedPaddedSuccessfulCall) { + std::vector keys = {"key1", "key2"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(keys.begin(), keys.end()); + std::string serialized_message = request.SerializeAsString(); + InternalLookupRequest request2; + int32_t padding_length = 10; + EXPECT_CALL(mock_cache_, GetKeyValuePairs(testing::_)) + .WillOnce(testing::Return(absl::flat_hash_map{ + {"key1", "value1"}, {"key2", "value2"}})); + auto response_status 
= + remote_lookup_client_->GetValues(serialized_message, padding_length); + EXPECT_TRUE(response_status.ok()); + InternalLookupResponse response = *response_status; + InternalLookupResponse expected; + google::protobuf::TextFormat::ParseFromString(R"pb(kv_pairs { + key: "key1" + value { value: "value1" } + } + kv_pairs { + key: "key2" + value { value: "value2" } + } + )pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); +} + +} // namespace +} // namespace kv_server diff --git a/components/internal_server/run_query_client.h b/components/internal_server/run_query_client.h new file mode 100644 index 00000000..aaba8c0c --- /dev/null +++ b/components/internal_server/run_query_client.h @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_INTERNAL_SERVER_RUN_QUERY_CLIENT_H_ +#define COMPONENTS_INTERNAL_SERVER_RUN_QUERY_CLIENT_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "components/internal_server/lookup.grpc.pb.h" + +namespace kv_server { + +// TOOD(b/261564359): Create a pool of gRPC channels and combine internal +// clients. +// Synchronous client for internal run query service +class RunQueryClient { + public: + virtual ~RunQueryClient() = default; + + // Calls the internal run query server. 
+ virtual absl::StatusOr RunQuery( + std::string query) const = 0; + + // If this client is called as part of the UDF hook, it must be constructed + // after the fork. + static std::unique_ptr Create( + std::string_view server_address); +}; + +} // namespace kv_server + +#endif // COMPONENTS_INTERNAL_SERVER_RUN_QUERY_CLIENT_H_ diff --git a/components/internal_server/run_query_client_impl.cc b/components/internal_server/run_query_client_impl.cc new file mode 100644 index 00000000..aff1e506 --- /dev/null +++ b/components/internal_server/run_query_client_impl.cc @@ -0,0 +1,82 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "components/internal_server/constants.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/run_query_client.h" +#include "glog/logging.h" +#include "grpcpp/grpcpp.h" + +ABSL_FLAG(absl::Duration, internal_run_query_deadline_duration, + absl::Milliseconds(50), + "Internal run query RPC deadline. 
Default value is 50 milliseconds"); + +namespace kv_server { +namespace { +class RunQueryClientImpl : public RunQueryClient { + public: + RunQueryClientImpl(const RunQueryClientImpl&) = delete; + RunQueryClientImpl& operator=(const RunQueryClientImpl&) = delete; + + explicit RunQueryClientImpl(std::string_view server_address) + : stub_(InternalLookupService::NewStub(grpc::CreateChannel( + std::string(server_address), grpc::InsecureChannelCredentials()))) { + } + + absl::StatusOr RunQuery( + std::string query) const override { + VLOG(8) << "Running query: " << query; + InternalRunQueryRequest request; + request.set_query(std::move(query)); + + InternalRunQueryResponse response; + grpc::ClientContext context; + absl::Duration deadline = + absl::GetFlag(FLAGS_internal_run_query_deadline_duration); + context.set_deadline( + gpr_time_add(gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_millis(absl::ToInt64Milliseconds(deadline), + GPR_TIMESPAN))); + grpc::Status status = stub_->InternalRunQuery(&context, request, &response); + + if (status.ok()) { + return response; + } + + LOG(ERROR) << status.error_code() << ": " << status.error_message(); + // Return an absl status from the gRPC status + return absl::Status((absl::StatusCode)status.error_code(), + status.error_message()); + } + + private: + std::unique_ptr stub_; +}; + +} // namespace + +std::unique_ptr RunQueryClient::Create( + std::string_view server_address) { + return std::make_unique(server_address); +} + +} // namespace kv_server diff --git a/components/internal_server/sharded_lookup_server_impl.cc b/components/internal_server/sharded_lookup_server_impl.cc new file mode 100644 index 00000000..e23c4950 --- /dev/null +++ b/components/internal_server/sharded_lookup_server_impl.cc @@ -0,0 +1,198 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/internal_server/sharded_lookup_server_impl.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "components/data_server/cache/cache.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "glog/logging.h" +#include "google/protobuf/message.h" +#include "grpcpp/grpcpp.h" +#include "src/cpp/telemetry/telemetry.h" + +namespace kv_server { + +constexpr char kShardedLookupServerSpan[] = "ShardedLookupServerHandler"; +constexpr char* kShardedLookupGrpcFailure = "ShardedLookupGrpcFailure"; + +using google::protobuf::RepeatedPtrField; + +namespace { +template +void UpdateResponse(const std::vector& key_list, T& kv_pairs, + std::function&& value_mapper, + InternalLookupResponse& response) { + for (const auto& key : key_list) { + const auto key_iter = kv_pairs.find(key); + if (key_iter == kv_pairs.end()) { + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message("Key not found"); + (*response.mutable_kv_pairs())[key] = std::move(result); + } else { + (*response.mutable_kv_pairs())[key] = + value_mapper(std::move(key_iter->second)); + } + } +} + +void SetRequestFailed(const std::vector& key_list, + InternalLookupResponse& response) { + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kInternal)); + status->set_message("Data lookup failed"); + for (const auto& key : key_list) { + (*response.mutable_kv_pairs())[key] 
= result; + } +} + +} // namespace + +absl::Status ShardedLookupServiceImpl::ProcessShardedKeys( + const RepeatedPtrField& keys, const Cache& cache, + InternalLookupResponse& response) { + if (keys.empty()) { + return absl::OkStatus(); + } + + const auto shard_lookup_inputs = ShardKeys(keys); + std::vector>> responses; + for (int shard_num = 0; shard_num < num_shards_; shard_num++) { + auto& shard_lookup_input = shard_lookup_inputs[shard_num]; + const auto& key_list = shard_lookup_input.keys; + if (shard_num == current_shard_num_) { + // Eventually this whole branch will go away. Meanwhile, we need the + // following line for proper indexing order when we process responses. + responses.emplace_back(); + if (key_list.empty()) { + continue; + } + auto kv_pairs = cache.GetKeyValuePairs(key_list); + UpdateResponse, + std::string>( + key_list, kv_pairs, + [](std::string result) { + SingleLookupResult actresult; + actresult.set_value(std::move(result)); + return actresult; + }, + response); + } else { + auto client = shard_manager_.Get(shard_num); + if (client == nullptr) { + return absl::InternalError("Internal lookup client is unavailable."); + } + responses.push_back(std::async( + std::launch::async, &ShardedLookupServiceImpl::GetValues, this, + std::ref(*client), shard_lookup_input.serialized_request, + shard_lookup_input.padding)); + } + } + // process responses + for (int shard_num = 0; shard_num < num_shards_; shard_num++) { + auto& shard_lookup_input = shard_lookup_inputs[shard_num]; + if (shard_num == current_shard_num_) { + continue; + } + + auto result = responses[shard_num].get(); + if (!result.ok()) { + // mark all keys as internal failure + SetRequestFailed(shard_lookup_input.keys, response); + continue; + } + auto kv_pairs = result->mutable_kv_pairs(); + UpdateResponse< + ::google::protobuf::Map, + SingleLookupResult>( + shard_lookup_input.keys, *kv_pairs, + [](SingleLookupResult result) { return result; }, response); + } + return absl::OkStatus(); +} + 
+absl::StatusOr ShardedLookupServiceImpl::GetValues( + RemoteLookupClient& client, std::string_view serialized_message, + int32_t padding_length) { + return client.GetValues(serialized_message, padding_length); +} + +std::vector +ShardedLookupServiceImpl::BucketKeys( + const RepeatedPtrField& keys) const { + ShardLookupInput sli; + std::vector lookup_inputs(num_shards_, sli); + for (const auto& key : keys) { + int32_t shard_num = hash_function_(key, num_shards_); + VLOG(9) << "key: " << key << ", shard number: " << shard_num; + lookup_inputs[shard_num].keys.emplace_back(key); + } + return lookup_inputs; +} + +void ShardedLookupServiceImpl::SerializeShardedRequests( + std::vector& lookup_inputs) + const { + for (auto& lookup_input : lookup_inputs) { + InternalLookupRequest request; + request.mutable_keys()->Assign(lookup_input.keys.begin(), + lookup_input.keys.end()); + lookup_input.serialized_request = request.SerializeAsString(); + } +} + +void ShardedLookupServiceImpl::ComputePadding( + std::vector& lookup_inputs) + const { + int32_t max_length = 0; + for (const auto& lookup_input : lookup_inputs) { + max_length = + std::max(max_length, int32_t(lookup_input.serialized_request.size())); + } + for (auto& lookup_input : lookup_inputs) { + lookup_input.padding = max_length - lookup_input.serialized_request.size(); + } +} + +std::vector +ShardedLookupServiceImpl::ShardKeys( + const RepeatedPtrField& keys) const { + auto lookup_inputs = BucketKeys(keys); + SerializeShardedRequests(lookup_inputs); + ComputePadding(lookup_inputs); + return lookup_inputs; +} + +grpc::Status ShardedLookupServiceImpl::InternalLookup( + grpc::ServerContext* context, const InternalLookupRequest* request, + InternalLookupResponse* response) { + auto current_status = grpc::Status::OK; + auto result = ProcessShardedKeys(request->keys(), cache_, *response); + if (!result.ok()) { + metrics_recorder_.IncrementEventCounter(kShardedLookupGrpcFailure); + current_status = 
grpc::Status(grpc::StatusCode::INTERNAL, "Internal error"); + } + return current_status; +} + +} // namespace kv_server diff --git a/components/internal_server/sharded_lookup_server_impl.h b/components/internal_server/sharded_lookup_server_impl.h new file mode 100644 index 00000000..ea22d6bd --- /dev/null +++ b/components/internal_server/sharded_lookup_server_impl.h @@ -0,0 +1,114 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_INTERNAL_SERVER_SHARDED_LOOKUP_SERVER_IMPL_H_ +#define COMPONENTS_INTERNAL_SERVER_SHARDED_LOOKUP_SERVER_IMPL_H_ + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "components/data_server/cache/cache.h" +#include "components/internal_server/lookup.grpc.pb.h" +#include "components/internal_server/remote_lookup_client.h" +#include "components/sharding/shard_manager.h" +#include "grpcpp/grpcpp.h" +#include "pir/hashing/sha256_hash_family.h" +#include "src/cpp/telemetry/metrics_recorder.h" +#include "src/cpp/telemetry/telemetry.h" + +namespace kv_server { + +// Implements the internal lookup service for the data store. 
+class ShardedLookupServiceImpl final
+    : public kv_server::InternalLookupService::Service {
+ public:
+  ShardedLookupServiceImpl(
+      privacy_sandbox::server_common::MetricsRecorder& metrics_recorder,
+      const Cache& cache, const int32_t num_shards,
+      const int32_t current_shard_num, ShardManager& shard_manager,
+      // We're currently going with a default empty string and not
+      // allowing AdTechs to modify it.
+      const std::string hashing_seed = "")
+      : metrics_recorder_(metrics_recorder),
+        cache_(cache),
+        num_shards_(num_shards),
+        current_shard_num_(current_shard_num),
+        hashing_seed_(hashing_seed),
+        hash_function_(
+            distributed_point_functions::SHA256HashFunction(hashing_seed_)),
+        shard_manager_(shard_manager) {
+    CHECK_GT(num_shards, 1)
+        << "num_shards for ShardedLookupServiceImpl must be > 1";
+  }
+
+  virtual ~ShardedLookupServiceImpl() = default;
+
+  // Iterates over all keys specified in the `request` and assigns them to shard
+  // buckets. Then for each bucket it queries the underlying data shard. At the
+  // moment, for the shard number matching the current server shard number, the
+  // logic will look up data in its own cache. Eventually, this will change when
+  // we have two types of servers: UDF and data servers. Then the responses are
+  // combined and the result is returned. If any underlying request fails -- we
+  // return an empty response and `Internal` error as the status for the gRPC
+  // status code.
+  grpc::Status InternalLookup(
+      grpc::ServerContext* context,
+      const kv_server::InternalLookupRequest* request,
+      kv_server::InternalLookupResponse* response) override;
+
+ private:
+  // Keeps sharded keys and associated metadata.
+  struct ShardLookupInput {
+    // Keys that are being looked up.
+    std::vector<std::string> keys;
+    // A serialized `InternalLookupRequest` with the corresponding keys
+    // from `keys`.
+    std::string serialized_request;
+    // Identifies by how many characters `keys` should be padded, so that
+    // all requests add up to the same length.
+ int32_t padding; + }; + std::vector ShardKeys( + const google::protobuf::RepeatedPtrField& keys) const; + void ComputePadding(std::vector& sk) const; + void SerializeShardedRequests(std::vector& sk) const; + std::vector BucketKeys( + const google::protobuf::RepeatedPtrField& keys) const; + absl::Status ProcessShardedKeys( + const google::protobuf::RepeatedPtrField& keys, + const Cache& cache, InternalLookupResponse& response); + void LookupKeysExternally(std::vector& key_list, + const Cache& cache, + InternalLookupResponse& response); + absl::StatusOr GetValues( + RemoteLookupClient& client, std::string_view serialized_message, + int32_t padding_length); + + privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; + const Cache& cache_; + const int32_t num_shards_; + const int32_t current_shard_num_; + const std::string hashing_seed_; + const distributed_point_functions::SHA256HashFunction hash_function_; + const ShardManager& shard_manager_; +}; + +} // namespace kv_server + +#endif // COMPONENTS_INTERNAL_SERVER_SHARDED_LOOKUP_SERVER_IMPL_H_ diff --git a/components/internal_server/sharded_lookup_server_impl_test.cc b/components/internal_server/sharded_lookup_server_impl_test.cc new file mode 100644 index 00000000..d2b2d117 --- /dev/null +++ b/components/internal_server/sharded_lookup_server_impl_test.cc @@ -0,0 +1,463 @@ + +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/internal_server/sharded_lookup_server_impl.h" + +#include +#include +#include +#include +#include + +#include "components/data_server/cache/cache.h" +#include "components/data_server/cache/mocks.h" +#include "components/sharding/mocks.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "grpcpp/grpcpp.h" +#include "gtest/gtest.h" +#include "public/test_util/proto_matcher.h" +#include "src/cpp/telemetry/mocks.h" + +namespace kv_server { +namespace { + +using google::protobuf::TextFormat; +using privacy_sandbox::server_common::MockMetricsRecorder; +using testing::_; +using testing::Return; +using testing::ReturnRef; + +class MockRemoteLookupClient : public RemoteLookupClient { + public: + MockRemoteLookupClient() : RemoteLookupClient() {} + MOCK_METHOD(absl::StatusOr, GetValues, + (std::string_view serialized_message, int32_t padding_length), + (const, override)); + MOCK_METHOD(std::string_view, GetIpAddress, (), (const, override)); +}; + +class ShardedLookupServiceImplTest : public ::testing::Test { + protected: + int32_t num_shards_ = 2; + int32_t shard_num_ = 0; + + MockMetricsRecorder mock_metrics_recorder_; + MockCache mock_cache_; + std::unique_ptr lookup_service_; + std::unique_ptr server_; + std::unique_ptr stub_; +}; + +TEST_F(ShardedLookupServiceImplTest, ReturnsKeysFromCache) { + InternalLookupRequest request; + request.add_keys("key1"); + request.add_keys("key4"); + + EXPECT_CALL(mock_cache_, GetKeyValuePairs(_)) + .WillOnce(Return( + absl::flat_hash_map{{"key4", "value4"}})); + + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1"}; + 
InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, 0)) + .WillOnce([&]() { + InternalLookupResponse resp; + SingleLookupResult result; + result.set_value("value1"); + (*resp.mutable_kv_pairs())["key1"] = result; + return resp; + }); + + return std::move(mock_remote_lookup_client_1); + }); + + InternalLookupResponse response; + grpc::ClientContext context; + + lookup_service_ = std::make_unique( + mock_metrics_recorder_, mock_cache_, num_shards_, shard_num_, + *(*shard_manager)); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + stub_ = InternalLookupService::NewStub( + server_->InProcessChannel(grpc::ChannelArguments())); + + grpc::Status status = stub_->InternalLookup(&context, request, &response); + + InternalLookupResponse expected; + TextFormat::ParseFromString(R"pb(kv_pairs { + key: "key1" + value { value: "value1" } + } + kv_pairs { + key: "key4" + value { value: "value4" } + } + )pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); + server_->Shutdown(); + server_->Wait(); +} + +TEST_F(ShardedLookupServiceImplTest, MissingKeyFromCache) { + InternalLookupRequest request; + request.add_keys("key1"); + request.add_keys("key4"); + request.add_keys("key5"); + + EXPECT_CALL(mock_cache_, GetKeyValuePairs(_)) + .WillOnce(Return( + absl::flat_hash_map{{"key4", "value4"}})); + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = 
{"key1", "key5"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, 0)) + .WillOnce([&]() { + InternalLookupResponse resp; + SingleLookupResult result; + auto status = result.mutable_status(); + status->set_code(static_cast(absl::StatusCode::kNotFound)); + status->set_message("Key not found"); + + (*resp.mutable_kv_pairs())["key1"] = result; + return resp; + }); + + return std::move(mock_remote_lookup_client_1); + }); + + InternalLookupResponse response; + grpc::ClientContext context; + lookup_service_ = std::make_unique( + mock_metrics_recorder_, mock_cache_, num_shards_, shard_num_, + *(*shard_manager)); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + stub_ = InternalLookupService::NewStub( + server_->InProcessChannel(grpc::ChannelArguments())); + grpc::Status status = stub_->InternalLookup(&context, request, &response); + + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb(kv_pairs { + key: "key1" + value { status: { code: 5, message: "Key not found" } } + } + kv_pairs { + key: "key4" + value { value: "value4" } + }, + kv_pairs { + key: "key5" + value { status: { code: 5, message: "Key not found" } } + } + )pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); + server_->Shutdown(); + server_->Wait(); +} + +TEST_F(ShardedLookupServiceImplTest, MissingKeys) { + InternalLookupRequest request; + InternalLookupResponse response; + grpc::ClientContext context; + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = + ShardManager::Create(num_shards_, std::move(cluster_mappings)); + lookup_service_ = std::make_unique( + mock_metrics_recorder_, mock_cache_, 
num_shards_, shard_num_, + **shard_manager); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + stub_ = InternalLookupService::NewStub( + server_->InProcessChannel(grpc::ChannelArguments())); + grpc::Status status = stub_->InternalLookup(&context, request, &response); + InternalLookupResponse expected; + TextFormat::ParseFromString(R"pb()pb", &expected); + EXPECT_THAT(response, EqualsProto(expected)); + server_->Shutdown(); + server_->Wait(); +} + +TEST_F(ShardedLookupServiceImplTest, FailedDownstreamRequest) { + InternalLookupRequest request; + request.add_keys("key1"); + request.add_keys("key4"); + EXPECT_CALL(mock_cache_, GetKeyValuePairs(_)) + .WillOnce(Return( + absl::flat_hash_map{{"key4", "value4"}})); + + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards_, std::move(cluster_mappings), + std::make_unique(), [](const std::string& ip) { + if (ip != "1") { + return std::make_unique(); + } + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, 0)) + .WillOnce([]() { return absl::DeadlineExceededError("too long"); }); + + return std::move(mock_remote_lookup_client_1); + }); + + InternalLookupResponse response; + grpc::ClientContext context; + lookup_service_ = std::make_unique( + mock_metrics_recorder_, mock_cache_, num_shards_, shard_num_, + **shard_manager); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + stub_ = InternalLookupService::NewStub( + 
server_->InProcessChannel(grpc::ChannelArguments())); + + grpc::Status status = stub_->InternalLookup(&context, request, &response); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb( + kv_pairs { + key: "key1" + value { status { code: 13 message: "Data lookup failed" } } + } + kv_pairs { + key: "key4" + value { value: "value4" } + })pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); + EXPECT_TRUE(status.ok()); + server_->Shutdown(); + server_->Wait(); +} + +TEST_F(ShardedLookupServiceImplTest, ReturnsKeysFromCachePadding) { + auto num_shards = 4; + InternalLookupRequest request; + // 0 + request.add_keys("key4"); + request.add_keys("verylongkey2"); + // 1 + request.add_keys("key1"); + request.add_keys("key2"); + request.add_keys("key3"); + // 2 + request.add_keys("randomkey5"); + // 3 + request.add_keys("longkey1"); + request.add_keys("randomkey3"); + + int total_length = 22; + + std::vector key_list = {"key4", "verylongkey2"}; + EXPECT_CALL(mock_cache_, GetKeyValuePairs(key_list)) + .WillOnce(Return(absl::flat_hash_map{ + {"key4", "key4value"}, {"verylongkey2", "verylongkey2value"}})); + + std::vector> cluster_mappings; + for (int i = 0; i < num_shards; i++) { + cluster_mappings.push_back({std::to_string(i)}); + } + auto shard_manager = ShardManager::Create( + num_shards, std::move(cluster_mappings), + std::make_unique(), + [total_length](const std::string& ip) { + if (ip == "1") { + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"key1", "key2", + "key3"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, testing::_)) + .WillOnce( + [total_length](const std::string_view serialized_message, + const int32_t padding_length) { + EXPECT_EQ(total_length, + 
(serialized_message.size() + padding_length)); + InternalLookupResponse resp; + SingleLookupResult result; + result.set_value("value1"); + (*resp.mutable_kv_pairs())["key1"] = result; + SingleLookupResult result2; + result2.set_value("value2"); + (*resp.mutable_kv_pairs())["key2"] = result2; + SingleLookupResult result3; + result3.set_value("value3"); + (*resp.mutable_kv_pairs())["key3"] = result3; + return resp; + }); + + return std::move(mock_remote_lookup_client_1); + } + if (ip == "2") { + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"randomkey5"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, testing::_)) + .WillOnce([&](const std::string_view serialized_message, + const int32_t padding_length) { + InternalLookupResponse resp; + return resp; + }); + + return std::move(mock_remote_lookup_client_1); + } + if (ip == "3") { + auto mock_remote_lookup_client_1 = + std::make_unique(); + const std::vector key_list_remote = {"longkey1", + "randomkey3"}; + InternalLookupRequest request; + request.mutable_keys()->Assign(key_list_remote.begin(), + key_list_remote.end()); + const std::string serialized_request = request.SerializeAsString(); + EXPECT_CALL(*mock_remote_lookup_client_1, + GetValues(serialized_request, testing::_)) + .WillOnce([&](const std::string_view serialized_message, + const int32_t padding_length) { + EXPECT_EQ(total_length, + (serialized_message.size() + padding_length)); + InternalLookupResponse resp; + SingleLookupResult result; + result.set_value("longkey1value"); + (*resp.mutable_kv_pairs())["longkey1"] = result; + SingleLookupResult result2; + result2.set_value("randomkey3value"); + (*resp.mutable_kv_pairs())["randomkey3"] = result2; + return resp; + }); + + return 
std::move(mock_remote_lookup_client_1); + } + // ip == "0" + return std::make_unique(); + }); + + InternalLookupResponse response; + grpc::ClientContext context; + + lookup_service_ = std::make_unique( + mock_metrics_recorder_, mock_cache_, num_shards, shard_num_, + *(*shard_manager)); + grpc::ServerBuilder builder; + builder.RegisterService(lookup_service_.get()); + server_ = (builder.BuildAndStart()); + stub_ = InternalLookupService::NewStub( + server_->InProcessChannel(grpc::ChannelArguments())); + + grpc::Status status = stub_->InternalLookup(&context, request, &response); + InternalLookupResponse expected; + TextFormat::ParseFromString( + R"pb( + kv_pairs { + key: "key1" + value { value: "value1" } + } + kv_pairs { + key: "key2" + value { value: "value2" } + } + kv_pairs { + key: "key3" + value { value: "value3" } + } + kv_pairs { + key: "key4" + value { value: "key4value" } + } + kv_pairs { + key: "longkey1" + value { value: "longkey1value" } + } + kv_pairs { + key: "randomkey3" + value { value: "randomkey3value" } + } + kv_pairs { + key: "randomkey5" + value { status { code: 5 message: "Key not found" } } + } + kv_pairs { key: "verylongkey2" + value { value: "verylongkey2value" } + )pb", + &expected); + EXPECT_THAT(response, EqualsProto(expected)); + server_->Shutdown(); + server_->Wait(); +} + +} // namespace +} // namespace kv_server diff --git a/components/internal_server/string_padder.cc b/components/internal_server/string_padder.cc new file mode 100644 index 00000000..663b2a4d --- /dev/null +++ b/components/internal_server/string_padder.cc @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "components/internal_server/string_padder.h" + +#include + +#include "glog/logging.h" +#include "quiche/common/quiche_data_reader.h" +#include "quiche/common/quiche_data_writer.h" + +namespace kv_server { + +std::string Pad(std::string_view string_to_pad, int32_t extra_padding) { + int output_size = sizeof(u_int32_t) + string_to_pad.size() + extra_padding; + std::string output(output_size, '0'); + + quiche::QuicheDataWriter data_writer(output.size(), output.data()); + data_writer.WriteUInt32(string_to_pad.size()); + data_writer.WriteStringPiece(string_to_pad); + return output; +} + +absl::StatusOr Unpad(std::string_view padded_string) { + auto data_reader = quiche::QuicheDataReader(padded_string); + uint32_t string_size = 0; + if (!data_reader.ReadUInt32(&string_size)) { + return absl::InvalidArgumentError("Failed to read string size"); + } + VLOG(9) << "string size: " << string_size; + std::string_view output; + if (!data_reader.ReadStringPiece(&output, string_size)) { + return absl::InvalidArgumentError("Failed to read a string"); + } + VLOG(9) << "string: " << output; + return std::string(output); +} + +} // namespace kv_server diff --git a/components/internal_server/string_padder.h b/components/internal_server/string_padder.h new file mode 100644 index 00000000..994150c8 --- /dev/null +++ b/components/internal_server/string_padder.h @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_INTERNAL_SERVER_STRING_PADDER_H_ +#define COMPONENTS_INTERNAL_SERVER_STRING_PADDER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace kv_server { +// Returns the string of the following format: +// [32 bit unsigned int][string_to_pad][padding] +// length data filler +// filler.size() == extra_padding +std::string Pad(std::string_view string_to_pad, int32_t extra_padding); +// Takes the string padded with the method above OR in the same format +// and returns the string. +absl::StatusOr Unpad(std::string_view padded_string); +} // namespace kv_server + +#endif // COMPONENTS_INTERNAL_SERVER_STRING_PADDER_H_ diff --git a/components/internal_server/string_padder_test.cc b/components/internal_server/string_padder_test.cc new file mode 100644 index 00000000..0f5bb28d --- /dev/null +++ b/components/internal_server/string_padder_test.cc @@ -0,0 +1,66 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/internal_server/string_padder.h" + +#include + +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(PadUnpad, Success) { + const std::string_view kTestString = "string to pad"; + const int32_t padding_size = 100; + auto padded_string = kv_server::Pad(kTestString, padding_size); + int32_t expected_length = + sizeof(u_int32_t) + kTestString.size() + padding_size; + EXPECT_EQ(expected_length, padded_string.size()); + auto original_string_status = Unpad(padded_string); + ASSERT_TRUE(original_string_status.ok()); + EXPECT_EQ(*original_string_status, kTestString); +} + +TEST(PadUnpadZeroPadding, Success) { + const std::string_view kTestString = "string to pad"; + const int32_t padding_size = 0; + auto padded_string = kv_server::Pad(kTestString, padding_size); + int32_t expected_length = + sizeof(u_int32_t) + kTestString.size() + padding_size; + EXPECT_EQ(expected_length, padded_string.size()); + auto original_string_status = Unpad(padded_string); + ASSERT_TRUE(original_string_status.ok()); + EXPECT_EQ(*original_string_status, kTestString); +} + +TEST(PadUnpadEmtpyString, Success) { + const std::string_view kTestString = ""; + const int32_t padding_size = 100; + auto padded_string = kv_server::Pad(kTestString, padding_size); + int32_t expected_length = + sizeof(u_int32_t) + kTestString.size() + padding_size; + EXPECT_EQ(expected_length, padded_string.size()); + auto original_string_status = Unpad(padded_string); + ASSERT_TRUE(original_string_status.ok()); + EXPECT_EQ(*original_string_status, kTestString); +} + +TEST(UnpadFailure, Success) { + auto original_string_status = Unpad("garbage"); + ASSERT_FALSE(original_string_status.ok()); +} + +} // namespace +} // namespace kv_server diff --git a/components/query/BUILD b/components/query/BUILD new file mode 100644 index 00000000..a5c4e2e4 --- /dev/null +++ b/components/query/BUILD @@ -0,0 +1,125 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_bison//bison:bison.bzl", "bison_cc_library") +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") +load("@rules_flex//flex:flex.bzl", "flex_cc_library") + +package(default_visibility = [ + "//components:__subpackages__", +]) + +cc_library( + name = "sets", + srcs = [ + ], + hdrs = [ + "sets.h", + ], + deps = [ + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +cc_library( + name = "ast", + srcs = [ + "ast.cc", + ], + hdrs = [ + "ast.h", + ], + deps = [ + ":sets", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:bind_front", + ], +) + +cc_test( + name = "ast_test", + size = "small", + srcs = [ + "ast_test.cc", + ], + deps = [ + ":ast", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "driver", + srcs = [ + "driver.cc", + ], + hdrs = [ + "driver.h", + "scanner.h", + ], + deps = [ + ":ast", + ":sets", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@rules_flex//flex:current_flex_toolchain", + ], +) + +cc_test( + name = "driver_test", + size = "small", + srcs = [ + "driver_test.cc", + ], + deps = [ + ":driver", + ":parser", + ":scanner", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + ], +) + +# yy extension required to produce .cc files instead of .c. 
+bison_cc_library( + name = "parser", + src = "parser.yy", + deps = [ + ":driver", + ], +) + +flex_cc_library( + name = "scanner", + src = "scanner.ll", + deps = [ + ":parser", + ], +) + +cc_test( + name = "scanner_test", + size = "small", + srcs = [ + "scanner_test.cc", + ], + deps = [ + ":scanner", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/components/query/ast.cc b/components/query/ast.cc new file mode 100644 index 00000000..41529b40 --- /dev/null +++ b/components/query/ast.cc @@ -0,0 +1,130 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/query/ast.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "components/query/sets.h" + +namespace kv_server { + +namespace { +// Traverses the binary tree starting at root. +// Returns a vector of `Node`s in post order. +// This is represents the infix input as postfix. +// Postfix can then be more easily evaluated. 
+std::vector PostOrderTraversal(const Node* root) { + std::vector result; + std::vector stack; + stack.push_back(root); + while (!stack.empty()) { + const Node* top = stack.back(); + stack.pop_back(); + result.push_back(top); + if (top->Left()) { + stack.push_back(top->Left()); + } + if (top->Right()) { + stack.push_back(top->Right()); + } + } + std::reverse(result.begin(), result.end()); + return result; +} + +} // namespace + +void ASTVisitor::Visit(const OpNode& node, std::vector& stack) { + KVSetView right = std::move(stack.back()); + stack.pop_back(); + KVSetView left = std::move(stack.back()); + stack.pop_back(); + stack.emplace_back(node.Op(std::move(left), std::move(right))); +} + +void ASTVisitor::Visit(const ValueNode& node, std::vector& stack) { + stack.emplace_back(node.Lookup()); +} + +KVSetView Compute(const std::vector& postorder) { + std::vector stack; + ASTVisitor visitor; + // Apply the operations on the postorder stack + for (const auto* node : postorder) { + node->Accept(visitor, stack); + } + return stack.back(); +} + +KVSetView Eval(const Node& node) { + std::vector postorder = PostOrderTraversal(&node); + return Compute(postorder); +} + +void OpNode::Accept(ASTVisitor& visitor, std::vector& stack) const { + visitor.Visit(*this, stack); +} + +absl::flat_hash_set OpNode::Keys() const { + std::vector nodes; + absl::flat_hash_set key_set; + nodes.push_back(this); + while (!nodes.empty()) { + const Node* next = nodes.back(); + nodes.pop_back(); + const Node* left = next->Left(); + const Node* right = next->Right(); + if (left == nullptr && right == nullptr) { + // ValueNode + absl::flat_hash_set value_keys = next->Keys(); + assert(value_keys.size() == 1); + key_set.merge(std::move(value_keys)); + } + if (left != nullptr) { + nodes.push_back(left); + } + if (right != nullptr) { + nodes.push_back(right); + } + } + return key_set; +} + +ValueNode::ValueNode( + absl::AnyInvocable lookup_fn, + std::string key) + : 
lookup_fn_(absl::bind_front(std::move(lookup_fn), key)), + key_(std::move(key)) {} + +void ValueNode::Accept(ASTVisitor& visitor, + std::vector& stack) const { + visitor.Visit(*this, stack); +} + +absl::flat_hash_set ValueNode::Keys() const { + // Return a set containing a view into this instances, `key_`. + // Be sure that the reference is not to any temp string. + return { + {key_}, + }; +} + +KVSetView ValueNode::Lookup() const { return lookup_fn_(); } + +} // namespace kv_server diff --git a/components/query/ast.h b/components/query/ast.h new file mode 100644 index 00000000..c88aee5c --- /dev/null +++ b/components/query/ast.h @@ -0,0 +1,122 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_QUERY_AST_H_ +#define COMPONENTS_QUERY_AST_H_ +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/functional/bind_front.h" +#include "components/query/sets.h" + +namespace kv_server { +class ASTVisitor; + +// All set operations operate on a reference to the data in the DB +// This means that the data in the DB must be locked throughout the lifetime of +// the result. 
+using KVSetView = absl::flat_hash_set; + +class Node { + public: + virtual ~Node() = default; + virtual Node* Left() const { return nullptr; } + virtual Node* Right() const { return nullptr; } + // Return all Keys associated with ValueNodes in the tree. + virtual absl::flat_hash_set Keys() const = 0; + // Uses the Visitor pattern for the concrete class + // to mutate the stack accordingly for `Eval` (ValueNode vs. OpNode) + virtual void Accept(ASTVisitor& visitor, + std::vector& stack) const = 0; +}; + +// The value associated with a `ValueNode` is the set with its associated `key`. +class ValueNode : public Node { + public: + ValueNode(absl::AnyInvocable lookup_fn, + std::string key); + absl::flat_hash_set Keys() const override; + KVSetView Lookup() const; + void Accept(ASTVisitor& visitor, + std::vector& stack) const override; + + private: + absl::AnyInvocable lookup_fn_; + std::string key_; +}; + +class OpNode : public Node { + public: + OpNode(std::unique_ptr left, std::unique_ptr right) + : left_(std::move(left)), right_(std::move(right)) {} + absl::flat_hash_set Keys() const override; + inline Node* Left() const override { return left_.get(); } + inline Node* Right() const override { return right_.get(); } + // Computes the operation over the `left` and `right` nodes. 
+ virtual KVSetView Op(KVSetView left, KVSetView right) const = 0; + void Accept(ASTVisitor& visitor, + std::vector& stack) const override; + + private: + std::unique_ptr left_; + std::unique_ptr right_; +}; + +class UnionNode : public OpNode { + public: + using OpNode::OpNode; + inline KVSetView Op(KVSetView left, KVSetView right) const override { + return Union(std::move(left), std::move(right)); + } +}; + +class IntersectionNode : public OpNode { + public: + using OpNode::OpNode; + inline KVSetView Op(KVSetView left, KVSetView right) const override { + return Intersection(std::move(left), std::move(right)); + } +}; + +class DifferenceNode : public OpNode { + public: + using OpNode::OpNode; + inline KVSetView Op(KVSetView left, KVSetView right) const override { + return Difference(std::move(left), std::move(right)); + } +}; + +// Creates execution plan and runs it. +KVSetView Eval(const Node& node); + +// Responsible for mutating the stack with the given `Node`. +// Avoids downcasting for subclass specific behaviors. +class ASTVisitor { + public: + // Applies the operation to the top two values on the stack. + // Replaces the top two values with the result. + void Visit(const OpNode& node, std::vector& stack); + // Pushes the result of `Lookup` to the stack. + void Visit(const ValueNode& node, std::vector& stack); +}; + +} // namespace kv_server +#endif // COMPONENTS_QUERY_AST_H_ diff --git a/components/query/ast_test.cc b/components/query/ast_test.cc new file mode 100644 index 00000000..8be4bbd2 --- /dev/null +++ b/components/query/ast_test.cc @@ -0,0 +1,151 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/query/ast.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +const absl::flat_hash_map> + kDb = { + {"A", {"a", "b", "c"}}, + {"B", {"b", "c", "d"}}, + {"C", {"c", "d", "e"}}, + {"D", {"d", "e", "f"}}, +}; + +absl::flat_hash_set Lookup(std::string_view key) { + const auto& it = kDb.find(key); + if (it != kDb.end()) { + return it->second; + } + return {}; +} + +TEST(AstTest, Value) { + ValueNode value(Lookup, "A"); + EXPECT_EQ(Eval(value), Lookup("A")); + ValueNode value2(Lookup, "B"); + EXPECT_EQ(Eval(value2), Lookup("B")); + ValueNode value3(Lookup, "C"); + EXPECT_EQ(Eval(value3), Lookup("C")); + ValueNode value4(Lookup, "D"); + EXPECT_EQ(Eval(value4), Lookup("D")); + ValueNode value5(Lookup, "E"); + EXPECT_EQ(Eval(value5), Lookup("E")); +} + +TEST(AstTest, Union) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr b = std::make_unique(Lookup, "B"); + UnionNode op(std::move(a), std::move(b)); + absl::flat_hash_set expected = {"a", "b", "c", "d"}; + EXPECT_EQ(Eval(op), expected); +} + +TEST(AstTest, UnionSelf) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr a2 = std::make_unique(Lookup, "A"); + UnionNode op(std::move(a), std::move(a2)); + absl::flat_hash_set expected = {"a", "b", "c"}; + EXPECT_EQ(Eval(op), expected); +} + +TEST(AstTest, Intersection) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr b = 
std::make_unique(Lookup, "B"); + IntersectionNode op(std::move(a), std::move(b)); + absl::flat_hash_set expected = {"b", "c"}; + EXPECT_EQ(Eval(op), expected); +} + +TEST(AstTest, IntersectionSelf) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr a2 = std::make_unique(Lookup, "A"); + IntersectionNode op(std::move(a), std::move(a2)); + absl::flat_hash_set expected = {"a", "b", "c"}; + EXPECT_EQ(Eval(op), expected); +} + +TEST(AstTest, Difference) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr b = std::make_unique(Lookup, "B"); + DifferenceNode op(std::move(a), std::move(b)); + absl::flat_hash_set expected = {"a"}; + EXPECT_EQ(Eval(op), expected); + + std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr b2 = std::make_unique(Lookup, "B"); + DifferenceNode op2(std::move(b2), std::move(a2)); + absl::flat_hash_set expected2 = {"d"}; + EXPECT_EQ(Eval(op2), expected2); +} + +TEST(AstTest, DifferenceSelf) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr a2 = std::make_unique(Lookup, "A"); + DifferenceNode op(std::move(a), std::move(a2)); + absl::flat_hash_set expected = {}; + EXPECT_EQ(Eval(op), expected); +} + +TEST(AstTest, All) { + // (A-B) | (C&D) = + // {a} | {d,e} = + // {a, d, e} + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr c = std::make_unique(Lookup, "C"); + std::unique_ptr d = std::make_unique(Lookup, "D"); + std::unique_ptr left = + std::make_unique(std::move(a), std::move(b)); + std::unique_ptr right = + std::make_unique(std::move(c), std::move(d)); + UnionNode center(std::move(left), std::move(right)); + absl::flat_hash_set expected = {"a", "d", "e"}; + EXPECT_EQ(Eval(center), expected); +} + +TEST(AstTest, ValueNodeKeys) { + ValueNode v(Lookup, "A"); + EXPECT_THAT(v.Keys(), testing::UnorderedElementsAre("A")); +} + +TEST(AstTest, OpNodeKeys) { + std::unique_ptr a = std::make_unique(Lookup, 
"A"); + std::unique_ptr b = std::make_unique(Lookup, "B"); + DifferenceNode op(std::move(b), std::move(a)); + EXPECT_THAT(op.Keys(), testing::UnorderedElementsAre("A", "B")); +} + +TEST(AstTest, DupeNodeKeys) { + std::unique_ptr a = std::make_unique(Lookup, "A"); + std::unique_ptr b = std::make_unique(Lookup, "B"); + std::unique_ptr c = std::make_unique(Lookup, "C"); + std::unique_ptr a2 = std::make_unique(Lookup, "A"); + std::unique_ptr left = + std::make_unique(std::move(a), std::move(b)); + std::unique_ptr right = + std::make_unique(std::move(c), std::move(a2)); + UnionNode center(std::move(left), std::move(right)); + EXPECT_THAT(center.Keys(), testing::UnorderedElementsAre("A", "B", "C")); +} + +} // namespace +} // namespace kv_server diff --git a/components/query/driver.cc b/components/query/driver.cc new file mode 100644 index 00000000..20c53217 --- /dev/null +++ b/components/query/driver.cc @@ -0,0 +1,55 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/query/driver.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" +#include "components/query/ast.h" + +namespace kv_server { + +Driver::Driver(absl::AnyInvocable( + std::string_view key) const> + lookup_fn) + : lookup_fn_(std::move(lookup_fn)) {} + +absl::flat_hash_set Driver::Lookup( + std::string_view key) const { + return lookup_fn_(key); +} + +void Driver::SetAst(std::unique_ptr ast) { ast_ = std::move(ast); } + +absl::StatusOr> Driver::GetResult() + const { + if (!status_.ok()) { + return status_; + } + if (ast_ == nullptr) { + return absl::flat_hash_set(); + } + return Eval(*ast_); +} + +void Driver::SetError(std::string error) { + status_ = absl::InvalidArgumentError(std::move(error)); +} + +const Node* Driver::GetRootNode() const { return ast_.get(); } + +} // namespace kv_server diff --git a/components/query/driver.h b/components/query/driver.h new file mode 100644 index 00000000..51b0baaa --- /dev/null +++ b/components/query/driver.h @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#ifndef COMPONENTS_QUERY_DRIVER_H_
+#define COMPONENTS_QUERY_DRIVER_H_
+
+#include
+#include
+#include
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "components/query/ast.h"
+
+namespace kv_server {
+
+// Driver is responsible for:
+// * Gathering the AST from the parser
+// * Creating the execution plan
+// * Executing the query
+// * Storing the result
+// Typical usage:
+// Driver driver(LookupFn);
+// std::istringstream stream(query);
+// Scanner scanner(stream);
+// Parser parse(driver, scanner);
+// int parse_result = parse();
+// auto result = driver.GetResult();
+// parse_result is only expected to be non-zero when result is a failure.
+class Driver {
+ public:
+  // `lookup_fn` returns the set associated with the provided key.
+  // If no key is present, an empty set should be returned.
+  explicit Driver(absl::AnyInvocable(
+                      std::string_view key) const>
+                      lookup_fn);
+
+  // The result contains views of the data within the DB.
+  absl::StatusOr> GetResult() const;
+
+  // Returns the `Node` associated with `SetAst`
+  // or nullptr if unset.
+  const kv_server::Node* GetRootNode() const;
+
+  // Clients should not call these functions, they are called by the parser.
+  void SetAst(std::unique_ptr);
+  void SetError(std::string error);
+
+  // Looks up the set which contains a view of the DB data.
+ absl::flat_hash_set Lookup(std::string_view key) const; + + private: + absl::AnyInvocable(std::string_view key) + const> + lookup_fn_; + std::unique_ptr ast_; + absl::Status status_ = absl::OkStatus(); +}; + +} // namespace kv_server +#endif // COMPONENTS_QUERY_DRIVER_H_ diff --git a/components/query/driver_test.cc b/components/query/driver_test.cc new file mode 100644 index 00000000..dd916ec2 --- /dev/null +++ b/components/query/driver_test.cc @@ -0,0 +1,205 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "components/query/driver.h"
+
+#include
+#include
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/notification.h"
+#include "components/query/scanner.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace kv_server {
+namespace {
+
+// Test fixture: owns one primary driver plus a pool of drivers used by the
+// MultipleThreads test, all backed by the in-memory db_ lookup below.
+class DriverTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    driver_ =
+        std::make_unique(absl::bind_front(&DriverTest::Lookup, this));
+    // Populate the driver pool consumed by the MultipleThreads test.
+    // NOTE(review): was `for (int i = 1000; i < 1; i++)`, a loop that never
+    // executes, leaving drivers_ empty and making MultipleThreads vacuous.
+    for (int i = 0; i < 1000; i++) {
+      drivers_.emplace_back(absl::bind_front(&DriverTest::Lookup, this));
+    }
+  }
+
+  // Returns the set for `key`, or an empty set when the key is absent.
+  absl::flat_hash_set Lookup(std::string_view key) {
+    const auto& it = db_.find(key);
+    if (it != db_.end()) {
+      return it->second;
+    }
+    return {};
+  }
+
+  // Runs the scanner/parser pipeline for `query` against driver_.
+  void Parse(const std::string& query) {
+    std::istringstream stream(query);
+    Scanner scanner(stream);
+    Parser parse(*driver_, scanner);
+    parse();
+  }
+
+  std::unique_ptr driver_;
+  std::vector drivers_;
+  const absl::flat_hash_map>
+      db_ = {
+          {"A", {"a", "b", "c"}},
+          {"B", {"b", "c", "d"}},
+          {"C", {"c", "d", "e"}},
+          {"D", {"d", "e", "f"}},
+      };
+};
+
+TEST_F(DriverTest, EmptyQuery) {
+  Parse("");
+  EXPECT_EQ(driver_->GetRootNode(), nullptr);
+  auto result = driver_->GetResult();
+  ASSERT_TRUE(result.ok());
+  absl::flat_hash_set expected;
+  EXPECT_EQ(*result, expected);
+}
+
+TEST_F(DriverTest, InvalidTokensQuery) {
+  Parse("!!
hi"); + EXPECT_EQ(driver_->GetRootNode(), nullptr); + auto result = driver_->GetResult(); + EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); +} + +TEST_F(DriverTest, InvalidOp) { + Parse("A UNION "); + EXPECT_EQ(driver_->GetRootNode(), nullptr); + auto result = driver_->GetResult(); + EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument); +} + +TEST_F(DriverTest, KeyOnly) { + Parse("A"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c")); + + Parse("B"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c", "d")); +} + +TEST_F(DriverTest, Union) { + Parse("A UNION B"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); + + Parse("A | B"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "b", "c", "d")); +} + +TEST_F(DriverTest, Difference) { + Parse("A - B"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + + Parse("A DIFFERENCE B"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + + Parse("B - A"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); + + Parse("B DIFFERENCE A"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("d")); +} + +TEST_F(DriverTest, Intersection) { + Parse("A INTERSECTION B"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("b", "c")); + + Parse("A & B"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, 
testing::UnorderedElementsAre("b", "c")); +} + +TEST_F(DriverTest, OrderOfOperations) { + Parse("A - B - C"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a")); + + Parse("A - (B - C)"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "c")); +} + +TEST_F(DriverTest, MultipleOperations) { + Parse("(A-B) | (C&D)"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); +} + +TEST_F(DriverTest, MultipleThreads) { + absl::Notification notification; + auto test_func = [¬ification](Driver* driver) { + notification.WaitForNotification(); + std::string query = "(A-B) | (C&D)"; + std::istringstream stream(query); + Scanner scanner(stream); + Parser parse(*driver, scanner); + parse(); + auto result = driver->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_THAT(*result, testing::UnorderedElementsAre("a", "d", "e")); + }; + + std::vector threads; + for (Driver& driver : drivers_) { + threads.push_back(std::thread(test_func, &driver)); + } + notification.Notify(); + for (auto& th : threads) { + th.join(); + } +} + +TEST_F(DriverTest, EmptyResults) { + // no overlap + Parse("A & D"); + auto result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 0); + + // missing key + Parse("A & E"); + result = driver_->GetResult(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result->size(), 0); +} + +} // namespace +} // namespace kv_server diff --git a/components/query/parser.yy b/components/query/parser.yy new file mode 100644 index 00000000..6f588ec9 --- /dev/null +++ b/components/query/parser.yy @@ -0,0 +1,89 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +%skeleton "lalr1.cc" // -*- C++ -*- +%require "3.3.2" +%language "c++" + +%define api.parser.class {Parser} +%define api.namespace {kv_server} +// Make sure we are thread safe and don't use yylval, yylloc +// https://www.gnu.org/software/bison/manual/html_node/Pure-Calling.html +%define api.pure full + +%code requires { + #include + #include + #include "components/query/ast.h" + + namespace kv_server { + class Scanner; + class Driver; + } // namespace kv_server +} +// The parsing context. +%param { Driver& driver } +%parse-param {Scanner& scanner} + + +%code { + #include "components/query/parser.h" + #include "components/query/driver.h" + #include "components/query/scanner.h" + #include "absl/functional/bind_front.h" + + #undef yylex + #define yylex(x) scanner.yylex(x) +} + +/* declare tokens */ +%token UNION INTERSECTION DIFFERENCE LPAREN RPAREN +%token VAR ERROR +%token YYEOF 0 + +// Allows defining the types returned by `term` and `exp below. 
+%define api.token.constructor +%define api.value.type variant + +%type > term +%nterm > exp + +/* Order of operations is left to right */ +%left UNION INTERSECTION DIFFERENCE + +%% + +query: + %empty + | query exp { driver.SetAst(std::move($2)); } + +exp: term {$$ = std::move($1);} + | exp UNION exp { $$ = std::make_unique(std::move($1), std::move($3)); } + | exp INTERSECTION exp { $$ = std::make_unique(std::move($1), std::move($3)); } + | exp DIFFERENCE exp { $$ = std::make_unique(std::move($1), std::move($3)); } + | LPAREN exp RPAREN { $$ = std::move($2); } + | exp exp { driver.SetError("Missing operator"); YYERROR; } + | ERROR { driver.SetError("Invalid token: " + $1); YYERROR;} + ; + +term: VAR { $$ = std::make_unique(absl::bind_front(&Driver::Lookup, &driver), std::move($1)); } + ; + +%% + +void +kv_server::Parser::error (const std::string& m) +{ + driver.SetError(m); +} diff --git a/components/query/scanner.h b/components/query/scanner.h new file mode 100644 index 00000000..fab7bd74 --- /dev/null +++ b/components/query/scanner.h @@ -0,0 +1,55 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_QUERY_SCANNER_H_ +#define COMPONENTS_QUERY_SCANNER_H_ + +// FlexLexer.h has no include guard so that it can be included multiple times, +// each time providing a different definition for yFlexLexer, to define +// multiple lexer base classes. 
+#ifndef yyFlexLexerOnce +#define yyFlexLexer KVFlexLexer +#include +#endif + +#include + +#include "components/query/parser.h" + +namespace kv_server { +class Driver; + +// Lexer responsible for converting input stream into tokens. +class Scanner : public yyFlexLexer { + public: + explicit Scanner(std::istream& input) : yyFlexLexer(&input) {} + ~Scanner() override = default; + virtual kv_server::Parser::symbol_type yylex(kv_server::Driver& driver); + + private: + // This function is never called. It is here to avoid + // Compilation warning: 'kv_server::Scanner::yylex' hides + // overloaded virtual function [-Woverloaded-virtual] + // defined in FlexLexer.h + // The above yylex function is called by parser.yy + int yylex() override { + assert(false); + return 1; + } +}; + +} // namespace kv_server +#endif // COMPONENTS_QUERY_SCANNER_H_ diff --git a/components/query/scanner.ll b/components/query/scanner.ll new file mode 100644 index 00000000..de8d4dbb --- /dev/null +++ b/components/query/scanner.ll @@ -0,0 +1,59 @@ +/* + Copyright 2023 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +/* recognize tokens to support set operations*/ +%{ +#include "components/query/parser.h" +#include "components/query/scanner.h" + +#undef YY_DECL +#define YY_DECL kv_server::Parser::symbol_type \ + kv_server::Scanner::yylex(kv_server::Driver& driver) + +%} + +%option c++ +%option yyclass="Scanner" +%option prefix="KV" +%option noyywrap nounput noinput debug batch + +/* Valid key name characters, this list can be expanded as needed */ +VAR_CHARS [a-zA-Z0-9_:\.] +/* + Characters used for set operations or those that could be confused. + Allowing for +, =, / makes the quoted key name characters a superset of + base64 encoding. +*/ +OP_CHARS [|&\-+=/] + +%% +[ \t\r\n]+ {} +"(" { return kv_server::Parser::make_LPAREN(); } +")" { return kv_server::Parser::make_RPAREN();} +(?i:UNION) { return kv_server::Parser::make_UNION(); } +"|" { return kv_server::Parser::make_UNION(); } +(?i:INTERSECTION) { return kv_server::Parser::make_INTERSECTION(); } +"&" { return kv_server::Parser::make_INTERSECTION(); } +(?i:DIFFERENCE) { return kv_server::Parser::make_DIFFERENCE(); } +"-" { return kv_server::Parser::make_DIFFERENCE(); } +{VAR_CHARS}+ { return kv_server::Parser::make_VAR(yytext); } +"\""({VAR_CHARS}+|{OP_CHARS}+)+"\"" { + // Exclude the double quotes from the var name. + yytext[strlen(yytext)-1]='\0'; + return kv_server::Parser::make_VAR(yytext+1);} +. { return kv_server::Parser::make_ERROR(yytext); } +<> { return kv_server::Parser::make_YYEOF(); } +%% diff --git a/components/query/scanner_test.cc b/components/query/scanner_test.cc new file mode 100644 index 00000000..7cd75238 --- /dev/null +++ b/components/query/scanner_test.cc @@ -0,0 +1,185 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/query/scanner.h" + +#include +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "components/query/driver.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +absl::flat_hash_set NeverUsedLookup(std::string_view key) { + // Should never be called + assert(0); + return {}; +} + +TEST(ScannerTest, Empty) { + std::istringstream stream(""); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + auto t = scanner.yylex(driver); + ASSERT_EQ(t.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, Var) { + std::istringstream stream("FOO foo Foo123"); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + // first token + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::VAR); + ASSERT_EQ(t1.value.as(), "FOO"); + + // second token + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::VAR); + ASSERT_EQ(t2.value.as(), "foo"); + + // third token + auto t3 = scanner.yylex(driver); + ASSERT_EQ(t3.token(), Parser::token::VAR); + ASSERT_EQ(t3.value.as(), "Foo123"); + + // done + auto t4 = scanner.yylex(driver); + ASSERT_EQ(t4.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, Parens) { + std::istringstream stream("()"); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::LPAREN); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::RPAREN); + auto t3 = scanner.yylex(driver); + ASSERT_EQ(t3.token(), Parser::token::YYEOF); +} + 
+TEST(ScannerTest, WhitespaceVar) { + std::istringstream stream(" FOO "); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + // first token + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::VAR); + ASSERT_EQ(t1.value.as(), "FOO"); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, NotAlphaNumVar) { + std::vector expected_vars = {"_", ":", ".", "A_B", "A:B", "A.B"}; + std::string token_list = absl::StrJoin(expected_vars, " "); + std::istringstream stream(token_list); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + for (const auto& expected_var : expected_vars) { + auto token = scanner.yylex(driver); + ASSERT_EQ(token.token(), Parser::token::VAR); + ASSERT_EQ(token.value.as(), expected_var); + } + auto last = scanner.yylex(driver); + ASSERT_EQ(last.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, QuotedVar) { + std::istringstream stream( + " \"A1:Stuff\" \"A-B:C&D=E|F\" \"A+B\" \"A/B\" \"A\" "); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::VAR); + ASSERT_EQ(t1.value.as(), "A1:Stuff"); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::VAR); + ASSERT_EQ(t2.value.as(), "A-B:C&D=E|F"); + auto t3 = scanner.yylex(driver); + ASSERT_EQ(t3.token(), Parser::token::VAR); + ASSERT_EQ(t3.value.as(), "A+B"); + auto t4 = scanner.yylex(driver); + ASSERT_EQ(t4.token(), Parser::token::VAR); + ASSERT_EQ(t4.value.as(), "A/B"); + auto t5 = scanner.yylex(driver); + ASSERT_EQ(t5.token(), Parser::token::VAR); + ASSERT_EQ(t5.value.as(), "A"); + auto last = scanner.yylex(driver); + ASSERT_EQ(last.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, EmptyQuotedInvalid) { + std::istringstream stream(" \"\" "); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + // Since it there is no valid match, we have 2 errors + // for each of the double quotes. 
+ auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::ERROR); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::ERROR); + + auto t3 = scanner.yylex(driver); + ASSERT_EQ(t3.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, Operators) { + std::istringstream stream("| UNION & INTERSECTION - DIFFERENCE"); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::UNION); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::UNION); + + auto t3 = scanner.yylex(driver); + ASSERT_EQ(t3.token(), Parser::token::INTERSECTION); + auto t4 = scanner.yylex(driver); + ASSERT_EQ(t4.token(), Parser::token::INTERSECTION); + + auto t5 = scanner.yylex(driver); + ASSERT_EQ(t5.token(), Parser::token::DIFFERENCE); + auto t6 = scanner.yylex(driver); + ASSERT_EQ(t6.token(), Parser::token::DIFFERENCE); + + auto t7 = scanner.yylex(driver); + ASSERT_EQ(t7.token(), Parser::token::YYEOF); +} + +TEST(ScannerTest, Error) { + std::istringstream stream("!"); + Scanner scanner(stream); + Driver driver(NeverUsedLookup); + + auto t1 = scanner.yylex(driver); + ASSERT_EQ(t1.token(), Parser::token::ERROR); + auto t2 = scanner.yylex(driver); + ASSERT_EQ(t2.token(), Parser::token::YYEOF); +} + +} // namespace +} // namespace kv_server diff --git a/components/query/sets.h b/components/query/sets.h new file mode 100644 index 00000000..84d0dae4 --- /dev/null +++ b/components/query/sets.h @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_QUERY_SETS_H_ +#define COMPONENTS_QUERY_SETS_H_ +#include "absl/container/flat_hash_set.h" + +namespace kv_server { +template +absl::flat_hash_set Union(absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + auto& small = left.size() <= right.size() ? left : right; + auto& big = left.size() <= right.size() ? right : left; + big.insert(small.begin(), small.end()); + return big; +} + +template +absl::flat_hash_set Intersection(absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + auto& small = left.size() <= right.size() ? left : right; + const auto& big = left.size() <= right.size() ? right : left; + // Traverse the smaller set removing what is not in both. + absl::erase_if(small, [&big](const T& elem) { return !big.contains(elem); }); + return small; +} + +template +absl::flat_hash_set Difference(absl::flat_hash_set&& left, + absl::flat_hash_set&& right) { + // Remove all elements in right from left. + for (const auto& element : right) { + left.erase(element); + } + return left; +} + +} // namespace kv_server +#endif // COMPONENTS_QUERY_SETS_H_ diff --git a/components/sharding/BUILD b/components/sharding/BUILD new file mode 100644 index 00000000..576db1b6 --- /dev/null +++ b/components/sharding/BUILD @@ -0,0 +1,99 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = [ + "//components:__subpackages__", + "//tools:__subpackages__", +]) + +cc_library( + name = "shard_manager", + srcs = + [ + "shard_manager.cc", + ], + hdrs = [ + "shard_manager.h", + ], + deps = [ + "//components/internal_server:remote_lookup_client_impl", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "shard_manager_test", + size = "small", + srcs = + [ + "shard_manager_test.cc", + ], + deps = [ + ":mocks", + ":shard_manager", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "mocks", + testonly = 1, + hdrs = ["mocks.h"], + deps = [ + ":shard_manager", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "cluster_mappings_manager", + srcs = + [ + "cluster_mappings_manager.cc", + ], + hdrs = [ + "cluster_mappings_manager.h", + ], + deps = [ + ":shard_manager", + "//components/cloud_config:instance_client", + "//components/data/common:thread_manager", + "//components/errors:retry", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "cluster_mappings_manager_test", + size = "small", + srcs = + [ + "cluster_mappings_manager_test.cc", + ], + deps = [ + ":cluster_mappings_manager", + ":mocks", + "//components/data_server/server:mocks", + "@com_google_googletest//:gtest", + 
"@com_google_googletest//:gtest_main", + "@google_privacysandbox_servers_common//src/cpp/telemetry:mocks", + ], +) diff --git a/components/sharding/cluster_mappings_manager.cc b/components/sharding/cluster_mappings_manager.cc new file mode 100644 index 00000000..e3f3fa82 --- /dev/null +++ b/components/sharding/cluster_mappings_manager.cc @@ -0,0 +1,130 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "components/sharding/cluster_mappings_manager.h" + +#include "components/errors/retry.h" + +namespace kv_server { + +absl::Status ClusterMappingsManager::Start(ShardManager& shard_manager) { + return thread_manager_->Start( + [this, &shard_manager]() { Watch(shard_manager); }); +} + +absl::Status ClusterMappingsManager::Stop() { + absl::Status status = sleep_for_->Stop(); + status.Update(thread_manager_->Stop()); + return status; +} + +bool ClusterMappingsManager::IsRunning() const { + return thread_manager_->IsRunning(); +} + +void ClusterMappingsManager::Watch(ShardManager& shard_manager) { + while (!thread_manager_->ShouldStop()) { + sleep_for_->Duration(absl::Milliseconds(update_interval_millis_)); + shard_manager.InsertBatch(GetClusterMappings()); + } +} + +std::vector> +ClusterMappingsManager::GetClusterMappings() { + absl::flat_hash_set instance_group_names; + for (int i = 0; i < num_shards_; i++) { + instance_group_names.insert( + absl::StrFormat("kv-server-%s-%d-instance-asg", environment_, i)); + } 
+ auto& instance_client = instance_client_; + auto instance_group_instances = TraceRetryUntilOk( + [&instance_client, &instance_group_names] { + return instance_client.DescribeInstanceGroupInstances( + instance_group_names); + }, + "DescribeInstanceGroupInstances", &metrics_recorder_); + + return GroupInstancesToClusterMappings(instance_group_instances); +} + +absl::StatusOr ClusterMappingsManager::GetShardNumberOffAsgName( + std::string asg_name) const { + std::smatch match_result; + if (std::regex_match(asg_name, match_result, asg_regex_)) { + int32_t shard_num; + if (!absl::SimpleAtoi(std::string(match_result[1]), &shard_num)) { + std::string error = absl::StrFormat("Failed converting %s to int32.", + std::string(match_result[1])); + return absl::InvalidArgumentError(error); + } + return shard_num; + } + return absl::InvalidArgumentError(absl::StrCat("Can't parse: ", asg_name)); +} + +absl::flat_hash_map +ClusterMappingsManager::GetInstaceIdToIpMapping( + const std::vector& instance_group_instances) const { + absl::flat_hash_set instance_ids; + for (const auto& instance : instance_group_instances) { + if (instance.service_status != InstanceServiceStatus::kInService) { + continue; + } + instance_ids.insert(instance.id); + } + + auto& instance_client = instance_client_; + std::vector instances_detailed_info = TraceRetryUntilOk( + [&instance_client, &instance_ids] { + return instance_client.DescribeInstances(instance_ids); + }, + "DescribeInstances", &metrics_recorder_); + + absl::flat_hash_map mapping; + for (const auto& instance : instances_detailed_info) { + mapping.emplace(instance.id, instance.private_ip_address); + } + return mapping; +} + +std::vector> +ClusterMappingsManager::GroupInstancesToClusterMappings( + std::vector& instance_group_instances) const { + auto id_to_ip = GetInstaceIdToIpMapping(instance_group_instances); + std::vector> cluster_mappings(num_shards_); + for (const auto& instance : instance_group_instances) { + if (instance.service_status 
!= InstanceServiceStatus::kInService) { + continue; + } + auto shard_num_status = GetShardNumberOffAsgName(instance.instance_group); + if (!shard_num_status.ok()) { + continue; + } + int32_t shard_num = *shard_num_status; + if (shard_num >= num_shards_) { + continue; + } + + const auto key_iter = id_to_ip.find(instance.id); + if (key_iter == id_to_ip.end() || key_iter->second.empty()) { + continue; + } + + cluster_mappings[shard_num].insert(key_iter->second); + } + return cluster_mappings; +} +} // namespace kv_server diff --git a/components/sharding/cluster_mappings_manager.h b/components/sharding/cluster_mappings_manager.h new file mode 100644 index 00000000..574133ac --- /dev/null +++ b/components/sharding/cluster_mappings_manager.h @@ -0,0 +1,91 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_SHARDING_CLUSTER_MAPPINGS_MANAGER_H_ +#define COMPONENTS_SHARDING_CLUSTER_MAPPINGS_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" +#include "components/cloud_config/instance_client.h" +#include "components/data/common/thread_manager.h" +#include "components/errors/retry.h" +#include "components/sharding/shard_manager.h" + +namespace kv_server { +// Continously updates shard manager's cluster mappings every +// `update_interval_millis`. 
+// Example:
+// cluster_mappings_manager_ = std::make_unique(
+//     environment_, num_shards_, *metrics_recorder_, *instance_client_);
+// cluster_mappings_manager_->Start(*shard_manager_);
+class ClusterMappingsManager {
+ public:
+  ClusterMappingsManager(
+      std::string environment, int32_t num_shards,
+      privacy_sandbox::server_common::MetricsRecorder& metrics_recorder,
+      InstanceClient& instance_client,
+      std::unique_ptr sleep_for = std::make_unique(),
+      int32_t update_interval_millis = 1000)
+      : environment_{std::move(environment)},
+        num_shards_{num_shards},
+        metrics_recorder_{metrics_recorder},
+        instance_client_{instance_client},
+        asg_regex_{std::regex(absl::StrCat("kv-server-", environment_,
+                                           R"(-(\d+)-instance-asg)"))},
+        thread_manager_(TheadManager::Create("Cluster mappings updater")),
+        sleep_for_(std::move(sleep_for)),
+        update_interval_millis_(update_interval_millis) {
+    CHECK_GT(num_shards, 1)
+        << "num_shards for ShardedLookupServiceImpl must be > 1";
+  }
+
+  // Retrieves cluster mappings for the given `environment`, which are
+  // necessary for the ShardManager.
+  // Mappings are:
+  // {shard_num --> {replica's private ip address 1, ...
},...} + // {{0 -> {ip1, ip2}}, ....{num_shards-1}-> {ipN, ipN+1}} + std::vector> GetClusterMappings(); + absl::Status Start(ShardManager& shard_manager); + absl::Status Stop(); + bool IsRunning() const; + + private: + void Watch(ShardManager& shard_manager); + absl::StatusOr GetShardNumberOffAsgName(std::string asg_name) const; + std::vector> GroupInstancesToClusterMappings( + std::vector& instance_group_instances) const; + absl::flat_hash_map GetInstaceIdToIpMapping( + const std::vector& instance_group_instances) const; + + std::string environment_; + int32_t num_shards_; + privacy_sandbox::server_common::MetricsRecorder& metrics_recorder_; + InstanceClient& instance_client_; + std::regex asg_regex_; + std::unique_ptr thread_manager_; + std::unique_ptr sleep_for_; + int32_t update_interval_millis_; +}; + +} // namespace kv_server +#endif // COMPONENTS_SHARDING_CLUSTER_MAPPINGS_MANAGER_H_ diff --git a/components/sharding/cluster_mappings_manager_test.cc b/components/sharding/cluster_mappings_manager_test.cc new file mode 100644 index 00000000..a829ff2d --- /dev/null +++ b/components/sharding/cluster_mappings_manager_test.cc @@ -0,0 +1,253 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/sharding/cluster_mappings_manager.h" + +#include +#include +#include +#include + +#include "components/data_server/server/mocks.h" +#include "components/internal_server/constants.h" +#include "components/sharding/mocks.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "src/cpp/telemetry/mocks.h" + +namespace kv_server { +namespace { + +TEST(ClusterMappingsTest, RetrieveMappingsSuccessfully) { + std::string environment = "testenv"; + int32_t num_shards = 4; + privacy_sandbox::server_common::MockMetricsRecorder mock_metrics_recorder; + auto instance_client = std::make_unique(); + EXPECT_CALL(*instance_client, DescribeInstanceGroupInstances(::testing::_)) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "kv-server-testenv-0-instance-asg", + "kv-server-testenv-1-instance-asg", + "kv-server-testenv-2-instance-asg", + "kv-server-testenv-3-instance-asg"}; + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + + InstanceInfo ii1 = { + .id = "id1", + .instance_group = "kv-server-testenv-0-instance-asg", + .service_status = InstanceServiceStatus::kInService, + }; + InstanceInfo ii2 = { + .id = "id2", + .instance_group = "kv-server-testenv-0-instance-asg", + .service_status = InstanceServiceStatus::kInService, + }; + InstanceInfo ii3 = { + .id = "id3", + .instance_group = "kv-server-testenv-1-instance-asg", + .service_status = InstanceServiceStatus::kInService, + }; + InstanceInfo ii4 = { + .id = "id4", + .instance_group = "kv-server-testenv-2-instance-asg", + .service_status = InstanceServiceStatus::kPreService, + }; + InstanceInfo ii5 = { + .id = "id5", + .instance_group = "garbage", + .service_status = InstanceServiceStatus::kPreService, + }; + std::vector instances{ii1, ii2, ii3, ii4, ii5}; + return instances; + }); + + EXPECT_CALL(*instance_client, DescribeInstances(::testing::_)) + .WillOnce( + 
[&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "id1", "id2", "id3"}; + + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + + InstanceInfo ii1 = {.id = "id1", .private_ip_address = "ip1"}; + InstanceInfo ii2 = {.id = "id2", .private_ip_address = "ip2"}; + InstanceInfo ii3 = {.id = "id3", .private_ip_address = "ip3"}; + + std::vector instances{ii1, ii2, ii3}; + return instances; + }); + + auto mgr = ClusterMappingsManager(environment, num_shards, + mock_metrics_recorder, *instance_client); + auto cluster_mappings = mgr.GetClusterMappings(); + EXPECT_EQ(cluster_mappings.size(), 4); + absl::flat_hash_set set0 = {"ip1", "ip2"}; + EXPECT_THAT(cluster_mappings[0], testing::UnorderedElementsAreArray(set0)); + absl::flat_hash_set set1 = {"ip3"}; + EXPECT_THAT(cluster_mappings[1], testing::UnorderedElementsAreArray(set1)); + absl::flat_hash_set set2; + EXPECT_THAT(cluster_mappings[2], testing::UnorderedElementsAreArray(set2)); + EXPECT_THAT(cluster_mappings[3], testing::UnorderedElementsAreArray(set2)); +} + +TEST(ClusterMappingsTest, RetrieveMappingsWithRetrySuccessfully) { + std::string environment = "testenv"; + int32_t num_shards = 2; + privacy_sandbox::server_common::MockMetricsRecorder mock_metrics_recorder; + auto instance_client = std::make_unique(); + EXPECT_CALL(*instance_client, DescribeInstanceGroupInstances(::testing::_)) + .WillOnce(testing::Return(absl::InternalError("Oops."))) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "kv-server-testenv-0-instance-asg", + "kv-server-testenv-1-instance-asg", + }; + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + InstanceInfo ii1 = { + .id = "id1", + .instance_group = "kv-server-testenv-0-instance-asg", + .service_status = InstanceServiceStatus::kInService}; + 
+ std::vector instances{ii1}; + return instances; + }); + + EXPECT_CALL(*instance_client, DescribeInstances(::testing::_)) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "id1"}; + + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + + InstanceInfo ii1 = {.id = "id1", .private_ip_address = "ip1"}; + std::vector instances{ii1}; + return instances; + }); + + auto mgr = ClusterMappingsManager(environment, num_shards, + mock_metrics_recorder, *instance_client); + auto cluster_mappings = mgr.GetClusterMappings(); + EXPECT_EQ(cluster_mappings.size(), 2); + absl::flat_hash_set set0 = {"ip1"}; + EXPECT_THAT(cluster_mappings[0], testing::UnorderedElementsAreArray(set0)); + absl::flat_hash_set set1 = {}; + EXPECT_THAT(cluster_mappings[1], testing::UnorderedElementsAreArray(set1)); +} + +TEST(ClusterMappingsTest, UpdateMappings) { + std::string environment = "testenv"; + int32_t num_shards = 2; + privacy_sandbox::server_common::MockMetricsRecorder mock_metrics_recorder; + auto instance_client = std::make_unique(); + std::vector> cluster_mappings; + for (int i = 0; i < num_shards; i++) { + cluster_mappings.push_back({"some_ip"}); + } + auto shard_manager_status = + ShardManager::Create(num_shards, cluster_mappings); + ASSERT_TRUE(shard_manager_status.ok()); + auto shard_manager = std::move(*shard_manager_status); + absl::Notification finished; + EXPECT_CALL(*instance_client, DescribeInstanceGroupInstances(::testing::_)) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "kv-server-testenv-0-instance-asg", + "kv-server-testenv-1-instance-asg", + }; + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + InstanceInfo ii1 = { + .id = "id10", + .instance_group = "kv-server-testenv-0-instance-asg", + .service_status 
= InstanceServiceStatus::kInService, + .private_ip_address = "ip10"}; + + std::vector instances{ii1}; + return instances; + }) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "kv-server-testenv-0-instance-asg", + "kv-server-testenv-1-instance-asg", + }; + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + InstanceInfo ii1 = { + .id = "id20", + .instance_group = "kv-server-testenv-0-instance-asg", + .service_status = InstanceServiceStatus::kInService, + .private_ip_address = "ip20"}; + + std::vector instances{ii1}; + + finished.Notify(); + return instances; + }); + + EXPECT_CALL(*instance_client, DescribeInstances(::testing::_)) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "id10"}; + + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + InstanceInfo ii1 = {.id = "id10", .private_ip_address = "ip10"}; + std::vector instances{ii1}; + return instances; + }) + .WillOnce( + [&](const absl::flat_hash_set& instance_group_names) { + absl::flat_hash_set instance_group_names_expected = { + "id20"}; + + EXPECT_THAT(instance_group_names, + testing::UnorderedElementsAreArray( + instance_group_names_expected)); + + InstanceInfo ii1 = {.id = "id20", .private_ip_address = "ip20"}; + std::vector instances{ii1}; + return instances; + }); + + auto mgr = + ClusterMappingsManager(environment, num_shards, mock_metrics_recorder, + *instance_client, std::make_unique(), + /* update_interval_millis */ 10); + mgr.Start(*shard_manager); + finished.WaitForNotification(); + ASSERT_TRUE(mgr.Stop().ok()); + EXPECT_FALSE(mgr.IsRunning()); + auto latest_ip = shard_manager->Get(0)->GetIpAddress(); + EXPECT_EQ(latest_ip, absl::StrCat("ip20:", kRemoteLookupServerPort)); +} + +} // namespace +} // namespace kv_server diff --git 
a/components/sharding/mocks.h b/components/sharding/mocks.h new file mode 100644 index 00000000..bb3ca5af --- /dev/null +++ b/components/sharding/mocks.h @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMPONENTS_DATA_SERVER_SHARDING_MOCKS_H_ +#define COMPONENTS_DATA_SERVER_SHARDING_MOCKS_H_ + +#include +#include +#include + +#include "components/sharding/shard_manager.h" +#include "gmock/gmock.h" + +namespace kv_server { +class MockRandomGenerator : public RandomGenerator { + public: + MockRandomGenerator() : RandomGenerator() {} + MOCK_METHOD(int64_t, Get, (int64_t upper_bound), (override)); +}; + +} // namespace kv_server + +#endif // COMPONENTS_DATA_SERVER_SHARDING_MOCKS_H_ diff --git a/components/sharding/shard_manager.cc b/components/sharding/shard_manager.cc new file mode 100644 index 00000000..dc3cb390 --- /dev/null +++ b/components/sharding/shard_manager.cc @@ -0,0 +1,170 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include "components/sharding/shard_manager.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" + +namespace kv_server { +namespace { + +class RandomGeneratorImpl : public RandomGenerator { + public: + RandomGeneratorImpl() : generator_{rand_dev_()} {} + + int64_t Get(int64_t upper_bound) { + std::uniform_int_distribution distr(0, upper_bound - 1); + return distr(generator_); + } + + private: + std::random_device rand_dev_; + std::mt19937 generator_; +}; + +class ShardManagerImpl : public ShardManager { + public: + ShardManagerImpl( + int32_t num_shards, + std::function(const std::string& ip)> + client_factory, + std::unique_ptr random_generator) + : num_shards_{num_shards}, + client_factory_{client_factory}, + random_generator_{std::move(random_generator)} {} + + // taking in a set to exclude duplicates. + // set doesn't have an O(1) lookup --> converting to vector. 
+ void InsertBatch(const std::vector>& + cluster_mappings) override { + if (cluster_mappings.size() != num_shards_) { + return; + } + std::vector> cluster_mappings_vector; + absl::MutexLock lock(&mutex_); + for (const auto& si : cluster_mappings) { + std::vector vc(si.begin(), si.end()); + for (const auto& ip : vc) { + const auto key_iter = remote_lookup_clients_.find(ip); + if (key_iter != remote_lookup_clients_.end()) { + continue; + } + remote_lookup_clients_.insert({ip, client_factory_(ip)}); + } + cluster_mappings_vector.emplace_back(std::move(vc)); + } + cluster_mappings_ = cluster_mappings_vector; + } + + RemoteLookupClient* Get(int64_t shard_num) const override { + absl::ReaderMutexLock lock(&mutex_); + if (shard_num < 0 || shard_num >= num_shards_ || + cluster_mappings_.size() != num_shards_) { + return nullptr; + } + const auto& shard_replicas = cluster_mappings_[shard_num]; + if (shard_replicas.size() == 0) { + return nullptr; + } + const auto replica_idx = random_generator_->Get(shard_replicas.size()); + const auto& ip_address = shard_replicas[replica_idx]; + const auto key_iter = remote_lookup_clients_.find(ip_address); + if (key_iter == remote_lookup_clients_.end()) { + return nullptr; + } else { + return key_iter->second.get(); + } + } + + private: + mutable absl::Mutex mutex_; + // (idx) shard id -> set of ip_addresses + std::vector> cluster_mappings_ + ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map> + remote_lookup_clients_ ABSL_GUARDED_BY(mutex_); + int32_t num_shards_; + std::function(const std::string& ip)> + client_factory_; + std::unique_ptr random_generator_; +}; + +absl::Status ValidateMapping( + int32_t num_shards, + const std::vector>& cluster_mappings) { + if (num_shards < 2) { + return absl::InvalidArgumentError("Should have at least 2 clusters."); + } + + if (num_shards != cluster_mappings.size()) { + return absl::InvalidArgumentError(absl::StrFormat( + "`num_shards`(%d) does not match the size of `cluster_mappings` (%d)", + num_shards, 
cluster_mappings.size())); + } + + for (auto& set : cluster_mappings) { + if (set.empty()) { + return absl::InvalidArgumentError( + "Should have at least 1 replica per cluster."); + } + } + + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> ShardManager::Create( + int32_t num_shards, + const std::vector>& cluster_mappings) { + auto validationStatus = ValidateMapping(num_shards, cluster_mappings); + if (!validationStatus.ok()) { + return validationStatus; + } + auto shard_manager = std::make_unique( + cluster_mappings.size(), + [](const std::string& ip) { return RemoteLookupClient::Create(ip); }, + std::make_unique()); + shard_manager->InsertBatch(std::move(cluster_mappings)); + return shard_manager; +} + +absl::StatusOr> ShardManager::Create( + int32_t num_shards, + const std::vector>& cluster_mappings, + std::unique_ptr random_generator, + std::function(const std::string& ip)> + client_factory) { + auto validationStatus = ValidateMapping(num_shards, cluster_mappings); + if (!validationStatus.ok()) { + return validationStatus; + } + auto shard_manager = std::make_unique( + cluster_mappings.size(), client_factory, std::move(random_generator)); + shard_manager->InsertBatch(std::move(cluster_mappings)); + return shard_manager; +} + +} // namespace kv_server diff --git a/components/sharding/shard_manager.h b/components/sharding/shard_manager.h new file mode 100644 index 00000000..e2ec7bc9 --- /dev/null +++ b/components/sharding/shard_manager.h @@ -0,0 +1,66 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_SHARDING_SHARD_MANAGER_H_ +#define COMPONENTS_SHARDING_SHARD_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "components/internal_server/remote_lookup_client.h" + +namespace kv_server { +// This class is useful for testing ShardManager +class RandomGenerator { + public: + virtual ~RandomGenerator() = default; + // Generate a random number in the interval [0,upper_bound) + virtual int64_t Get(int64_t upper_bound) = 0; +}; + +// This class allows communication between a UDF server and data servers. +// A mapping from a shard number to a set of ip addresses should be inserted +// periodically. The class allows to retreive a RemoteLookupClient assigned to a +// random ip address from the provided pool. ShardManager is thread safe. +class ShardManager { + public: + virtual ~ShardManager() = default; + // Insert the mapping of { shard number -> corresponding replicas' ip + // adresseses }. An index of the vector is the shard number. The length of the + // vector must be equal to the `num_shards`. + virtual void InsertBatch(const std::vector>& + cluster_mappings) = 0; + // Given the shard number, get a remote lookup client for one of the replicas + // in the pool. 
+ virtual RemoteLookupClient* Get(int64_t shard_num) const = 0; + static absl::StatusOr> Create( + int32_t num_shards, + const std::vector>& cluster_mappings); + static absl::StatusOr> Create( + int32_t num_shards, + const std::vector>& cluster_mappings, + std::unique_ptr random_generator, + std::function(const std::string& ip)> + client_factory = [](const std::string& ip) { + return RemoteLookupClient::Create(ip); + }); +}; +} // namespace kv_server +#endif // COMPONENTS_SHARDING_SHARD_MANAGER_H_ diff --git a/components/sharding/shard_manager_test.cc b/components/sharding/shard_manager_test.cc new file mode 100644 index 00000000..c6484f01 --- /dev/null +++ b/components/sharding/shard_manager_test.cc @@ -0,0 +1,124 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/sharding/shard_manager.h" + +#include +#include +#include + +#include "components/internal_server/constants.h" +#include "components/sharding/mocks.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(ShardManagerTest, CreationNotInitialized) { + std::vector> cluster_mappings; + auto shard_manager = ShardManager::Create(4, std::move(cluster_mappings)); + ASSERT_FALSE(shard_manager.ok()); +} + +TEST(ShardManagerTest, CreationInitialized) { + int32_t num_shards = 4; + std::vector> cluster_mappings; + for (int i = 0; i < num_shards; i++) { + cluster_mappings.push_back({"some_ip"}); + } + auto shard_manager = + ShardManager::Create(num_shards, std::move(cluster_mappings)); + ASSERT_TRUE(shard_manager.ok()); +} + +TEST(ShardManagerTest, CreationNotInitializedMissingClusters) { + int32_t num_shards = 4; + std::vector> cluster_mappings; + for (int i = 0; i < 2; i++) { + cluster_mappings.push_back({"some_ip"}); + } + auto shard_manager = + ShardManager::Create(num_shards, std::move(cluster_mappings)); + ASSERT_FALSE(shard_manager.ok()); +} + +TEST(ShardManagerTest, CreationNotInitializedMissingReplicas) { + int32_t num_shards = 4; + std::vector> cluster_mappings; + for (int i = 0; i < 3; i++) { + cluster_mappings.push_back({"some_ip"}); + } + cluster_mappings.push_back({}); + auto shard_manager = + ShardManager::Create(num_shards, std::move(cluster_mappings)); + ASSERT_FALSE(shard_manager.ok()); +} + +TEST(ShardManagerTest, InsertRetrieveSuccess) { + int32_t num_shards = 4; + std::vector> cluster_mappings; + for (int i = 0; i < num_shards; i++) { + cluster_mappings.push_back({"some_ip"}); + } + auto shard_manager = + ShardManager::Create(num_shards, std::move(cluster_mappings)); + ASSERT_TRUE(shard_manager.ok()); + EXPECT_EQ(absl::StrCat("some_ip:", kRemoteLookupServerPort), + (*shard_manager)->Get(0)->GetIpAddress()); +} + +TEST(ShardManagerTest, InsertMissingReplicasRetrieveSuccess) { + int32_t 
num_shards = 4; + std::vector> cluster_mappings; + for (int i = 0; i < num_shards; i++) { + cluster_mappings.push_back({"some_ip"}); + } + auto shard_manager = + ShardManager::Create(num_shards, std::move(cluster_mappings)); + std::vector> cluster_mappings_2; + for (int i = 0; i < 3; i++) { + cluster_mappings_2.push_back({"some_ip"}); + } + cluster_mappings_2.push_back({}); + (*shard_manager)->InsertBatch(std::move(cluster_mappings_2)); + EXPECT_EQ(absl::StrCat("some_ip:", kRemoteLookupServerPort), + (*shard_manager)->Get(0)->GetIpAddress()); +} + +TEST(ShardManagerTest, InsertRetrieveTwoVersions) { + auto random_generator = std::make_unique(); + EXPECT_CALL(*random_generator, Get(testing::_)) + .WillOnce([]() { return 0; }) + .WillOnce([]() { return 1; }); + std::string instance_id_1 = "some_ip_1"; + std::string instance_id_2 = "some_ip_2"; + std::vector> cluster_mappings; + cluster_mappings.push_back({instance_id_2, instance_id_1}); + for (int i = 0; i < 3; i++) { + cluster_mappings.push_back({"some_ip_3"}); + } + auto shard_manager = ShardManager::Create(4, std::move(cluster_mappings), + std::move(random_generator)); + std::set etalon = { + absl::StrCat(instance_id_1, ":", kRemoteLookupServerPort), + absl::StrCat(instance_id_2, ":", kRemoteLookupServerPort)}; + std::set result; + result.insert(std::string((*shard_manager)->Get(0)->GetIpAddress())); + result.insert(std::string((*shard_manager)->Get(0)->GetIpAddress())); + EXPECT_EQ(etalon, result); +} + +} // namespace +} // namespace kv_server diff --git a/components/tools/BUILD b/components/tools/BUILD index f9de8239..1275c9b7 100644 --- a/components/tools/BUILD +++ b/components/tools/BUILD @@ -68,6 +68,7 @@ cc_binary( "//components/data_server/cache", "//components/data_server/cache:key_value_cache", "//components/data_server/data_loading:data_orchestrator", + "//components/udf:noop_udf_client", "//components/util:platform_initializer", "//public:base_types_cc_proto", "//public/data_loading:data_loading_fbs", 
@@ -145,7 +146,7 @@ cc_binary( deps = [ "//components/data/blob_storage:blob_storage_client", "//components/data/blob_storage:delta_file_notifier", - "//components/data/common:thread_notifier", + "//components/data/common:thread_manager", "//components/util:platform_initializer", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", @@ -163,10 +164,12 @@ cc_binary( "//components/util:platform_initializer", "//public/data_loading:data_loading_fbs", "//public/data_loading:filename_utils", + "//public/data_loading:records_utils", "//public/data_loading/readers:riegeli_stream_io", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/flags:usage", + "@com_google_absl//absl/strings", ], ) @@ -209,3 +212,18 @@ cc_binary( "@com_google_absl//absl/strings", ], ) + +cc_binary( + name = "query_toy", + srcs = ["query_toy.cc"], + visibility = ["//production/packaging:__subpackages__"], + deps = [ + "//components/query:driver", + "//components/query:scanner", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/strings", + ], +) diff --git a/components/tools/benchmarks/BUILD b/components/tools/benchmarks/BUILD index bd45758f..a2da55cf 100644 --- a/components/tools/benchmarks/BUILD +++ b/components/tools/benchmarks/BUILD @@ -54,6 +54,7 @@ cc_binary( "//components/data_server/cache:key_value_cache", "//components/util:platform_initializer", "//public/data_loading:data_loading_fbs", + "//public/data_loading:records_utils", "//public/data_loading/readers:riegeli_stream_io", "@com_github_google_glog//:glog", "@com_google_absl//absl/container:flat_hash_map", @@ -65,3 +66,21 @@ cc_binary( "@com_google_benchmark//:benchmark", ], ) + +cc_binary( + name = "cache_benchmark", + srcs = ["cache_benchmark.cc"], + deps = [ + ":benchmark_util", + 
"//components/data_server/cache", + "//components/data_server/cache:key_value_cache", + "@com_github_google_glog//:glog", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_benchmark//:benchmark", + ], +) diff --git a/components/tools/benchmarks/benchmark_util.cc b/components/tools/benchmarks/benchmark_util.cc index b9656b63..854953e9 100644 --- a/components/tools/benchmarks/benchmark_util.cc +++ b/components/tools/benchmarks/benchmark_util.cc @@ -17,6 +17,7 @@ #include "components/tools/benchmarks/benchmark_util.h" #include +#include #include #include "absl/strings/numbers.h" @@ -26,11 +27,10 @@ #include "public/data_loading/writers/delta_record_stream_writer.h" namespace kv_server::benchmark { -namespace { + std::string GenerateRandomString(const int64_t char_count) { return std::string(char_count, 'A' + (std::rand() % 15)); } -} // namespace absl::Status WriteRecords(int64_t num_records, const int64_t record_size, std::iostream& output_stream) { @@ -42,14 +42,15 @@ absl::Status WriteRecords(int64_t num_records, const int64_t record_size, while (num_records > 0) { const std::string key = absl::StrCat("foo", num_records); const std::string value = GenerateRandomString(record_size); - auto status = - (*record_writer) - ->WriteRecord(DeltaFileRecordStruct{ - .mutation_type = DeltaMutationType::Update, - .logical_commit_time = absl::ToUnixSeconds(absl::Now()), - .key = key, - .value = value, - }); + auto kv_mutation_record = KeyValueMutationRecordStruct{ + .mutation_type = KeyValueMutationType::Update, + .logical_commit_time = absl::ToUnixSeconds(absl::Now()), + .key = key, + .value = value, + }; + auto status = (*record_writer) + ->WriteRecord(DataRecordStruct{ + .record = std::move(kv_mutation_record)}); if (!status.ok()) { return status; } diff --git 
a/components/tools/benchmarks/benchmark_util.h b/components/tools/benchmarks/benchmark_util.h index 7173794f..99f3fae6 100644 --- a/components/tools/benchmarks/benchmark_util.h +++ b/components/tools/benchmarks/benchmark_util.h @@ -26,6 +26,9 @@ namespace kv_server::benchmark { +// Generates a random string with `char_count` characters. +std::string GenerateRandomString(const int64_t char_count); + // Write num_records, each with a size of record_size, to output_stream. absl::Status WriteRecords(int64_t num_records, int64_t record_size, std::iostream& output_stream); diff --git a/components/tools/benchmarks/benchmark_util_test.cc b/components/tools/benchmarks/benchmark_util_test.cc index 20013ac8..0eb2ea75 100644 --- a/components/tools/benchmarks/benchmark_util_test.cc +++ b/components/tools/benchmarks/benchmark_util_test.cc @@ -48,15 +48,21 @@ TEST(BenchmarkUtilTest, VerifyParseInt64ListFailsWithInvalidList) { TEST(BenchmarkUtilTest, VerifyWriteRecords) { std::stringstream data_stream; int64_t num_records = 1000; - int64_t record_size = 100; + int64_t record_size = 2048; auto status = WriteRecords(num_records, record_size, data_stream); EXPECT_TRUE(status.ok()) << status; DeltaRecordStreamReader record_reader(data_stream); - testing::MockFunction record_callback; + testing::MockFunction record_callback; EXPECT_CALL(record_callback, Call) .Times(num_records) - .WillRepeatedly([record_size](DeltaFileRecordStruct record) { - EXPECT_EQ(record.value.size(), record_size); + .WillRepeatedly([record_size](DataRecordStruct data_record) { + if (std::holds_alternative( + data_record.record)) { + auto kv_record = + std::get(data_record.record); + EXPECT_EQ(std::get(kv_record.value).size(), + record_size); + } return absl::OkStatus(); }); status = record_reader.ReadRecords(record_callback.AsStdFunction()); diff --git a/components/tools/benchmarks/cache_benchmark.cc b/components/tools/benchmarks/cache_benchmark.cc new file mode 100644 index 00000000..d487f20b --- /dev/null +++ 
b/components/tools/benchmarks/cache_benchmark.cc @@ -0,0 +1,237 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/time/time.h" +#include "benchmark/benchmark.h" +#include "components/data_server/cache/cache.h" +#include "components/data_server/cache/key_value_cache.h" +#include "components/tools/benchmarks/benchmark_util.h" +#include "glog/logging.h" + +ABSL_FLAG(std::vector, args_record_size, + std::vector({"100"}), + "Sizes of records that we want to insert into the cache."); +ABSL_FLAG(std::vector, args_dataset_size, + std::vector({"1000"}), + "Number of unique key/value pairs in the cache."); +ABSL_FLAG(std::vector, args_reads_keyset_size, + std::vector({"100"}), + "Sizes of the keyset that we want to read in each iteration for " + "read benchmarks."); +ABSL_FLAG(int64_t, args_benchmark_iterations, -1, + "Number of iterations to run each benchmark."); +ABSL_FLAG(int64_t, args_benchmark_max_concurrent_readers, -1, + "Maximum number of threads for benchmarking."); + +using kv_server::Cache; +using kv_server::GetKeyValueSetResult; +using kv_server::KeyValueCache; +using kv_server::benchmark::GenerateRandomString; +using 
kv_server::benchmark::ParseInt64List; + +// Format strings used to generate benchmark names. The "dsz", "ksz" and "rsz" +// components in the benchmark name represents the following: +// => dsz - dataset size, i.e., number of key/value pairs in cache when the +// benchmark was run. +// => ksz - keyset size, i.e., number of keys queried for each GetKeyValuePairs +// call. +// => rsz - record size, i.e., approximate byte size of each key/value pair in +// cache. +constexpr std::string_view kNoOpCacheGetKeyValuesFmt = + "BM_NoOpCache_GetKeyValuePairs/dsz:%d/ksz:%d/rsz:%d"; +constexpr std::string_view kKeyValueCacheGetKeyValuesFmt = + "BM_KeyValueCache_GetKeyValuePairs/dsz:%d/ksz:%d/rsz:%d"; + +class NoOpCache : public Cache { + public: + absl::flat_hash_map GetKeyValuePairs( + const std::vector& key_list) const override { + return {}; + }; + std::unique_ptr GetKeyValueSet( + const absl::flat_hash_set& key_set) const override { + return std::make_unique(); + } + void UpdateKeyValue(std::string_view key, std::string_view value, + int64_t logical_commit_time) override {} + void UpdateKeyValueSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) override {} + void DeleteKey(std::string_view key, int64_t logical_commit_time) override {} + void DeleteValuesInSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) override {} + void RemoveDeletedKeys(int64_t logical_commit_time) override {} + static std::unique_ptr Create() { + return std::make_unique(); + } + + private: + class NoOpGetKeyValueSetResult : public GetKeyValueSetResult { + absl::flat_hash_set GetValueSet( + std::string_view key) const override { + return {}; + } + void AddKeyValueSet( + absl::Mutex& key_mutex, std::string_view key, + const absl::flat_hash_set& value_set) override {} + }; +}; + +static Cache* shared_cache = nullptr; +void InitSharedNoOpCache(const benchmark::State&) { + shared_cache = NoOpCache::Create().release(); +} + +void 
InitSharedKeyValueCache(const benchmark::State&) { + shared_cache = KeyValueCache::Create().release(); +} + +void DeleteSharedCache(const benchmark::State&) { delete shared_cache; } + +struct BenchmarkArgs { + int64_t record_size; + int64_t reads_keyset_size; + int64_t dataset_size; +}; + +class CacheReadsBenchmark { + public: + CacheReadsBenchmark(BenchmarkArgs args, Cache& cache) + : args_(args), cache_(cache) {} + + std::vector PreLoadDataIntoCache() { + for (int i = 0; i < args_.dataset_size; i++) { + auto&& key = absl::StrCat("key", i); + auto&& value = GenerateRandomString(args_.record_size); + cache_.UpdateKeyValue(key, value, 10); + } + std::vector keyset; + for (int i = 0; i < args_.reads_keyset_size; i++) { + keyset.push_back(absl::StrCat("key", std::rand() % args_.dataset_size)); + } + return keyset; + } + + void RunGetKeyValuePair(const std::vector& keyset, + benchmark::State& state) { + for (auto _ : state) { + auto result = cache_.GetKeyValuePairs(keyset); + benchmark::DoNotOptimize(result); + } + } + + private: + BenchmarkArgs args_; + Cache& cache_; +}; + +std::vector ToStringViewList( + const std::vector& keyset) { + std::vector keyset_view; + for (const auto& key : keyset) { + keyset_view.emplace_back(key); + } + return keyset_view; +} + +void BM_Cache_GetKeyValuePairs(benchmark::State& state, BenchmarkArgs args) { + std::vector keyset; + static CacheReadsBenchmark* cache_benchmark = nullptr; + if (state.thread_index() == 0) { + cache_benchmark = new CacheReadsBenchmark(args, *shared_cache); + keyset = cache_benchmark->PreLoadDataIntoCache(); + } + cache_benchmark->RunGetKeyValuePair(ToStringViewList(keyset), state); +} + +// Registers a function to benchmark. 
+void RegisterBenchmark( + std::string name, BenchmarkArgs args, + std::function benchmark, + void (*benchmark_setup_fn)(const benchmark::State&) = nullptr, + void (*benchmark_teardown_fn)(const benchmark::State&) = nullptr) { + auto b = benchmark::RegisterBenchmark(name.c_str(), benchmark, args); + if (benchmark_setup_fn) { + b->Setup(benchmark_setup_fn); + } + if (benchmark_teardown_fn) { + b->Teardown(benchmark_teardown_fn); + } + if (absl::GetFlag(FLAGS_args_benchmark_iterations) > 0) { + b->Iterations(absl::GetFlag(FLAGS_args_benchmark_iterations)); + } + if (absl ::GetFlag(FLAGS_args_benchmark_max_concurrent_readers) > 0) { + b->ThreadRange(1, + absl::GetFlag(FLAGS_args_benchmark_max_concurrent_readers)); + } +} + +void RegisterCacheReadsBenchmarks() { + auto record_sizes = ParseInt64List(absl::GetFlag(FLAGS_args_record_size)); + auto keyset_sizes = + ParseInt64List(absl::GetFlag(FLAGS_args_reads_keyset_size)); + auto dataset_sizes = ParseInt64List(absl::GetFlag(FLAGS_args_dataset_size)); + for (auto dataset_size : dataset_sizes.value()) { + for (auto keyset_size : keyset_sizes.value()) { + for (auto record_size : record_sizes.value()) { + auto args = BenchmarkArgs{ + .record_size = record_size, + .reads_keyset_size = keyset_size, + .dataset_size = dataset_size, + }; + RegisterBenchmark( + absl::StrFormat(kNoOpCacheGetKeyValuesFmt, dataset_size, + keyset_size, record_size), + args, BM_Cache_GetKeyValuePairs, InitSharedNoOpCache, + DeleteSharedCache); + RegisterBenchmark( + absl::StrFormat(kKeyValueCacheGetKeyValuesFmt, dataset_size, + keyset_size, record_size), + args, BM_Cache_GetKeyValuePairs, InitSharedKeyValueCache, + DeleteSharedCache); + } + } + } +} + +// Microbenchmarks for Cache impelementations. 
Sample run: +// +// GLOG_logtostderr=1 bazel run -c opt \ +// //components/tools/benchmarks:cache_benchmark \ +// --//:instance=local \ +// --//:platform=local -- \ +// --benchmark_counters_tabular=true +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + ::benchmark::Initialize(&argc, argv); + absl::ParseCommandLine(argc, argv); + RegisterCacheReadsBenchmarks(); + ::benchmark::RunSpecifiedBenchmarks(); + ::benchmark::Shutdown(); + return 0; +} diff --git a/components/tools/benchmarks/data_loading_benchmark.cc b/components/tools/benchmarks/data_loading_benchmark.cc index c31f2904..50a0e48e 100644 --- a/components/tools/benchmarks/data_loading_benchmark.cc +++ b/components/tools/benchmarks/data_loading_benchmark.cc @@ -35,6 +35,8 @@ #include "glog/logging.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/readers/riegeli_stream_io.h" +#include "public/data_loading/records_utils.h" +#include "src/cpp/telemetry/metrics_recorder.h" #include "src/cpp/telemetry/telemetry_provider.h" ABSL_FLAG(std::string, data_directory, "", @@ -68,12 +70,19 @@ using kv_server::BlobReader; using kv_server::BlobStorageClient; using kv_server::Cache; using kv_server::ConcurrentStreamRecordReader; -using kv_server::DeltaFileRecord; -using kv_server::DeltaMutationType; +using kv_server::DataRecord; +using kv_server::DeserializeDataRecord; +using kv_server::GetKeyValueSetResult; +using kv_server::GetRecordValue; using kv_server::KeyValueCache; +using kv_server::KeyValueMutationRecord; +using kv_server::KeyValueMutationType; +using kv_server::Record; using kv_server::RecordStream; +using kv_server::Value; using kv_server::benchmark::ParseInt64List; using kv_server::benchmark::WriteRecords; +using privacy_sandbox::server_common::MetricsRecorder; using privacy_sandbox::server_common::TelemetryProvider; constexpr std::string_view kNoOpCacheNameFormat = @@ -119,13 +128,34 @@ class NoOpCache : public Cache { const std::vector& key_list) 
const override { return {}; }; + std::unique_ptr GetKeyValueSet( + const absl::flat_hash_set& key_set) const override { + return std::make_unique(); + } void UpdateKeyValue(std::string_view key, std::string_view value, int64_t logical_commit_time) override {} + void UpdateKeyValueSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) override {} void DeleteKey(std::string_view key, int64_t logical_commit_time) override {} + void DeleteValuesInSet(std::string_view key, + absl::Span value_set, + int64_t logical_commit_time) override {} void RemoveDeletedKeys(int64_t logical_commit_time) override {} static std::unique_ptr Create() { return std::make_unique(); } + + private: + class NoOpGetKeyValueSetResult : public kv_server::GetKeyValueSetResult { + absl::flat_hash_set GetValueSet( + std::string_view key) const override { + return {}; + } + void AddKeyValueSet( + absl::Mutex& key_mutex, std::string_view key, + const absl::flat_hash_set& value_set) override {} + }; }; BlobStorageClient::DataLocation GetBlobLocation() { @@ -143,11 +173,14 @@ int64_t GetBlobSize(BlobStorageClient& blob_client, return stream.tellg(); } -void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args); +void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args, + MetricsRecorder& metrics_recorder); -void RegisterBenchmark(std::string_view benchmark_name, BenchmarkArgs args) { - auto b = benchmark::RegisterBenchmark(benchmark_name.data(), - BM_LoadDataIntoCache, args); +void RegisterBenchmark(std::string_view benchmark_name, BenchmarkArgs args, + MetricsRecorder& metrics_recorder) { + auto b = + benchmark::RegisterBenchmark(benchmark_name.data(), BM_LoadDataIntoCache, + args, std::ref(metrics_recorder)); b->MeasureProcessCPUTime(); b->UseRealTime(); if (absl::GetFlag(FLAGS_args_benchmark_iterations) > 0) { @@ -156,7 +189,7 @@ void RegisterBenchmark(std::string_view benchmark_name, BenchmarkArgs args) { } // Registers benchmark -void 
RegisterBenchmarks() { +void RegisterBenchmarks(MetricsRecorder& metrics_recorder) { auto num_worker_threads = ParseInt64List(absl::GetFlag(FLAGS_args_reader_worker_threads)); auto client_max_conns = @@ -174,25 +207,60 @@ void RegisterBenchmarks() { }; RegisterBenchmark(absl::StrFormat(kNoOpCacheNameFormat, num_threads, num_connections, byte_range_mb), - args); + args, metrics_recorder); args.create_cache_fn = []() { return KeyValueCache::Create(); }; RegisterBenchmark(absl::StrFormat(kMutexCacheNameFormat, num_threads, num_connections, byte_range_mb), - args); + args, metrics_recorder); } } } } -void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args) { +absl::Status ApplyUpdateMutation(const KeyValueMutationRecord& record, + Cache& cache) { + if (record.value_type() == Value::String) { + cache.UpdateKeyValue(record.key()->string_view(), + GetRecordValue(record), + record.logical_commit_time()); + return absl::OkStatus(); + } + if (record.value_type() == Value::StringSet) { + auto values = GetRecordValue>(record); + cache.UpdateKeyValueSet(record.key()->string_view(), absl::MakeSpan(values), + record.logical_commit_time()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("Record with key: ", record.key()->string_view(), + " has unsupported value type: ", record.value_type())); +} + +absl::Status ApplyDeleteMutation(const KeyValueMutationRecord& record, + Cache& cache) { + if (record.value_type() == Value::String) { + cache.DeleteKey(record.key()->string_view(), record.logical_commit_time()); + return absl::OkStatus(); + } + if (record.value_type() == Value::StringSet) { + auto values = GetRecordValue>(record); + cache.DeleteValuesInSet(record.key()->string_view(), absl::MakeSpan(values), + record.logical_commit_time()); + return absl::OkStatus(); + } + return absl::InvalidArgumentError( + absl::StrCat("Record with key: ", record.key()->string_view(), + " has unsupported value type: ", record.value_type())); +} + 
+void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args, + MetricsRecorder& metrics_recorder) { BlobStorageClient::ClientOptions options; options.max_range_bytes = args.client_max_range_mb * 1024 * 1024; options.max_connections = args.client_max_connections; - auto noop_metrics_recorder = - TelemetryProvider::GetInstance().CreateMetricsRecorder(); - auto blob_client = BlobStorageClient::Create(*noop_metrics_recorder, options); + auto blob_client = BlobStorageClient::Create(metrics_recorder, options); ConcurrentStreamRecordReader record_reader( - *noop_metrics_recorder, + metrics_recorder, /*stream_factory=*/ [blob_client = blob_client.get()]() { return std::make_unique( @@ -208,34 +276,37 @@ void BM_LoadDataIntoCache(benchmark::State& state, BenchmarkArgs args) { state.PauseTiming(); auto cache = args.create_cache_fn(); state.ResumeTiming(); - auto status = record_reader.ReadStreamRecords( - [&num_records_read, cache = cache.get()](std::string_view raw) { - num_records_read++; - auto record = flatbuffers::GetRoot(raw.data()); - auto recordVerifier = flatbuffers::Verifier( - reinterpret_cast(raw.data()), raw.size()); - if (!record->Verify(recordVerifier)) { - return absl::InvalidArgumentError("Invalid flatbuffer format"); - } + auto status = record_reader.ReadStreamRecords([&num_records_read, + cache = cache.get()]( + std::string_view raw) { + num_records_read++; + return DeserializeDataRecord(raw, [cache](const DataRecord& data_record) { + if (data_record.record_type() == Record::KeyValueMutationRecord) { + const auto* record = data_record.record_as_KeyValueMutationRecord(); switch (record->mutation_type()) { - case DeltaMutationType::Update: { - cache->UpdateKeyValue(record->key()->string_view(), - record->value()->string_view(), - record->logical_commit_time()); + case KeyValueMutationType::Update: { + if (auto status = ApplyUpdateMutation(*record, *cache); + status.ok()) { + return status; + } break; } - case DeltaMutationType::Delete: { - 
cache->DeleteKey(record->key()->string_view(), - record->logical_commit_time()); - break; + case KeyValueMutationType::Delete: { + if (auto status = ApplyDeleteMutation(*record, *cache); + status.ok()) { + return status; + } } default: - return absl::InvalidArgumentError(absl::StrCat( - "Invalid mutation type: ", - EnumNameDeltaMutationType(record->mutation_type()))); + return absl::InvalidArgumentError( + absl::StrCat("Invalid mutation type: ", + kv_server::EnumNameKeyValueMutationType( + record->mutation_type()))); } - return absl::OkStatus(); - }); + } + return absl::OkStatus(); + }); + }); benchmark::DoNotOptimize(status); } state.SetItemsProcessed(num_records_read); @@ -292,7 +363,7 @@ int main(int argc, char** argv) { } LOG(INFO) << "Done creating input file: " << GetBlobLocation(); } - RegisterBenchmarks(); + RegisterBenchmarks(*noop_metrics_recorder); ::benchmark::RunSpecifiedBenchmarks(); ::benchmark::Shutdown(); if (absl::GetFlag(FLAGS_create_input_file)) { diff --git a/components/tools/data_loading_analyzer.cc b/components/tools/data_loading_analyzer.cc index d876957e..d6324dbb 100644 --- a/components/tools/data_loading_analyzer.cc +++ b/components/tools/data_loading_analyzer.cc @@ -25,6 +25,7 @@ #include "components/data_server/cache/cache.h" #include "components/data_server/cache/key_value_cache.h" #include "components/data_server/data_loading/data_orchestrator.h" +#include "components/udf/noop_udf_client.h" #include "components/util/platform_initializer.h" #include "glog/logging.h" #include "public/base_types.pb.h" @@ -100,7 +101,7 @@ class ReadonlyStreamReaderFactory : public StreamRecordReaderFactory { absl::Cleanup reader_closer([&reader] { reader.Close(); }); std::string_view raw; while (reader.ReadRecord(raw)) { - auto record = flatbuffers::GetRoot(raw.data()); + auto record = flatbuffers::GetRoot(raw.data()); if (record->logical_commit_time() == 0) { LOG(INFO) << "This is a dummy log line (that should not be called) in " "order to read the 
record. A logical commit time of 0 is " @@ -135,6 +136,7 @@ std::vector OperationsFromFlag() { } absl::Status InitOnce(Operation operation) { + std::unique_ptr noop_udf_client = NewNoopUdfClient(); std::unique_ptr cache = KeyValueCache::Create(); std::unique_ptr metrics_recorder = TelemetryProvider::GetInstance().CreateMetricsRecorder(); @@ -175,7 +177,6 @@ absl::Status InitOnce(Operation operation) { realtime_option.realtime_notifier = RealtimeNotifier::Create(*metrics_recorder); realtime_options.push_back(std::move(realtime_option)); - maybe_data_orchestrator = DataOrchestrator::TryCreate( { .data_bucket = absl::GetFlag(FLAGS_bucket), @@ -185,6 +186,7 @@ absl::Status InitOnce(Operation operation) { .change_notifier = change_notifier, .delta_stream_reader_factory = *delta_stream_reader_factory, .realtime_options = realtime_options, + .udf_client = *noop_udf_client, }, *metrics_recorder); absl::Time end_time = absl::Now(); diff --git a/components/tools/delta_file_record_change_watcher.cc b/components/tools/delta_file_record_change_watcher.cc index aff79bfc..b23e6f4e 100644 --- a/components/tools/delta_file_record_change_watcher.cc +++ b/components/tools/delta_file_record_change_watcher.cc @@ -17,17 +17,22 @@ #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/flags/usage.h" +#include "absl/strings/str_join.h" #include "components/data/realtime/delta_file_record_change_notifier.h" #include "components/util/platform_initializer.h" #include "public/constants.h" #include "public/data_loading/data_loading_generated.h" #include "public/data_loading/filename_utils.h" #include "public/data_loading/readers/riegeli_stream_io.h" +#include "public/data_loading/records_utils.h" #include "src/cpp/telemetry/telemetry_provider.h" ABSL_FLAG(std::string, sns_arn, "", "sns_arn"); using kv_server::DeltaFileRecordChangeNotifier; +using kv_server::GetRecordValue; +using kv_server::KeyValueMutationRecord; +using kv_server::Value; using 
privacy_sandbox::server_common::TelemetryProvider; void Print(std::string string_decoded) { @@ -39,7 +44,8 @@ void Print(std::string string_decoded) { auto record_reader = delta_stream_reader_factory->CreateReader(is); auto result = record_reader->ReadStreamRecords([](std::string_view raw) { - auto record = flatbuffers::GetRoot(raw.data()); + auto record = + flatbuffers::GetRoot(raw.data()); auto recordVerifier = flatbuffers::Verifier( reinterpret_cast(raw.data()), raw.size()); @@ -52,14 +58,25 @@ void Print(std::string string_decoded) { auto update_type = "update"; switch (record->mutation_type()) { - case kv_server::DeltaMutationType::Delete: { + case kv_server::KeyValueMutationType::Delete: { update_type = "delete"; break; } } + auto format_value_func = + [](const KeyValueMutationRecord& record) -> std::string { + if (record.value_type() == Value::String) { + return std::string(GetRecordValue(record)); + } + if (record.value_type() == Value::StringSet) { + return absl::StrJoin( + GetRecordValue>(record), ","); + } + return ""; + }; std::cout << "key: " << record->key()->string_view() << std::endl; - std::cout << "value: " << record->value()->string_view() << std::endl; + std::cout << "value: " << format_value_func(*record) << std::endl; std::cout << "logical_commit_time: " << record->logical_commit_time() << std::endl; std::cout << "update_type: " << update_type << std::endl; diff --git a/components/tools/delta_file_watcher_aws.cc b/components/tools/delta_file_watcher_aws.cc index 17be3499..b4c5415c 100644 --- a/components/tools/delta_file_watcher_aws.cc +++ b/components/tools/delta_file_watcher_aws.cc @@ -19,7 +19,7 @@ #include "absl/flags/usage.h" #include "components/data/blob_storage/blob_storage_client.h" #include "components/data/blob_storage/delta_file_notifier.h" -#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/util/platform_initializer.h" #include 
"src/cpp/telemetry/telemetry_provider.h" diff --git a/components/tools/delta_file_watcher_local.cc b/components/tools/delta_file_watcher_local.cc index db661636..0a92bb64 100644 --- a/components/tools/delta_file_watcher_local.cc +++ b/components/tools/delta_file_watcher_local.cc @@ -24,7 +24,7 @@ #include "components/data/blob_storage/blob_storage_client.h" #include "components/data/blob_storage/delta_file_notifier.h" #include "components/data/common/change_notifier.h" -#include "components/data/common/thread_notifier.h" +#include "components/data/common/thread_manager.h" #include "components/util/platform_initializer.h" #include "src/cpp/telemetry/telemetry_provider.h" diff --git a/components/tools/query_toy.cc b/components/tools/query_toy.cc new file mode 100644 index 00000000..4ee0191d --- /dev/null +++ b/components/tools/query_toy.cc @@ -0,0 +1,146 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This program can be run to query the hard coded database below, ex: +// bazel run components/tools:query_toy -- --query="A UNION B" +// results in: [a,b,c,d] +// Alternatively you can run in interactive, allowing to query multiple times. 
+// bazel run components/tools:query_toy + +#include + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" +#include "absl/strings/str_join.h" +#include "components/query/driver.h" +#include "components/query/scanner.h" + +ABSL_FLAG(std::string, query, "", + "If provided outputs the result to stdout. Does not enter " + "interactive mode. Interactive mode supplies user with repeated " + "query prompts."); + +absl::flat_hash_map> kDb = { + {"A", {"a", "b", "c"}}, + {"B", {"b", "c", "d"}}, + {"C", {"c", "d", "e"}}, + {"D", {"d", "e", "f"}}, +}; + +absl::flat_hash_set kEmptySet; + +template +std::string ToString(const T& set) { + std::vector sorted_set(set.begin(), set.end()); + std::sort(sorted_set.begin(), sorted_set.end()); + return absl::StrCat("[", absl::StrJoin(sorted_set, ","), "]"); +} + +std::string ToString( + const absl::flat_hash_map>& + db) { + // Get an alphabetically sorted list of string keys. 
+ std::vector keys; + keys.reserve(db.size()); // Reserve space to avoid unnecessary reallocations + for (const auto& pair : db) { + keys.push_back(pair.first); + } + std::sort(keys.begin(), keys.end()); + std::string result = "{\n"; + for (const auto& key : keys) { + const auto& it = db.find(key); + if (it != db.end()) { + absl::StrAppend(&result, "\t{", key, ", ", ToString(it->second), "},\n"); + } + } + absl::StrAppend(&result, "}"); + return result; +} + +absl::flat_hash_set ToView( + const absl::flat_hash_set& values) { + absl::flat_hash_set result; + result.reserve(values.size()); + result.insert(values.begin(), values.end()); + return result; +} + +absl::StatusOr> Parse( + kv_server::Driver& driver, std::string query) { + std::istringstream stream(query); + kv_server::Scanner scanner(stream); + kv_server::Parser parse(driver, scanner); + int parse_result = parse(); + auto result = driver.GetResult(); + if (parse_result && result.ok()) { + std::cerr << "Unexpected failed parse result with an OK query result."; + } + return result; +} + +absl::flat_hash_set Lookup(std::string_view key) { + const auto& it = kDb.find(key); + if (it != kDb.end()) { + return ToView(it->second); + } + return kEmptySet; +} + +void ProcessQuery(kv_server::Driver& driver, std::string query) { + const auto result = Parse(driver, query); + if (!result.ok()) { + std::cout << result.status() << std::endl; + return; + } + std::cout << ToString(result.value()) << std::endl; +} + +void PromptForQuery(kv_server::Driver& driver) { + while (true) { + std::cout << ">> "; + std::string query; + std::getline(std::cin, query); + ProcessQuery(driver, query); + } +} + +void SignalHandler(int signal) { + std::cout << " Quitting." 
<< std::endl; + exit(0); +} + +int main(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); + kv_server::Driver driver(Lookup); + const std::string query = absl::GetFlag(FLAGS_query); + if (!query.empty()) { + ProcessQuery(driver, query); + return 0; + } + signal(SIGINT, SignalHandler); + signal(SIGQUIT, SignalHandler); + std::cout << "/*" << std::endl << "Sets available to query:" << std::endl; + std::cout << ToString(kDb) << std::endl; + std::cout << "*/" << std::endl; + PromptForQuery(driver); + return 0; +} diff --git a/components/udf/BUILD b/components/udf/BUILD index 928d00bc..717e437f 100644 --- a/components/udf/BUILD +++ b/components/udf/BUILD @@ -22,8 +22,9 @@ package(default_visibility = [ cc_library( name = "code_config", srcs = [ - "code_config.h", + "code_config.cc", ], + hdrs = ["code_config.h"], deps = [ ], ) @@ -38,32 +39,49 @@ cc_library( ], deps = [ ":code_config", - ":get_values_hook_impl", + ":get_values_hook", + ":run_query_hook", "//components/errors:retry", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@control_plane_shared//cc/roma/interface:roma_interface_lib", "@control_plane_shared//cc/roma/roma_service/src:roma_service_lib", - "@google_privacysandbox_servers_common//src/cpp/telemetry", ], ) cc_library( - name = "code_fetcher", + name = "noop_udf_client", srcs = [ - "code_fetcher.cc", + "noop_udf_client.cc", ], hdrs = [ - "code_fetcher.h", + "noop_udf_client.h", ], + visibility = ["//components/tools:__subpackages__"], deps = [ ":code_config", - "//components/internal_lookup:lookup_client_impl", - "//public/udf:constants", - "@com_github_google_glog//:glog", + ":udf_client", + "//components/errors:retry", "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "udf_config_builder", + srcs = [ + "udf_config_builder.cc", + ], + hdrs = [ + "udf_config_builder.h", + ], + deps = [ + ":code_config", + ":get_values_hook", + ":logging_hook", + ":run_query_hook", 
"@control_plane_shared//cc/roma/interface:roma_interface_lib", + "@control_plane_shared//cc/roma/roma_service/src:roma_service_lib", ], ) @@ -78,7 +96,7 @@ cc_library( deps = [ ":get_values_hook", "//components/data_server/cache", - "//components/internal_lookup:internal_lookup_cc_proto", + "//components/internal_server:internal_lookup_cc_proto", "@com_github_google_glog//:glog", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -88,17 +106,16 @@ cc_library( ) cc_library( - name = "get_values_hook_impl", + name = "get_values_hook", srcs = [ - "get_values_hook_impl.cc", + "get_values_hook.cc", ], hdrs = [ - "get_values_hook_impl.h", + "get_values_hook.h", ], deps = [ - ":get_values_hook", - "//components/internal_lookup:internal_lookup_cc_proto", - "//components/internal_lookup:lookup_client_impl", + "//components/internal_server:internal_lookup_cc_proto", + "//components/internal_server:lookup_client_impl", "@com_github_google_glog//:glog", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", @@ -110,9 +127,30 @@ cc_library( ) cc_library( - name = "get_values_hook", + name = "run_query_hook", + srcs = [ + "run_query_hook.cc", + ], hdrs = [ - "get_values_hook.h", + "run_query_hook.h", + ], + deps = [ + "//components/internal_server:internal_lookup_cc_proto", + "//components/internal_server:run_query_client_impl", + "@com_github_google_glog//:glog", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "logging_hook", + srcs = [ + "logging_hook.h", + ], + deps = [ + "@com_github_google_glog//:glog", ], ) @@ -129,8 +167,10 @@ cc_test( ":code_config", ":mocks", ":udf_client", - "//components/internal_lookup:mocks", - "//components/udf:get_values_hook_impl", + ":udf_config_builder", + "//components/internal_server:mocks", + "//components/udf:get_values_hook", + "//components/udf:run_query_hook", 
"//public/test_util:proto_matcher", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -141,33 +181,31 @@ cc_test( ) cc_test( - name = "code_fetcher_test", + name = "get_values_hook_test", size = "small", srcs = [ - "code_fetcher_test.cc", + "get_values_hook_test.cc", ], deps = [ - ":code_fetcher", - "//components/internal_lookup:mocks", + ":get_values_hook", + "//components/internal_server:mocks", "//public/test_util:proto_matcher", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", - "@control_plane_shared//cc/roma/interface:roma_interface_lib", - "@control_plane_shared//cc/roma/roma_service/src:roma_service_lib", ], ) cc_test( - name = "get_values_hook_impl_test", + name = "cache_get_values_hook_test", size = "small", srcs = [ - "get_values_hook_impl_test.cc", + "cache_get_values_hook_test.cc", ], deps = [ - ":get_values_hook_impl", - "//components/internal_lookup:mocks", - "//public/test_util:proto_matcher", + ":cache_get_values_hook", + "//components/data_server/cache", + "//components/data_server/cache:key_value_cache", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -175,15 +213,15 @@ cc_test( ) cc_test( - name = "cache_get_values_hook_test", + name = "run_query_hook_test", size = "small", srcs = [ - "cache_get_values_hook_test.cc", + "run_query_hook_test.cc", ], deps = [ - ":cache_get_values_hook", - "//components/data_server/cache", - "//components/data_server/cache:key_value_cache", + ":run_query_hook", + "//components/internal_server:mocks", + "//public/test_util:proto_matcher", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -196,9 +234,8 @@ cc_library( hdrs = ["mocks.h"], deps = [ ":code_config", - ":code_fetcher", ":udf_client", - "//components/internal_lookup:lookup_client_impl", + "//components/internal_server:lookup_client_impl", 
"@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@control_plane_shared//cc/roma/interface:roma_interface_lib", diff --git a/components/udf/cache_get_values_hook.cc b/components/udf/cache_get_values_hook.cc index 53daaa75..5154a5ac 100644 --- a/components/udf/cache_get_values_hook.cc +++ b/components/udf/cache_get_values_hook.cc @@ -22,7 +22,7 @@ #include "absl/strings/str_join.h" #include "components/data_server/cache/cache.h" -#include "components/internal_lookup/lookup.pb.h" +#include "components/internal_server/lookup.pb.h" #include "components/udf/get_values_hook.h" #include "glog/logging.h" #include "google/protobuf/util/json_util.h" diff --git a/components/udf/code_config.cc b/components/udf/code_config.cc new file mode 100644 index 00000000..c153ea44 --- /dev/null +++ b/components/udf/code_config.cc @@ -0,0 +1,29 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/udf/code_config.h" + +namespace kv_server { + +bool operator==(const CodeConfig& lhs_config, const CodeConfig& rhs_config) { + return lhs_config.logical_commit_time == rhs_config.logical_commit_time && + lhs_config.udf_handler_name == rhs_config.udf_handler_name && + lhs_config.js == rhs_config.js && lhs_config.wasm == rhs_config.wasm; +} + +bool operator!=(const CodeConfig& lhs_config, const CodeConfig& rhs_config) { + return !operator==(lhs_config, rhs_config); +} + +} // namespace kv_server diff --git a/components/udf/code_config.h b/components/udf/code_config.h index e8eccedd..b02dd087 100644 --- a/components/udf/code_config.h +++ b/components/udf/code_config.h @@ -28,8 +28,12 @@ struct CodeConfig { std::string js; std::string wasm; std::string udf_handler_name; + int64_t logical_commit_time; }; +bool operator==(const CodeConfig& lhs_config, const CodeConfig& rhs_config); +bool operator!=(const CodeConfig& lhs_config, const CodeConfig& rhs_config); + } // namespace kv_server #endif // COMPONENTS_UDF_CODE_CONFIG_H_ diff --git a/components/udf/code_fetcher.cc b/components/udf/code_fetcher.cc deleted file mode 100644 index c81daecb..00000000 --- a/components/udf/code_fetcher.cc +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "components/udf/code_fetcher.h" - -#include -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "cc/roma/interface/roma.h" -#include "components/internal_lookup/lookup.grpc.pb.h" -#include "components/internal_lookup/lookup_client.h" -#include "components/udf/code_config.h" -#include "glog/logging.h" -#include "public/udf/constants.h" - -namespace kv_server { - -constexpr char kId[] = "UDF"; - -class CodeFetcherImpl : public CodeFetcher { - public: - CodeFetcherImpl() = default; - - absl::StatusOr FetchCodeConfig( - const LookupClient& lookup_client) { - absl::StatusOr lookup_response_or_status = - lookup_client.GetValues({kUdfHandlerNameKey, kUdfCodeSnippetKey}); - if (!lookup_response_or_status.ok()) { - return lookup_response_or_status.status(); - } - - auto lookup_response = std::move(lookup_response_or_status).value(); - const auto udf_handler_name_pair = - lookup_response.kv_pairs().find(kUdfHandlerNameKey); - if (udf_handler_name_pair == lookup_response.kv_pairs().end()) { - return absl::NotFoundError("UDF handler name not found"); - } else if (udf_handler_name_pair->second.has_status()) { - return absl::UnknownError( - udf_handler_name_pair->second.status().message()); - } else if (udf_handler_name_pair->second.value().empty()) { - return absl::NotFoundError("UDF handler name value empty"); - } - - const auto udf_code_snippet_pair = - lookup_response.kv_pairs().find(kUdfCodeSnippetKey); - if (udf_code_snippet_pair == lookup_response.kv_pairs().end()) { - return absl::NotFoundError("UDF code snippet not found"); - } else if (udf_code_snippet_pair->second.has_status()) { - return absl::UnknownError( - udf_code_snippet_pair->second.status().message()); - } else if (udf_code_snippet_pair->second.value().empty()) { - return absl::NotFoundError("UDF code snippet value empty"); - } - - return CodeConfig{ - .js = std::move(udf_code_snippet_pair->second.value()), - .udf_handler_name = 
std::move(udf_handler_name_pair->second.value())}; - } -}; - -std::unique_ptr CodeFetcher::Create() { - return std::make_unique(); -} -} // namespace kv_server diff --git a/components/udf/code_fetcher_test.cc b/components/udf/code_fetcher_test.cc deleted file mode 100644 index 50be7722..00000000 --- a/components/udf/code_fetcher_test.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "components/udf/code_fetcher.h" - -#include -#include -#include - -#include "absl/status/statusor.h" -#include "components/internal_lookup/mocks.h" -#include "gmock/gmock.h" -#include "google/protobuf/text_format.h" -#include "gtest/gtest.h" -#include "roma/interface/roma.h" - -namespace kv_server { -namespace { - -using google::protobuf::TextFormat; -using testing::_; -using testing::Return; - -TEST(CodeFetcherTest, CreatesCodeConfigWithCodeSnippetAndHandlerName) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString(R"pb(kv_pairs { - key: "udf_handler_name" - value { value: "SomeHandler" } - } - kv_pairs { - key: "udf_code_snippet" - value { value: "function SomeUDFCode(){}" } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); - auto code_fetcher = CodeFetcher::Create(); - auto code_config = 
code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_TRUE(code_config.ok()); - EXPECT_EQ(code_config->js, "function SomeUDFCode(){}"); - EXPECT_EQ(code_config->udf_handler_name, "SomeHandler"); -} - -TEST(CodeFetcherTest, ErrorOnEmptyUdfCodeSnippet) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString(R"pb(kv_pairs { - key: "udf_handler_name" - value { value: "Something" } - } - kv_pairs { - key: "udf_code_snippet" - value { value: "" } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), - absl::NotFoundError("UDF code snippet value empty")); -} - -TEST(CodeFetcherTest, ErrorOnEmptyUdfHandlerNameSnippet) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString(R"pb(kv_pairs { - key: "udf_handler_name" - value { value: "" } - } - kv_pairs { - key: "udf_code_snippet" - value { value: "Something" } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); - - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), - absl::NotFoundError("UDF handler name value empty")); -} - -TEST(CodeFetcherTest, ErrorOnUdfHandlerNameWithErrorStatus) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString( - R"pb(kv_pairs { - key: "udf_handler_name" - value { status { code: 2 message: "Some error" } } - } - kv_pairs { - key: 
"udf_code_snippet" - value { value: "function SomeUDFCode(){}" } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); - - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), absl::UnknownError("Some error")); -} - -TEST(CodeFetcherTest, ErrorOnUdfCodeSnippetWithErrorStatus) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString( - R"pb(kv_pairs { - key: "udf_handler_name" - value { value: "Something" } - } - kv_pairs { - key: "udf_code_snippet" - value { status { code: 2 message: "Some error" } } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); - - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), absl::UnknownError("Some error")); -} - -TEST(CodeFetcherTest, ErrorOnLookupError) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(absl::UnknownError("Some error"))); - - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), absl::UnknownError("Some error")); -} - -TEST(CodeFetcherTest, ErrorOnNoUdfHandlerName) { - MockLookupClient mock_lookup_client; - - std::vector udf_config_keys = {"udf_handler_name", - "udf_code_snippet"}; - InternalLookupResponse response; - TextFormat::ParseFromString( - R"pb( - kv_pairs { - key: "udf_code_snippet" - value { value: "Something" } - })pb", - &response); - EXPECT_CALL(mock_lookup_client, GetValues(udf_config_keys)) - .WillOnce(Return(response)); 
- - auto code_fetcher = CodeFetcher::Create(); - auto code_config = code_fetcher->FetchCodeConfig(mock_lookup_client); - EXPECT_EQ(code_config.status(), - absl::NotFoundError("UDF handler name not found")); -} - -} // namespace -} // namespace kv_server diff --git a/components/udf/get_values_hook_impl.cc b/components/udf/get_values_hook.cc similarity index 55% rename from components/udf/get_values_hook_impl.cc rename to components/udf/get_values_hook.cc index c1c973fe..e79614ad 100644 --- a/components/udf/get_values_hook_impl.cc +++ b/components/udf/get_values_hook.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "components/udf/get_values_hook_impl.h" +#include "components/udf/get_values_hook.h" #include #include @@ -22,8 +22,7 @@ #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" -#include "components/internal_lookup/lookup.grpc.pb.h" -#include "components/udf/get_values_hook.h" +#include "components/internal_server/lookup.grpc.pb.h" #include "glog/logging.h" #include "google/protobuf/util/json_util.h" #include "nlohmann/json.hpp" @@ -32,35 +31,26 @@ namespace kv_server { namespace { -constexpr char kGetValuesHookSpan[] = "GetValuesHook"; -constexpr char kProcessClientResponseSpan[] = "ProcessLookupClientResponse"; -constexpr char kProtoToJsonSpan[] = "ProtoToJson"; - using google::protobuf::util::MessageToJsonString; using privacy_sandbox::server_common::GetTracer; class GetValuesHookImpl : public GetValuesHook { public: - explicit GetValuesHookImpl( - const absl::AnyInvocable& get_lookup_client) - : lookup_client_factory_(std::move(get_lookup_client)) {} + explicit GetValuesHookImpl(absl::AnyInvocable()> + lookup_client_supplier) + : lookup_client_supplier_(std::move(lookup_client_supplier)) {} std::string operator()(std::tuple>& input) { + // Lazy load lookup client on first call. 
+ if (lookup_client_ == nullptr) { + lookup_client_ = lookup_client_supplier_(); + } // TODO(b/261181061): Determine where to InitTracer. - - auto span = GetTracer()->StartSpan(kGetValuesHookSpan); - auto scope = opentelemetry::trace::Scope(span); - - LOG(INFO) << "Calling internal lookup client"; + VLOG(9) << "Calling internal lookup client"; absl::StatusOr response_or_status = - lookup_client_factory_().GetValues(std::get<0>(input)); + lookup_client_->GetValues(std::get<0>(input)); - auto process_response_span = - GetTracer()->StartSpan(kProcessClientResponseSpan); - auto process_response_scope = - opentelemetry::trace::Scope(process_response_span); - - LOG(INFO) << "Processing internal lookup response"; + VLOG(9) << "Processing internal lookup response"; if (!response_or_status.ok()) { nlohmann::json status; status["code"] = response_or_status.status().code(); @@ -68,9 +58,6 @@ class GetValuesHookImpl : public GetValuesHook { return status.dump(); } - auto proto_to_json_span = GetTracer()->StartSpan(kProtoToJsonSpan); - auto proto_to_json_scope = opentelemetry::trace::Scope(proto_to_json_span); - std::string kv_pairs_json; MessageToJsonString(response_or_status.value(), &kv_pairs_json); @@ -78,13 +65,18 @@ class GetValuesHookImpl : public GetValuesHook { } private: - const absl::AnyInvocable& lookup_client_factory_; + // `lookup_client_` is lazy loaded because getting one can cause thread + // creation. Lazy load is used to ensure that it only happens after Roma + // forks. 
+ absl::AnyInvocable()> lookup_client_supplier_; + std::unique_ptr lookup_client_; }; } // namespace -std::unique_ptr NewGetValuesHook( - const absl::AnyInvocable& get_lookup_client) { - return std::make_unique(std::move(get_lookup_client)); +std::unique_ptr GetValuesHook::Create( + absl::AnyInvocable()> + lookup_client_supplier) { + return std::make_unique(std::move(lookup_client_supplier)); } } // namespace kv_server diff --git a/components/udf/get_values_hook.h b/components/udf/get_values_hook.h index fbb405fe..aa9c73e7 100644 --- a/components/udf/get_values_hook.h +++ b/components/udf/get_values_hook.h @@ -22,6 +22,9 @@ #include #include +#include "absl/functional/any_invocable.h" +#include "components/internal_server/lookup_client.h" + namespace kv_server { // Functor that acts as a wrapper for the internal lookup client call. @@ -33,6 +36,10 @@ class GetValuesHook { // the internal lookup client. virtual std::string operator()( std::tuple>& input) = 0; + + static std::unique_ptr Create( + absl::AnyInvocable()> + lookup_client_supplier); }; } // namespace kv_server diff --git a/components/udf/get_values_hook_impl_test.cc b/components/udf/get_values_hook_test.cc similarity index 77% rename from components/udf/get_values_hook_impl_test.cc rename to components/udf/get_values_hook_test.cc index 65320d26..833d1086 100644 --- a/components/udf/get_values_hook_impl_test.cc +++ b/components/udf/get_values_hook_test.cc @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-#include "components/udf/get_values_hook_impl.h" +#include "components/udf/get_values_hook.h" #include #include @@ -20,7 +20,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "components/internal_lookup/mocks.h" +#include "components/internal_server/mocks.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" @@ -45,13 +45,15 @@ TEST(GetValuesHookTest, SuccessfullyProcessesValue) { value { value: "value2" } })pb", &lookup_response); - MockLookupClient mock_lookup_client; - EXPECT_CALL(mock_lookup_client, GetValues(keys)) - .WillOnce(Return(lookup_response)); auto input = std::make_tuple(keys); - auto get_values_hook = - NewGetValuesHook([&]() -> LookupClient& { return mock_lookup_client; }); + auto mlc = std::make_unique(); + MockLookupClient* mock_lookup_client = mlc.get(); + auto get_values_hook = GetValuesHook::Create( + [mlc = std::move(mlc)]() mutable { return std::move(mlc); }); + + EXPECT_CALL(*mock_lookup_client, GetValues(keys)) + .WillOnce(Return(lookup_response)); std::string result = (*get_values_hook)(input); nlohmann::json result_json = nlohmann::json::parse(result); @@ -71,14 +73,15 @@ TEST(GetValuesHookTest, SuccessfullyProcessesResultsWithStatus) { value { status { code: 2, message: "Some error" } } })pb", &lookup_response); - MockLookupClient mock_lookup_client; - EXPECT_CALL(mock_lookup_client, GetValues(keys)) + auto mlc = std::make_unique(); + MockLookupClient* mock_lookup_client = mlc.get(); + auto get_values_hook = GetValuesHook::Create( + [mlc = std::move(mlc)]() mutable { return std::move(mlc); }); + + EXPECT_CALL(*mock_lookup_client, GetValues(keys)) .WillOnce(Return(lookup_response)); auto input = std::make_tuple(keys); - auto get_values_hook = - NewGetValuesHook([&]() -> LookupClient& { return mock_lookup_client; }); - std::string result = (*get_values_hook)(input); nlohmann::json expected = R"({"kvPairs":{"key1":{"status":{"code":2,"message":"Some error"}}}})"_json; 
@@ -87,14 +90,15 @@ TEST(GetValuesHookTest, SuccessfullyProcessesResultsWithStatus) { TEST(GetValuesHookTest, LookupClientReturnsError) { std::vector keys = {"key1"}; - MockLookupClient mock_lookup_client; - EXPECT_CALL(mock_lookup_client, GetValues(keys)) + auto mlc = std::make_unique(); + MockLookupClient* mock_lookup_client = mlc.get(); + auto get_values_hook = GetValuesHook::Create( + [mlc = std::move(mlc)]() mutable { return std::move(mlc); }); + + EXPECT_CALL(*mock_lookup_client, GetValues(keys)) .WillOnce(Return(absl::UnknownError("Some error"))); auto input = std::make_tuple(keys); - auto get_values_hook = - NewGetValuesHook([&]() -> LookupClient& { return mock_lookup_client; }); - std::string result = (*get_values_hook)(input); nlohmann::json expected = R"({"code":2,"message":"Some error"})"_json; EXPECT_EQ(result, expected.dump()); diff --git a/components/udf/logging_hook.h b/components/udf/logging_hook.h new file mode 100644 index 00000000..a7c83c9b --- /dev/null +++ b/components/udf/logging_hook.h @@ -0,0 +1,37 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPONENTS_UDF_LOGGING_HOOK_H_ +#define COMPONENTS_UDF_LOGGING_HOOK_H_ + +#include +#include + +#include "glog/logging.h" + +namespace kv_server { + +// UDF hook for logging a string. +// TODO(b/285331079): Disable for production builds. 
+inline std::string LogMessage(std::tuple& input) { + LOG(INFO) << std::get<0>(input); + // void is not allowed as an output type, so return an empty string. + return ""; +} + +} // namespace kv_server + +#endif // COMPONENTS_UDF_LOGGING_HOOK_H_ diff --git a/components/udf/mocks.h b/components/udf/mocks.h index b3eaf8cc..dec69550 100644 --- a/components/udf/mocks.h +++ b/components/udf/mocks.h @@ -22,9 +22,8 @@ #include #include "absl/status/statusor.h" -#include "components/internal_lookup/lookup_client.h" +#include "components/internal_server/lookup_client.h" #include "components/udf/code_config.h" -#include "components/udf/code_fetcher.h" #include "components/udf/udf_client.h" #include "gmock/gmock.h" #include "roma/interface/roma.h" @@ -41,12 +40,6 @@ class MockUdfClient : public UdfClient { (CodeConfig, google::scp::roma::WasmDataType), (override)); }; -class MockCodeFetcher : public CodeFetcher { - public: - MOCK_METHOD((absl::StatusOr), FetchCodeConfig, - (const LookupClient&), (override)); -}; - } // namespace kv_server #endif // COMPONENTS_UDF_MOCKS_H_ diff --git a/components/udf/noop_udf_client.cc b/components/udf/noop_udf_client.cc new file mode 100644 index 00000000..97c381be --- /dev/null +++ b/components/udf/noop_udf_client.cc @@ -0,0 +1,57 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "components/udf/noop_udf_client.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "components/udf/code_config.h" +#include "components/udf/udf_client.h" +#include "roma/config/src/config.h" + +namespace kv_server { + +namespace { +class NoopUdfClientImpl : public UdfClient { + public: + absl::StatusOr ExecuteCode(std::vector keys) const { + return ""; + } + + absl::Status Stop() { return absl::OkStatus(); } + + absl::Status SetCodeObject(CodeConfig code_config) { + return absl::OkStatus(); + } + + absl::Status SetWasmCodeObject( + CodeConfig code_config, + google::scp::roma::WasmDataType wasm_return_type) { + return absl::OkStatus(); + } +}; + +} // namespace + +std::unique_ptr NewNoopUdfClient() { + return std::make_unique(); +} + +} // namespace kv_server diff --git a/components/udf/code_fetcher.h b/components/udf/noop_udf_client.h similarity index 53% rename from components/udf/code_fetcher.h rename to components/udf/noop_udf_client.h index 69785bd7..fe71a61e 100644 --- a/components/udf/code_fetcher.h +++ b/components/udf/noop_udf_client.h @@ -1,5 +1,5 @@ /* - * Copyright 2022 Google LLC + * Copyright 2023 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,28 +14,19 @@ * limitations under the License. */ -#ifndef COMPONENTS_UDF_CODE_FETCHER_H_ -#define COMPONENTS_UDF_CODE_FETCHER_H_ +#ifndef COMPONENTS_UDF_NOOP_UDF_CLIENT_H_ +#define COMPONENTS_UDF_NOOP_UDF_CLIENT_H_ #include -#include "absl/status/status.h" -#include "components/internal_lookup/lookup_client.h" -#include "components/udf/code_config.h" +#include "components/udf/udf_client.h" namespace kv_server { -class CodeFetcher { - public: - virtual ~CodeFetcher() {} - - // Fetches untrusted code for UDF execution from the lookup server. 
- virtual absl::StatusOr FetchCodeConfig( - const LookupClient& lookup_client) = 0; - - static std::unique_ptr Create(); -}; +// Create a no-op UDF client that doesn't do anything. Useful for certain +// tests/tools that need a udf client, but don't actually want to use it. +std::unique_ptr NewNoopUdfClient(); } // namespace kv_server -#endif // COMPONENTS_UDF_CODE_FETCHER_H_ +#endif // COMPONENTS_UDF_NOOP_UDF_CLIENT_H_ diff --git a/components/udf/run_query_hook.cc b/components/udf/run_query_hook.cc new file mode 100644 index 00000000..e1780937 --- /dev/null +++ b/components/udf/run_query_hook.cc @@ -0,0 +1,77 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "components/udf/run_query_hook.h" + +#include +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "components/internal_server/run_query_client.h" +#include "glog/logging.h" + +namespace kv_server { +namespace { + +class RunQueryHookImpl : public RunQueryHook { + public: + explicit RunQueryHookImpl( + absl::AnyInvocable()> + query_client_supplier) + : query_client_supplier_(std::move(query_client_supplier)) {} + + // TODO(b/283091615): Add tests. + std::vector operator()(std::tuple& input) { + if (query_client_ == nullptr) { + query_client_ = query_client_supplier_(); + } + // TODO(b/261181061): Determine where to InitTracer. 
+ VLOG(9) << "Calling internal run query client"; + absl::StatusOr response_or_status = + query_client_->RunQuery(std::get<0>(input)); + + VLOG(9) << "Processing internal run query response"; + if (!response_or_status.ok()) { + LOG(ERROR) << "Internal run query returned error: " + << response_or_status.status(); + return std::vector(); + } + std::vector result; + for (auto&& element : + *std::move(response_or_status).value().mutable_elements()) { + result.push_back(std::move(element)); + } + return result; + } + + private: + // `query_client_` is lazy loaded because getting one can cause thread + // creation. Lazy load is used to ensure that it only happens after Roma + // forks. + absl::AnyInvocable()> query_client_supplier_; + std::unique_ptr query_client_; +}; +} // namespace + +std::unique_ptr RunQueryHook::Create( + absl::AnyInvocable()> + query_client_supplier) { + return std::make_unique(std::move(query_client_supplier)); +} + +} // namespace kv_server diff --git a/components/udf/run_query_hook.h b/components/udf/run_query_hook.h new file mode 100644 index 00000000..8598e71d --- /dev/null +++ b/components/udf/run_query_hook.h @@ -0,0 +1,47 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMPONENTS_UDF_RUN_QUERY_HOOK_H_ +#define COMPONENTS_UDF_RUN_QUERY_HOOK_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "components/internal_server/run_query_client.h" + +namespace kv_server { + +// Functor that acts as a wrapper for the internal query client call. +class RunQueryHook { + public: + virtual ~RunQueryHook() = default; + + // This is registered with v8 and is exposed to the UDF. Internally, it calls + // the internal query client. + virtual std::vector operator()( + std::tuple& input) = 0; + + static std::unique_ptr Create( + absl::AnyInvocable()> + query_client_supplier); +}; + +} // namespace kv_server + +#endif // COMPONENTS_UDF_RUN_QUERY_HOOK_H_ diff --git a/components/udf/run_query_hook_test.cc b/components/udf/run_query_hook_test.cc new file mode 100644 index 00000000..e884f166 --- /dev/null +++ b/components/udf/run_query_hook_test.cc @@ -0,0 +1,71 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/udf/run_query_hook.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "components/internal_server/mocks.h" +#include "gmock/gmock.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +using google::protobuf::TextFormat; +using testing::_; +using testing::Return; +using testing::UnorderedElementsAre; + +TEST(RunQueryHookTest, SuccessfullyProcessesValue) { + std::string query = "Q"; + InternalRunQueryResponse run_query_response; + TextFormat::ParseFromString(R"pb(elements: "a" elements: "b")pb", + &run_query_response); + auto mrq = std::make_unique(); + MockRunQueryClient* mock_run_query_client = mrq.get(); + EXPECT_CALL(*mock_run_query_client, RunQuery(query)) + .WillOnce(Return(run_query_response)); + + auto input = std::make_tuple(query); + auto run_query_hook = RunQueryHook::Create( + [mrq = std::move(mrq)]() mutable { return std::move(mrq); }); + + std::vector result = (*run_query_hook)(input); + EXPECT_THAT(result, UnorderedElementsAre("a", "b")); +} + +TEST(GetValuesHookTest, RunQueryClientReturnsError) { + std::string query = "Q"; + auto mrq = std::make_unique(); + MockRunQueryClient* mock_run_query_client = mrq.get(); + EXPECT_CALL(*mock_run_query_client, RunQuery(query)) + .WillOnce(Return(absl::UnknownError("Some error"))); + + auto input = std::make_tuple(query); + auto run_query_hook = RunQueryHook::Create( + [mrq = std::move(mrq)]() mutable { return std::move(mrq); }); + + std::vector result = (*run_query_hook)(input); + EXPECT_TRUE(result.empty()); +} + +} // namespace +} // namespace kv_server diff --git a/components/udf/udf_client.cc b/components/udf/udf_client.cc index 9ee27cef..aac62d68 100644 --- a/components/udf/udf_client.cc +++ b/components/udf/udf_client.cc @@ -26,13 +26,9 @@ #include "absl/synchronization/notification.h" #include "absl/time/time.h" #include "components/errors/retry.h" 
-#include "components/internal_lookup/lookup_client.h" -#include "components/udf/get_values_hook.h" #include "glog/logging.h" #include "roma/config/src/config.h" -#include "roma/config/src/function_binding_object.h" #include "roma/interface/roma.h" -#include "src/cpp/telemetry/telemetry.h" namespace kv_server { @@ -40,19 +36,13 @@ namespace { using google::scp::roma::CodeObject; using google::scp::roma::Config; using google::scp::roma::Execute; -using google::scp::roma::FunctionBindingObject; using google::scp::roma::InvocationRequestStrInput; using google::scp::roma::LoadCodeObj; using google::scp::roma::ResponseObject; using google::scp::roma::RomaInit; using google::scp::roma::RomaStop; using google::scp::roma::WasmDataType; -using privacy_sandbox::server_common::GetTracer; -constexpr char kExecuteCodeSpan[] = "UdfClientExecuteCode"; -constexpr char kUpdateCodeObjectSpan[] = "UdfUpdateCodeObject"; - -constexpr char kGetValuesHookJsName[] = "getValues"; constexpr absl::Duration kCallbackTimeout = absl::Seconds(1); constexpr absl::Duration kCodeUpdateTimeout = absl::Seconds(1); @@ -68,37 +58,40 @@ class UdfClientImpl : public UdfClient { UdfClientImpl() = default; absl::StatusOr ExecuteCode(std::vector keys) const { - auto span = GetTracer()->StartSpan(kExecuteCodeSpan); - auto scope = opentelemetry::trace::Scope(span); - - absl::Status response_status; - std::string result; - absl::Notification notification; + std::shared_ptr response_status = + std::make_shared(); + std::shared_ptr result = std::make_shared(); + std::shared_ptr notification = + std::make_shared(); InvocationRequestStrInput invocation_request = BuildInvocationRequest(std::move(keys)); + VLOG(9) << "Executing UDF"; const auto status = Execute(std::make_unique(invocation_request), - [¬ification, &response_status, &result]( + [notification, response_status, result]( std::unique_ptr> response) { if (response->ok()) { auto& code_response = **response; - result = std::move(code_response.resp); + 
*result = std::move(code_response.resp); } else { - response_status.Update(std::move(response->status())); + response_status->Update(std::move(response->status())); } - notification.Notify(); + notification->Notify(); }); if (!status.ok()) { - LOG(ERROR) << "Error executing UDF: " << status; + LOG(ERROR) << "Error sending UDF for execution: " << status; return status; } - notification.WaitForNotificationWithTimeout(kCallbackTimeout); - if (!response_status.ok()) { - LOG(ERROR) << "Error executing UDF: " << response_status; - return response_status; + notification->WaitForNotificationWithTimeout(kCallbackTimeout); + if (!notification->HasBeenNotified()) { + return absl::InternalError("Timed out waiting for UDF result."); } - return result; + if (!response_status->ok()) { + LOG(ERROR) << "Error executing UDF: " << *response_status; + return *response_status; + } + return *result; } static absl::Status Init(const Config& config) { return RomaInit(config); } @@ -106,32 +99,43 @@ class UdfClientImpl : public UdfClient { absl::Status Stop() { return RomaStop(); } absl::Status SetCodeObject(CodeConfig code_config) { - auto span = GetTracer()->StartSpan(kUpdateCodeObjectSpan); - auto scope = opentelemetry::trace::Scope(span); - - absl::Status response_status; - absl::Notification notification; + // Only update code if logical commit time is larger. + if (logical_commit_time_ >= code_config.logical_commit_time) { + VLOG(1) << "Not updating code object. 
logical_commit_time " + << code_config.logical_commit_time + << " too small, should be greater than " << logical_commit_time_; + return absl::OkStatus(); + } + std::shared_ptr response_status = + std::make_shared(); + std::shared_ptr notification = + std::make_shared(); CodeObject code_object = BuildCodeObject(std::move(code_config.js), std::move(code_config.wasm)); absl::Status load_status = LoadCodeObj(std::make_unique(code_object), - [&](std::unique_ptr> resp) { + [notification, response_status]( + std::unique_ptr> resp) { if (!resp->ok()) { - response_status.Update(std::move(resp->status())); + response_status->Update(std::move(resp->status())); } - notification.Notify(); + notification->Notify(); }); if (!load_status.ok()) { LOG(ERROR) << "Error setting UDF Code object: " << load_status; return load_status; } - notification.WaitForNotificationWithTimeout(kCodeUpdateTimeout); - if (!response_status.ok()) { - LOG(ERROR) << "Error setting UDF Code object: " << response_status; - return response_status; + notification->WaitForNotificationWithTimeout(kCodeUpdateTimeout); + if (!notification->HasBeenNotified()) { + return absl::InternalError("Timed out setting UDF code object."); + } + if (!response_status->ok()) { + LOG(ERROR) << "Error setting UDF Code object: " << *response_status; + return *response_status; } handler_name_ = std::move(code_config.udf_handler_name); + logical_commit_time_ = code_config.logical_commit_time; return absl::OkStatus(); } @@ -163,6 +167,7 @@ class UdfClientImpl : public UdfClient { } std::string handler_name_; + int64_t logical_commit_time_ = -1; WasmDataType wasm_return_type_; }; @@ -177,22 +182,4 @@ absl::StatusOr> UdfClient::Create( return std::make_unique(); } -Config UdfClient::ConfigWithGetValuesHook(GetValuesHook& get_values_hook, - const int number_of_workers) { - auto function_object = std::make_unique< - FunctionBindingObject>>(); - function_object->function_name = kGetValuesHookJsName; - // TODO(b/260874774): Investigate 
other options - function_object->function = - [&get_values_hook]( - std::tuple>& in) -> std::string { - return get_values_hook(in); - }; - - Config config; - config.RegisterFunctionBinding(std::move(function_object)); - config.NumberOfWorkers = number_of_workers; - return config; -} - } // namespace kv_server diff --git a/components/udf/udf_client.h b/components/udf/udf_client.h index 29787c8e..8fe73e83 100644 --- a/components/udf/udf_client.h +++ b/components/udf/udf_client.h @@ -24,7 +24,6 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "components/udf/code_config.h" -#include "components/udf/get_values_hook.h" #include "roma/config/src/config.h" #include "roma/interface/roma.h" @@ -53,12 +52,6 @@ class UdfClient { // Creates a UDF executor. This calls Roma::Init, which forks. static absl::StatusOr> Create( const google::scp::roma::Config& config = google::scp::roma::Config()); - - // Creates a config with the get value hook. Caller needs to make sure - // the hook object stays alive for UDF execution. - // TODO(b/260874772): Pass hook as unique_ptr. 
- static google::scp::roma::Config ConfigWithGetValuesHook( - GetValuesHook& get_values_hook, int number_of_workers = 0); }; } // namespace kv_server diff --git a/components/udf/udf_client_test.cc b/components/udf/udf_client_test.cc index 5b7cc295..50b1ada8 100644 --- a/components/udf/udf_client_test.cc +++ b/components/udf/udf_client_test.cc @@ -21,10 +21,12 @@ #include #include "absl/status/statusor.h" -#include "components/internal_lookup/mocks.h" +#include "components/internal_server/mocks.h" #include "components/udf/code_config.h" -#include "components/udf/get_values_hook_impl.h" +#include "components/udf/get_values_hook.h" #include "components/udf/mocks.h" +#include "components/udf/run_query_hook.h" +#include "components/udf/udf_config_builder.h" #include "gmock/gmock.h" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" @@ -35,6 +37,7 @@ using google::protobuf::TextFormat; using google::scp::roma::Config; using google::scp::roma::FunctionBindingObject; +using google::scp::roma::FunctionBindingObjectBase; using google::scp::roma::WasmDataType; using testing::_; using testing::Return; @@ -52,11 +55,12 @@ TEST(UdfClientTest, JsCallSucceeds) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello() { return "Hello world!"; } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result = udf_client.value()->ExecuteCode({}); @@ -71,11 +75,12 @@ TEST(UdfClientTest, RepeatedJsCallsSucceed) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello() { 
return "Hello world!"; } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result1 = udf_client.value()->ExecuteCode({}); @@ -103,6 +108,7 @@ TEST(UdfClientTest, WasmCallSucceeds) { 0x00, 0x20, 0x01, 0x6a, 0x0b}; code_config.wasm.assign(wasm_bin, sizeof(wasm_bin)); code_config.udf_handler_name = "add"; + code_config.logical_commit_time = 1; absl::Status code_obj_status = udf_client.value()->SetWasmCodeObject( std::move(code_config), @@ -129,6 +135,7 @@ TEST(UdfClientTest, WasmFromFileSucceeds) { CodeConfig code_config; code_config.wasm = content; code_config.udf_handler_name = "add"; + code_config.logical_commit_time = 1; absl::Status code_obj_status = udf_client.value()->SetWasmCodeObject( std::move(code_config), /*wasm_return_type=*/WasmDataType::kUint32); @@ -147,11 +154,12 @@ TEST(UdfClientTest, JsEchoCallSucceeds) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello(input) { return "Hello world! " + JSON.stringify(input); } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result = @@ -181,11 +189,12 @@ TEST(UdfClientTest, JsEchoHookCallSucceeds) { UdfClient::Create(config); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello(input) { return "Hello world! 
" + echo(input); } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result = @@ -198,7 +207,8 @@ function hello(input) { return "Hello world! " + echo(input); } } TEST(UdfClientTest, JsStringInWithGetValuesHookSucceeds) { - MockLookupClient mock_lookup_client; + auto mlc = std::make_unique(); + MockLookupClient* mock_lookup_client = mlc.get(); InternalLookupResponse response; TextFormat::ParseFromString(R"pb(kv_pairs { @@ -206,17 +216,19 @@ TEST(UdfClientTest, JsStringInWithGetValuesHookSucceeds) { value { value: "value1" } })pb", &response); - ON_CALL(mock_lookup_client, GetValues(_)).WillByDefault(Return(response)); + ON_CALL(*mock_lookup_client, GetValues(_)).WillByDefault(Return(response)); + auto get_values_hook = GetValuesHook::Create( + [mlc = std::move(mlc)]() mutable { return std::move(mlc); }); + UdfConfigBuilder config_builder; absl::StatusOr> udf_client = - UdfClient::Create(UdfClient::ConfigWithGetValuesHook( - *NewGetValuesHook( - [&]() -> LookupClient& { return mock_lookup_client; }), - 1)); + UdfClient::Create(config_builder.RegisterGetValuesHook(*get_values_hook) + .SetNumberOfWorkers(1) + .Config()); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello(input) { let kvPairs = JSON.parse(getValues([input])).kvPairs; let output = ""; @@ -228,7 +240,8 @@ function hello(input) { } return output; } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result = @@ -241,7 +254,8 @@ function hello(input) { } TEST(UdfClientTest, JsJSONObjectInWithGetValuesHookSucceeds) { - MockLookupClient mock_lookup_client; + auto mlc = std::make_unique(); + MockLookupClient* mock_lookup_client = mlc.get(); 
InternalLookupResponse response; TextFormat::ParseFromString(R"pb(kv_pairs { @@ -249,17 +263,19 @@ TEST(UdfClientTest, JsJSONObjectInWithGetValuesHookSucceeds) { value { value: "value1" } })pb", &response); - ON_CALL(mock_lookup_client, GetValues(_)).WillByDefault(Return(response)); + ON_CALL(*mock_lookup_client, GetValues(_)).WillByDefault(Return(response)); + auto get_values_hook = GetValuesHook::Create( + [mlc = std::move(mlc)]() mutable { return std::move(mlc); }); + UdfConfigBuilder config_builder; absl::StatusOr> udf_client = - UdfClient::Create(UdfClient::ConfigWithGetValuesHook( - *NewGetValuesHook( - [&]() -> LookupClient& { return mock_lookup_client; }), - 1)); + UdfClient::Create(config_builder.RegisterGetValuesHook(*get_values_hook) + .SetNumberOfWorkers(1) + .Config()); EXPECT_TRUE(udf_client.ok()); - absl::Status code_obj_status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello(input) { let keys = input.keys; let kvPairs = JSON.parse(getValues(keys)).kvPairs; @@ -273,7 +289,8 @@ TEST(UdfClientTest, JsJSONObjectInWithGetValuesHookSucceeds) { return output; } )", - .udf_handler_name = "hello"}); + .udf_handler_name = "hello", + .logical_commit_time = 1}); EXPECT_TRUE(code_obj_status.ok()); absl::StatusOr result = @@ -285,23 +302,63 @@ TEST(UdfClientTest, JsJSONObjectInWithGetValuesHookSucceeds) { EXPECT_TRUE(stop.ok()); } +TEST(UdfClientTest, JsJSONObjectInWithRunQueryHookSucceeds) { + auto mrq = std::make_unique(); + MockRunQueryClient* mock_run_query_client = mrq.get(); + + InternalRunQueryResponse response; + TextFormat::ParseFromString(R"pb(elements: "a")pb", &response); + ON_CALL(*mock_run_query_client, RunQuery(_)).WillByDefault(Return(response)); + + auto run_query_hook = RunQueryHook::Create( + [mrq = std::move(mrq)]() mutable { return std::move(mrq); }); + UdfConfigBuilder config_builder; + absl::StatusOr> udf_client = + 
UdfClient::Create(config_builder.RegisterRunQueryHook(*run_query_hook) + .SetNumberOfWorkers(1) + .Config()); + EXPECT_TRUE(udf_client.ok()); + + absl::Status code_obj_status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( + function hello(input) { + let keys = input.keys; + let queryResultArray = runQuery(keys[0]); + return queryResultArray; + } + )", + .udf_handler_name = "hello", + .logical_commit_time = 1}); + EXPECT_TRUE(code_obj_status.ok()); + + absl::StatusOr result = + udf_client.value()->ExecuteCode({R"({"keys":["key1"]})"}); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, R"(["a"])"); + + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} + TEST(UdfClientTest, UpdatesCodeObjectTwice) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); - auto status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + auto status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello1() { return "1"; } )", - .udf_handler_name = "hello1"}); + .udf_handler_name = "hello1", + .logical_commit_time = 1}); EXPECT_TRUE(status.ok()); - status = udf_client.value()->SetCodeObject( - CodeConfig{.js = R"( + status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( function hello2() { return "2"; } )", - .udf_handler_name = "hello2"}); + .udf_handler_name = "hello2", + .logical_commit_time = 2}); EXPECT_TRUE(status.ok()); absl::StatusOr result = udf_client.value()->ExecuteCode({}); @@ -312,6 +369,63 @@ TEST(UdfClientTest, UpdatesCodeObjectTwice) { EXPECT_TRUE(stop.ok()); } +TEST(UdfClientTest, IgnoresCodeObjectWithSameCommitTime) { + auto udf_client = CreateUdfClient(); + EXPECT_TRUE(udf_client.ok()); + + auto status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( + function hello1() { return "1"; } + )", + .udf_handler_name = "hello1", + .logical_commit_time = 1}); + + EXPECT_TRUE(status.ok()); + status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( + function hello2() { 
return "2"; } + )", + + .udf_handler_name = "hello2", + .logical_commit_time = 1}); + EXPECT_TRUE(status.ok()); + + absl::StatusOr result = udf_client.value()->ExecuteCode({}); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, R"("1")"); + + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} + +TEST(UdfClientTest, IgnoresCodeObjectWithSmallerCommitTime) { + auto udf_client = CreateUdfClient(); + EXPECT_TRUE(udf_client.ok()); + + auto status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( + function hello1() { return "1"; } + )", + .udf_handler_name = "hello1", + .logical_commit_time = 2}); + + EXPECT_TRUE(status.ok()); + status = + udf_client.value()->SetCodeObject(CodeConfig{.js = R"( + function hello2() { return "2"; } + )", + .udf_handler_name = "hello2", + .logical_commit_time = 1}); + EXPECT_TRUE(status.ok()); + + absl::StatusOr result = udf_client.value()->ExecuteCode({}); + EXPECT_TRUE(result.ok()); + EXPECT_EQ(*result, R"("1")"); + + absl::Status stop = udf_client.value()->Stop(); + EXPECT_TRUE(stop.ok()); +} + TEST(UdfClientTest, CodeObjectNotSetError) { auto udf_client = CreateUdfClient(); EXPECT_TRUE(udf_client.ok()); diff --git a/components/udf/udf_config_builder.cc b/components/udf/udf_config_builder.cc new file mode 100644 index 00000000..ec275343 --- /dev/null +++ b/components/udf/udf_config_builder.cc @@ -0,0 +1,86 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "components/udf/udf_config_builder.h" + +#include +#include +#include +#include +#include +#include + +#include "components/udf/get_values_hook.h" +#include "components/udf/logging_hook.h" +#include "components/udf/run_query_hook.h" +#include "roma/config/src/config.h" +#include "roma/config/src/function_binding_object.h" +#include "roma/interface/roma.h" + +namespace kv_server { +using google::scp::roma::Config; +using google::scp::roma::FunctionBindingObject; + +constexpr char kGetValuesHookJsName[] = "getValues"; +constexpr char kRunQueryHookJsName[] = "runQuery"; +constexpr char kLoggingHookJsName[] = "logMessage"; + +UdfConfigBuilder& UdfConfigBuilder::RegisterGetValuesHook( + GetValuesHook& get_values_hook) { + auto get_values_function_object = std::make_unique< + FunctionBindingObject>>(); + get_values_function_object->function_name = kGetValuesHookJsName; + get_values_function_object->function = + [&get_values_hook]( + std::tuple>& in) -> std::string { + return get_values_hook(in); + }; + config_.RegisterFunctionBinding(std::move(get_values_function_object)); + return *this; +} + +UdfConfigBuilder& UdfConfigBuilder::RegisterRunQueryHook( + RunQueryHook& run_query_hook) { + auto run_query_function_object = std::make_unique< + FunctionBindingObject, std::string>>(); + run_query_function_object->function_name = kRunQueryHookJsName; + run_query_function_object->function = + [&run_query_hook]( + std::tuple& in) -> std::vector { + return run_query_hook(in); + }; + config_.RegisterFunctionBinding(std::move(run_query_function_object)); + return *this; +} + +UdfConfigBuilder& UdfConfigBuilder::RegisterLoggingHook() { + auto logging_function_object = + std::make_unique>(); + logging_function_object->function_name = kLoggingHookJsName; + logging_function_object->function = LogMessage; + config_.RegisterFunctionBinding(std::move(logging_function_object)); + return *this; +} + +UdfConfigBuilder& UdfConfigBuilder::SetNumberOfWorkers( + const int 
number_of_workers) { + config_.NumberOfWorkers = number_of_workers; + return *this; +} + +const google::scp::roma::Config& UdfConfigBuilder::Config() const { + return config_; +} + +} // namespace kv_server diff --git a/components/udf/get_values_hook_impl.h b/components/udf/udf_config_builder.h similarity index 57% rename from components/udf/get_values_hook_impl.h rename to components/udf/udf_config_builder.h index dbb3f711..cebb3afb 100644 --- a/components/udf/get_values_hook_impl.h +++ b/components/udf/udf_config_builder.h @@ -1,6 +1,5 @@ - /* - * Copyright 2022 Google LLC + * Copyright 2023 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,23 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef COMPONENTS_UDF_GET_VALUES_HOOK_IMPL_H_ -#define COMPONENTS_UDF_GET_VALUES_HOOK_IMPL_H_ - #include -#include "absl/functional/any_invocable.h" -#include "components/internal_lookup/lookup_client.h" #include "components/udf/get_values_hook.h" +#include "components/udf/run_query_hook.h" +#include "roma/config/src/config.h" namespace kv_server { -// Create a GetValuesHook that will be registered with v8 and calls a lookup -// client. 
-std::unique_ptr NewGetValuesHook( - const absl::AnyInvocable& get_lookup_client); +class UdfConfigBuilder { + public: + UdfConfigBuilder& RegisterGetValuesHook(GetValuesHook& get_values_hook); -} // namespace kv_server + UdfConfigBuilder& RegisterRunQueryHook(RunQueryHook& run_query_hook); -#endif // COMPONENTS_UDF_GET_VALUES_HOOK_IMPL_H_ + UdfConfigBuilder& RegisterLoggingHook(); + + UdfConfigBuilder& SetNumberOfWorkers(const int number_of_workers); + + const google::scp::roma::Config& Config() const; + + private: + google::scp::roma::Config config_; +}; +} // namespace kv_server diff --git a/components/util/BUILD b/components/util/BUILD index dd2ef5a1..3066f8c9 100644 --- a/components/util/BUILD +++ b/components/util/BUILD @@ -17,6 +17,7 @@ load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") package(default_visibility = [ "//components:__subpackages__", + "//tools:__subpackages__", ]) cc_library( diff --git a/docs/AWS_Terraform_vars.md b/docs/AWS_Terraform_vars.md new file mode 100644 index 00000000..f3a490ae --- /dev/null +++ b/docs/AWS_Terraform_vars.md @@ -0,0 +1,165 @@ +# AWS Key Value Server Terraform vars documentation + +- **autoscaling_desired_capacity** + + Number of Amazon EC2 instances that should be running in the autoscaling group + +- **autoscaling_max_size** + + Maximum size of the Auto Scaling Group + +- **autoscaling_min_size** + + Minimum size of the Auto Scaling Group + +- **backup_poll_frequency_secs** + + Interval between attempts to check if there are new data files on S3, as a backup to listening + to new data files. + +- **certificate_arn** + + If you want to create a public AWS ACM certificate for a domain from scratch, follow + [these steps to request a public certificate](https://docs.aws.amazon.com/acm/latest/userguide/gs-acm-request-public.html). 
+ If you want to import an existing public certificate into ACM, follow these steps to + [import the certificate](https://docs.aws.amazon.com/acm/latest/userguide/import-certificate.html). + +- **data_loading_num_threads** + + the number of concurrent threads used to read and load a single delta or snapshot file from blob + storage. + +- **enclave_cpu_count** + + Set how many CPUs the server will use. + +- **enclave_memory_mib** + + Set how much RAM the server will use. + +- **environment** + + The value can be any arbitrary unique string (there is a length limit of ~10), and for example, + strings like `staging` and `prod` can be used to represent the environment that the Key/Value + server will run in. + +- **healthcheck_healthy_threshold** + + Consecutive health check successes required to be considered healthy + +- **healthcheck_interval_sec** + + Amount of time between health check intervals in seconds. + +- **healthcheck_unhealthy_threshold** + + Consecutive health check failures required to be considered unhealthy. + +- **instance_ami_id** + + Set the value to the AMI ID that was generated when the image was built. + +- **instance_type** + + Set the instance type. Use instances with at least four vCPUs. Learn more about which types are + supported from the + [AWS article](https://docs.aws.amazon.com/enclaves/latest/user/nitro-enclave.html). + +- **metrics_export_interval_millis** + + Export interval for metrics in milliseconds. + +- **metrics_export_timeout_millis** + + Export timeout for metrics in milliseconds. + +- **mode** + + Set the server mode. The acceptable values are [DSP] or [SSP] + +- **num_shards** + + Total number of shards + +- **prometheus_service_region** + + Specifies which region to find Prometheus service and use. Not all regions have Prometheus + service. (See for + supported regions). If this region does not have Prometheus service, it must be created + beforehand either manually or by deploying this system in that region. 
At this time Prometheus + service is needed. In the future it can be refactored to become optional. + +- **prometheus_workspace_id** + + Only required if the region does not have its own Amazon Prometheus workspace, in which case an + existing workspace id from another region should be provided. It is expected that the workspace + from that region is created before this terraform file is applied. That can be done by running + the Key Value service terraform file in that region. + +- **realtime_updater_num_threads** + + The number of threads to process real time updates. + +- **region** + + The region that the Key/Value server will operate in. Each terraform file specifies one region. + +- **root_domain** + + Set the root domain for the server. If your domain is managed by + [AWS Route 53](https://aws.amazon.com/route53/), then you can simply set your domain value to + the `root_domain` property in the Terraform configuration that will be described in the next + section. If your domain is not managed by Route 53, and you do not wish to migrate your domain + to Route 53, you can + [delegate subdomain management to Route 53](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/CreatingNewSubdomain.html). + +- **root_domain_zone_id** + + Set the hosted zone ID. The ID can be found in the details of the hosted zone in Route 53. + +- **route_v1_requests_to_v2** + + Whether to route V1 requests through V2 + +- **s3_delta_file_bucket_name** + + Set a name for the bucket that the server will read data from. The bucket name must be globally + unique. This bucket is different from the one that was manually created for Terraform states + earlier. + +- **s3client_max_connections** + + S3 Client max connections for reading data files. + +- **s3client_max_range_bytes** + + S3 Client max range bytes for reading data files. + +- **server_port** + + Set the port of the EC2 parent instance (that hosts the Nitro Enclave instance). 
+ +- **sqs_cleanup_image_uri** + + The image built previously in the ECR. Example: + `123456789.dkr.ecr.us-east-1.amazonaws.com/sqs_lambda:latest` + +- **sqs_cleanup_schedule** + + How often to clean up SQS + +- **sqs_queue_timeout_secs** + + Clean up queues not updated within the timeout period. + +- **ssh_source_cidr_blocks** + + Source ips allowed to send ssh traffic to the ssh instance. + +- **udf_num_workers** + + Total number of workers for UDF execution + +- **vpc_cidr_block** + + CIDR range for the VPC where KV server will be deployed. diff --git a/docs/CONTRIBUTING.md b/docs/contributing.md similarity index 100% rename from docs/CONTRIBUTING.md rename to docs/contributing.md diff --git a/docs/data_loading_capabilities.md b/docs/data_loading_capabilities.md new file mode 100644 index 00000000..e0d9450d --- /dev/null +++ b/docs/data_loading_capabilities.md @@ -0,0 +1,86 @@ +# Data loading capabilities + +See [Generating and loading data guide](loading_data.md) for more information on the process, tools +and libraries available for generating and loading data into KV servers. + +## Tuning parameters + +At a high level, KV server instances stream data files using a parallel reader capable of reading +different chunks of a single data file concurrently. For AWS S3, the level of concurrency for the +reader and the number of concurrent S3 client connections can be tuned using the following terraform +parameters: + +- [data_loading_num_threads](https://github.com/privacysandbox/fledge-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) + sets the number of concurrent threads used to read a single data file [default: 16]. 
+- [s3client_max_connections](https://github.com/privacysandbox/fledge-key-value-service/blob/4d6f691b0d12f9604988c14f534f6e91f4025f29/production/terraform/aws/environments/kv_server_variables.tf) + sets the maximum number of concurrent connections used by the S3 client [default: 64]. + +## Benchmarking tool + +The data loading benchmark tool can be used to search for optimal +[tuning parameters](#tuning-parameters) that are best suited to specific hardware, memory and +network specs. To build the benchmarking tool for AWS use the following command (note the +`--//:platform=aws` build flag): + +```sh +builders/tools/bazel-debian run //production/packaging/tools:copy_to_dist --//:instance=local --//:platform=aws +``` + +After building, load the tool into docker as follows: + +```sh +docker load -i dist/tools_binaries_docker_image.tar +``` + +See the [Benchmarking on AWS EC2](#benchmarking-on-aws-ec2) section for some examples of how to run +the tool. + +## Benchmarking on AWS EC2 + +Setup: + +The results below were obtained using the following setup. + +- m5.2xlarge instance (8 vCPUs, 32 GB ram, upto 10Gbps network bandwidth) +- input delta file (4,000,000 records, each record ~512 bytes in size) + +Full benchmark command: + +```sh +AWS_ACCESS_KEY_ID=... +AWS_SECRET_ACCESS_KEY=... 
+AWS_DEFAULT_REGION=us-east-1 +docker run -it --rm \ + --env AWS_DEFAULT_REGION \ + --env AWS_ACCESS_KEY_ID \ + --env AWS_SECRET_ACCESS_KEY \ + --entrypoint=/tools/benchmarks/data_loading_benchmark \ + bazel/production/packaging/tools:tools_binaries_docker_image \ + --benchmark_time_unit=ms \ + --benchmark_counters_tabular=true \ + --benchmark_filter="BM_DataLoading_MutexCache*" \ + --data_directory=kv-server-gorekore-data-bucket \ + --filename="benchmarking-data" \ + --create_input_file \ + --num_records=4000000 \ + --record_size=512 \ + --args_benchmark_iterations=1 \ + --args_client_max_range_mb=8 \ + --args_client_max_connections=64,128 \ + --args_reader_worker_threads=16,32,64 +``` + +Results: + +| data_loading_num_threads | s3client_max_connections | data loading time | num records per sec | +| ------------------------ | ------------------------ | ----------------- | ------------------- | +| 16 | 64 | 7.92 sec | 505 k/s | +| 32 | 64 | 7.82 sec | 511 k/s | +| 64 | 64 | 7.95 sec | 503 k/s | +| 16 | 128 | 7.46 sec | 536 k/s | +| 32 | 128 | 7.70 sec | 519 k/s | +| 64 | 128 | 8.71 sec | 459 k/s | + +One intepretation of the results above is that, it takes the KV server 7.92 sec to load a delta file +with 4'000'000 records (each 0.5kb in size) at a rate of 505k records per second when the parallel +reader is using 16 threads and the S3 client is configured with 64 concurrent connections. diff --git a/docs/deploying_locally.md b/docs/deploying_locally.md index 3b5791e2..4b9f81b5 100644 --- a/docs/deploying_locally.md +++ b/docs/deploying_locally.md @@ -83,15 +83,7 @@ will correct the symlink target paths. 
_It must be run after each call compilati ## Start the server ```sh -./bazel-bin/components/data_server/server/server \ - --delta_directory=/tmp/deltas \ - --realtime_directory=/tmp/realtime -``` - -To have server logs and telemetry written to _STDOUT_ you can prepend this flag: - -```sh -GLOG_logtostderr=1 \ +GLOG_alsologtostderr=1 \ ./bazel-bin/components/data_server/server/server \ --delta_directory=/tmp/deltas \ --realtime_directory=/tmp/realtime diff --git a/docs/deploying_on_aws.md b/docs/deploying_on_aws.md index 11acd17a..c85e7a1a 100644 --- a/docs/deploying_on_aws.md +++ b/docs/deploying_on_aws.md @@ -164,58 +164,10 @@ For your Terraform configuration, you can use the template under your environment name such as dev/staging/prod, name the files inside according to the region you want to deploy to, and update the following file content. -In `[[REGION]].tfvars.json`: - -- Environment - - `environment` - The default is `demo`. The value can be any arbitrary unique string, and for - example, strings like `staging` and `prod` can be used to represent the environment that the - Key/Value server will run in. - - `region` - Update the region that the Key/Value server will operate in. The default is - `us-east-1`. -- Network - - `root_domain` - Set the root domain for the server. - - If your domain is managed by [AWS Route 53](https://aws.amazon.com/route53/), then you - can simply set your domain value to the `root_domain` property in the Terraform - configuration that will be described in the next section. If your domain is not managed - by Route 53, and you do not wish to migrate your domain to Route 53, you can - [delegate subdomain management to Route 53](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/CreatingNewSubdomain.html). - - `root_domain_zone_id` - Set the hosted zone ID. The ID can be found in the details of the - hosted zone in Route 53. 
- - `certificate_arn` - - If you want to create a public AWS ACM certificate for a domain from scratch, follow - [these steps to request a public certificate](https://docs.aws.amazon.com/acm/latest/userguide/gs-acm-request-public.html). - - If you want to import an existing public certificate into ACM, follow these steps to - [import the certificate](https://docs.aws.amazon.com/acm/latest/userguide/import-certificate.html). -- EC2 instance - - `instance_type` - Set the instance type. Use instances with at least four vCPUs. Learn more - about which types are supported from the - [AWS article](https://docs.aws.amazon.com/enclaves/latest/user/nitro-enclave.html). - - `instance_ami_id` - Set the value to the AMI ID that was generated when the image was built. -- Server configuration - - `mode` - Set the server mode. The acceptable values are "DSP" or "SSP". - - `server_port` - Set the port of the EC2 parent instance (that hosts the Nitro Enclave - instance). - - `enclave_cpu_count` - Set how many CPUs the server will use. - - `enclave_memory_mib` - Set how much RAM the server will use. -- Data storage - - `sqs_cleanup_image_uri` - The image built previously in the ECR. Example: - `123456789.dkr.ecr.us-east-1.amazonaws.com/sqs_lambda:latest` - - `s3_delta_file_bucket_name` - Set a name for the bucket that the server will read data from. - The bucket name must be globally unique. This bucket is different from the one that was - manually created for Terraform states earlier. - - `backup_poll_frequency_secs` - Interval between attempts to check if there are new data - files on S3, as a backup to listening to new data files. -- Peripheral features - - `prometheus_service_region` - Specifies which region to find Prometheus service and use. - [Not all regions have Prometheus service.](See - for supported - regions). If this region does not have Prometheus service, it must be created beforehand - either manually or by deploying this system in that region. 
At this time Prometheus service - is needed. In the future it can be refactored to become optional. - - `prometheus_workspace_id` - If the target Prometheus service runs in a different region than - this deployment, the workspace id must be specified. - -In `[[REGION]].backend.conf`: +Update the `[[REGION]].tfvars.json` with Terraform variables for your environment. The description +of each variable is described in [AWS Terraform Vars doc](/docs/AWS_Terraform_vars.md). + +Update the `[[REGION]].backend.conf`: - `bucket` - Set the bucket name that Terraform will use. The bucket was created in the previous [Setup S3 bucket for Terraform states](#setup-s3-bucket-for-terraform-states) step. @@ -343,7 +295,7 @@ curl -vX PUT -d "$BODY" ${KV_SERVER_URL}/v2/getvalues Or gRPC (using [grpcurl](https://github.com/fullstorydev/grpcurl)): ```sh -grpcurl --protoset dist/query_api_descriptor_set.pb -d '{"raw_body": {"data": "'"$(echo -n $BODY|base64 -w 0)"'"}}' -plaintext demo.kv-server.your-domain.example:8443 kv_server.v2.KeyValueService/GetValues +grpcurl --protoset dist/query_api_descriptor_set.pb -d '{"raw_body": {"data": "'"$(echo -n $BODY|base64 -w 0)"'"}}' demo.kv-server.your-domain.example:8443 kv_server.v2.KeyValueService/GetValues ``` ## SSH into EC2 diff --git a/docs/developing_the_server.md b/docs/developing_the_server.md index fb30ac27..1640f9d7 100644 --- a/docs/developing_the_server.md +++ b/docs/developing_the_server.md @@ -81,7 +81,7 @@ The data server provides the read API for the KV service. 1. Build the server artifacts and copy them into the `dist/debian/` directory. ```sh - builders/tools/bazel-debian run //production/packaging/aws/data_server:copy_to_dist --//:instance=local --//:platform=aws + builders/tools/bazel-debian run //production/packaging/aws/data_server:copy_to_dist --config local_instance --//:platform=aws ``` 1. 
Load the image into docker @@ -124,7 +124,7 @@ docker run -it --rm --network host bazel/testing/run_local:envoy_image For example: ```sh -builders/tools/bazel-debian run //components/data_server/server:server --//:instance=local --//:platform=aws -- --environment="dev" +builders/tools/bazel-debian run //components/data_server/server:server --config local_instance --//:platform=aws -- --environment="dev" ``` We are currently developing this server for local testing and for use on AWS Nitro instances diff --git a/docs/generating_udf_files.md b/docs/generating_udf_files.md index 963a418f..cbfc34a7 100644 --- a/docs/generating_udf_files.md +++ b/docs/generating_udf_files.md @@ -3,18 +3,21 @@ # Generating UDF code configs for the server -During server startup, the server reads in the UDF configuration through existing delta/snapshot -files in the bucket. The UDF delta file must be in the delta storage before the server starts. If -not, it will default to a simple pass-through implementation at `public/udf/constants.h`. +The server starts with a simple pass-through implementation at `public/udf/constants.h`. -Currently, the server does not support code updates. +UDF configurations can be updated as the server is running using delta/snapshot files as per the +[data loading guide](generating_udf_files.md). -Depending on the environment, you will need to either include the delta file with the code configs -in the terraform or the local directory. +- To override an existing UDF, the delta/snapshot file must have a + [`DataRecord`](/public/data_loading/data_loading.fbs) with a `UserDefinedFunctionsConfig`. + +- Similar to a `KeyValueMutationRecord`, the `UserDefinedFunctionsConfig` has a + `logical_commit_time`. The UDF will only be updated for configs with a higher + `logical_commit_time` than the existing one. The minimum `logical_commit_time` is 1. 
Please read through the -[UDF explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_user_defined_functions.md#keyvalue-service-user-defined-functions-udfs) -for requirements and APIs. +[UDF explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_service_user_defined_functions.md#keyvalue-service-user-defined-functions-udfs) +for more requirements and APIs. # Steps for including the UDF delta file @@ -23,13 +26,13 @@ for requirements and APIs. ### Option A. Write a custom UDF Write the UDF according to the -[API in the UDF explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_user_defined_functions.md#apis). +[API in the UDF explainer](https://github.com/privacysandbox/fledge-docs/blob/main/key_value_service_user_defined_functions.md#apis). Note that the UDF should be in JavaScript (and optionally JavaScript + inline WASM). ### Option B. Use the reference UDF -We provide a [simple reference implementation](tools/udf/udf.js): +We provide a [simple reference implementation](/tools/udf/sample_udf/udf.js): - The implementation ignores part of the request, e.g. the `context` field. - For each `keyGroup` in the request, it calls `getValues(keyGroup.keyList)` to retrieve the keys @@ -65,12 +68,7 @@ Tools to generate UDF delta files and test them are in the `tools/udf` directory You can use other options to generate delta files, e.g. using the [`data_cli` tool](./loading_data.md). -The delta file must have the following key value records with an `UPDATE` mutation type: - -| Key | Value | -| ---------------- | --------------------------------------------------------------------- | -| udf_handler_name | Name of the handler function that serves as the execution entry point | -| udf_code_snippet | UDF code snippet that contains the handler | +The delta file must have a `DataRecord` with a `UserDefinedFunctionsConfig` as its record. ### Option 3. 
Using sample UDF configurations diff --git a/docs/loading_data.md b/docs/loading_data.md index 9ef2546e..d3050222 100644 --- a/docs/loading_data.md +++ b/docs/loading_data.md @@ -35,11 +35,10 @@ Delta filename must conform to the regular expression `DELTA_\d{16}`. See [constants.h](../public/constants.h) for the most up-to-date format. More recent delta files are lexicographically greater than older delta files. Delta files have the following properties: -- Consists of key/value mutation events (updates/deletes) for a fixed time window. +- Consists of key/value mutation events (updates/deletes) for a fixed time window. The events are + in the format of Flatbuffers ([Schema](/public/data_loading/data_loading.fbs)). - Each mutation event is associated with a `logical_commit_timestamp`, larger timestamp indicates a more recent record. -- Mutation events are ordered by their `logical_commit_timestamp` and this order is important at - read time to make sure that mutations are applied correctly. - `logical_commit_timestamp` of the records have no relation with their file's name. It is acceptable to also use timestamps in file names for ordering purposes for your convenience but the system makes no assumption on the relation between the record timestamps and the file names. @@ -115,6 +114,10 @@ Commands: [--input_format] (Optional) Defaults to "CSV". Possible options=(CSV|DELTA) [--output_file] (Optional) Defaults to stdout. Output file to write converted records to. [--output_format] (Optional) Defaults to "DELTA". Possible options=(CSV|DELTA). + [--record_type] (Optional) Defaults to "KEY_VALUE_MUTATION_RECORD". Possible + options=(KEY_VALUE_MUTATION_RECORD|USER_DEFINED_FUNCTIONS_CONFIG). + If reading/writing a UDF config, use "USER_DEFINED_FUNCTIONS_CONFIG". + Examples: ... 
@@ -145,6 +148,29 @@ As an example, to convert a CSV file to a DELTA file, run the following command: --output_format=DELTA ``` +Here are samples of valid csv files that can be used as input to the cli: + +```sh +# The following csv example shows csv with simple string value. +key,mutation_type,logical_commit_time,value,value_type +key1,UPDATE,1680815895468055,value1,string +key2,UPDATE,1680815895468056,value2,string +key1,UPDATE,1680815895468057,value11,string +key2,DELETE,1680815895468058,value2,string + +# The following csv example shows csv with set values. +# By default, column delimiter = "," and value delimiter = "|" +key,mutation_type,logical_commit_time,value,value_type +key1,UPDATE,1680815895468055,elem1|elem2,string_set +key2,UPDATE,1680815895468056,elem3|elem4,string_set +key1,UPDATE,1680815895468057,elem6|elem7|elem8,string_set +key2,DELETE,1680815895468058,elem10,string_set +``` + +Note that the csv delimiters for set values can be changed to any character combination, but if the +defaults are not used, then the chosen delimiters should be passed to the data_cli using the +`--csv_column_delimiter` and `--csv_value_delimiter` flags. + And to generate a snapshot from a set of delta files, run the following command (replacing flag values with your own values): @@ -199,6 +225,9 @@ data library. Keep the following things in mind: The server watches an S3 bucket for new files. The bucket name is provided by you in the Terraform config and is globally unique. +> Note: Access control of the S3 bucket is managed by your IAM system on the cloud platform. Make +> sure to set the right permissions. + You can use the AWS CLI to upload the sample data to S3, or you can also use the UI. ```sh @@ -206,7 +235,7 @@ You can use the AWS CLI to upload the sample data to S3, or you can also use the -$ aws s3 cp riegeli_data s3://${S3_BUCKET}/DELTA_001 ``` -> Cauition: The filename must start with `DELTA_` prefix, followed by a 16-digit number.
+> Caution: The filename must start with `DELTA_` prefix, followed by a 16-digit number. Confirm that the file is present in the S3 bucket: diff --git a/infrastructure/testing/protocol_testing_helper_server.proto b/infrastructure/testing/protocol_testing_helper_server.proto index 3ff17060..f53f5880 100644 --- a/infrastructure/testing/protocol_testing_helper_server.proto +++ b/infrastructure/testing/protocol_testing_helper_server.proto @@ -32,16 +32,13 @@ service ProtocolTestingHelper { // Given a cleartext http message body in bytes format, wraps it in BinaryHTTP // format. - rpc BHTTPEncapsulate(BHTTPEncapsulateRequest) - returns (BHTTPEncapsulateResponse) {} + rpc BHTTPEncapsulate(BHTTPEncapsulateRequest) returns (BHTTPEncapsulateResponse) {} // Given a BinaryHTTP message, unwraps it into a cleartext message. - rpc BHTTPDecapsulate(BHTTPDecapsulateRequest) - returns (BHTTPDecapsulateResponse) {} + rpc BHTTPDecapsulate(BHTTPDecapsulateRequest) returns (BHTTPDecapsulateResponse) {} // Wraps a byte string with Oblivious HTTP encryption. - rpc OHTTPEncapsulate(OHTTPEncapsulateRequest) - returns (OHTTPEncapsulateResponse) {} + rpc OHTTPEncapsulate(OHTTPEncapsulateRequest) returns (OHTTPEncapsulateResponse) {} // Decrypts a ciphertext encrypted by OHTTP. This is specifically for // decrypting response from the OHTTP gateway (server). That is, if you need @@ -53,8 +50,7 @@ service ProtocolTestingHelper { // OHTTPEncapsulate on this particular helper server process for this RPC to // work. The context_token used as input to this RPC should be from the // response of the earlier call. 
- rpc OHTTPDecapsulate(OHTTPDecapsulateRequest) - returns (OHTTPDecapsulateResponse) {} + rpc OHTTPDecapsulate(OHTTPDecapsulateRequest) returns (OHTTPDecapsulateResponse) {} } message GetTestConfigRequest {} diff --git a/package.json b/package.json deleted file mode 100644 index 3707a3bc..00000000 --- a/package.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "repository": "https://github.com/privacysandbox/fledge-key-value-service" -} diff --git a/production/packaging/aws/build_and_test b/production/packaging/aws/build_and_test index 66945f2f..5e5629c5 100755 --- a/production/packaging/aws/build_and_test +++ b/production/packaging/aws/build_and_test @@ -60,8 +60,7 @@ while [[ $# -gt 0 ]]; do case "$1" in --with-ami) AMI_REGIONS+=("$2") - shift - shift + shift 2 || usage ;; --no-precommit) BUILD_AND_TEST_ARGS+=("--no-precommit") @@ -72,14 +71,8 @@ while [[ $# -gt 0 ]]; do set -o xtrace shift ;; - -h | --help) - usage 0 - break - ;; - *) - usage - break - ;; + -h | --help) usage 0 ;; + *) usage ;; esac done @@ -89,14 +82,36 @@ function arr_to_string_list() { printf "[%s]" "${joined%,}" } -SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" -readonly SCRIPT_DIR +# Exit 1 on any error before pushing to origin. +function fail() { + printf "\n\n[ERROR]: Failure: %s\n\n" "$@" + sleep 5s # Make sure that stdout has time to be written + exit 1 +} + +printf "==== Sourcing builder.sh =====\n" +# We can't use $WORKSPACE here as it may not be set - it'll be created +# by builder.sh if it's empty. +readonly SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +readonly BUILDER="${SCRIPT_DIR}"/../../../builders/tools/builder.sh +if [[ ! 
-f ${BUILDER} ]]; then + ERROR_MESSAGE=$(printf "builder.sh file does not exist, expected at: %s\n" "${BUILDER}") + fail "$ERROR_MESSAGE" +fi # shellcheck source=builders/tools/builder.sh -source "${SCRIPT_DIR}"/../../../builders/tools/builder.sh +source "${BUILDER}" || fail "Failed to source builder.sh" readonly DIST="${WORKSPACE}"/dist -"${WORKSPACE}"/production/packaging/build_and_test_all_in_docker "${BUILD_AND_TEST_ARGS[@]}" --instance aws +printf "==== Running build_and_test_all_in_docker =====\n" +if ! [[ -r ${WORKSPACE}/production/packaging/build_and_test_all_in_docker && -x ${WORKSPACE}/production/packaging/build_and_test_all_in_docker ]]; then + printf "build_and_test script not found at location: %s/production/packaging/build_and_test_all_in_docker\n" "${WORKSPACE}" &>/dev/stderr + fail "build_and_test not found" +fi +if ! "${WORKSPACE}"/production/packaging/build_and_test_all_in_docker "${BUILD_AND_TEST_ARGS[@]}" --instance aws; then + fail "Failed to run build_and_test_all_in_docker" +fi +printf "==== Creating dist dir =====\n" mkdir -p "${DIST}"/aws chmod 770 "${DIST}" "${DIST}"/aws @@ -121,8 +136,9 @@ jq --compact-output --raw-output '.Measurements.PCR0' dist/aws/server_enclave_im cat dist/aws/pcr0.json exit 0 " -docker image rm --force ${IMAGE_URI}:"${IMAGE_TAG}" +docker image rm --force ${IMAGE_URI}:"${IMAGE_TAG}" || fail "Unable to remove Docker image" +printf "==== Checking PCR0 =====\n" readonly PCR0_REL_DIR=production/packaging/aws/data_server/nitro-pcr0 readonly PCR0_DIR="${WORKSPACE}"/${PCR0_REL_DIR} readonly PCR0_FILE="${PCR0_DIR}"/${BUILD_ARCH}.json @@ -141,6 +157,7 @@ else cat "${PCR0_FILE}" fi +printf "==== Copying to dist =====\n" builder::cbuild_al2 $" trap _collect_logs EXIT function _collect_logs() { @@ -152,7 +169,13 @@ function _collect_logs() { set -o errexit bazel ${BAZEL_STARTUP_ARGS} run ${BAZEL_EXTRA_ARGS} //production/packaging/aws/data_server/ami:copy_to_dist " -"${WORKSPACE}"/builders/tools/normalize-dist + +printf "==== 
Running normalize-dist =====\n" +if ! [[ -r ${WORKSPACE}/builders/tools/normalize-dist && -x ${WORKSPACE}/builders/tools/normalize-dist ]]; then + printf "normalize-dist file does not exist, expected at: %s/builders/tools/normalize-dist\n" "${WORKSPACE}" &>/dev/stderr + fail "normalize-dist does not exist" +fi +"${WORKSPACE}"/builders/tools/normalize-dist || { fail "Unable to run normalize-dist"; } if [[ -n ${AMI_REGIONS[0]} ]]; then UTILS_IMAGE="$("${WORKSPACE}"/builders/tools/get-builder-image-tagged --image utils)" diff --git a/production/packaging/aws/data_server/bin/init_server_basic b/production/packaging/aws/data_server/bin/init_server_basic index 4d2cf6c5..3f1365b4 100755 --- a/production/packaging/aws/data_server/bin/init_server_basic +++ b/production/packaging/aws/data_server/bin/init_server_basic @@ -15,6 +15,8 @@ set -o errexit +readonly INTERNAL_SERVER_ADDRESS="unix:///server/socket/internal.sock" + usage() { exitval=${1-1} cat >&2 << USAGE @@ -41,4 +43,4 @@ while [ $# -gt 0 ]; do done # Start the server.
-GLOG_logtostderr=1 $PROXY /server/bin/server "$@" +GLOG_logtostderr=1 $PROXY /server/bin/server "$@" --internal_server_address=${INTERNAL_SERVER_ADDRESS} diff --git a/production/packaging/aws/data_server/nitro-pcr0/amd64.json b/production/packaging/aws/data_server/nitro-pcr0/amd64.json index ffb709e6..c4bf4e10 100644 --- a/production/packaging/aws/data_server/nitro-pcr0/amd64.json +++ b/production/packaging/aws/data_server/nitro-pcr0/amd64.json @@ -1 +1 @@ -{"PCR0":"d4042763ce32a2609bf9f2e315050c6d2f3a88ce9fc9cee63d3a8d3f0f35f413473c8d693d738db230cca321643b6c3f"} +{"PCR0":"f3d3cf8a7c9ce4a97f8a079f5d0a7c67385bb319bb9f99f51a5cc8310da6ee7f07594a0e78073ae0d1ab077616d3c18a"} diff --git a/production/packaging/lib_local_server.sh b/production/packaging/lib_local_server.sh index ff9c1667..a2ba5f21 100644 --- a/production/packaging/lib_local_server.sh +++ b/production/packaging/lib_local_server.sh @@ -24,9 +24,13 @@ fi function local_server::_sut_cleanup() { declare -r -i STATUS=$? declare -n _cleanup_args=$1 + declare -i _pid=$2 if [[ ${_cleanup_args[0]} ]]; then docker compose "${_cleanup_args[@]}" down fi + if [[ ${_pid} -gt 0 ]]; then + kill "${_pid}" &>/dev/null + fi return ${STATUS} } @@ -61,8 +65,14 @@ EOF --env-file "${tmp_env}" ) - trap "local_server::_sut_cleanup docker_compose_args && rm -f \${tmp_env@Q}" ERR RETURN + trap "local_server::_sut_cleanup docker_compose_args \${docker_compose_logs_pid} && rm -f \${tmp_env@Q}" ERR RETURN docker compose "${docker_compose_args[@]}" up --quiet-pull --detach + mkdir -p "${WORKSPACE}"/dist/logs + declare logfile + logfile="$(mktemp --tmpdir="${WORKSPACE}/dist/logs" --dry-run "${sut_name}-dcompose-XXXX" --suffix=".log")" + docker compose "${docker_compose_args[@]}" logs --follow >"${logfile}" & + declare -r -i docker_compose_logs_pid=$! 
+ "${WORKSPACE}"/testing/functionaltest/run-tests --sut-name "${sut_name}" } diff --git a/production/terraform/aws/environments/demo/us-east-1.tfvars.json b/production/terraform/aws/environments/demo/us-east-1.tfvars.json index 00911d53..5c1011f6 100644 --- a/production/terraform/aws/environments/demo/us-east-1.tfvars.json +++ b/production/terraform/aws/environments/demo/us-east-1.tfvars.json @@ -22,6 +22,7 @@ "region": "us-east-1", "root_domain": "demo-server.com", "root_domain_zone_id": "zone-id", + "route_v1_requests_to_v2": false, "s3_delta_file_bucket_name": "globally-unique-bucket", "s3client_max_connections": 64, "s3client_max_range_bytes": 8388608, @@ -30,5 +31,6 @@ "sqs_cleanup_schedule": "rate(6 hours)", "sqs_queue_timeout_secs": 86400, "ssh_source_cidr_blocks": ["0.0.0.0/0"], + "udf_num_workers": 2, "vpc_cidr_block": "10.0.0.0/16" } diff --git a/production/terraform/aws/environments/demo/us-west-1.tfvars.json b/production/terraform/aws/environments/demo/us-west-1.tfvars.json index 6082062a..9a369c70 100644 --- a/production/terraform/aws/environments/demo/us-west-1.tfvars.json +++ b/production/terraform/aws/environments/demo/us-west-1.tfvars.json @@ -22,6 +22,7 @@ "region": "us-west-1", "root_domain": "demo-server.com", "root_domain_zone_id": "zone-id", + "route_v1_requests_to_v2": false, "s3_delta_file_bucket_name": "globally-unique-bucket", "s3client_max_connections": 64, "s3client_max_range_bytes": 8388608, @@ -30,5 +31,6 @@ "sqs_cleanup_schedule": "rate(6 hours)", "sqs_queue_timeout_secs": 86400, "ssh_source_cidr_blocks": ["0.0.0.0/0"], + "udf_num_workers": 2, "vpc_cidr_block": "10.0.0.0/16" } diff --git a/production/terraform/aws/environments/kv_server.tf b/production/terraform/aws/environments/kv_server.tf index 83b2bd91..b1e1226a 100644 --- a/production/terraform/aws/environments/kv_server.tf +++ b/production/terraform/aws/environments/kv_server.tf @@ -32,10 +32,11 @@ module "kv_server" { instance_ami_id = var.instance_ami_id # Variables related to 
server configuration. - mode = var.mode - server_port = var.server_port - enclave_cpu_count = var.enclave_cpu_count - enclave_memory_mib = var.enclave_memory_mib + mode = var.mode + route_v1_requests_to_v2 = var.route_v1_requests_to_v2 + server_port = var.server_port + enclave_cpu_count = var.enclave_cpu_count + enclave_memory_mib = var.enclave_memory_mib # Variables related to autoscaling and load balancing. autoscaling_desired_capacity = var.autoscaling_desired_capacity @@ -73,6 +74,9 @@ module "kv_server" { # Variables related to sharding. num_shards = var.num_shards + + # Variables related to UDF execution. + udf_num_workers = var.udf_num_workers } output "kv_server_url" { diff --git a/production/terraform/aws/environments/kv_server_variables.tf b/production/terraform/aws/environments/kv_server_variables.tf index c9e7ba14..e2a75be9 100644 --- a/production/terraform/aws/environments/kv_server_variables.tf +++ b/production/terraform/aws/environments/kv_server_variables.tf @@ -187,3 +187,13 @@ variable "num_shards" { description = "Total number of shards." type = number } + +variable "udf_num_workers" { + description = "Total number of workers for UDF execution." + type = number +} + +variable "route_v1_requests_to_v2" { + description = "Whether to route V1 requests through V2." 
+ type = bool +} diff --git a/production/terraform/aws/modules/kv_server/main.tf b/production/terraform/aws/modules/kv_server/main.tf index 9b009158..f633b6a4 100644 --- a/production/terraform/aws/modules/kv_server/main.tf +++ b/production/terraform/aws/modules/kv_server/main.tf @@ -152,6 +152,8 @@ module "parameter" { s3client_max_connections_parameter_value = var.s3client_max_connections s3client_max_range_bytes_parameter_value = var.s3client_max_range_bytes num_shards_parameter_value = var.num_shards + udf_num_workers_parameter_value = var.udf_num_workers + route_v1_requests_to_v2_parameter_value = var.route_v1_requests_to_v2 } module "security_group_rules" { @@ -193,6 +195,8 @@ module "iam_role_policies" { module.parameter.s3client_max_connections_parameter_arn, module.parameter.s3client_max_range_bytes_parameter_arn, module.parameter.num_shards_parameter_arn, + module.parameter.udf_num_workers_parameter_arn, + module.parameter.route_v1_requests_to_v2_parameter_arn, ] } diff --git a/production/terraform/aws/modules/kv_server/variables.tf b/production/terraform/aws/modules/kv_server/variables.tf index edc78f2f..c74fa224 100644 --- a/production/terraform/aws/modules/kv_server/variables.tf +++ b/production/terraform/aws/modules/kv_server/variables.tf @@ -177,3 +177,13 @@ variable "num_shards" { description = "Number of shards." type = number } + +variable "udf_num_workers" { + description = "Number of workers for UDF execution." + type = number +} + +variable "route_v1_requests_to_v2" { + description = "Whether to route V1 requests through V2." 
+ type = bool +} diff --git a/production/terraform/aws/services/iam_role_policies/main.tf b/production/terraform/aws/services/iam_role_policies/main.tf index cc15a48d..4299f544 100644 --- a/production/terraform/aws/services/iam_role_policies/main.tf +++ b/production/terraform/aws/services/iam_role_policies/main.tf @@ -51,6 +51,12 @@ data "aws_iam_policy_document" "instance_policy_doc" { effect = "Allow" resources = ["*"] } + statement { + sid = "AllowInstancesToDescribeAutoScalingGroups" + actions = ["autoscaling:DescribeAutoScalingGroups"] + effect = "Allow" + resources = ["*"] + } statement { sid = "AllowInstancesToReadParameters" actions = ["ssm:GetParameter"] diff --git a/production/terraform/aws/services/parameter/main.tf b/production/terraform/aws/services/parameter/main.tf index df13a6ea..12a3006f 100644 --- a/production/terraform/aws/services/parameter/main.tf +++ b/production/terraform/aws/services/parameter/main.tf @@ -104,3 +104,17 @@ resource "aws_ssm_parameter" "num_shards_parameter" { value = var.num_shards_parameter_value overwrite = true } + +resource "aws_ssm_parameter" "udf_num_workers_parameter" { + name = "${var.service}-${var.environment}-udf-num-workers" + type = "String" + value = var.udf_num_workers_parameter_value + overwrite = true +} + +resource "aws_ssm_parameter" "route_v1_requests_to_v2_parameter" { + name = "${var.service}-${var.environment}-route-v1-to-v2" + type = "String" + value = var.route_v1_requests_to_v2_parameter_value + overwrite = true +} diff --git a/production/terraform/aws/services/parameter/outputs.tf b/production/terraform/aws/services/parameter/outputs.tf index dbe8872a..3ed372ed 100644 --- a/production/terraform/aws/services/parameter/outputs.tf +++ b/production/terraform/aws/services/parameter/outputs.tf @@ -69,3 +69,11 @@ output "s3client_max_range_bytes_parameter_arn" { output "num_shards_parameter_arn" { value = aws_ssm_parameter.num_shards_parameter.arn } + +output "udf_num_workers_parameter_arn" { + value = 
aws_ssm_parameter.udf_num_workers_parameter.arn +} + +output "route_v1_requests_to_v2_parameter_arn" { + value = aws_ssm_parameter.route_v1_requests_to_v2_parameter.arn +} diff --git a/production/terraform/aws/services/parameter/variables.tf b/production/terraform/aws/services/parameter/variables.tf index ddf26ee0..a8cc7ccc 100644 --- a/production/terraform/aws/services/parameter/variables.tf +++ b/production/terraform/aws/services/parameter/variables.tf @@ -83,3 +83,13 @@ variable "num_shards_parameter_value" { description = "Total shards numbers." type = number } + +variable "udf_num_workers_parameter_value" { + description = "Total number of workers for UDF execution." + type = number +} + +variable "route_v1_requests_to_v2_parameter_value" { + description = "Whether to route V1 requests through V2." + type = bool +} diff --git a/production/terraform/aws/services/security_group_rules/main.tf b/production/terraform/aws/services/security_group_rules/main.tf index e60bd03c..a528b9d2 100644 --- a/production/terraform/aws/services/security_group_rules/main.tf +++ b/production/terraform/aws/services/security_group_rules/main.tf @@ -139,3 +139,21 @@ resource "aws_security_group_rule" "allow_ssh_instance_to_vpce_ingress" { type = "ingress" source_security_group_id = var.ssh_security_group_id } + +resource "aws_security_group_rule" "allow_ec2_to_ec2_endpoint_egress" { + from_port = 50100 + protocol = "TCP" + security_group_id = var.instances_security_group_id + to_port = 50100 + type = "egress" + source_security_group_id = var.instances_security_group_id +} + +resource "aws_security_group_rule" "allow_ec2_to_ec2_endpoint_ingress" { + from_port = 50100 + protocol = "TCP" + security_group_id = var.instances_security_group_id + to_port = 50100 + type = "ingress" + source_security_group_id = var.instances_security_group_id +} diff --git a/public/BUILD b/public/BUILD index 5ebc4162..7dc9cce4 100644 --- a/public/BUILD +++ b/public/BUILD @@ -28,7 +28,7 @@ proto_library( 
buf_lint_test( name = "base_types_proto_lint", size = "small", - config = "//public:buf.yaml", + config = "//:buf.yaml", targets = [":base_types_proto"], ) @@ -52,10 +52,6 @@ cc_library( ], ) -exports_files([ - "buf.yaml", -]) - cc_test( name = "constants_test", size = "small", diff --git a/public/data_loading/BUILD b/public/data_loading/BUILD index 873ef9de..19ecc97e 100644 --- a/public/data_loading/BUILD +++ b/public/data_loading/BUILD @@ -43,7 +43,7 @@ proto_library( buf_lint_test( name = "riegeli_metadata_proto_lint", size = "small", - config = "//public:buf.yaml", + config = "//:buf.yaml", targets = [":riegeli_metadata_proto"], ) @@ -53,6 +53,9 @@ cc_library( hdrs = ["records_utils.h"], deps = [ ":data_loading_fbs", + "@com_github_google_glog//:glog", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/public/data_loading/aggregation/BUILD b/public/data_loading/aggregation/BUILD index dc228c75..6e941d84 100644 --- a/public/data_loading/aggregation/BUILD +++ b/public/data_loading/aggregation/BUILD @@ -26,6 +26,7 @@ cc_library( "//public/data_loading/writers:delta_record_stream_writer", "//public/data_loading/writers:delta_record_writer", "@com_github_google_glog//:glog", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/public/data_loading/aggregation/record_aggregator.cc b/public/data_loading/aggregation/record_aggregator.cc index 3ecaa81d..893cf36b 100644 --- a/public/data_loading/aggregation/record_aggregator.cc +++ b/public/data_loading/aggregation/record_aggregator.cc @@ -20,6 +20,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -115,33 +116,21 @@ class RecordRowCallback { class DeltaRecordCallback : public RecordRowCallback { public: explicit DeltaRecordCallback( - 
std::function delta_callback) + std::function delta_callback) : delta_callback_(std::move(delta_callback)) {} absl::Status operator()(RecordRow record_row) override { - auto verifier = flatbuffers::Verifier( - reinterpret_cast(record_row.record_blob.data()), - record_row.record_blob.size()); - auto record = - flatbuffers::GetRoot(record_row.record_blob.data()); - if (!record->Verify(verifier)) { - return absl::InvalidArgumentError( - "Record flatbuffer format is not valid."); - } - return delta_callback_(DeltaFileRecordStruct{ - .mutation_type = record->mutation_type(), - .logical_commit_time = record->logical_commit_time(), - .key = record->key()->string_view(), - .value = record->value()->string_view()}); + return DeserializeRecord(record_row.record_blob, delta_callback_); } private: - std::function delta_callback_; + std::function + delta_callback_; }; class ReadRecordsCallback : public RecordRowCallback { public: explicit ReadRecordsCallback( - std::function delta_callback) + std::function delta_callback) : delta_callback_(DeltaRecordCallback(std::move(delta_callback))) {} const int64_t& LastProcessedRecordKey() const { return last_processed_key_; } absl::Status operator()(RecordRow record_row) { @@ -168,11 +157,11 @@ absl::Status CreateRecordsTable(sqlite3* db) { return absl::OkStatus(); } -absl::Status ValidateRecord(const DeltaFileRecordStruct& record) { +absl::Status ValidateRecord(const KeyValueMutationRecordStruct& record) { if (record.key.empty()) { return absl::InvalidArgumentError("Record key must not be empty."); } - if (record.value.empty()) { + if (IsEmptyValue(record.value)) { return absl::InvalidArgumentError("Record value must not be empty."); } return absl::OkStatus(); @@ -279,11 +268,52 @@ RecordAggregator::CreateFileBackedAggregator(std::string_view data_file) { return absl::WrapUnique(new RecordAggregator(std::move(db_owner))); } +absl::StatusOr> +RecordAggregator::MergeSetValueIfRecordExists( + int64_t record_key, const 
KeyValueMutationRecordStruct& record) { + auto new_values_set = std::get>(record.value); + absl::flat_hash_set merged_values_set; + std::copy(new_values_set.begin(), new_values_set.end(), + std::inserter(merged_values_set, merged_values_set.end())); + auto status = ReadRecord( + record_key, + [&merged_values_set](KeyValueMutationRecordStruct existing_record) { + if (std::holds_alternative>( + existing_record.value)) { + auto existing_values_set = + std::get>(existing_record.value); + std::copy(existing_values_set.begin(), existing_values_set.end(), + std::inserter(merged_values_set, merged_values_set.end())); + } + return absl::OkStatus(); + }); + if (!status.ok()) { + return status; + } + std::vector merged_values_list; + for (auto value : merged_values_set) { + merged_values_list.push_back(std::string(value)); + } + return merged_values_list; +} + absl::Status RecordAggregator::InsertOrUpdateRecord( - int64_t record_key, const DeltaFileRecordStruct& record) { + int64_t record_key, const KeyValueMutationRecordStruct& record) { if (absl::Status status = ValidateRecord(record); !status.ok()) { return status; } + KeyValueMutationRecordStruct mutable_record = record; + std::vector values; + if (std::holds_alternative>( + mutable_record.value)) { + auto maybe_values = MergeSetValueIfRecordExists(record_key, mutable_record); + if (!maybe_values.ok()) { + return maybe_values.status(); + } + values = std::move(*maybe_values); + mutable_record.value = + std::vector(values.begin(), values.end()); + } sqlite3_stmt* insert_stmt; if (absl::Status status = PrepareStatement(kInsertRecordSql, &insert_stmt, db_.get()); @@ -298,14 +328,14 @@ absl::Status RecordAggregator::InsertOrUpdateRecord( return status; } if (absl::Status status = - BindInt64(record.logical_commit_time, kLogicalCommitTimeBindValueIdx, - owned_stmt.get()); + BindInt64(mutable_record.logical_commit_time, + kLogicalCommitTimeBindValueIdx, owned_stmt.get()); !status.ok()) { return status; } - std::string_view 
record_blob = ToStringView(record.ToFlatBuffer()); if (absl::Status status = - BindBlob(record_blob, kRecordBlobBindValueIdx, owned_stmt.get()); + BindBlob(ToStringView(ToFlatBufferBuilder(mutable_record)), + kRecordBlobBindValueIdx, owned_stmt.get()); !status.ok()) { return status; } @@ -319,7 +349,7 @@ absl::Status RecordAggregator::InsertOrUpdateRecord( absl::Status RecordAggregator::ReadRecord( int64_t record_key, - std::function record_callback) { + std::function record_callback) { sqlite3_stmt* select_stmt; if (absl::Status status = PrepareStatement(kSelectRecordSql, &select_stmt, db_.get()); @@ -347,7 +377,7 @@ absl::Status RecordAggregator::ReadRecord( } absl::Status RecordAggregator::ReadRecords( - std::function record_callback) { + std::function record_callback) { sqlite3_stmt* batch_select_stmt; if (absl::Status status = PrepareStatement(kBatchSelectRecordsSql, &batch_select_stmt, db_.get()); @@ -358,9 +388,9 @@ absl::Status RecordAggregator::ReadRecords( batch_select_stmt, StmtDeleter{}); ReadRecordsCallback callback(std::move(record_callback)); // Loop through all batches of records until we have read all avaiable rows. - // Each batch has a size of `kRecordsQueryBatchSizeBindIdx` and we know we are - // done if we process less than `kRecordsQueryBatchSizeBindIdx` in a batch - // successfuly. + // Each batch has a size of `kRecordsQueryBatchSizeBindIdx` and we know we + // are done if we process less than `kRecordsQueryBatchSizeBindIdx` in a + // batch successfuly. while (true) { if (absl::Status status = ResetPreparedStatement(owned_stmt.get()); !status.ok()) { @@ -378,8 +408,8 @@ absl::Status RecordAggregator::ReadRecords( !status.ok()) { return status; } - // Loop through all rows in a batch until we get a SQLITE_DONE or SQLITE_OK - // meaning that we have finished processing rows. + // Loop through all rows in a batch until we get a SQLITE_DONE or + // SQLITE_OK meaning that we have finished processing rows. 
int64_t num_processed_records = 0; while (true) { auto result = ProcessRecordRow(callback, owned_stmt.get()); diff --git a/public/data_loading/aggregation/record_aggregator.h b/public/data_loading/aggregation/record_aggregator.h index ff03332f..c102c626 100644 --- a/public/data_loading/aggregation/record_aggregator.h +++ b/public/data_loading/aggregation/record_aggregator.h @@ -20,8 +20,10 @@ #include #include #include +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -30,9 +32,10 @@ #include "sqlite3.h" namespace kv_server { -// A `RecordAggregator` aggregates `DeltaFileRecordStruct` records added to an -// aggregator instance from potentially multiple record streams. Records can -// be aggregated by repeatedly calling `InsertOrUpdateRecord(...)` as follows: +// A `RecordAggregator` aggregates `KeyValueMutationRecordStruct` records added +// to an aggregator instance from potentially multiple record streams. Records +// can be aggregated by repeatedly calling `InsertOrUpdateRecord(...)` as +// follows: // //``` // auto record_aggregator = RecordAggregator::CreateInMemoryAggregator(); @@ -42,7 +45,7 @@ namespace kv_server { // for (auto record_stream : records_streams) { // DeltaRecordStreamReader record_reader(record_stream); // absl::Status status = record_reader.ReadRecords( -// [&](const DeltaFileRecordStruct& record) { +// [&](const KeyValueMutationRecordStruct& record) { // return record_aggregator->InsertOrUpdateRecord(GetRecordKey(record), // record); // }); @@ -80,7 +83,7 @@ class RecordAggregator { // - !absl::OkStatus() - if there are any errors. The returned status // contains a detailed error message. absl::Status InsertOrUpdateRecord(int64_t record_key, - const DeltaFileRecordStruct& record); + const KeyValueMutationRecordStruct& record); // Reads a record keyed by `record_key` and calls the provided // `record_callback` function with the record. 
If no record keyed by // `record_key` exists, then `record_callback` is never called. @@ -91,7 +94,8 @@ class RecordAggregator { // contains a detailed error message. absl::Status ReadRecord( int64_t record_key, - std::function record_callback); + std::function + record_callback); // Reads all records currently in the aggregator. The `record_callback` // function is called exactly once for each record. // @@ -100,7 +104,8 @@ class RecordAggregator { // - !absl::OkStatus() - if there are any errors. The returned status // contains a detailed error message. absl::Status ReadRecords( - std::function record_callback); + std::function + record_callback); // Deletes record keyed by `record_key`. Silently succeeds if the record // does not exist. // @@ -126,6 +131,9 @@ class RecordAggregator { explicit RecordAggregator(std::unique_ptr db) : db_(std::move(db)) {} + absl::StatusOr> MergeSetValueIfRecordExists( + int64_t record_key, const KeyValueMutationRecordStruct& record); + std::unique_ptr db_; }; } // namespace kv_server diff --git a/public/data_loading/aggregation/record_aggregator_benchmarks.cc b/public/data_loading/aggregation/record_aggregator_benchmarks.cc index e9d4b81f..cda3a28d 100644 --- a/public/data_loading/aggregation/record_aggregator_benchmarks.cc +++ b/public/data_loading/aggregation/record_aggregator_benchmarks.cc @@ -19,8 +19,8 @@ #include "public/data_loading/aggregation/record_aggregator.h" #include "public/data_loading/records_utils.h" -using kv_server::DeltaFileRecordStruct; -using kv_server::DeltaMutationType; +using kv_server::KeyValueMutationRecordStruct; +using kv_server::KeyValueMutationType; using kv_server::RecordAggregator; static std::string GenerateRecordValue(int64_t char_count) { @@ -30,9 +30,10 @@ static std::string GenerateRecordValue(int64_t char_count) { static void BM_InMemoryRecordAggregator_InsertRecord(benchmark::State& state) { auto record_aggregator = RecordAggregator::CreateInMemoryAggregator(); std::string record_value = 
GenerateRecordValue(state.range(0)); - DeltaFileRecordStruct record{.mutation_type = DeltaMutationType::Update, - .logical_commit_time = 1234567890, - .value = record_value}; + KeyValueMutationRecordStruct record{ + .mutation_type = KeyValueMutationType::Update, + .logical_commit_time = 1234567890, + .value = record_value}; for (auto _ : state) { state.PauseTiming(); std::string record_key = absl::StrCat("key", std::rand() % 10'000); diff --git a/public/data_loading/aggregation/record_aggregator_test.cc b/public/data_loading/aggregation/record_aggregator_test.cc index d4afb9f1..273b3eaa 100644 --- a/public/data_loading/aggregation/record_aggregator_test.cc +++ b/public/data_loading/aggregation/record_aggregator_test.cc @@ -31,16 +31,18 @@ namespace kv_server { namespace { -size_t GetRecordKey(const DeltaFileRecordStruct& record) { +size_t GetRecordKey(const KeyValueMutationRecordStruct& record) { return absl::HashOf(record.key); } -DeltaFileRecordStruct GetDeltaRecord(std::string_view key = "key") { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetDeltaRecord( + std::string_view key = "key", + KeyValueMutationRecordValueT value = "value") { + KeyValueMutationRecordStruct record; record.key = key; - record.value = "value"; + record.value = value; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } @@ -75,10 +77,11 @@ TEST_P(RecordAggregatorTest, ValidateReadRecord) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback1; + testing::MockFunction + record_callback1; EXPECT_CALL(record_callback1, Call) .Times(1) - .WillRepeatedly([](DeltaFileRecordStruct record) { + .WillRepeatedly([](KeyValueMutationRecordStruct record) { EXPECT_EQ(record, GetDeltaRecord()); return absl::OkStatus(); }); @@ -87,7 +90,8 @@ TEST_P(RecordAggregatorTest, 
ValidateReadRecord) { ->ReadRecord(GetRecordKey(record), record_callback1.AsStdFunction()); EXPECT_TRUE(status.ok()) << status; // We don't expect calls to our callback for records that do not exist. - testing::MockFunction record_callback2; + testing::MockFunction + record_callback2; EXPECT_CALL(record_callback2, Call).Times(0); status = (*record_aggregator) ->ReadRecord(std::hash{}("non-existing-record-key"), @@ -103,10 +107,11 @@ TEST_P(RecordAggregatorTest, ValidateDeleteRecord) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback1; + testing::MockFunction + record_callback1; EXPECT_CALL(record_callback1, Call) .Times(1) - .WillRepeatedly([](DeltaFileRecordStruct record) { + .WillRepeatedly([](KeyValueMutationRecordStruct record) { EXPECT_EQ(record, GetDeltaRecord()); return absl::OkStatus(); }); @@ -116,7 +121,8 @@ TEST_P(RecordAggregatorTest, ValidateDeleteRecord) { EXPECT_TRUE(status.ok()) << status; status = (*record_aggregator)->DeleteRecord(GetRecordKey(record)); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback2; + testing::MockFunction + record_callback2; EXPECT_CALL(record_callback2, Call).Times(0); status = (*record_aggregator) @@ -134,7 +140,8 @@ TEST_P(RecordAggregatorTest, ValidateDeleteRecords) { EXPECT_TRUE(status.ok()) << status; status = (*record_aggregator)->DeleteRecords(); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call).Times(0); status = (*record_aggregator) @@ -150,10 +157,11 @@ TEST_P(RecordAggregatorTest, ValidateInsertingNonExistingRecord) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call) .Times(1) - 
.WillRepeatedly([](DeltaFileRecordStruct record) { + .WillRepeatedly([](KeyValueMutationRecordStruct record) { EXPECT_EQ(record, GetDeltaRecord()); return absl::OkStatus(); }); @@ -174,14 +182,16 @@ TEST_P(RecordAggregatorTest, ValidateInsertingMoreRecentRecord) { // Update record to be more recent and verify that updates are reflected in // stored record. record.logical_commit_time = record.logical_commit_time + 1; - std::string updated_value = absl::StrCat("Updated ", record.value); + std::string updated_value = + absl::StrCat("Updated ", std::get(record.value)); record.value = updated_value; status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call) .Times(1) - .WillRepeatedly([&](DeltaFileRecordStruct existing_record) { + .WillRepeatedly([&](KeyValueMutationRecordStruct existing_record) { EXPECT_EQ(existing_record, record); return absl::OkStatus(); }); @@ -201,14 +211,16 @@ TEST_P(RecordAggregatorTest, ValidateInsertingOlderRecord) { EXPECT_TRUE(status.ok()) << status; // Update record to be older and verify that stored record is not updated. 
record.logical_commit_time = record.logical_commit_time - 1; - std::string updated_value = absl::StrCat("Updated ", record.value); + std::string updated_value = + absl::StrCat("Updated ", std::get(record.value)); record.value = updated_value; status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call) .Times(1) - .WillRepeatedly([](DeltaFileRecordStruct existing_record) { + .WillRepeatedly([](KeyValueMutationRecordStruct existing_record) { EXPECT_EQ(existing_record, GetDeltaRecord()); return absl::OkStatus(); }); @@ -229,14 +241,16 @@ TEST_P(RecordAggregatorTest, ValidateInsertingUpdatedRecordWithSameTimestamp) { // Update record and verify that stored record is updated. // Since updated record has the same timestamp, new values should be // reflected in store. - std::string updated_value = absl::StrCat("Updated ", record.value); + std::string updated_value = + absl::StrCat("Updated ", std::get(record.value)); record.value = updated_value; status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call) .Times(1) - .WillRepeatedly([&](DeltaFileRecordStruct existing_record) { + .WillRepeatedly([&](KeyValueMutationRecordStruct existing_record) { EXPECT_EQ(existing_record, record); return absl::OkStatus(); }); @@ -254,10 +268,11 @@ TEST_P(RecordAggregatorTest, ValidateInsertingMultipleRecords) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback1; + testing::MockFunction + record_callback1; EXPECT_CALL(record_callback1, Call) .Times(1) - .WillRepeatedly([](DeltaFileRecordStruct existing_record) { + .WillRepeatedly([](KeyValueMutationRecordStruct existing_record) { 
EXPECT_EQ(existing_record, GetDeltaRecord()); return absl::OkStatus(); }); @@ -270,10 +285,11 @@ TEST_P(RecordAggregatorTest, ValidateInsertingMultipleRecords) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback2; + testing::MockFunction + record_callback2; EXPECT_CALL(record_callback2, Call) .Times(1) - .WillRepeatedly([&](DeltaFileRecordStruct existing_record) { + .WillRepeatedly([&](KeyValueMutationRecordStruct existing_record) { EXPECT_EQ(existing_record, record); return absl::OkStatus(); }); @@ -286,7 +302,7 @@ TEST_P(RecordAggregatorTest, ValidateInsertingInvalidRecords) { auto record_aggregator = RecordAggregatorTest::CreateAggregator(); auto status = (*record_aggregator)->DeleteRecords(); EXPECT_TRUE(status.ok()) << status; - DeltaFileRecordStruct record; + KeyValueMutationRecordStruct record; status = (*record_aggregator) ->InsertOrUpdateRecord(std::hash{}("key1"), record); EXPECT_FALSE(status.ok()) << status; @@ -308,7 +324,8 @@ TEST_P(RecordAggregatorTest, ValidateReadingRecords) { EXPECT_TRUE(status.ok()) << status; constexpr std::array kRecordKeys = { "key1", "key2", "key3", "key4", "key5"}; - testing::MockFunction record_callback; + testing::MockFunction + record_callback; for (std::string_view key : kRecordKeys) { auto record = GetDeltaRecord(key); auto status = (*record_aggregator) @@ -318,7 +335,8 @@ TEST_P(RecordAggregatorTest, ValidateReadingRecords) { // inserted records and each call should be with a record that matches an // inserted record. 
EXPECT_CALL(record_callback, Call(record)) - .WillOnce([](DeltaFileRecordStruct) { return absl::OkStatus(); }); + .WillOnce( + [](KeyValueMutationRecordStruct) { return absl::OkStatus(); }); } EXPECT_TRUE( (*record_aggregator)->ReadRecords(record_callback.AsStdFunction()).ok()); @@ -328,7 +346,8 @@ TEST_P(RecordAggregatorTest, ValidateReadingRecordsFromEmptyAggregator) { auto record_aggregator = RecordAggregatorTest::CreateAggregator(); auto status = (*record_aggregator)->DeleteRecords(); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call).Times(0); EXPECT_TRUE( (*record_aggregator)->ReadRecords(record_callback.AsStdFunction()).ok()); @@ -342,10 +361,11 @@ TEST_P(RecordAggregatorTest, ValidateReadingRecordsWhenCallbackFails) { status = (*record_aggregator)->InsertOrUpdateRecord(GetRecordKey(record), record); EXPECT_TRUE(status.ok()) << status; - testing::MockFunction record_callback; + testing::MockFunction + record_callback; EXPECT_CALL(record_callback, Call) .Times(1) - .WillOnce([&](DeltaFileRecordStruct record) { + .WillOnce([&](KeyValueMutationRecordStruct record) { return absl::InvalidArgumentError("Callback failed."); }); status = (*record_aggregator)->ReadRecords(record_callback.AsStdFunction()); @@ -354,5 +374,33 @@ TEST_P(RecordAggregatorTest, ValidateReadingRecordsWhenCallbackFails) { EXPECT_STREQ(status.message().data(), "Callback failed."); } +TEST_P(RecordAggregatorTest, ValidateAggregatingRecordsWithSetValues) { + auto record_aggregator = RecordAggregatorTest::CreateAggregator(); + auto status = (*record_aggregator)->DeleteRecords(); + EXPECT_TRUE(status.ok()) << status; + testing::MockFunction + record_callback; + auto record1 = GetDeltaRecord( + "key1", + std::vector{"value1", "value2", "value3", "value4"}); + status = (*record_aggregator) + ->InsertOrUpdateRecord(GetRecordKey(record1), record1); + EXPECT_TRUE(status.ok()) << status; + auto 
record2 = GetDeltaRecord( + "key1", std::vector{"value3", "value4", "value5"}); + status = (*record_aggregator) + ->InsertOrUpdateRecord(GetRecordKey(record2), record2); + EXPECT_TRUE(status.ok()) << status; + EXPECT_CALL(record_callback, Call) + .WillOnce([](KeyValueMutationRecordStruct record) { + EXPECT_THAT(std::get>(record.value), + testing::UnorderedElementsAre("value1", "value2", "value3", + "value4", "value5")); + return absl::OkStatus(); + }); + EXPECT_TRUE( + (*record_aggregator)->ReadRecords(record_callback.AsStdFunction()).ok()); +} + } // namespace } // namespace kv_server diff --git a/public/data_loading/csv/BUILD b/public/data_loading/csv/BUILD index ea648d14..43b2f194 100644 --- a/public/data_loading/csv/BUILD +++ b/public/data_loading/csv/BUILD @@ -35,6 +35,7 @@ cc_library( "//public/data_loading/writers:delta_record_writer", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_riegeli//riegeli/bytes:ostream_writer", "@com_google_riegeli//riegeli/csv:csv_record", "@com_google_riegeli//riegeli/csv:csv_writer", diff --git a/public/data_loading/csv/constants.h b/public/data_loading/csv/constants.h index 26085cfe..8b722413 100644 --- a/public/data_loading/csv/constants.h +++ b/public/data_loading/csv/constants.h @@ -30,6 +30,28 @@ inline constexpr std::string_view kLogicalCommitTimeColumn = "logical_commit_time"; inline constexpr std::string_view kKeyColumn = "key"; inline constexpr std::string_view kValueColumn = "value"; +inline constexpr std::string_view kValueTypeColumn = "value_type"; +inline constexpr std::string_view kValueTypeString = "string"; +inline constexpr std::string_view kValueTypeStringSet = "string_set"; + +inline constexpr std::string_view kRecordTypeColumn = "record_type"; +inline constexpr std::string_view kRecordTypeKVMutation = "key_value_mutation"; +inline constexpr std::string_view kRecordTypeUdfConfig = + "user_defined_functions_config"; + +inline constexpr 
std::string_view kCodeSnippetColumn = "code_snippet"; +inline constexpr std::string_view kHandlerNameColumn = "handler_name"; +inline constexpr std::string_view kLanguageColumn = "language"; +inline constexpr std::string_view kLanguageJavascript = "javascript"; + +inline constexpr std::array kKeyValueMutationRecordHeader = + {kKeyColumn, kLogicalCommitTimeColumn, kMutationTypeColumn, kValueColumn, + kValueTypeColumn}; + +inline constexpr std::array + kUserDefinedFunctionsConfigHeader = {kCodeSnippetColumn, kHandlerNameColumn, + kLogicalCommitTimeColumn, + kLanguageColumn}; } // namespace kv_server diff --git a/public/data_loading/csv/csv_delta_record_stream_reader.cc b/public/data_loading/csv/csv_delta_record_stream_reader.cc index 187c1290..4af6dcef 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader.cc @@ -17,6 +17,10 @@ #include "public/data_loading/csv/csv_delta_record_stream_reader.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/str_split.h" +#include "glog/logging.h" +#include "public/data_loading/records_utils.h" namespace kv_server { namespace { @@ -29,39 +33,108 @@ absl::StatusOr GetLogicalCommitTime( "Cannot convert timestamp:", logical_commit_time, " to a number.")); } -absl::StatusOr GetDeltaMutationType( +absl::StatusOr GetDeltaMutationType( absl::string_view mutation_type) { std::string mt_lower = absl::AsciiStrToLower(mutation_type); if (mt_lower == kUpdateMutationType) { - return DeltaMutationType::Update; + return KeyValueMutationType::Update; } if (mt_lower == kDeleteMutationType) { - return DeltaMutationType::Delete; + return KeyValueMutationType::Delete; } return absl::InvalidArgumentError( absl::StrCat("Unknown mutation type:", mutation_type)); } -} // namespace -namespace internal { -absl::StatusOr MakeDeltaFileRecordStruct( - const riegeli::CsvRecord& csv_record) { - DeltaFileRecordStruct record; +absl::StatusOr 
GetRecordValue( + const riegeli::CsvRecord& csv_record, char value_separator) { + auto type = absl::AsciiStrToLower(csv_record[kValueTypeColumn]); + if (kValueTypeString == type) { + return csv_record[kValueColumn]; + } + if (kValueTypeStringSet == type) { + return absl::StrSplit(csv_record[kValueColumn], value_separator); + } + return absl::InvalidArgumentError( + absl::StrCat("Value type: ", type, " is not supported")); +} + +absl::StatusOr MakeDeltaFileRecordStructWithKVMutation( + const riegeli::CsvRecord& csv_record, char value_separator) { + KeyValueMutationRecordStruct record; record.key = csv_record[kKeyColumn]; - record.value = csv_record[kValueColumn]; + auto value = GetRecordValue(csv_record, value_separator); + if (!value.ok()) { + return value.status(); + } + record.value = *value; absl::StatusOr commit_time = GetLogicalCommitTime(csv_record[kLogicalCommitTimeColumn]); if (!commit_time.ok()) { return commit_time.status(); } record.logical_commit_time = *commit_time; - absl::StatusOr mutation_type = + absl::StatusOr mutation_type = GetDeltaMutationType(csv_record[kMutationTypeColumn]); if (!mutation_type.ok()) { return mutation_type.status(); } record.mutation_type = *mutation_type; - return record; + + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + +absl::StatusOr GetUdfLanguage( + const riegeli::CsvRecord& csv_record) { + auto language = absl::AsciiStrToLower(csv_record[kLanguageColumn]); + if (kLanguageJavascript == language) { + return UserDefinedFunctionsLanguage::Javascript; + } + return absl::InvalidArgumentError( + absl::StrCat("Language: ", language, " is not supported.")); +} + +absl::StatusOr MakeDeltaFileRecordStructWithUdfConfig( + const riegeli::CsvRecord& csv_record) { + UserDefinedFunctionsConfigStruct udf_config; + udf_config.code_snippet = csv_record[kCodeSnippetColumn]; + udf_config.handler_name = csv_record[kHandlerNameColumn]; + + absl::StatusOr commit_time = + 
GetLogicalCommitTime(csv_record[kLogicalCommitTimeColumn]); + if (!commit_time.ok()) { + return commit_time.status(); + } + udf_config.logical_commit_time = *commit_time; + + auto language = GetUdfLanguage(csv_record); + if (!language.ok()) { + return language.status(); + } + udf_config.language = *language; + + DataRecordStruct data_record; + data_record.record = udf_config; + return data_record; +} + +} // namespace + +namespace internal { +absl::StatusOr MakeDeltaFileRecordStruct( + const riegeli::CsvRecord& csv_record, const DataRecordType& record_type, + char value_separator) { + switch (record_type) { + case DataRecordType::kKeyValueMutationRecord: + return MakeDeltaFileRecordStructWithKVMutation(csv_record, + value_separator); + case DataRecordType::kUserDefinedFunctionsConfig: + return MakeDeltaFileRecordStructWithUdfConfig(csv_record); + default: + return absl::InvalidArgumentError("Invalid record type."); + } } } // namespace internal diff --git a/public/data_loading/csv/csv_delta_record_stream_reader.h b/public/data_loading/csv/csv_delta_record_stream_reader.h index b34a032a..7af9ca06 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader.h +++ b/public/data_loading/csv/csv_delta_record_stream_reader.h @@ -28,33 +28,48 @@ namespace kv_server { -// A `CsvDeltaRecordStreamReader` reads CSV records as `DeltaFileRecordStruct` -// records from a `std::iostream` or `std::istream` with CSV formatted data. +// A `CsvDeltaRecordStreamReader` reads CSV records as +// `DataRecordStruct` records from a `std::iostream` or +// `std::istream` with CSV formatted data. 
// // A `CsvDeltaRecordStreamReader` can be used to read records as follows: // ``` // std::ifstream csv_file(my_filename); // CsvDeltaRecordStreamReader record_reader(csv_file); // absl::Status status = record_reader.ReadRecords( -// [](const DeltaFileRecordStruct& record) { +// [](const DataRecordStruct& record) { // UseRecord(record); // return absl::OkStatus(); // } // ); // ``` -// The default delimiter is assumed to be a ',' and records are assumed to have -// the following fields: -// `header = ["mutation_type", "logical_commit_time", "key", "value"]` -// These defaults can be overriden by specifying `Options` when initializing the -// record reader. +// +// The record reader has the following default options, which can be overridden +// by specifying `Options` when initializing the record reader. +// +// - `record_type`: +// If DataRecordType::kKeyValueMutationRecord, records are assumed to be key +// value mutation records with the following header: ["mutation_type", +// "logical_commit_time", "key", "value", "value_type"]`. If +// DataRecordType::kUserDefinedFunctionsConfig, records are assumed to be +// user-defined function configs with the following header: +// `["code_snippet", "handler_name", "language", "logical_commit_time"]`. +// Default `DataRecordType::kKeyValueMutationRecord`. +// +// - `field_separator`: CSV delimiter +// Default ','. +// +// - `value_separator`: For set values, the delimiter for values in a set. +// Default `|`. + template class CsvDeltaRecordStreamReader : public DeltaRecordReader { public: struct Options { char field_separator = ','; - std::vector header = {kKeyColumn, - kLogicalCommitTimeColumn, - kMutationTypeColumn, kValueColumn}; + // Used as a separator for set value elements.
+ char value_separator = '|'; + DataRecordType record_type = DataRecordType::kKeyValueMutationRecord; }; CsvDeltaRecordStreamReader(SrcStreamT& src_stream, @@ -64,9 +79,8 @@ class CsvDeltaRecordStreamReader : public DeltaRecordReader { CsvDeltaRecordStreamReader& operator=(const CsvDeltaRecordStreamReader&) = delete; - absl::Status ReadRecords( - const std::function& record_callback) - override; + absl::Status ReadRecords(const std::function& + record_callback) override; bool IsOpen() const override { return record_reader_.is_open(); }; absl::Status Status() const override { return record_reader_.status(); } @@ -76,15 +90,30 @@ class CsvDeltaRecordStreamReader : public DeltaRecordReader { }; namespace internal { -absl::StatusOr MakeDeltaFileRecordStruct( - const riegeli::CsvRecord& csv_record); +absl::StatusOr MakeDeltaFileRecordStruct( + const riegeli::CsvRecord& csv_record, const DataRecordType& record_type, + char value_separator); template riegeli::CsvReaderBase::Options GetRecordReaderOptions( const typename CsvDeltaRecordStreamReader::Options& options) { riegeli::CsvReaderBase::Options reader_options; reader_options.set_field_separator(options.field_separator); - reader_options.set_required_header(options.header); + + std::vector header; + switch (options.record_type) { + case DataRecordType::kKeyValueMutationRecord: + header = + std::vector(kKeyValueMutationRecordHeader.begin(), + kKeyValueMutationRecordHeader.end()); + break; + case DataRecordType::kUserDefinedFunctionsConfig: + header = std::vector( + kUserDefinedFunctionsConfigHeader.begin(), + kUserDefinedFunctionsConfigHeader.end()); + break; + } + reader_options.set_required_header(std::move(header)); return reader_options; } } // namespace internal @@ -99,12 +128,13 @@ CsvDeltaRecordStreamReader::CsvDeltaRecordStreamReader( template absl::Status CsvDeltaRecordStreamReader::ReadRecords( - const std::function& record_callback) { + const std::function& record_callback) { riegeli::CsvRecord csv_record; 
absl::Status overall_status; while (record_reader_.ReadRecord(csv_record)) { - absl::StatusOr delta_record = - internal::MakeDeltaFileRecordStruct(csv_record); + absl::StatusOr delta_record = + internal::MakeDeltaFileRecordStruct(csv_record, options_.record_type, + options_.value_separator); if (!delta_record.ok()) { overall_status.Update(delta_record.status()); continue; diff --git a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc index 8d35b357..2e11f7cb 100644 --- a/public/data_loading/csv/csv_delta_record_stream_reader_test.cc +++ b/public/data_loading/csv/csv_delta_record_stream_reader_test.cc @@ -18,6 +18,7 @@ #include +#include "gmock/gmock.h" #include "gtest/gtest.h" #include "public/data_loading/csv/csv_delta_record_stream_writer.h" #include "public/data_loading/records_utils.h" @@ -25,38 +26,59 @@ namespace kv_server { namespace { -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +using testing::UnorderedElementsAre; + +KeyValueMutationRecordStruct GetKVMutationRecord( + KeyValueMutationRecordValueT value = "value") { + KeyValueMutationRecordStruct record; record.key = "key"; - record.value = "value"; + record.value = value; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } -TEST(CsvDeltaRecordStreamReaderTest, ValidateReadingAndWritingRecords) { +UserDefinedFunctionsConfigStruct GetUserDefinedFunctionsConfig() { + UserDefinedFunctionsConfigStruct udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; 
+ return data_record; +} + +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingAndWriting_KVMutation_StringValues_Success) { std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); - EXPECT_TRUE(record_writer.WriteRecord(GetDeltaRecord()).ok()); + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); EXPECT_TRUE(record_writer.Flush().ok()); CsvDeltaRecordStreamReader record_reader(string_stream); EXPECT_TRUE(record_reader - .ReadRecords([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); return absl::OkStatus(); }) .ok()); } TEST(CsvDeltaRecordStreamReaderTest, - ValidateReadingCsvRecordsWithInvalidTimestamps) { - const char invalid_data[] = R"csv(key,value,mutation_type,logical_commit_time - key,value,Update,invalid_time)csv"; + ValidateReadingCsvRecords_KVMutation_InvalidTimestamps_Failure) { + const char invalid_data[] = + R"csv(key,value,value_type,mutation_type,logical_commit_time + key,value,string,Update,invalid_time)csv"; std::stringstream csv_stream; csv_stream.str(invalid_data); CsvDeltaRecordStreamReader record_reader(csv_stream); absl::Status status = record_reader.ReadRecords( - [](const DeltaFileRecordStruct&) { return absl::OkStatus(); }); + [](const DataRecordStruct&) { return absl::OkStatus(); }); EXPECT_FALSE(status.ok()) << status; EXPECT_STREQ(std::string(status.message()).c_str(), "Cannot convert timestamp:invalid_time to a number.") @@ -65,14 +87,15 @@ TEST(CsvDeltaRecordStreamReaderTest, } TEST(CsvDeltaRecordStreamReaderTest, - ValidateReadingCsvRecordsWithInvalidMutation) { - const char invalid_data[] = R"csv(key,value,mutation_type,logical_commit_time - key,value,invalid_mutation,1000000)csv"; + ValidateReadingCsvRecords_KVMutation_InvalidMutation_Failure) { + const char invalid_data[] = + 
R"csv(key,value,value_type,mutation_type,logical_commit_time + key,value,string,invalid_mutation,1000000)csv"; std::stringstream csv_stream; csv_stream.str(invalid_data); CsvDeltaRecordStreamReader record_reader(csv_stream); absl::Status status = record_reader.ReadRecords( - [](DeltaFileRecordStruct) { return absl::OkStatus(); }); + [](DataRecordStruct) { return absl::OkStatus(); }); EXPECT_FALSE(status.ok()) << status; EXPECT_STREQ(std::string(status.message()).c_str(), "Unknown mutation type:invalid_mutation") @@ -80,5 +103,127 @@ TEST(CsvDeltaRecordStreamReaderTest, EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; } +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingAndWriting_KVMutation_SetValues_Success) { + const std::vector values{ + "elem1", + "elem2", + "elem3", + }; + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer(string_stream); + + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord(values)); + auto status = record_writer.WriteRecord(expected); + EXPECT_TRUE(status.ok()) << status; + status = record_writer.Flush(); + EXPECT_TRUE(status.ok()) << status; + CsvDeltaRecordStreamReader record_reader(string_stream); + status = record_reader.ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(CsvDeltaRecordStreamReaderTest, + ReadingCsvRecords_KvMutation_UdfConfigHeader_Failure) { + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer(string_stream); + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + CsvDeltaRecordStreamReader record_reader( + string_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + const auto status = record_reader.ReadRecords( + [](DataRecordStruct) { return 
absl::OkStatus(); }); + EXPECT_FALSE(status.ok()) << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; +} + +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingAndWriting_UdfConfig_Success) { + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + CsvDeltaRecordStreamReader record_reader( + string_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + EXPECT_TRUE(record_reader + .ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }) + .ok()); +} + +TEST(CsvDeltaRecordStreamReaderTest, + ReadingAndWriting_UdfConfig_KvMutationHeader_Failure) { + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + CsvDeltaRecordStreamReader record_reader( + string_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kKeyValueMutationRecord}); + const auto status = record_reader.ReadRecords( + [](DataRecordStruct) { return absl::OkStatus(); }); + EXPECT_FALSE(status.ok()) << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; +} + +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingCsvRecords_UdfConfig_InvalidTimestamps_Failure) { + const char invalid_data[] = + R"csv(code_snippet,handler_name,logical_commit_time,language + function 
hello(){},hello,invalid_time,javascript)csv"; + std::stringstream csv_stream; + csv_stream.str(invalid_data); + CsvDeltaRecordStreamReader record_reader( + csv_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + absl::Status status = record_reader.ReadRecords( + [](const DataRecordStruct&) { return absl::OkStatus(); }); + EXPECT_FALSE(status.ok()) << status; + EXPECT_STREQ(std::string(status.message()).c_str(), + "Cannot convert timestamp:invalid_time to a number.") + << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; +} + +TEST(CsvDeltaRecordStreamReaderTest, + ValidateReadingCsvRecords_UdfConfig_InvalidLanguage_Failure) { + const char invalid_data[] = + R"csv(code_snippet,handler_name,logical_commit_time,language + function hello(){},hello,1000000,invalid_language)csv"; + std::stringstream csv_stream; + csv_stream.str(invalid_data); + CsvDeltaRecordStreamReader record_reader( + csv_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + absl::Status status = record_reader.ReadRecords( + [](const DataRecordStruct&) { return absl::OkStatus(); }); + EXPECT_FALSE(status.ok()) << status; + EXPECT_STREQ(std::string(status.message()).c_str(), + "Language: invalid_language is not supported.") + << status; + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; +} + } // namespace } // namespace kv_server diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.cc b/public/data_loading/csv/csv_delta_record_stream_writer.cc index 9aa52b33..b4f4b073 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.cc +++ b/public/data_loading/csv/csv_delta_record_stream_writer.cc @@ -16,35 +16,80 @@ #include "public/data_loading/csv/csv_delta_record_stream_writer.h" -#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "glog/logging.h" #include 
"public/data_loading/data_loading_generated.h" namespace kv_server { namespace { + +struct ValueStruct { + std::string value_type; + std::string value; +}; + +struct RecordStruct { + std::string record_type; + KeyValueMutationRecordStruct kv_mutation_record; + UserDefinedFunctionsConfigStruct udf_config; +}; + absl::StatusOr GetMutationType( - const DeltaFileRecordStruct& record) { + const KeyValueMutationRecordStruct& record) { switch (record.mutation_type) { - case DeltaMutationType::Update: + case KeyValueMutationType::Update: return kUpdateMutationType; - case DeltaMutationType::Delete: + case KeyValueMutationType::Delete: return kDeleteMutationType; default: return absl::InvalidArgumentError( absl::StrCat("Invalid mutation type: ", - EnumNameDeltaMutationType(record.mutation_type))); + EnumNameKeyValueMutationType(record.mutation_type))); } } -} // namespace -namespace internal { -absl::StatusOr MakeCsvRecord( - const DeltaFileRecordStruct& record, - const std::vector& header) { - // TODO: Consider using ctor with fields for performance gain if order is - // known ahead of time. 
- riegeli::CsvRecord csv_record(header); +absl::StatusOr GetRecordValue( + const KeyValueMutationRecordValueT& value, + std::string_view value_separator) { + return std::visit( + [value_separator](auto&& arg) -> absl::StatusOr { + using VariantT = std::decay_t; + if constexpr (std::is_same_v) { + return ValueStruct{ + .value_type = std::string(kValueTypeString), + .value = std::string(arg), + }; + } + if constexpr (std::is_same_v>) { + return ValueStruct{ + .value_type = std::string(kValueTypeStringSet), + .value = absl::StrJoin(arg, value_separator), + }; + } + return absl::InvalidArgumentError("Value must be set."); + }, + value); +} + +absl::StatusOr MakeCsvRecordWithKVMutation( + const DataRecordStruct& data_record, char value_separator) { + if (!std::holds_alternative( + data_record.record)) { + return absl::InvalidArgumentError( + "DataRecord must contain a KeyValueMutationRecord."); + } + const auto record = + std::get(data_record.record); + + riegeli::CsvRecord csv_record(kKeyValueMutationRecordHeader); csv_record[kKeyColumn] = record.key; - csv_record[kValueColumn] = record.value; + absl::StatusOr value = + GetRecordValue(record.value, std::string(1, value_separator)); + if (!value.ok()) { + return value.status(); + } + csv_record[kValueColumn] = value->value; + csv_record[kValueTypeColumn] = value->value_type; absl::StatusOr mutation_type = GetMutationType(record); if (!mutation_type.ok()) { return mutation_type.status(); @@ -54,6 +99,59 @@ absl::StatusOr MakeCsvRecord( absl::StrCat(record.logical_commit_time); return csv_record; } + +absl::StatusOr GetUdfLanguage( + const UserDefinedFunctionsConfigStruct& udf_config) { + switch (udf_config.language) { + case UserDefinedFunctionsLanguage::Javascript: + return kLanguageJavascript; + default: + return absl::InvalidArgumentError(absl::StrCat( + "Invalid UDF language: ", + EnumNameUserDefinedFunctionsLanguage(udf_config.language))); + } +} + +absl::StatusOr MakeCsvRecordWithUdfConfig( + const 
DataRecordStruct& data_record) { + if (!std::holds_alternative( + data_record.record)) { + return absl::InvalidArgumentError( + "DataRecord must contain a UserDefinedFunctionsConfig."); + } + const auto udf_config = + std::get(data_record.record); + + riegeli::CsvRecord csv_record(kUserDefinedFunctionsConfigHeader); + csv_record[kCodeSnippetColumn] = udf_config.code_snippet; + csv_record[kHandlerNameColumn] = udf_config.handler_name; + csv_record[kLogicalCommitTimeColumn] = + absl::StrCat(udf_config.logical_commit_time); + auto udf_language = GetUdfLanguage(udf_config); + if (!udf_language.ok()) { + return udf_language.status(); + } + csv_record[kLanguageColumn] = *udf_language; + return csv_record; +} + +} // namespace + +namespace internal { +absl::StatusOr MakeCsvRecord( + const DataRecordStruct& data_record, const DataRecordType& record_type, + char value_separator) { + // TODO: Consider using ctor with fields for performance gain if order is + // known ahead of time. + switch (record_type) { + case DataRecordType::kKeyValueMutationRecord: + return MakeCsvRecordWithKVMutation(data_record, value_separator); + case DataRecordType::kUserDefinedFunctionsConfig: + return MakeCsvRecordWithUdfConfig(data_record); + default: + return absl::InvalidArgumentError("Invalid record type."); + } +} } // namespace internal } // namespace kv_server diff --git a/public/data_loading/csv/csv_delta_record_stream_writer.h b/public/data_loading/csv/csv_delta_record_stream_writer.h index 7b5d7b3d..2fd31f56 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer.h +++ b/public/data_loading/csv/csv_delta_record_stream_writer.h @@ -30,27 +30,46 @@ namespace kv_server { -// A `CsvDeltaRecordStreamWriter` writes `DeltaFileRecordStruct` records as CSV -// records to a `std::iostream` or `std::ostream.` or other subclasses of these -// two streams. 
+// A `CsvDeltaRecordStreamWriter` writes `DataRecordStruct` records +// as CSV records to a `std::iostream` or `std::ostream.` or other subclasses of +// these two streams. // // A `CsvDeltaRecordStreamWriter` can be used to write CSV records as follows: // ``` // std::stringstream ostream; // CsvDeltaRecordStreamWriter record_writer(ostream); -// DeltaFileRecordStruct record = ...; +// DataRecordStruct record = ...; // if (absl::Status status = record_writer.WriteRecord(); !status.ok()) { // LOG(ERROR) << "Failed to write record: " << status; // } // ``` +// +// The record writer has the following default options, which can be overriden +// by specifying `Options` when initializing the record writer. +// +// - `record_type`: +// If DataRecordType::kKeyValueMutationRecord, records are assumed to be key +// value mutation records with the following header: ["mutation_type", +// "logical_commit_time", "key", "value", "value_type"]`. If +// DataRecordType::kUserDefinedFunctionsConfig, records are assumed to be +// user-defined function configs with the following header: +// `["code_snippet", "handler_name", "language", "logical_commit_time"]`. +// Default `DataRecordType::kKeyValueMutationRecord`. +// +// - `field_separator`: CSV delimiter +// Default ','. +// +// - `value_separator`: For set values, the delimiter for values in a set. +// Default `|`. + template class CsvDeltaRecordStreamWriter : public DeltaRecordWriter { public: struct Options : public DeltaRecordWriter::Options { char field_separator = ','; - std::vector header = {kKeyColumn, - kLogicalCommitTimeColumn, - kMutationTypeColumn, kValueColumn}; + // Used as a separator for set value elements. 
+ char value_separator = '|'; + DataRecordType record_type = DataRecordType::kKeyValueMutationRecord; }; CsvDeltaRecordStreamWriter(DestStreamT& dest_stream, @@ -60,7 +79,7 @@ class CsvDeltaRecordStreamWriter : public DeltaRecordWriter { CsvDeltaRecordStreamWriter& operator=(const CsvDeltaRecordStreamWriter&) = delete; - absl::Status WriteRecord(const DeltaFileRecordStruct& record) override; + absl::Status WriteRecord(const DataRecordStruct& record) override; absl::Status Flush() override; const Options& GetOptions() const override { return options_; } void Close() override { record_writer_.Close(); } @@ -74,15 +93,28 @@ class CsvDeltaRecordStreamWriter : public DeltaRecordWriter { namespace internal { absl::StatusOr MakeCsvRecord( - const DeltaFileRecordStruct& record, - const std::vector& header); + const DataRecordStruct& data_record, const DataRecordType& record_type, + char value_separator); template riegeli::CsvWriterBase::Options GetRecordWriterOptions( const typename CsvDeltaRecordStreamWriter::Options& options) { riegeli::CsvWriterBase::Options writer_options; writer_options.set_field_separator(options.field_separator); - writer_options.set_header(options.header); + std::vector header; + switch (options.record_type) { + case DataRecordType::kKeyValueMutationRecord: + header = + std::vector(kKeyValueMutationRecordHeader.begin(), + kKeyValueMutationRecordHeader.end()); + break; + case DataRecordType::kUserDefinedFunctionsConfig: + header = std::vector( + kUserDefinedFunctionsConfigHeader.begin(), + kUserDefinedFunctionsConfigHeader.end()); + break; + } + writer_options.set_header(std::move(header)); return writer_options; } } // namespace internal @@ -97,14 +129,14 @@ CsvDeltaRecordStreamWriter::CsvDeltaRecordStreamWriter( template absl::Status CsvDeltaRecordStreamWriter::WriteRecord( - const DeltaFileRecordStruct& record) { - absl::StatusOr csv_record = - internal::MakeCsvRecord(record, options_.header); + const DataRecordStruct& data_record) { + 
absl::StatusOr csv_record = internal::MakeCsvRecord( + data_record, options_.record_type, options_.value_separator); if (!csv_record.ok()) { return csv_record.status(); } if (!record_writer_.WriteRecord(*csv_record) && options_.recovery_function) { - options_.recovery_function(record); + options_.recovery_function(data_record); } return record_writer_.status(); } diff --git a/public/data_loading/csv/csv_delta_record_stream_writer_test.cc b/public/data_loading/csv/csv_delta_record_stream_writer_test.cc index 79549c7c..5f649ede 100644 --- a/public/data_loading/csv/csv_delta_record_stream_writer_test.cc +++ b/public/data_loading/csv/csv_delta_record_stream_writer_test.cc @@ -25,67 +25,177 @@ namespace kv_server { namespace { -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKVMutationRecord( + KeyValueMutationRecordValueT value = "value") { + KeyValueMutationRecordStruct record; record.key = "key"; - record.value = "value"; + record.value = value; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } -TEST(CsvDeltaRecordStreamWriterTest, ValidateWritingCsvRecordFromDelta) { +UserDefinedFunctionsConfigStruct GetUserDefinedFunctionsConfig() { + UserDefinedFunctionsConfigStruct udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + +TEST(CsvDeltaRecordStreamWriterTest, + ValidateWritingCsvRecord_KVMutation_StringValue_Success) { + std::stringstream string_stream; + CsvDeltaRecordStreamWriter record_writer(string_stream); + + DataRecordStruct 
expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + CsvDeltaRecordStreamReader record_reader(string_stream); + EXPECT_TRUE(record_reader + .ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }) + .ok()); +} + +TEST(CsvDeltaRecordStreamWriterTest, + ValidateWritingCsvRecord_KVMutation_SetValue_Success) { + const std::vector values{ + "elem1", + "elem2", + "elem3", + }; std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); - EXPECT_TRUE(record_writer.WriteRecord(GetDeltaRecord()).ok()); + + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord(values)); + EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); EXPECT_TRUE(record_writer.Flush().ok()); CsvDeltaRecordStreamReader record_reader(string_stream); EXPECT_TRUE(record_reader - .ReadRecords([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); return absl::OkStatus(); }) .ok()); } +TEST(CsvDeltaRecordStreamWriterTest, + WritingCsvRecord_KvMutation_UdfConfigHeader_Fails) { + std::stringstream string_stream; + + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_FALSE(record_writer.WriteRecord(expected).ok()); + record_writer.Close(); +} + +TEST(CsvDeltaRecordStreamWriterTest, + ValidateWritingCsvRecord_UdfConfig_Success) { + std::stringstream string_stream; + + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + 
EXPECT_TRUE(record_writer.WriteRecord(expected).ok()); + EXPECT_TRUE(record_writer.Flush().ok()); + CsvDeltaRecordStreamReader record_reader( + string_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + EXPECT_TRUE(record_reader + .ReadRecords([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }) + .ok()); +} + +TEST(CsvDeltaRecordStreamWriterTest, + WritingCsvRecord_UdfConfig_UdfLanguageUnknown_Fails) { + std::stringstream string_stream; + + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + + UserDefinedFunctionsConfigStruct udf_config; + DataRecordStruct data_record = GetDataRecord(udf_config); + const auto status = record_writer.WriteRecord(data_record); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.message(), "Invalid UDF language: "); + record_writer.Close(); +} + +TEST(CsvDeltaRecordStreamWriterTest, + WritingCsvRecord_UdfConfig_KvMutationHeader_Fails) { + std::stringstream string_stream; + + CsvDeltaRecordStreamWriter record_writer( + string_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kKeyValueMutationRecord}); + + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_FALSE(record_writer.WriteRecord(expected).ok()); + record_writer.Close(); +} + TEST(CsvDeltaRecordStreamWriterTest, ValidateThatWritingUsingClosedWriterFails) { std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); record_writer.Close(); - EXPECT_FALSE(record_writer.WriteRecord(GetDeltaRecord()).ok()); + EXPECT_FALSE( + record_writer.WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); } TEST(CsvDeltaRecordStreamWriterTest, ValidateThatFailedRecordsAreRecoverable) { - DeltaFileRecordStruct recovered_record; + DataRecordStruct recovered_record; 
CsvDeltaRecordStreamWriter::Options options; options.recovery_function = - [&recovered_record](DeltaFileRecordStruct failed_record) { + [&recovered_record](DataRecordStruct failed_record) { recovered_record = failed_record; }; std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream, options); record_writer.Close(); - EXPECT_NE(recovered_record, GetDeltaRecord()); - EXPECT_FALSE(record_writer.WriteRecord(GetDeltaRecord()).ok()); - EXPECT_EQ(recovered_record, GetDeltaRecord()); + + DataRecordStruct empty_record; + EXPECT_EQ(recovered_record, empty_record); + + // Writer is closed, so writing fails. + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_FALSE(record_writer.WriteRecord(expected).ok()); + + EXPECT_EQ(recovered_record, expected); } -TEST(CsvDeltaRecordStreamWriterTest, ValidateWritingDefaultRecord) { +TEST(CsvDeltaRecordStreamWriterTest, ValidateWritingDefaultRecordFails) { std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); - EXPECT_TRUE(record_writer.WriteRecord(DeltaFileRecordStruct{}).ok()); - EXPECT_TRUE(record_writer.Flush().ok()); - CsvDeltaRecordStreamReader record_reader(string_stream); - EXPECT_TRUE(record_reader - .ReadRecords([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, DeltaFileRecordStruct{}); - return absl::OkStatus(); - }) - .ok()); + auto status = record_writer.WriteRecord(DataRecordStruct{}); + EXPECT_FALSE(status.ok()) << status; } -TEST(CsvDeltaRecordStreamWriterTest, ValidateClosingRecordWriter) { +TEST(CsvDeltaRecordStreamWriterTest, ValidateClosingRecordWriterSucceeds) { std::stringstream string_stream; CsvDeltaRecordStreamWriter record_writer(string_stream); record_writer.Close(); diff --git a/public/data_loading/data_loading.fbs b/public/data_loading/data_loading.fbs index 8487637b..e40e22c4 100644 --- a/public/data_loading/data_loading.fbs +++ b/public/data_loading/data_loading.fbs @@ -2,22 +2,59 @@ namespace kv_server; -enum 
DeltaMutationType:byte { Update = 0, Delete = 1 } +enum KeyValueMutationType:byte { Update = 0, Delete = 1 } -table DeltaFileRecord { +table String { value:string; } +// For set values: +// (1) `Update` mutation creates the set if one doesn't exist, +// otherwise inserts the elements into the existing set. +// (2) `Delete` mutation removes the elements from existing set. +table StringSet { value:[string]; } +union Value { String, StringSet } + +table KeyValueMutationRecord { // Required. For updates, the value will overwrite the previous value, if any. - mutation_type: DeltaMutationType; + mutation_type: KeyValueMutationType; // Required. Used to represent the commit time of the record. In cases where 2 // records of the same key are compared, the one with a larger logical time - // is considered newer. There is no constraints as of 2022 Q3 on what format - // the time must be other than that a larger number represents a newer - // timestamp. + // is considered newer. There is no constraints on what format the time must + // be other than that a larger number represents a newer timestamp. For sets, + // all elements will have the same timestamp. logical_commit_time:int64; // Required. key:string; // Required. - value:string; + value:Value; } + + +enum UserDefinedFunctionsLanguage:byte { Javascript = 0 } + +table UserDefinedFunctionsConfig { + // Required. Language of the user-defined function. + language:UserDefinedFunctionsLanguage; + + // Required. Code snippet containing the user-defined function. + code_snippet:string; + + // Required. Handler name is the entry point for user-defined function + // execution. + handler_name:string; + + // Required. Used to represent the commit time of the record. In cases where 2 + // records of the same key are compared, the one with a larger logical time + // is considered newer. There is no constraints on what format the time must + // be other than that a larger number represents a newer timestamp. 
+ logical_commit_time:int64; +} + +union Record { KeyValueMutationRecord, UserDefinedFunctionsConfig } + +table DataRecord { + record:Record; +} + +root_type DataRecord; diff --git a/public/data_loading/readers/delta_record_reader.h b/public/data_loading/readers/delta_record_reader.h index 78a04649..337d51c8 100644 --- a/public/data_loading/readers/delta_record_reader.h +++ b/public/data_loading/readers/delta_record_reader.h @@ -32,7 +32,7 @@ namespace kv_server { // ``` // DeltaRecordReader record_reader = ... // absl::Status status = record_reader.ReadRecords( -// [](const DeltaFileRecordStruct& record) { +// [](const DataRecordStruct& record) { // UseRecord(record); // return absl::OkStatus(); // } @@ -44,11 +44,10 @@ namespace kv_server { class DeltaRecordReader { public: virtual ~DeltaRecordReader() = default; - // Reads `DeltaFileRecordStruct` records from the underlying record source and - // passes them to `record_callback` function. + // Reads `DataRecordStruct` records from the underlying record + // source and passes them to `record_callback` function. virtual absl::Status ReadRecords( - const std::function& - record_callback) = 0; + const std::function& record_callback) = 0; // Returns true if the reader is open for reading records. virtual bool IsOpen() const = 0; // Returns status of the `DeltaRecordReader`. diff --git a/public/data_loading/readers/delta_record_stream_reader.h b/public/data_loading/readers/delta_record_stream_reader.h index 7b592e38..3a639b93 100644 --- a/public/data_loading/readers/delta_record_stream_reader.h +++ b/public/data_loading/readers/delta_record_stream_reader.h @@ -24,15 +24,15 @@ namespace kv_server { -// A `DeltaRecordStreamReader` reads records as `DeltaFileRecordStruct`s from a -// delta record input stream source. +// A `DeltaRecordStreamReader` reads records as `DataRecordStruct`s +// from a delta record input stream source. 
// // A `DeltaRecordStreamReader` can be used to read records as follows: // ``` // std::ifstream delta_file(my_filename); // DeltaRecordStreamReader record_reader(delta_file); // absl::Status status = record_reader.ReadRecords( -// [](const DeltaFileRecordStruct& record) { +// [](const DataRecordStruct& record) { // UseRecord(record); // return absl::OkStatus(); // } @@ -52,9 +52,8 @@ class DeltaRecordStreamReader : public DeltaRecordReader { DeltaRecordStreamReader(const DeltaRecordStreamReader&) = delete; DeltaRecordStreamReader& operator=(const DeltaRecordStreamReader&) = delete; - absl::Status ReadRecords( - const std::function& record_callback) - override; + absl::Status ReadRecords(const std::function& + record_callback) override; bool IsOpen() const override { return stream_reader_.IsOpen(); }; absl::Status Status() const override { return stream_reader_.Status(); } absl::StatusOr ReadMetadata() { @@ -67,16 +66,11 @@ class DeltaRecordStreamReader : public DeltaRecordReader { template absl::Status DeltaRecordStreamReader::ReadRecords( - const std::function& record_callback) { - return stream_reader_.ReadStreamRecords([&](std::string_view record_string) { - auto fbs_record = - flatbuffers::GetRoot(record_string.data()); - return record_callback(DeltaFileRecordStruct{ - .mutation_type = fbs_record->mutation_type(), - .logical_commit_time = fbs_record->logical_commit_time(), - .key = fbs_record->key()->string_view(), - .value = fbs_record->value()->string_view()}); - }); + const std::function& record_callback) { + return stream_reader_.ReadStreamRecords( + [&record_callback](std::string_view record_string) { + return DeserializeDataRecord(record_string, record_callback); + }); } } // namespace kv_server diff --git a/public/data_loading/readers/delta_record_stream_reader_test.cc b/public/data_loading/readers/delta_record_stream_reader_test.cc index 9764ce66..35c9d07e 100644 --- a/public/data_loading/readers/delta_record_stream_reader_test.cc +++ 
b/public/data_loading/readers/delta_record_stream_reader_test.cc @@ -29,47 +29,109 @@ KVFileMetadata GetMetadata() { return metadata; } -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKVMutationRecord() { + KeyValueMutationRecordStruct record; record.key = "key"; record.value = "value"; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } -TEST(DeltaRecordStreamReaderTest, ValidateReadingRecords) { +UserDefinedFunctionsConfigStruct GetUserDefinedFunctionsConfig() { + UserDefinedFunctionsConfigStruct udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + +TEST(DeltaRecordStreamReaderTest, KVRecord_ValidateReadingRecords) { std::stringstream string_stream; auto record_writer = DeltaRecordStreamWriter<>::Create( string_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); EXPECT_TRUE(record_writer.ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()); + + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); (*record_writer)->Close(); DeltaRecordStreamReader record_reader(string_stream); EXPECT_TRUE(record_reader - .ReadRecords([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .ReadRecords([&expected](DataRecordStruct data_record) { + EXPECT_EQ(data_record, expected); return absl::OkStatus(); }) .ok()); } -TEST(DeltaRecordStreamReaderTest, ValidateReadingRecordCallsRecordCallback) { 
+TEST(DeltaRecordStreamReaderTest, + KVRecord_ValidateReadingRecordCallsRecordCallback) { std::stringstream string_stream; auto record_writer = DeltaRecordStreamWriter<>::Create( string_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); EXPECT_TRUE(record_writer.ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()); + + DataRecordStruct expected = GetDataRecord(GetKVMutationRecord()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + (*record_writer)->Close(); + DeltaRecordStreamReader record_reader(string_stream); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .Times(4) + .WillRepeatedly([&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }); + EXPECT_TRUE(record_reader.ReadRecords(record_callback.AsStdFunction()).ok()); +} + +TEST(DeltaRecordStreamReaderTest, UdfConfig_ValidateReadingRecords) { + std::stringstream string_stream; + auto record_writer = DeltaRecordStreamWriter<>::Create( + string_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); + EXPECT_TRUE(record_writer.ok()); + + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + (*record_writer)->Close(); + DeltaRecordStreamReader record_reader(string_stream); + EXPECT_TRUE(record_reader + .ReadRecords([&expected](DataRecordStruct data_record) { + EXPECT_EQ(data_record, expected); + return absl::OkStatus(); + }) + .ok()); +} + +TEST(DeltaRecordStreamReaderTest, + 
UdfConfig_ValidateReadingRecordCallsRecordCallback) { + std::stringstream string_stream; + auto record_writer = DeltaRecordStreamWriter<>::Create( + string_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); + EXPECT_TRUE(record_writer.ok()); + + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()); (*record_writer)->Close(); DeltaRecordStreamReader record_reader(string_stream); - testing::MockFunction record_callback; + testing::MockFunction record_callback; EXPECT_CALL(record_callback, Call) .Times(4) - .WillRepeatedly([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .WillRepeatedly([&expected](const DataRecordStruct& data_record) { + EXPECT_EQ(data_record, expected); return absl::OkStatus(); }); EXPECT_TRUE(record_reader.ReadRecords(record_callback.AsStdFunction()).ok()); diff --git a/public/data_loading/records_utils.cc b/public/data_loading/records_utils.cc index 91ad3f11..bfdc6f8b 100644 --- a/public/data_loading/records_utils.cc +++ b/public/data_loading/records_utils.cc @@ -14,38 +14,296 @@ #include "public/data_loading/records_utils.h" +#include + +#include "absl/status/statusor.h" +#include "glog/logging.h" + namespace kv_server { namespace { // An arbitrary small number in case the flat buffer needs some space for // overheads. 
constexpr int kOverheadSize = 10; -} // namespace -std::string_view ToStringView(const flatbuffers::FlatBufferBuilder& fb_buffer) { - return std::string_view(reinterpret_cast(fb_buffer.GetBufferPointer()), - fb_buffer.GetSize()); +struct ValueUnion { + Value value_type; + flatbuffers::Offset value; +}; + +struct RecordUnion { + Record record_type; + flatbuffers::Offset record; +}; + +ValueUnion BuildValueUnion(const KeyValueMutationRecordValueT& value, + flatbuffers::FlatBufferBuilder& builder) { + return std::visit( + [&builder](auto&& arg) { + using VariantT = std::decay_t; + if constexpr (std::is_same_v) { + return ValueUnion{ + .value_type = Value::String, + .value = CreateStringDirect(builder, arg.data()).Union(), + }; + } + if constexpr (std::is_same_v>) { + auto values_offset = builder.CreateVectorOfStrings(arg); + return ValueUnion{ + .value_type = Value::StringSet, + .value = CreateStringSet(builder, values_offset).Union(), + }; + } + return ValueUnion{ + .value_type = Value::NONE, + .value = flatbuffers::Offset(), + }; + }, + value); } -flatbuffers::FlatBufferBuilder DeltaFileRecordStruct::ToFlatBuffer() const { - flatbuffers::FlatBufferBuilder builder(key.size() + value.size() + - sizeof(logical_commit_time) + - sizeof(mutation_type) + kOverheadSize); - const auto record = CreateDeltaFileRecordDirect( - builder, mutation_type, logical_commit_time, key.data(), value.data()); - builder.Finish(record); - return builder; +flatbuffers::Offset KeyValueMutationFromStruct( + flatbuffers::FlatBufferBuilder& builder, + const KeyValueMutationRecordStruct& record) { + auto fb_value = BuildValueUnion(record.value, builder); + return CreateKeyValueMutationRecordDirect( + builder, record.mutation_type, record.logical_commit_time, + record.key.data(), fb_value.value_type, fb_value.value); +} + +flatbuffers::Offset UdfConfigFromStruct( + flatbuffers::FlatBufferBuilder& builder, + const UserDefinedFunctionsConfigStruct& udf_config_struct) { + return 
CreateUserDefinedFunctionsConfigDirect( + builder, udf_config_struct.language, + udf_config_struct.code_snippet.data(), + udf_config_struct.handler_name.data(), + udf_config_struct.logical_commit_time); +} + +RecordUnion BuildRecordUnion(const RecordT& record, + flatbuffers::FlatBufferBuilder& builder) { + return std::visit( + [&builder](auto&& arg) { + using VariantT = std::decay_t; + if constexpr (std::is_same_v) { + return RecordUnion{ + .record_type = Record::KeyValueMutationRecord, + .record = KeyValueMutationFromStruct(builder, arg).Union(), + }; + } + if constexpr (std::is_same_v) { + return RecordUnion{ + .record_type = Record::UserDefinedFunctionsConfig, + .record = UdfConfigFromStruct(builder, arg).Union(), + }; + } + return RecordUnion{ + .record_type = Record::NONE, + .record = flatbuffers::Offset(), + }; + }, + record); +} + +template +absl::StatusOr DeserializeAndVerifyRecord( + std::string_view record_bytes) { + auto fbs_record = flatbuffers::GetRoot(record_bytes.data()); + auto record_verifier = flatbuffers::Verifier( + reinterpret_cast(record_bytes.data()), + record_bytes.size(), flatbuffers::Verifier::Options{}); + if (!fbs_record->Verify(record_verifier)) { + // TODO(b/239061954): Publish metrics for alerting + return absl::InvalidArgumentError("Invalid flatbuffer bytes."); + } + return fbs_record; +} + +KeyValueMutationRecordValueT GetRecordStructValue( + const KeyValueMutationRecord& fbs_record) { + KeyValueMutationRecordValueT value; + if (fbs_record.value_type() == Value::String) { + value = GetRecordValue(fbs_record); + } + if (fbs_record.value_type() == Value::StringSet) { + value = GetRecordValue>(fbs_record); + } + return value; } -bool operator==(const DeltaFileRecordStruct& lhs_record, - const DeltaFileRecordStruct& rhs_record) { +RecordT GetRecordStruct(const DataRecord& data_record) { + RecordT record; + if (data_record.record_type() == Record::KeyValueMutationRecord) { + record = GetTypedRecordStruct(data_record); + } + if 
(data_record.record_type() == Record::UserDefinedFunctionsConfig) { + record = + GetTypedRecordStruct(data_record); + } + return record; +} + +} // namespace + +bool operator==(const KeyValueMutationRecordStruct& lhs_record, + const KeyValueMutationRecordStruct& rhs_record) { return lhs_record.logical_commit_time == rhs_record.logical_commit_time && lhs_record.mutation_type == rhs_record.mutation_type && lhs_record.key == rhs_record.key && lhs_record.value == rhs_record.value; } -bool operator!=(const DeltaFileRecordStruct& lhs_record, - const DeltaFileRecordStruct& rhs_record) { +bool operator!=(const KeyValueMutationRecordStruct& lhs_record, + const KeyValueMutationRecordStruct& rhs_record) { + return !operator==(lhs_record, rhs_record); +} + +bool operator==(const UserDefinedFunctionsConfigStruct& lhs_record, + const UserDefinedFunctionsConfigStruct& rhs_record) { + return lhs_record.logical_commit_time == rhs_record.logical_commit_time && + lhs_record.language == rhs_record.language && + lhs_record.code_snippet == rhs_record.code_snippet && + lhs_record.handler_name == rhs_record.handler_name; +} + +bool operator!=(const UserDefinedFunctionsConfigStruct& lhs_record, + const UserDefinedFunctionsConfigStruct& rhs_record) { + return !operator==(lhs_record, rhs_record); +} + +bool operator==(const DataRecordStruct& lhs_record, + const DataRecordStruct& rhs_record) { + return lhs_record.record == rhs_record.record; +} + +bool operator!=(const DataRecordStruct& lhs_record, + const DataRecordStruct& rhs_record) { return !operator==(lhs_record, rhs_record); } + +bool IsEmptyValue(const KeyValueMutationRecordValueT& value) { + return value.index() == 0; +} + +flatbuffers::FlatBufferBuilder ToFlatBufferBuilder( + const KeyValueMutationRecordStruct& record) { + flatbuffers::FlatBufferBuilder builder; + const auto fbs_record = KeyValueMutationFromStruct(builder, record); + builder.Finish(fbs_record); + return builder; +} + +flatbuffers::FlatBufferBuilder 
ToFlatBufferBuilder( + const DataRecordStruct& data_record) { + flatbuffers::FlatBufferBuilder builder; + auto kv_fbs_record = BuildRecordUnion(data_record.record, builder); + const auto fbs_record = CreateDataRecord(builder, kv_fbs_record.record_type, + kv_fbs_record.record); + builder.Finish(fbs_record); + return builder; +} + +std::string_view ToStringView(const flatbuffers::FlatBufferBuilder& fb_buffer) { + return std::string_view( + reinterpret_cast(fb_buffer.GetBufferPointer()), + fb_buffer.GetSize()); +} + +absl::Status DeserializeRecord( + std::string_view record_bytes, + const std::function& + record_callback) { + auto fbs_record = + DeserializeAndVerifyRecord(record_bytes); + if (!fbs_record.ok()) { + return fbs_record.status(); + } + if (fbs_record.value()->value_type() == Value::NONE) { + return absl::InvalidArgumentError("Record value is not set."); + } + return record_callback(**fbs_record); +} + +absl::Status DeserializeRecord( + std::string_view record_bytes, + const std::function& + record_callback) { + return DeserializeRecord( + record_bytes, + [&record_callback](const KeyValueMutationRecord& fbs_record) { + KeyValueMutationRecordStruct record_struct; + record_struct.key = fbs_record.key()->string_view(); + record_struct.logical_commit_time = fbs_record.logical_commit_time(); + record_struct.mutation_type = fbs_record.mutation_type(); + record_struct.value = GetRecordStructValue(fbs_record); + return record_callback(record_struct); + }); +} + +absl::Status DeserializeDataRecord( + std::string_view record_bytes, + const std::function& record_callback) { + auto fbs_record = DeserializeAndVerifyRecord(record_bytes); + if (!fbs_record.ok()) { + return fbs_record.status(); + } + // TODO(b/269472380): Add data validation. Not + // necessarily here. 
+ return record_callback(**fbs_record); +} + +absl::Status DeserializeDataRecord( + std::string_view record_bytes, + const std::function& + record_callback) { + return DeserializeDataRecord( + record_bytes, [&record_callback](const DataRecord& fbs_record) { + DataRecordStruct data_struct; + data_struct.record = GetRecordStruct(fbs_record); + return record_callback(data_struct); + }); +} + +template <> +std::string_view GetRecordValue(const KeyValueMutationRecord& record) { + return record.value_as_String()->value()->string_view(); +} + +template <> +std::vector GetRecordValue( + const KeyValueMutationRecord& record) { + std::vector values; + for (const auto val : *record.value_as_StringSet()->value()) { + values.push_back(val->string_view()); + } + return values; +} + +template <> +KeyValueMutationRecordStruct GetTypedRecordStruct( + const DataRecord& data_record) { + KeyValueMutationRecordStruct kv_mutation_struct; + const auto* kv_mutation_record = + data_record.record_as_KeyValueMutationRecord(); + kv_mutation_struct.key = kv_mutation_record->key()->string_view(); + kv_mutation_struct.logical_commit_time = + kv_mutation_record->logical_commit_time(); + kv_mutation_struct.mutation_type = kv_mutation_record->mutation_type(); + kv_mutation_struct.value = GetRecordStructValue(*kv_mutation_record); + return kv_mutation_struct; +} + +template <> +UserDefinedFunctionsConfigStruct GetTypedRecordStruct( + const DataRecord& data_record) { + UserDefinedFunctionsConfigStruct udf_config_struct; + const auto* udf_config = data_record.record_as_UserDefinedFunctionsConfig(); + udf_config_struct.language = udf_config->language(); + udf_config_struct.logical_commit_time = udf_config->logical_commit_time(); + udf_config_struct.code_snippet = udf_config->code_snippet()->string_view(); + udf_config_struct.handler_name = udf_config->handler_name()->string_view(); + return udf_config_struct; +} + } // namespace kv_server diff --git a/public/data_loading/records_utils.h 
b/public/data_loading/records_utils.h index cabe8bcc..931e4bb4 100644 --- a/public/data_loading/records_utils.h +++ b/public/data_loading/records_utils.h @@ -19,27 +19,139 @@ #include #include +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "public/data_loading/data_loading_generated.h" namespace kv_server { -std::string_view ToStringView(const flatbuffers::FlatBufferBuilder& fb_buffer); +enum class DataRecordType : int { + kKeyValueMutationRecord, + kUserDefinedFunctionsConfig +}; + +using KeyValueMutationRecordValueT = + std::variant>; -struct DeltaFileRecordStruct { - kv_server::DeltaMutationType mutation_type; +struct KeyValueMutationRecordStruct { + KeyValueMutationType mutation_type; int64_t logical_commit_time; std::string_view key; - std::string_view value; + KeyValueMutationRecordValueT value; +}; - flatbuffers::FlatBufferBuilder ToFlatBuffer() const; +struct UserDefinedFunctionsConfigStruct { + UserDefinedFunctionsLanguage language; + std::string_view code_snippet; + std::string_view handler_name; + int64_t logical_commit_time; }; -bool operator==(const DeltaFileRecordStruct& lhs_record, - const DeltaFileRecordStruct& rhs_record); +using RecordT = std::variant; + +struct DataRecordStruct { + RecordT record; +}; + +bool operator==(const KeyValueMutationRecordStruct& lhs_record, + const KeyValueMutationRecordStruct& rhs_record); +bool operator!=(const KeyValueMutationRecordStruct& lhs_record, + const KeyValueMutationRecordStruct& rhs_record); + +bool operator==(const UserDefinedFunctionsConfigStruct& lhs_record, + const UserDefinedFunctionsConfigStruct& rhs_record); +bool operator!=(const UserDefinedFunctionsConfigStruct& lhs_record, + const UserDefinedFunctionsConfigStruct& rhs_record); + +bool operator==(const DataRecordStruct& lhs_record, + const DataRecordStruct& rhs_record); +bool operator!=(const DataRecordStruct& lhs_record, + const DataRecordStruct& rhs_record); + +// Returns true if the value has 
been default initialized and no variant was +// set. +bool IsEmptyValue(const KeyValueMutationRecordValueT& value); + +// Casts the flat buffer `record_buffer` into a string representation. +std::string_view ToStringView( + const flatbuffers::FlatBufferBuilder& record_buffer); + +// Serializes the record struct to a flat buffer builder using format defined by +// `data_loading.fbs:KeyValueMutationRecord` table. +flatbuffers::FlatBufferBuilder ToFlatBufferBuilder( + const KeyValueMutationRecordStruct& record); + +// Serializes the file record struct to a flat buffer builder using format +// defined by `data_loading.fbs:DataRecord` table. +flatbuffers::FlatBufferBuilder ToFlatBufferBuilder( + const DataRecordStruct& data_record); + +// Deserializes "data_loading.fbs:KeyValueMutationRecord" raw flatbuffer record +// bytes and calls `record_callback` with the resulting `KeyValueMutationRecord` +// object. +// Returns `absl::InvalidArgumentError` if deserilization fails, otherwise +// returns the result of calling `record_callback`. +absl::Status DeserializeRecord( + std::string_view record_bytes, + const std::function& + record_callback); + +// Deserializes "data_loading.fbs:KeyValueMutationRecord" raw flatbuffer record +// bytes and calls `record_callback` with the resulting +// `KeyValueMutationRecordStruct` object. +// Returns `absl::InvalidArgumentError` if deserilization fails, otherwise +// returns the result of calling `record_callback`. +absl::Status DeserializeRecord( + std::string_view record_bytes, + const std::function& + record_callback); + +// Deserializes "data_loading.fbs:DataRecord" raw flatbuffer record +// bytes and calls `record_callback` with the resulting `DataRecord` +// object. +// Returns `absl::InvalidArgumentError` if deserialization fails, otherwise +// returns the result of calling `record_callback`. 
+absl::Status DeserializeDataRecord( + std::string_view record_bytes, + const std::function& record_callback); + +// Deserializes "data_loading.fbs:DataRecord" raw flatbuffer record +// bytes and calls `record_callback` with the resulting +// `DataRecordStruct` object. +// Returns `absl::InvalidArgumentError` if deserilization fails, otherwise +// returns the result of calling `record_callback`. +absl::Status DeserializeDataRecord( + std::string_view record_bytes, + const std::function& + record_callback); + +// Utility function to get the union value set on the `record`. Must +// be called after checking the type of the union value using +// `record.value_type()` function. +template +ValueT GetRecordValue(const KeyValueMutationRecord& record); +template <> +std::string_view GetRecordValue(const KeyValueMutationRecord& record); +template <> +std::vector GetRecordValue( + const KeyValueMutationRecord& record); -bool operator!=(const DeltaFileRecordStruct& lhs_record, - const DeltaFileRecordStruct& rhs_record); +// Utility function to get the union record set on the `data_record`. Must +// be called after checking the type of the union record using +// `data_record.record_type()` function. 
+template +RecordT GetTypedRecordStruct(const DataRecord& data_record); +template <> +KeyValueMutationRecordStruct GetTypedRecordStruct( + const DataRecord& data_record); +template <> +UserDefinedFunctionsConfigStruct GetTypedRecordStruct( + const DataRecord& data_record); } // namespace kv_server diff --git a/public/data_loading/records_utils_test.cc b/public/data_loading/records_utils_test.cc index 0b0c1ce1..66ba3ffb 100644 --- a/public/data_loading/records_utils_test.cc +++ b/public/data_loading/records_utils_test.cc @@ -15,23 +15,248 @@ #include "public/data_loading/records_utils.h" #include "absl/hash/hash_testing.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" namespace kv_server { namespace { -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKeyValueMutationRecord( + KeyValueMutationRecordValueT value = "value") { + KeyValueMutationRecordStruct record; record.key = "key"; - record.value = "value"; + record.value = value; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } -TEST(DeltaFileRecordStructTest, ValidateEqualsOperator) { - EXPECT_EQ(GetDeltaRecord(), GetDeltaRecord()); - EXPECT_NE(GetDeltaRecord(), DeltaFileRecordStruct{.key = "key1"}); +UserDefinedFunctionsConfigStruct GetUdfConfigStruct( + std::string_view code_snippet = "function my_handler(){}") { + UserDefinedFunctionsConfigStruct udf_config_struct; + udf_config_struct.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_struct.code_snippet = code_snippet; + udf_config_struct.handler_name = "my_handler"; + udf_config_struct.logical_commit_time = 1234567890; + return udf_config_struct; } + +DataRecordStruct GetDataRecord(RecordT record) { + DataRecordStruct data_record_struct; + data_record_struct.record = record; + return data_record_struct; +} + +TEST(KeyValueMutationRecordStructTest, ValidateEqualsOperator) 
{ + EXPECT_EQ(GetKeyValueMutationRecord(), GetKeyValueMutationRecord()); + EXPECT_NE(GetKeyValueMutationRecord("value1"), + GetKeyValueMutationRecord("value2")); + std::vector values1{"value1", "value2"}; + EXPECT_EQ(GetKeyValueMutationRecord(values1), + GetKeyValueMutationRecord(values1)); + std::vector values2{"value3", "value4"}; + EXPECT_NE(GetKeyValueMutationRecord(values1), + GetKeyValueMutationRecord(values2)); +} + +TEST(KeyValueMutationRecordStructTest, VerifyRecordStructValueIsEmpty) { + KeyValueMutationRecordValueT value; + EXPECT_TRUE(IsEmptyValue(value)); + value = "test"; + EXPECT_FALSE(IsEmptyValue(value)); + value = std::vector{"test1", "test2"}; + EXPECT_FALSE(IsEmptyValue(value)); +} + +TEST(UdfConfigStructTest, ValidateEqualsOperator) { + EXPECT_EQ(GetUdfConfigStruct(), GetUdfConfigStruct()); + EXPECT_NE(GetUdfConfigStruct("code_snippet1"), + GetUdfConfigStruct("code_snippet2")); +} + +TEST(DataRecordStructTest, ValidateEqualsOperator) { + RecordT record; + EXPECT_EQ(GetDataRecord(record), GetDataRecord(record)); + + EXPECT_EQ(GetDataRecord(GetKeyValueMutationRecord()), + GetDataRecord(GetKeyValueMutationRecord())); + EXPECT_NE(GetDataRecord(GetKeyValueMutationRecord("value1")), + GetDataRecord(GetKeyValueMutationRecord("value2"))); + std::vector values1{"value1", "value2"}; + EXPECT_EQ(GetDataRecord(GetKeyValueMutationRecord(values1)), + GetDataRecord(GetKeyValueMutationRecord(values1))); + std::vector values2{"value3", "value4"}; + EXPECT_NE(GetDataRecord(GetKeyValueMutationRecord(values1)), + GetDataRecord(GetKeyValueMutationRecord(values2))); + + EXPECT_EQ(GetDataRecord(GetUdfConfigStruct()), + GetDataRecord(GetUdfConfigStruct())); + EXPECT_NE(GetDataRecord(GetUdfConfigStruct("code_snippet1")), + GetDataRecord(GetUdfConfigStruct("code_snippet2"))); +} + +class RecordValueTest + : public testing::TestWithParam { + protected: + KeyValueMutationRecordValueT GetValue() { return GetParam(); } +}; + +void ExpectEqual(const 
KeyValueMutationRecordStruct& record, + const KeyValueMutationRecord& fbs_record) { + EXPECT_EQ(record.key, fbs_record.key()->string_view()); + EXPECT_EQ(record.logical_commit_time, fbs_record.logical_commit_time()); + EXPECT_EQ(record.mutation_type, fbs_record.mutation_type()); + if (fbs_record.value_type() == Value::String) { + EXPECT_EQ(std::get(record.value), + GetRecordValue(fbs_record)); + } + if (fbs_record.value_type() == Value::StringSet) { + EXPECT_THAT(std::get>(record.value), + testing::ContainerEq( + GetRecordValue>(fbs_record))); + } +} + +void ExpectEqual(const UserDefinedFunctionsConfigStruct& record, + const UserDefinedFunctionsConfig& fbs_record) { + EXPECT_EQ(record.language, fbs_record.language()); + EXPECT_EQ(record.logical_commit_time, fbs_record.logical_commit_time()); + EXPECT_EQ(record.code_snippet, fbs_record.code_snippet()->string_view()); + EXPECT_EQ(record.handler_name, fbs_record.handler_name()->string_view()); +} + +void ExpectEqual(const DataRecordStruct& record, const DataRecord& fbs_record) { + if (fbs_record.record_type() == Record::KeyValueMutationRecord) { + ExpectEqual(std::get(record.record), + *fbs_record.record_as_KeyValueMutationRecord()); + } + if (fbs_record.record_type() == Record::UserDefinedFunctionsConfig) { + ExpectEqual(std::get(record.record), + *fbs_record.record_as_UserDefinedFunctionsConfig()); + } +} + +INSTANTIATE_TEST_SUITE_P(RecordValueType, RecordValueTest, + testing::Values("value1", + std::vector{ + "value1", "value2"})); +TEST_P(RecordValueTest, VerifyDeserializeRecordToFbsRecord) { + auto record = GetKeyValueMutationRecord(GetValue()); + testing::MockFunction + record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&record](const KeyValueMutationRecord& fbs_record) { + ExpectEqual(record, fbs_record); + return absl::OkStatus(); + }); + auto status = DeserializeRecord(ToStringView(ToFlatBufferBuilder(record)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + 
+TEST_P(RecordValueTest, VerifyDeserializeRecordToRecordStruct) { + auto record = GetKeyValueMutationRecord(GetValue()); + testing::MockFunction + record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&record](const KeyValueMutationRecordStruct& actual_record) { + EXPECT_EQ(record, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeRecord(ToStringView(ToFlatBufferBuilder(record)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_StringValue_Success) { + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord("value")); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecord& data_record_fbs) { + ExpectEqual(data_record_struct, data_record_fbs); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, + DeserializeDataRecord_ToFbsRecord_KVMutation_StringVectorValue_Success) { + std::vector values({"value1", "value2"}); + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecord& data_record_fbs) { + ExpectEqual(data_record_struct, data_record_fbs); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, + DeserializeDataRecord_ToStruct_KVMutation_StringValue_Success) { + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord("value")); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const 
DataRecordStruct& actual_record) { + EXPECT_EQ(data_record_struct, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, + DeserializeDataRecord_ToStruct_KVMutation_VectorStringValue_Success) { + std::vector values({"value1", "value2"}); + auto data_record_struct = GetDataRecord(GetKeyValueMutationRecord(values)); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecordStruct& actual_record) { + EXPECT_EQ(data_record_struct, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, DeserializeDataRecord_ToFbsRecord_UdfConfig_Success) { + auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecord& data_record_fbs) { + ExpectEqual(data_record_struct, data_record_fbs); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST(DataRecordTest, DeserializeDataRecord_ToStruct_UdfConfig_Success) { + auto data_record_struct = GetDataRecord(GetUdfConfigStruct()); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .WillOnce([&data_record_struct](const DataRecordStruct& actual_record) { + EXPECT_EQ(data_record_struct, actual_record); + return absl::OkStatus(); + }); + auto status = DeserializeDataRecord( + ToStringView(ToFlatBufferBuilder(data_record_struct)), + record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) 
<< status; +} + } // namespace } // namespace kv_server diff --git a/public/data_loading/riegeli_metadata.proto b/public/data_loading/riegeli_metadata.proto index 409e7b78..2a17dfb0 100644 --- a/public/data_loading/riegeli_metadata.proto +++ b/public/data_loading/riegeli_metadata.proto @@ -22,7 +22,7 @@ import "riegeli/records/records_metadata.proto"; // Metadata specific to DELTA files. message DeltaMetadata {} -//Metadata specific to SNAPSHOT files. +// Metadata specific to SNAPSHOT files. message SnapshotMetadata { // [Required] // (1) Name of the previous snapshot file used to generate this snapshot or @@ -34,6 +34,26 @@ message SnapshotMetadata { optional string ending_delta_file = 2; } +// Defines boundaries for a logical shard in a data file. All shard records are +// strictly contained by the byte range specified by `start_index` and +// `end_index`. +message Shard { + // Byte index where the shard starts. + optional int64 start_index = 1; + + // Byte index where the shard ends. + optional int64 end_index = 2; +} + +// Metadata specifying whether the file is sharded and if so, shard boundaries. +message ShardsMetadata { + // Whether the file is logically sharded or not. + optional bool is_sharded_file = 1; + + // Metadata about logical shards in the file. + repeated Shard shards = 2; +} + // All K/V server metadata related to one riegeli file. message KVFileMetadata { // All records in one file are from this namespace. 
@@ -43,6 +63,8 @@ message KVFileMetadata { DeltaMetadata delta = 2; SnapshotMetadata snapshot = 3; } + + optional ShardsMetadata shards_metadata = 4; } extend riegeli.RecordsMetadata { diff --git a/public/data_loading/writers/BUILD b/public/data_loading/writers/BUILD index bb4a95dd..66041da5 100644 --- a/public/data_loading/writers/BUILD +++ b/public/data_loading/writers/BUILD @@ -79,3 +79,29 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "sharded_record_buffer", + srcs = ["sharded_record_buffer.cc"], + hdrs = ["sharded_record_buffer.h"], + deps = [ + "//public/data_loading:records_utils", + "//public/sharding:sharding_function", + "@com_github_google_glog//:glog", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_riegeli//riegeli/bytes:ostream_writer", + "@com_google_riegeli//riegeli/records:record_writer", + ], +) + +cc_test( + name = "sharded_record_buffer_test", + srcs = ["sharded_record_buffer_test.cc"], + deps = [ + ":sharded_record_buffer", + "//public/data_loading/readers:delta_record_stream_reader", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/public/data_loading/writers/delta_record_stream_writer.h b/public/data_loading/writers/delta_record_stream_writer.h index e38433c4..df08766e 100644 --- a/public/data_loading/writers/delta_record_stream_writer.h +++ b/public/data_loading/writers/delta_record_stream_writer.h @@ -41,7 +41,7 @@ class DeltaRecordStreamWriter : public DeltaRecordWriter { static absl::StatusOr> Create( DestStreamT& dest_stream, Options options); - absl::Status WriteRecord(const DeltaFileRecordStruct& record) override; + absl::Status WriteRecord(const DataRecordStruct& data_record) override; const Options& GetOptions() const override { return options_; } absl::Status Flush() override; void Close() override { record_writer_->Close(); } @@ -78,10 +78,11 @@ DeltaRecordStreamWriter::Create(DestStreamT& dest_stream, 
template absl::Status DeltaRecordStreamWriter::WriteRecord( - const DeltaFileRecordStruct& record) { - if (!record_writer_->WriteRecord(ToStringView(record.ToFlatBuffer())) && + const DataRecordStruct& data_record) { + if (!record_writer_->WriteRecord( + ToStringView(ToFlatBufferBuilder(data_record))) && options_.recovery_function) { - options_.recovery_function(record); + options_.recovery_function(data_record); } return record_writer_->status(); } diff --git a/public/data_loading/writers/delta_record_stream_writer_test.cc b/public/data_loading/writers/delta_record_stream_writer_test.cc index 49a01010..c0b3edf9 100644 --- a/public/data_loading/writers/delta_record_stream_writer_test.cc +++ b/public/data_loading/writers/delta_record_stream_writer_test.cc @@ -32,12 +32,36 @@ KVFileMetadata GetMetadata() { return metadata; } -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKeyValueMutationRecord() { + KeyValueMutationRecordStruct kv_mutation_record; + kv_mutation_record.key = "key"; + kv_mutation_record.value = "value"; + kv_mutation_record.logical_commit_time = 1234567890; + kv_mutation_record.mutation_type = KeyValueMutationType::Update; + return kv_mutation_record; +} + +UserDefinedFunctionsConfigStruct GetUserDefinedFunctionsConfig() { + UserDefinedFunctionsConfigStruct udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + +KeyValueMutationRecordStruct GetDeltaSetRecord() { + KeyValueMutationRecordStruct record; record.key = "key"; - record.value = "value"; + record.value = std::vector{"v1", "v2"}; record.logical_commit_time = 1234567890; - 
record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } @@ -59,11 +83,73 @@ INSTANTIATE_TEST_SUITE_P( DeltaRecordWriter::Options{.enable_compression = true, .metadata = GetMetadata()})); -TEST_P(DeltaRecordStreamWriterTest, ValidateWritingAndReadingDeltaStream) { +TEST_P(DeltaRecordStreamWriterTest, + ValidateWritingAndReadingWithKVMutationDeltaStream) { std::stringstream string_stream; auto record_writer = CreateDeltaRecordStreamWriter(string_stream); EXPECT_TRUE(record_writer.ok()); - EXPECT_TRUE((*record_writer)->WriteRecord(GetDeltaRecord()).ok()) + + DataRecordStruct expected = GetDataRecord(GetKeyValueMutationRecord()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()) + << "Failed to write delta record."; + (*record_writer)->Close(); + EXPECT_FALSE((*record_writer)->IsOpen()); + auto stream_reader_factory = + StreamRecordReaderFactory::Create(); + auto stream_reader = stream_reader_factory->CreateReader(string_stream); + absl::StatusOr metadata = stream_reader->GetKVFileMetadata(); + EXPECT_TRUE(metadata.ok()) << "Failed to read metadata"; + EXPECT_TRUE( + stream_reader + ->ReadStreamRecords( + [&expected](std::string_view record_string) -> absl::Status { + return DeserializeDataRecord( + record_string, [&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + return absl::OkStatus(); + }); + }) + .ok()); +} + +TEST_P(DeltaRecordStreamWriterTest, + ValidateWritingAndReadingWithUdfConfigDeltaStream) { + std::stringstream string_stream; + auto record_writer = CreateDeltaRecordStreamWriter(string_stream); + EXPECT_TRUE(record_writer.ok()); + + DataRecordStruct expected = GetDataRecord(GetUserDefinedFunctionsConfig()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()) + << "Failed to write delta record."; + (*record_writer)->Close(); + EXPECT_FALSE((*record_writer)->IsOpen()); + auto stream_reader_factory = + StreamRecordReaderFactory::Create(); + auto 
stream_reader = stream_reader_factory->CreateReader(string_stream); + absl::StatusOr metadata = stream_reader->GetKVFileMetadata(); + EXPECT_TRUE(metadata.ok()) << "Failed to read metadata"; + EXPECT_TRUE( + stream_reader + ->ReadStreamRecords( + [&expected](std::string_view record_string) -> absl::Status { + return DeserializeDataRecord( + record_string, + [&expected](const DataRecordStruct& data_record) { + EXPECT_EQ(data_record, expected); + return absl::OkStatus(); + }); + }) + .ok()); +} + +TEST_P(DeltaRecordStreamWriterTest, + ValidateWritingAndReadingDeltaStreamForSet) { + std::stringstream string_stream; + auto record_writer = CreateDeltaRecordStreamWriter(string_stream); + EXPECT_TRUE(record_writer.ok()); + + DataRecordStruct expected = GetDataRecord(GetDeltaSetRecord()); + EXPECT_TRUE((*record_writer)->WriteRecord(expected).ok()) << "Failed to write delta record."; (*record_writer)->Close(); EXPECT_FALSE((*record_writer)->IsOpen()); @@ -72,21 +158,17 @@ TEST_P(DeltaRecordStreamWriterTest, ValidateWritingAndReadingDeltaStream) { auto stream_reader = stream_reader_factory->CreateReader(string_stream); absl::StatusOr metadata = stream_reader->GetKVFileMetadata(); EXPECT_TRUE(metadata.ok()) << "Failed to read metadata"; - EXPECT_TRUE(stream_reader - ->ReadStreamRecords( - [](std::string_view record_string) -> absl::Status { - DeltaFileRecordStruct record; - auto fbs_record = flatbuffers::GetRoot( - record_string.data()); - record.key = fbs_record->key()->string_view(); - record.value = fbs_record->value()->string_view(); - record.mutation_type = fbs_record->mutation_type(); - record.logical_commit_time = - fbs_record->logical_commit_time(); - EXPECT_EQ(record, GetDeltaRecord()); - return absl::OkStatus(); - }) - .ok()); + EXPECT_TRUE( + stream_reader + ->ReadStreamRecords( + [&expected](std::string_view record_string) -> absl::Status { + return DeserializeDataRecord( + record_string, [&expected](DataRecordStruct record) { + EXPECT_EQ(record, expected); + 
return absl::OkStatus(); + }); + }) + .ok()); } TEST(DeltaRecordStreamWriterTest, ValidateWritingFailsAfterClose) { @@ -96,7 +178,8 @@ TEST(DeltaRecordStreamWriterTest, ValidateWritingFailsAfterClose) { string_stream, std::move(options)); EXPECT_TRUE(record_writer.ok()); (*record_writer)->Close(); - auto status = (*record_writer)->WriteRecord(GetDeltaRecord()); + auto status = + (*record_writer)->WriteRecord(GetDataRecord(GetKeyValueMutationRecord())); EXPECT_FALSE(status.ok()); } diff --git a/public/data_loading/writers/delta_record_writer.h b/public/data_loading/writers/delta_record_writer.h index f85cd733..ed579b2f 100644 --- a/public/data_loading/writers/delta_record_writer.h +++ b/public/data_loading/writers/delta_record_writer.h @@ -35,7 +35,7 @@ namespace kv_server { // ``` // DeltaRecordWriter record_writer = ... // while(more records to write) { -// DeltaFileRecordStruct record = ... +// DataRecordStruct record = ... // if (absl::Status status = record_writer.WriteRecord(record); !status.ok()) // { // LOG(WARN) << "Failed to write record."; @@ -52,15 +52,16 @@ class DeltaRecordWriter { bool enable_compression; // If writing a record fails, this function will be called with the failed // record. - std::function recovery_function; + std::function recovery_function; // Metadata required for delta files. KVFileMetadata metadata; }; virtual ~DeltaRecordWriter() = default; - // Writes a `DeltaFileRecordStruct` record to the underlying destination. - virtual absl::Status WriteRecord(const DeltaFileRecordStruct& record) = 0; + // Writes a `DataRecordStruct` record to the underlying + // destination. + virtual absl::Status WriteRecord(const DataRecordStruct& data_record) = 0; // Flushes any written data to the underlying destination and makes it // visible outside the writing process. 
`Flush()` is different from // `Close()` in that it allows for more records to be written after some data diff --git a/public/data_loading/writers/sharded_record_buffer.cc b/public/data_loading/writers/sharded_record_buffer.cc new file mode 100644 index 00000000..bf395e4c --- /dev/null +++ b/public/data_loading/writers/sharded_record_buffer.cc @@ -0,0 +1,143 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "public/data_loading/writers/sharded_record_buffer.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" +#include "riegeli/bytes/ostream_writer.h" +#include "riegeli/records/record_writer.h" + +namespace kv_server { +namespace { + +class RecordBufferImpl : public RecordBuffer { + public: + ~RecordBufferImpl() { record_writer_.Close(); } + + absl::Status AddRecord(const DataRecordStruct& record) override { + if (!record_writer_.WriteRecord( + ToStringView(ToFlatBufferBuilder(record)))) { + return record_writer_.status(); + } + return absl::OkStatus(); + } + + absl::Status Flush() override { + if (!record_writer_.Flush()) { + return record_writer_.status(); + } + return absl::OkStatus(); + } + + std::istream* RecordStream() override { return record_stream_.get(); } + + static std::unique_ptr Create() { + auto record_stream = std::make_unique(); + riegeli::RecordWriterBase::Options options; + options.set_uncompressed(); + auto record_writer = + 
riegeli::RecordWriter>( + riegeli::OStreamWriter(record_stream.get()), options); + return absl::WrapUnique(new RecordBufferImpl(std::move(record_stream), + std::move(record_writer))); + } + + private: + RecordBufferImpl( + std::unique_ptr record_stream, + riegeli::RecordWriter> + record_writer) + : record_stream_(std::move(record_stream)), + record_writer_(std::move(record_writer)) {} + + std::unique_ptr record_stream_; + riegeli::RecordWriter> + record_writer_; +}; + +absl::Status IsWithinBounds(int shard_id, int num_shards) { + if (shard_id < 0 || shard_id >= num_shards) { + return absl::InvalidArgumentError(absl::StrFormat( + "Shard id: %d is out of range: [%d, %d)", shard_id, 0, num_shards)); + } + return absl::OkStatus(); +} + +} // namespace + +ShardedRecordBuffer::ShardedRecordBuffer( + ShardingFunction sharding_func, + std::vector> shard_buffers) + : sharding_func_(std::move(sharding_func)), + shard_buffers_(std::move(shard_buffers)) {} + +absl::StatusOr> +ShardedRecordBuffer::Create(int num_shards, ShardingFunction sharding_func) { + if (num_shards <= 0) { + return absl::InvalidArgumentError(absl::StrFormat( + "Number of shards: %d must be greater than 0", num_shards)); + } + std::vector> shard_stores; + shard_stores.reserve(num_shards); + for (int shard_id = 0; shard_id < num_shards; shard_id++) { + shard_stores.push_back(RecordBufferImpl::Create()); + } + return absl::WrapUnique(new ShardedRecordBuffer(std::move(sharding_func), + std::move(shard_stores))); +} + +absl::StatusOr ShardedRecordBuffer::GetShardRecordStream( + int shard_id) { + if (auto status = IsWithinBounds(shard_id, shard_buffers_.size()); + !status.ok()) { + return status; + } + return shard_buffers_[shard_id]->RecordStream(); +} + +absl::Status ShardedRecordBuffer::AddRecord( + const DataRecordStruct& data_record) { + if (std::holds_alternative( + data_record.record)) { + auto kv_record = std::get(data_record.record); + auto shard_id = + sharding_func_.GetShardNumForKey(kv_record.key, 
shard_buffers_.size()); + return shard_buffers_[shard_id]->AddRecord(data_record); + } + return absl::OkStatus(); +} + +absl::Status ShardedRecordBuffer::Flush(int shard_id) { + if (shard_id < 0) { + for (const auto& buffer : shard_buffers_) { + if (auto status = buffer->Flush(); !status.ok()) { + return status; + } + } + return absl::OkStatus(); + } + if (auto status = IsWithinBounds(shard_id, shard_buffers_.size()); + !status.ok()) { + return status; + } + return shard_buffers_[shard_id]->Flush(); +} + +} // namespace kv_server diff --git a/public/data_loading/writers/sharded_record_buffer.h b/public/data_loading/writers/sharded_record_buffer.h new file mode 100644 index 00000000..42a41da5 --- /dev/null +++ b/public/data_loading/writers/sharded_record_buffer.h @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PUBLIC_DATA_LOADING_WRITERS_SHARDED_RECORD_BUFFER_H_ +#define PUBLIC_DATA_LOADING_WRITERS_SHARDED_RECORD_BUFFER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "public/data_loading/records_utils.h" +#include "public/sharding/sharding_function.h" + +namespace kv_server { + +// A `RecordBuffer` buffers `DataRecordStruct` records serialized as +// `data_loading.fbs:DataRecord` flatbuffers. Records can be read +// out of the buffer from `RecordStream()`. 
+class RecordBuffer { + public: + virtual ~RecordBuffer() = default; + static std::unique_ptr Create(); + // Returns an error status if adding a record to the buffer fails for some + // reason. + virtual absl::Status AddRecord(const DataRecordStruct& record) = 0; + // Flushes buffered records so that they are visible for reading via + // `RecordStream()`. + virtual absl::Status Flush() = 0; + // Call `Flush()` to guarantee that all buffered records are visible before + // reading. + virtual std::istream* RecordStream() = 0; +}; + +// A `ShardedRecordBuffer` buffers `DataRecordStruct` records +// serialized as `data_loading.fbs:DataRecord` flatbuffers in +// separate sharded streams +class ShardedRecordBuffer { + public: + ~ShardedRecordBuffer() = default; + ShardedRecordBuffer(const ShardedRecordBuffer&) = delete; + ShardedRecordBuffer& operator=(const ShardedRecordBuffer&) = delete; + + static absl::StatusOr> Create( + int num_shards, ShardingFunction sharding_func = ShardingFunction("")); + absl::StatusOr GetShardRecordStream(int shard_id); + absl::Status AddRecord(const DataRecordStruct& record); + // Flushes buffered records so that they are visible for reading via + // `RecordStream()`. Specify a `shard_id` to flush records buffered for a + // specific shard or -1 to flush all buffered records. 
+ absl::Status Flush(int shard_id = -1); + + private: + ShardedRecordBuffer(ShardingFunction sharding_func, + std::vector> shard_buffers); + ShardingFunction sharding_func_; + std::vector> shard_buffers_; +}; + +} // namespace kv_server + +#endif // PUBLIC_DATA_LOADING_WRITERS_SHARDED_RECORD_BUFFER_H_ diff --git a/public/data_loading/writers/sharded_record_buffer_test.cc b/public/data_loading/writers/sharded_record_buffer_test.cc new file mode 100644 index 00000000..c85c090f --- /dev/null +++ b/public/data_loading/writers/sharded_record_buffer_test.cc @@ -0,0 +1,137 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "public/data_loading/writers/sharded_record_buffer.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "public/data_loading/readers/delta_record_stream_reader.h" +#include "public/sharding/sharding_function.h" + +namespace kv_server { +namespace { + +KeyValueMutationRecordStruct GetKVMutationRecord(std::string_view key) { + return KeyValueMutationRecordStruct{ + .mutation_type = KeyValueMutationType::Update, + .logical_commit_time = 1234567890, + .key = key, + .value = "value", + }; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + +void ValidateRecordStream(std::istream& record_stream) { + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call).Times(0); + DeltaRecordStreamReader record_reader(record_stream); + auto status = record_reader.ReadRecords(record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +void ValidateRecordStream(const std::vector keys, + std::istream& record_stream) { + testing::MockFunction record_callback; + if (keys.size() == 1) { + EXPECT_CALL(record_callback, Call) + .Times(1) + .WillOnce([&keys](DataRecordStruct data_record) { + if (std::holds_alternative( + data_record.record)) { + auto kv_record = + std::get(data_record.record); + EXPECT_EQ(kv_record.key, keys[0]); + } + return absl::OkStatus(); + }); + } else { + EXPECT_CALL(record_callback, Call) + .Times(keys.size()) + .WillRepeatedly([&keys](DataRecordStruct data_record) { + if (std::holds_alternative( + data_record.record)) { + auto kv_record = + std::get(data_record.record); + auto key_iter = std::find(keys.begin(), keys.end(), kv_record.key); + EXPECT_TRUE(key_iter != keys.end()); + } + return absl::OkStatus(); + }); + } + DeltaRecordStreamReader record_reader(record_stream); + auto status = record_reader.ReadRecords(record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + 
+TEST(ShardedRecordBufferTest, ValidateCreatingBuffer) { + auto buffer = ShardedRecordBuffer::Create(-5); + EXPECT_FALSE(buffer.ok()) << buffer.status(); + EXPECT_EQ(buffer.status().code(), absl::StatusCode::kInvalidArgument); + buffer = ShardedRecordBuffer::Create(7); + EXPECT_TRUE(buffer.ok()) << buffer.status(); +} + +TEST(ShardedRecordBufferTest, ValidateAddingAndReadingRecords) { + int num_shards = 7; + auto keys = std::vector{ + "key1", "key2", "key3", "key4", "key5", "key6", "key7", + }; + ShardingFunction sharding_func(/*seed=*/""); + auto record_buffer = ShardedRecordBuffer::Create(num_shards, sharding_func); + EXPECT_TRUE(record_buffer.ok()) << record_buffer.status(); + for (const auto& key : keys) { + auto status = + (*record_buffer)->AddRecord(GetDataRecord(GetKVMutationRecord(key))); + EXPECT_TRUE(status.ok()) << status; + } + auto status = (*record_buffer)->Flush(); + EXPECT_TRUE(status.ok()) << status; + + // shard 2 and 4 have no records. + for (auto shard_id : std::vector{2, 4}) { + auto shard_stream = (*record_buffer)->GetShardRecordStream(shard_id); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream(**shard_stream); + } + // {key1,key5}=5, {key2,key7}=6, key3=1, key4=0, key6=3 + auto shard_stream = (*record_buffer)->GetShardRecordStream(5); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream({"key1", "key5"}, **shard_stream); + + shard_stream = (*record_buffer)->GetShardRecordStream(6); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream({"key2", "key7"}, **shard_stream); + + shard_stream = (*record_buffer)->GetShardRecordStream(1); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream({"key3"}, **shard_stream); + + shard_stream = (*record_buffer)->GetShardRecordStream(0); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream({"key4"}, **shard_stream); + + shard_stream = 
(*record_buffer)->GetShardRecordStream(3); + EXPECT_TRUE(shard_stream.ok()) << shard_stream.status(); + ValidateRecordStream({"key6"}, **shard_stream); +} + +} // namespace +} // namespace kv_server diff --git a/public/data_loading/writers/snapshot_stream_writer.h b/public/data_loading/writers/snapshot_stream_writer.h index 9db11b6b..68054194 100644 --- a/public/data_loading/writers/snapshot_stream_writer.h +++ b/public/data_loading/writers/snapshot_stream_writer.h @@ -37,7 +37,7 @@ namespace kv_server { -// A `SnapshotStreamWriter` writes `DeltaFileRecordStruct` records to a +// A `SnapshotStreamWriter` writes `DataRecordStruct` records to a // destination snapshot stream. The `SnapshotStreamWriter` can be used to: // (1) merge multiple delta files into a single snapshot file or // (2) merge a base snapshot file with multiple delta files into a single @@ -85,9 +85,9 @@ class SnapshotStreamWriter { static absl::StatusOr> Create( Options options, DestStreamT& dest_snapshot_stream); - absl::Status WriteRecord(const DeltaFileRecordStruct& record); - // Writes `DeltaFileRecordStruct` records from `src_stream` to the output - // snapshot stream, `dest_snapshot_stream`. Valid source streams can be + absl::Status WriteRecord(const DataRecordStruct& record); + // Writes `DataRecordStruct` records from `src_stream` to the + // output snapshot stream, `dest_snapshot_stream`. Valid source streams can be // snapshot files generated using `SnapshotStreamWriter` instances or // delta files generated using `DeltaRecordStreamWriter` instances. 
template @@ -109,7 +109,7 @@ class SnapshotStreamWriter { std::unique_ptr> record_writer, std::unique_ptr record_aggregator, Options options); - absl::Status InsertOrUpdateRecord(const DeltaFileRecordStruct& record); + absl::Status InsertOrUpdateRecord(const DataRecordStruct& record); template absl::Status InsertOrUpdateRecords(SrcStreamT& src_stream); static absl::StatusOr> @@ -123,6 +123,7 @@ class SnapshotStreamWriter { std::unique_ptr record_aggregator_; Options options_; bool is_finalized_ = false; + std::unique_ptr udf_config_; }; template @@ -169,9 +170,28 @@ SnapshotStreamWriter::CreateDeltaRecordWriterOptions( // TODO: Think about the best way to handle failed records. Should this be // exposed as a field of `SnapshotStreamWriter::Options`? .recovery_function = - [](const DeltaFileRecordStruct& record) { - LOG(ERROR) << "Failed to write record to snapshot stream. (key: " - << record.key << ")"; + [](const DataRecordStruct& data_record) { + if (std::holds_alternative( + data_record.record)) { + LOG(ERROR) << "Failed to write record to snapshot stream. (key: " + << std::get( + data_record.record) + .key + << ")"; + return; + } + if (std::holds_alternative( + data_record.record)) { + LOG(ERROR) << "Failed to write record to snapshot stream. " + "(udf_code_snippet: " + << std::get( + data_record.record) + .code_snippet + << ")"; + return; + } + LOG(ERROR) << "Failed to write record to snapshot stream. " + "No KeyValueMutation or UdfConfig specified. 
"; }, .metadata = options.metadata, }; @@ -188,9 +208,25 @@ SnapshotStreamWriter::CreateRecordAggregator( template absl::Status SnapshotStreamWriter::InsertOrUpdateRecord( - const DeltaFileRecordStruct& record) { - return record_aggregator_->InsertOrUpdateRecord(absl::HashOf(record.key), - record); + const DataRecordStruct& data_record) { + if (std::holds_alternative( + data_record.record)) { + auto kv_record = std::get(data_record.record); + return record_aggregator_->InsertOrUpdateRecord(absl::HashOf(kv_record.key), + kv_record); + } + if (std::holds_alternative( + data_record.record)) { + auto udf_config = + std::get(data_record.record); + if (udf_config_ == nullptr || + udf_config_->logical_commit_time < udf_config.logical_commit_time) { + udf_config_ = + std::make_unique(udf_config); + } + return absl::OkStatus(); + } + return absl::OkStatus(); } template @@ -203,17 +239,17 @@ absl::Status SnapshotStreamWriter::InsertOrUpdateRecords( return metadata.status(); } return record_reader.ReadRecords( - [this](auto record) { return InsertOrUpdateRecord(record); }); + [this](auto data_record) { return InsertOrUpdateRecord(data_record); }); } template absl::Status SnapshotStreamWriter::WriteRecord( - const DeltaFileRecordStruct& record) { + const DataRecordStruct& data_record) { if (is_finalized_) { return absl::FailedPreconditionError( "Cannot write records after finalizing the snapshot."); } - return InsertOrUpdateRecord(record); + return InsertOrUpdateRecord(data_record); } template @@ -233,16 +269,27 @@ absl::Status SnapshotStreamWriter::Finalize() { return absl::OkStatus(); } if (absl::Status status = record_aggregator_->ReadRecords( - [record_writer = record_writer_.get()](auto record) { + [record_writer = record_writer_.get()]( + KeyValueMutationRecordStruct kv_mutation_record) { // By definition, snapshots do NOT contain DELETE mutations. 
- if (record.mutation_type == DeltaMutationType::Delete) { + if (kv_mutation_record.mutation_type == + KeyValueMutationType::Delete) { return absl::OkStatus(); } - return record_writer->WriteRecord(std::move(record)); + DataRecordStruct data_record; + data_record.record = std::move(kv_mutation_record); + return record_writer->WriteRecord(data_record); }); !status.ok()) { return status; } + if (udf_config_ != nullptr) { + if (absl::Status status = record_writer_->WriteRecord( + DataRecordStruct{.record = *udf_config_}); + !status.ok()) { + return status; + } + } if (absl::Status status = record_writer_->Flush(); !status.ok()) { return status; } @@ -275,7 +322,8 @@ SnapshotStreamWriter::ValidateRequiredSnapshotMetadata( !IsSnapshotFilename(metadata.snapshot().starting_file())) { return absl::InvalidArgumentError(absl::StrCat( "Snapshot starting filename: ", metadata.snapshot().starting_file(), - " must either be a valid delta filename or valid snapshot filename.")); + " must either be a valid delta filename or valid snapshot " + "filename.")); } return absl::OkStatus(); } diff --git a/public/data_loading/writers/snapshot_stream_writer_test.cc b/public/data_loading/writers/snapshot_stream_writer_test.cc index 46bdf307..d2ac6716 100644 --- a/public/data_loading/writers/snapshot_stream_writer_test.cc +++ b/public/data_loading/writers/snapshot_stream_writer_test.cc @@ -46,15 +46,30 @@ KVFileMetadata GetSnapshotMetadata() { return metadata; } -DeltaFileRecordStruct GetDeltaRecord(std::string_view key = "key") { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKVMutationRecord(std::string_view key = "key") { + KeyValueMutationRecordStruct record; record.key = key; record.value = "value"; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } +UserDefinedFunctionsConfigStruct GetUserDefinedFunctionsConfig() { + UserDefinedFunctionsConfigStruct 
udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + std::filesystem::path GetRecordAggregatorDbFile() { std::filesystem::path full_path(std::filesystem::temp_directory_path()); full_path /= absl::StrFormat("RecordAggregator.%d.db", std::rand()); @@ -95,22 +110,23 @@ TEST_P(SnapshotStreamWriterTest, ValidateThatRecordsAreDedupedInSnapshot) { auto snapshot_writer = SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); - auto record = GetDeltaRecord(); + auto data_record = GetDataRecord(GetKVMutationRecord()); // Write the same record to snapshot 3 times. - std::vector records{record, record, record}; - for (const auto& recd : records) { + std::vector data_records{data_record, data_record, data_record}; + for (const auto& recd : data_records) { auto status = (*snapshot_writer)->WriteRecord(recd); EXPECT_TRUE(status.ok()) << status; } auto status = (*snapshot_writer)->Finalize(); EXPECT_TRUE(status.ok()) << status; + DeltaRecordStreamReader record_reader(dest_stream); - testing::MockFunction record_callback; + testing::MockFunction record_callback; // We expect one call to record_callback because records will be deduped in // the snapshot stream. 
- EXPECT_CALL(record_callback, Call(record)) + EXPECT_CALL(record_callback, Call(data_record)) .Times(1) - .WillOnce([](DeltaFileRecordStruct) { return absl::OkStatus(); }); + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); status = record_reader.ReadRecords(record_callback.AsStdFunction()); EXPECT_TRUE(status.ok()) << status; } @@ -120,19 +136,22 @@ TEST_P(SnapshotStreamWriterTest, ValidateThatDeletedRecordsAreNotInSnapshot) { auto snapshot_writer = SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); - auto record = GetDeltaRecord(); - auto status = (*snapshot_writer)->WriteRecord(record); + auto kv_record = GetKVMutationRecord(); + auto data_record = GetDataRecord(kv_record); + auto status = (*snapshot_writer)->WriteRecord(data_record); EXPECT_TRUE(status.ok()) << status; // Delete the record written above. - record.mutation_type = DeltaMutationType::Delete; - record.logical_commit_time++; - status = (*snapshot_writer)->WriteRecord(record); + kv_record.mutation_type = KeyValueMutationType::Delete; + kv_record.logical_commit_time++; + auto data_record_with_deletion = GetDataRecord(kv_record); + status = (*snapshot_writer)->WriteRecord(data_record_with_deletion); EXPECT_TRUE(status.ok()) << status; status = (*snapshot_writer)->Finalize(); EXPECT_TRUE(status.ok()) << status; + DeltaRecordStreamReader record_reader(dest_stream); - testing::MockFunction record_callback; - EXPECT_CALL(record_callback, Call(record)).Times(0); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call(data_record)).Times(0); status = record_reader.ReadRecords(record_callback.AsStdFunction()); EXPECT_TRUE(status.ok()) << status; } @@ -143,22 +162,25 @@ TEST_P(SnapshotStreamWriterTest, auto snapshot_writer = SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); - auto record = GetDeltaRecord(); - auto status = 
(*snapshot_writer)->WriteRecord(record); + auto kv_record = GetKVMutationRecord(); + auto data_record = GetDataRecord(kv_record); + auto status = (*snapshot_writer)->WriteRecord(data_record); EXPECT_TRUE(status.ok()) << status; // Update the record written above. - std::string value = absl::StrCat(record.value, "-updated"); - record.value = value; - record.logical_commit_time++; - status = (*snapshot_writer)->WriteRecord(record); + std::string value = + absl::StrCat(std::get(kv_record.value), "-updated"); + kv_record.value = value; + kv_record.logical_commit_time++; + status = (*snapshot_writer)->WriteRecord(data_record); EXPECT_TRUE(status.ok()) << status; status = (*snapshot_writer)->Finalize(); EXPECT_TRUE(status.ok()) << status; + DeltaRecordStreamReader record_reader(dest_stream); - testing::MockFunction record_callback; - EXPECT_CALL(record_callback, Call(record)) + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call(data_record)) .Times(1) - .WillOnce([](DeltaFileRecordStruct) { return absl::OkStatus(); }); + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); status = record_reader.ReadRecords(record_callback.AsStdFunction()); EXPECT_TRUE(status.ok()) << status; } @@ -169,39 +191,43 @@ TEST_P(SnapshotStreamWriterTest, auto snapshot_writer = SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); - auto record = GetDeltaRecord(); - auto status = (*snapshot_writer)->WriteRecord(record); + auto kv_record = GetKVMutationRecord(); + auto data_record = GetDataRecord(kv_record); + auto status = (*snapshot_writer)->WriteRecord(data_record); EXPECT_TRUE(status.ok()) << status; // Update the record written above, but also with an older // logical_commit_time. 
- std::string value = absl::StrCat(record.value, "-updated"); - record.value = value; - record.logical_commit_time--; - status = (*snapshot_writer)->WriteRecord(record); + std::string value = + absl::StrCat(std::get(kv_record.value), "-updated"); + kv_record.value = value; + kv_record.logical_commit_time--; + auto old_data_record = GetDataRecord(kv_record); + status = (*snapshot_writer)->WriteRecord(data_record); EXPECT_TRUE(status.ok()) << status; status = (*snapshot_writer)->Finalize(); EXPECT_TRUE(status.ok()) << status; + DeltaRecordStreamReader record_reader(dest_stream); - testing::MockFunction record_callback; - EXPECT_CALL(record_callback, Call(GetDeltaRecord())) + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call(data_record)) .Times(1) - .WillOnce([](DeltaFileRecordStruct) { return absl::OkStatus(); }); + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); status = record_reader.ReadRecords(record_callback.AsStdFunction()); EXPECT_TRUE(status.ok()) << status; } TEST_P(SnapshotStreamWriterTest, ValidateWritingMultipleRecordsUsingASrcStream) { - testing::MockFunction record_callback; + testing::MockFunction record_callback; std::stringstream src_stream; auto record_writer = DeltaRecordStreamWriter<>::Create( src_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); EXPECT_TRUE(record_writer.ok()); for (std::string_view key : std::vector{"key1", "key2", "key3", "key4"}) { - auto record = GetDeltaRecord(key); - EXPECT_TRUE((*record_writer)->WriteRecord(record).ok()); - EXPECT_CALL(record_callback, Call(record)) - .WillOnce([](DeltaFileRecordStruct) { return absl::OkStatus(); }); + auto data_record = GetDataRecord(GetKVMutationRecord(key)); + EXPECT_TRUE((*record_writer)->WriteRecord(data_record).ok()); + EXPECT_CALL(record_callback, Call(data_record)) + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); } (*record_writer)->Close(); std::stringstream dest_stream; @@ -232,6 +258,89 @@ 
TEST_P(SnapshotStreamWriterTest, GetSnapshotMetadata(), *metadata)); } +TEST_P(SnapshotStreamWriterTest, UdfConfig_DedupedInSnapshot) { + std::stringstream dest_stream; + auto snapshot_writer = + SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); + EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); + auto data_record = GetDataRecord(GetUserDefinedFunctionsConfig()); + // Write the same record to snapshot 3 times. + std::vector data_records{data_record, data_record, data_record}; + for (const auto& recd : data_records) { + auto status = (*snapshot_writer)->WriteRecord(recd); + EXPECT_TRUE(status.ok()) << status; + } + auto status = (*snapshot_writer)->Finalize(); + EXPECT_TRUE(status.ok()) << status; + + DeltaRecordStreamReader record_reader(dest_stream); + testing::MockFunction record_callback; + // We expect one call to record_callback because records will be deduped in + // the snapshot stream. + EXPECT_CALL(record_callback, Call(data_record)) + .Times(1) + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); + status = record_reader.ReadRecords(record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST_P(SnapshotStreamWriterTest, + UdfConfig_UpdatesWithLargestCommitTimestampInSnapshot) { + std::stringstream dest_stream; + auto snapshot_writer = + SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); + EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); + auto udf_config = GetUserDefinedFunctionsConfig(); + auto data_record = GetDataRecord(udf_config); + auto status = (*snapshot_writer)->WriteRecord(data_record); + EXPECT_TRUE(status.ok()) << status; + // Update the udf config written above. 
+ std::string handler_name = absl::StrCat(udf_config.handler_name, "-updated"); + udf_config.handler_name = handler_name; + udf_config.logical_commit_time++; + status = (*snapshot_writer)->WriteRecord(data_record); + EXPECT_TRUE(status.ok()) << status; + status = (*snapshot_writer)->Finalize(); + EXPECT_TRUE(status.ok()) << status; + + DeltaRecordStreamReader record_reader(dest_stream); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call(data_record)) + .Times(1) + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); + status = record_reader.ReadRecords(record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + +TEST_P(SnapshotStreamWriterTest, + UdfConfig_IgnoresSameCommitTimestampInSnapshot) { + std::stringstream dest_stream; + auto snapshot_writer = + SnapshotStreamWriterTest::CreateSnapshotWriter(dest_stream); + EXPECT_TRUE(snapshot_writer.ok()) << snapshot_writer.status(); + auto udf_config = GetUserDefinedFunctionsConfig(); + auto data_record = GetDataRecord(udf_config); + auto status = (*snapshot_writer)->WriteRecord(data_record); + EXPECT_TRUE(status.ok()) << status; + // Update the udf config written above. 
+ auto old_data_record = GetDataRecord(udf_config); + std::string handler_name = absl::StrCat(udf_config.handler_name, "-updated"); + udf_config.handler_name = handler_name; + udf_config.logical_commit_time; + status = (*snapshot_writer)->WriteRecord(old_data_record); + EXPECT_TRUE(status.ok()) << status; + status = (*snapshot_writer)->Finalize(); + EXPECT_TRUE(status.ok()) << status; + + DeltaRecordStreamReader record_reader(dest_stream); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call(data_record)) + .Times(1) + .WillOnce([](DataRecordStruct) { return absl::OkStatus(); }); + status = record_reader.ReadRecords(record_callback.AsStdFunction()); + EXPECT_TRUE(status.ok()) << status; +} + TEST(SnapshotStreamWriterTest, ValidateCreatingSnapshotWriterWithValidMetadata) { std::stringstream dest_stream; diff --git a/public/query/BUILD b/public/query/BUILD index 9db1d1c0..c02a174d 100644 --- a/public/query/BUILD +++ b/public/query/BUILD @@ -32,7 +32,7 @@ proto_library( buf_lint_test( name = "get_values_proto_lint", size = "small", - config = "//public:buf.yaml", + config = "//:buf.yaml", targets = [ ":get_values_proto", ], diff --git a/public/query/get_values.proto b/public/query/get_values.proto index 51a4bcde..5370f954 100644 --- a/public/query/get_values.proto +++ b/public/query/get_values.proto @@ -25,9 +25,7 @@ service KeyValueService { // V1 Query API as described in // https://github.com/WICG/turtledove/blob/main/FLEDGE_Key_Value_Server_API.md rpc GetValues(GetValuesRequest) returns (GetValuesResponse) { - option (google.api.http) = { - get: "/v1/getvalues" - }; + option (google.api.http) = {get: "/v1/getvalues"}; } } diff --git a/public/query/v2/BUILD b/public/query/v2/BUILD index fb923b1d..97001876 100644 --- a/public/query/v2/BUILD +++ b/public/query/v2/BUILD @@ -22,7 +22,7 @@ package(default_visibility = ["//visibility:public"]) buf_lint_test( name = "get_values_proto_lint", size = "small", - config = "//public:buf.yaml", + config = 
"//:buf.yaml", targets = [ ":get_values_v2_proto", ], diff --git a/public/query/v2/get_values_v2.proto b/public/query/v2/get_values_v2.proto index fe4bb640..001d3bf8 100644 --- a/public/query/v2/get_values_v2.proto +++ b/public/query/v2/get_values_v2.proto @@ -43,8 +43,7 @@ service KeyValueService { // The response will be a binary Http response. The response will return 200 // code as long as the binary Http response can be encoded. The actual status // code of the processing can be extracted from the binary Http response. - rpc BinaryHttpGetValues(BinaryHttpGetValuesRequest) - returns (google.api.HttpBody) { + rpc BinaryHttpGetValues(BinaryHttpGetValuesRequest) returns (google.api.HttpBody) { option (google.api.http) = { post: "/v2/bhttp_getvalues" body: "raw_body" @@ -52,8 +51,7 @@ service KeyValueService { } // V2 GetValues API based on the Oblivious HTTP protocol. - rpc ObliviousGetValues(ObliviousGetValuesRequest) - returns (google.api.HttpBody) { + rpc ObliviousGetValues(ObliviousGetValuesRequest) returns (google.api.HttpBody) { option (google.api.http) = { post: "/v2/oblivious_getvalues" body: "raw_body" diff --git a/public/sharding/BUILD b/public/sharding/BUILD new file mode 100644 index 00000000..d153b166 --- /dev/null +++ b/public/sharding/BUILD @@ -0,0 +1,38 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "sharding_function", + srcs = ["sharding_function.cc"], + hdrs = ["sharding_function.h"], + deps = [ + "@distributed_point_functions//pir/hashing:sha256_hash_family", + ], +) + +cc_test( + name = "sharding_function_test", + size = "small", + srcs = [ + "sharding_function_test.cc", + ], + deps = [ + ":sharding_function", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/public/sharding/sharding_function.cc b/public/sharding/sharding_function.cc new file mode 100644 index 00000000..d59b8cc9 --- /dev/null +++ b/public/sharding/sharding_function.cc @@ -0,0 +1,29 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "public/sharding/sharding_function.h" + +namespace kv_server { + +ShardingFunction::ShardingFunction(std::string seed) + : hash_function_(std::move(seed)) {} + +int ShardingFunction::GetShardNumForKey(std::string_view key, + int num_shards) const { + return hash_function_(key, num_shards); +} + +} // namespace kv_server diff --git a/public/sharding/sharding_function.h b/public/sharding/sharding_function.h new file mode 100644 index 00000000..0be0137f --- /dev/null +++ b/public/sharding/sharding_function.h @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PUBLIC_SHARDING_SHARDING_FUNCTION_H_ +#define PUBLIC_SHARDING_SHARDING_FUNCTION_H_ + +#include +#include + +#include "pir/hashing/sha256_hash_family.h" + +namespace kv_server { + +// Sharding function to assign different keys to shard numbers within the range +// [0, `num_shards`). 
+class ShardingFunction { + public: + explicit ShardingFunction(std::string seed); + int GetShardNumForKey(std::string_view key, int num_shards) const; + + private: + distributed_point_functions::SHA256HashFunction hash_function_; +}; + +} // namespace kv_server + +#endif // PUBLIC_SHARDING_SHARDING_FUNCTION_H_ diff --git a/public/sharding/sharding_function_test.cc b/public/sharding/sharding_function_test.cc new file mode 100644 index 00000000..122bf80e --- /dev/null +++ b/public/sharding/sharding_function_test.cc @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "public/sharding/sharding_function.h" + +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(ShardingFunctionTest, VerifyAssigningKeysToShards) { + ShardingFunction func(""); + EXPECT_EQ(5, func.GetShardNumForKey("key1", 7)); + EXPECT_EQ(6, func.GetShardNumForKey("key2", 7)); + EXPECT_EQ(1, func.GetShardNumForKey("key3", 7)); +} + +} // namespace +} // namespace kv_server diff --git a/public/udf/constants.h b/public/udf/constants.h index ba92c7f3..9fa36e45 100644 --- a/public/udf/constants.h +++ b/public/udf/constants.h @@ -17,25 +17,32 @@ namespace kv_server { constexpr char kUdfCodeSnippetKey[] = "udf_code_snippet"; constexpr char kUdfHandlerNameKey[] = "udf_handler_name"; +constexpr int64_t kDefaultLogicalCommitTime = 0; + constexpr char kDefaultUdfCodeSnippet[] = R"( - function HandleRequest(input) { - const keyGroupOutputs = []; - for (const keyGroup of input.keyGroups) { - const keyGroupOutput = {}; - keyGroupOutput.tags = keyGroup.tags; + function HandleRequest(input) { + const keyGroupOutputs = []; + for (const keyGroup of input.keyGroups) { + const keyGroupOutput = {}; + keyGroupOutput.tags = keyGroup.tags; - const kvPairs = JSON.parse(getValues(keyGroup.keyList)).kvPairs; - const keyValuesOutput = {}; - for (const key in kvPairs) { - if (kvPairs[key].hasOwnProperty("value")) { - keyValuesOutput[key] = { "value": kvPairs[key].value }; + const getValuesResult = JSON.parse(getValues(keyGroup.keyList)); + // getValuesResult returns "kvPairs" when successful and "code" on failure. + // Ignore failures and only add successful getValuesResult lookups to output. 
+ if (getValuesResult.hasOwnProperty("kvPairs")) { + const kvPairs = getValuesResult.kvPairs; + const keyValuesOutput = {}; + for (const key in kvPairs) { + if (kvPairs[key].hasOwnProperty("value")) { + keyValuesOutput[key] = { "value": kvPairs[key].value }; + } + } + keyGroupOutput.keyValues = keyValuesOutput; + keyGroupOutputs.push(keyGroupOutput); } } - keyGroupOutput.keyValues = keyValuesOutput; - keyGroupOutputs.push(keyGroupOutput); + return {keyGroupOutputs, udfOutputApiVersion: 1}; } - return {keyGroupOutputs, udfOutputApiVersion: 1}; -} )"; constexpr char kDefaultUdfHandlerName[] = "HandleRequest"; diff --git a/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000001.csv b/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000001.csv index 73cb9ba9..e9da496a 100644 --- a/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000001.csv +++ b/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000001.csv @@ -1,3 +1,3 @@ -key,logical_commit_time,mutation_type,value -abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111 -abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222 +key,logical_commit_time,mutation_type,value,value_type +abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111,string +abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222,string diff --git a/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000002.csv b/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000002.csv index e7bc2346..31c0ab6d 100644 --- a/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000002.csv +++ b/testing/functionaltest/suts/baseline-dsp/data/kvserver/DELTA_0000000000000002.csv @@ -1,5 +1,5 @@ -key,logical_commit_time,mutation_type,value -foo1,12,UPDATE,AAAAAAAAAAAAAAAAAA -foo2,12,UPDATE,BBBBBBBBBBBBBBBBBB -foo3,12,UPDATE,cccccccccccccccccc -foo4,12,UPDATE,DDDDDDDDDDDDDDDDDD +key,logical_commit_time,mutation_type,value,value_type +foo1,12,UPDATE,AAAAAAAAAAAAAAAAAA,string 
+foo2,12,UPDATE,BBBBBBBBBBBBBBBBBB,string +foo3,12,UPDATE,cccccccccccccccccc,string +foo4,12,UPDATE,DDDDDDDDDDDDDDDDDD,string diff --git a/testing/functionaltest/suts/baseline-dsp/docker-compose.yaml b/testing/functionaltest/suts/baseline-dsp/docker-compose.yaml index c18db471..264206c0 100644 --- a/testing/functionaltest/suts/baseline-dsp/docker-compose.yaml +++ b/testing/functionaltest/suts/baseline-dsp/docker-compose.yaml @@ -22,6 +22,7 @@ services: - --delta_directory=/srvdata/deltas - --realtime_directory=/srvdata/realtime_data - --mode=DSP + - --internal_lookup_deadline_duration=1s hostname: kv-server networks: - kvserver-net diff --git a/testing/functionaltest/suts/baseline-ssp/docker-compose.yaml b/testing/functionaltest/suts/baseline-ssp/docker-compose.yaml index b3df0ec6..5f4c6cb1 100644 --- a/testing/functionaltest/suts/baseline-ssp/docker-compose.yaml +++ b/testing/functionaltest/suts/baseline-ssp/docker-compose.yaml @@ -22,6 +22,7 @@ services: - --delta_directory=/srvdata/deltas - --realtime_directory=/srvdata/realtime_data - --mode=SSP + - --internal_lookup_deadline_duration=1s hostname: kv-server networks: - kvserver-net diff --git a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp1/DELTA_0000000000000001.csv b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp1/DELTA_0000000000000001.csv index 73cb9ba9..e9da496a 100644 --- a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp1/DELTA_0000000000000001.csv +++ b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp1/DELTA_0000000000000001.csv @@ -1,3 +1,3 @@ -key,logical_commit_time,mutation_type,value -abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111 -abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222 +key,logical_commit_time,mutation_type,value,value_type +abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111,string +abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222,string diff --git a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp2/DELTA_0000000000000002.csv 
b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp2/DELTA_0000000000000002.csv index e7bc2346..31c0ab6d 100644 --- a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp2/DELTA_0000000000000002.csv +++ b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-dsp2/DELTA_0000000000000002.csv @@ -1,5 +1,5 @@ -key,logical_commit_time,mutation_type,value -foo1,12,UPDATE,AAAAAAAAAAAAAAAAAA -foo2,12,UPDATE,BBBBBBBBBBBBBBBBBB -foo3,12,UPDATE,cccccccccccccccccc -foo4,12,UPDATE,DDDDDDDDDDDDDDDDDD +key,logical_commit_time,mutation_type,value,value_type +foo1,12,UPDATE,AAAAAAAAAAAAAAAAAA,string +foo2,12,UPDATE,BBBBBBBBBBBBBBBBBB,string +foo3,12,UPDATE,cccccccccccccccccc,string +foo4,12,UPDATE,DDDDDDDDDDDDDDDDDD,string diff --git a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-ssp/DELTA_0000000000000001.csv b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-ssp/DELTA_0000000000000001.csv index 73cb9ba9..e9da496a 100644 --- a/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-ssp/DELTA_0000000000000001.csv +++ b/testing/functionaltest/suts/multiple-kv-servers/data/kvserver-ssp/DELTA_0000000000000001.csv @@ -1,3 +1,3 @@ -key,logical_commit_time,mutation_type,value -abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111 -abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222 +key,logical_commit_time,mutation_type,value,value_type +abc1,1,UPDATE,AAAAAABBBBBBCCCCCC111111,string +abc2,2,UPDATE,AAAAAABBBBBBCCCCCC222222,string diff --git a/testing/functionaltest/suts/multiple-kv-servers/docker-compose.yaml b/testing/functionaltest/suts/multiple-kv-servers/docker-compose.yaml index 7787c4e9..999044a9 100644 --- a/testing/functionaltest/suts/multiple-kv-servers/docker-compose.yaml +++ b/testing/functionaltest/suts/multiple-kv-servers/docker-compose.yaml @@ -24,6 +24,7 @@ services: - --delta_directory=/srvdata/deltas - --realtime_directory=/srvdata/realtime_data - --mode=SSP + - --internal_lookup_deadline_duration=1s hostname: kv-server 
networks: - kvserver-ssp-net @@ -66,6 +67,7 @@ services: - --delta_directory=/srvdata/deltas - --realtime_directory=/srvdata/realtime_data - --mode=DSP + - --internal_lookup_deadline_duration=1s hostname: kv-server networks: - kvserver-dsp1-net @@ -108,6 +110,7 @@ services: - --delta_directory=/srvdata/deltas - --realtime_directory=/srvdata/realtime_data - --mode=DSP + - --internal_lookup_deadline_duration=1s hostname: kv-server networks: - kvserver-dsp2-net diff --git a/third_party/aws_c_common.BUILD b/third_party/aws_c_common.BUILD deleted file mode 100644 index 6401df2d..00000000 --- a/third_party/aws_c_common.BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -# Description: -# AWS C Common -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "aws-c-common", - srcs = glob([ - "include/aws/common/*.h", - "include/aws/common/private/*.h", - "source/*.c", - "source/posix/*.c", - ]), - hdrs = [ - "include/aws/common/config.h", - ], - defines = [], - includes = [ - "include", - ], - textual_hdrs = glob([ - "include/**/*.inl", - ]), - deps = [], -) - -genrule( - name = "config_h", - srcs = [ - "include/aws/common/config.h.in", - ], - outs = [ - "include/aws/common/config.h", - ], - cmd = "sed 's/cmakedefine/undef/g' $< > $@", -) diff --git a/third_party/aws_c_event_stream.BUILD b/third_party/aws_c_event_stream.BUILD deleted file mode 100644 index 3d9c05e8..00000000 --- a/third_party/aws_c_event_stream.BUILD +++ /dev/null @@ -1,27 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -# Description: -# AWS C Event Stream -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "aws-c-event-stream", - srcs = glob([ - "include/**/*.h", - "source/**/*.c", - ]), - hdrs = [ - ], - defines = [], - includes = [ - "include", - ], - deps = [ - "@aws-c-common", - 
"@aws-checksums", - ], -) diff --git a/third_party/aws_checksums.BUILD b/third_party/aws_checksums.BUILD deleted file mode 100644 index 5132d648..00000000 --- a/third_party/aws_checksums.BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -# Description: -# AWS CheckSums -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "aws-checksums", - srcs = glob([ - "include/aws/checksums/*.h", - "include/aws/checksums/private/*.h", - "source/*.c", - ]) + [ - "crc_hw.c", - ], - hdrs = [], - defines = [], - includes = [ - "include", - ], - deps = [], -) - -genrule( - name = "crc_hw_c", - outs = ["crc_hw.c"], - cmd_bash = """cat <'$@' -#include -#include -int aws_checksums_do_cpu_id(int32_t* cpuid) { - return 0; -} -uint32_t aws_checksums_crc32c_hw(const uint8_t* input, int length, uint32_t previousCrc32) { - return aws_checksums_crc32c_sw(input, length, previousCrc32); -} -EOF""", -) diff --git a/third_party/aws_sdk_cpp.BUILD b/third_party/aws_sdk_cpp.BUILD deleted file mode 100644 index c13e3dae..00000000 --- a/third_party/aws_sdk_cpp.BUILD +++ /dev/null @@ -1,241 +0,0 @@ -# Description: -# AWS C++ SDK - -load("@rules_cc//cc:defs.bzl", "cc_library") - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "core", - srcs = glob([ - "aws-cpp-sdk-core/source/*.cpp", # AWS_SOURCE - "aws-cpp-sdk-core/source/external/tinyxml2/*.cpp", # AWS_TINYXML2_SOURCE - "aws-cpp-sdk-core/source/external/cjson/*.cpp", # CJSON_SOURCE - "aws-cpp-sdk-core/source/auth/*.cpp", # AWS_AUTH_SOURCE - "aws-cpp-sdk-core/source/client/*.cpp", # AWS_CLIENT_SOURCE - "aws-cpp-sdk-core/source/internal/*.cpp", # AWS_INTERNAL_SOURCE - "aws-cpp-sdk-core/source/aws/model/*.cpp", # AWS_MODEL_SOURCE - "aws-cpp-sdk-core/source/http/*.cpp", # HTTP_SOURCE - "aws-cpp-sdk-core/source/http/standard/*.cpp", # 
HTTP_STANDARD_SOURCE - "aws-cpp-sdk-core/source/config/*.cpp", # CONFIG_SOURCE - "aws-cpp-sdk-core/source/monitoring/*.cpp", # MONITORING_SOURCE - "aws-cpp-sdk-core/source/net/linux-shared/*.cpp", # NET_SOURCE - "aws-cpp-sdk-core/source/platform/linux-shared/*.cpp", # PLATFORM_LINUX_SHARED_SOURCE - "aws-cpp-sdk-core/source/utils/*.cpp", # UTILS_SOURCE - "aws-cpp-sdk-core/source/utils/event/*.cpp", # UTILS_EVENT_SOURCE - "aws-cpp-sdk-core/source/utils/base64/*.cpp", # UTILS_BASE64_SOURCE - "aws-cpp-sdk-core/source/utils/crypto/*.cpp", # UTILS_CRYPTO_SOURCE - "aws-cpp-sdk-core/source/utils/json/*.cpp", # UTILS_JSON_SOURCE - "aws-cpp-sdk-core/source/utils/threading/*.cpp", # UTILS_THREADING_SOURCE - "aws-cpp-sdk-core/source/utils/xml/*.cpp", # UTILS_XML_SOURCE - "aws-cpp-sdk-core/source/utils/logging/*.cpp", # UTILS_LOGGING_SOURCE - "aws-cpp-sdk-core/source/utils/memory/*.cpp", # UTILS_MEMORY_SOURCE - "aws-cpp-sdk-core/source/utils/memory/stl/*.cpp", # UTILS_MEMORY_STL_SOURCE - "aws-cpp-sdk-core/source/utils/stream/*.cpp", # UTILS_STREAM_SOURCE - "aws-cpp-sdk-core/source/utils/crypto/factory/*.cpp", # UTILS_CRYPTO_FACTORY_SOURCE - "aws-cpp-sdk-core/source/http/curl/*.cpp", # HTTP_CURL_CLIENT_SOURCE - "aws-cpp-sdk-core/source/utils/crypto/openssl/*.cpp", # UTILS_CRYPTO_OPENSSL_SOURCE - ]), - hdrs = [ - "aws-cpp-sdk-core/include/aws/core/SDKConfig.h", - ] + glob([ - "aws-cpp-sdk-core/include/aws/core/*.h", # AWS_HEADERS - "aws-cpp-sdk-core/include/aws/core/auth/*.h", # AWS_AUTH_HEADERS - "aws-cpp-sdk-core/include/aws/core/client/*.h", # AWS_CLIENT_HEADERS - "aws-cpp-sdk-core/include/aws/core/internal/*.h", # AWS_INTERNAL_HEADERS - "aws-cpp-sdk-core/include/aws/core/net/*.h", # NET_HEADERS - "aws-cpp-sdk-core/include/aws/core/http/*.h", # HTTP_HEADERS - "aws-cpp-sdk-core/include/aws/core/http/standard/*.h", # HTTP_STANDARD_HEADERS - "aws-cpp-sdk-core/include/aws/core/config/*.h", # CONFIG_HEADERS - "aws-cpp-sdk-core/include/aws/core/monitoring/*.h", # MONITORING_HEADERS 
- "aws-cpp-sdk-core/include/aws/core/platform/*.h", # PLATFORM_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/*.h", # UTILS_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/event/*.h", # UTILS_EVENT_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/base64/*.h", # UTILS_BASE64_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/crypto/*.h", # UTILS_CRYPTO_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/json/*.h", # UTILS_JSON_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/threading/*.h", # UTILS_THREADING_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/xml/*.h", # UTILS_XML_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/memory/*.h", # UTILS_MEMORY_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/memory/stl/*.h", # UTILS_STL_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/logging/*.h", # UTILS_LOGGING_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/ratelimiter/*.h", # UTILS_RATE_LIMITER_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/stream/*.h", # UTILS_STREAM_HEADERS - "aws-cpp-sdk-core/include/aws/core/external/cjson/*.h", # CJSON_HEADERS - "aws-cpp-sdk-core/include/aws/core/external/tinyxml2/*.h", # TINYXML2_HEADERS - "aws-cpp-sdk-core/include/aws/core/http/curl/*.h", # HTTP_CURL_CLIENT_HEADERS - "aws-cpp-sdk-core/include/aws/core/utils/crypto/openssl/*.h", # UTILS_CRYPTO_OPENSSL_HEADERS - ]), - defines = [ - 'AWS_SDK_VERSION_STRING=\\"1.8.186\\"', - "AWS_SDK_VERSION_MAJOR=1", - "AWS_SDK_VERSION_MINOR=8", - "AWS_SDK_VERSION_PATCH=186", - "ENABLE_OPENSSL_ENCRYPTION=1", - "ENABLE_CURL_CLIENT=1", - "OPENSSL_IS_BORINGSSL=1", - "PLATFORM_LINUX", - ], - includes = [ - "aws-cpp-sdk-core/include", - ], - deps = [ - "@aws-c-event-stream", - "@boringssl//:crypto", - "@boringssl//:ssl", - "@curl", - ], -) - -cc_library( - name = "ec2", - srcs = glob([ - "aws-cpp-sdk-ec2/source/*.cpp", # AWS_EC2_SOURCE - "aws-cpp-sdk-ec2/source/model/*.cpp", # AWS_EC2_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-ec2/include/aws/ec2/*.h", # AWS_EC2_HEADERS 
- "aws-cpp-sdk-ec2/include/aws/ec2/model/*.h", # AWS_EC2_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-ec2/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "s3", - srcs = glob([ - "aws-cpp-sdk-s3/source/*.cpp", # AWS_S3_SOURCE - "aws-cpp-sdk-s3/source/model/*.cpp", # AWS_S3_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-s3/include/aws/s3/*.h", # AWS_S3_HEADERS - "aws-cpp-sdk-s3/include/aws/s3/model/*.h", # AWS_S3_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-s3/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "sns", - srcs = glob([ - "aws-cpp-sdk-sns/source/*.cpp", # AWS_SNS_SOURCE - "aws-cpp-sdk-sns/source/model/*.cpp", # AWS_SNS_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-sns/include/aws/sns/*.h", # AWS_SNS_HEADERS - "aws-cpp-sdk-sns/include/aws/sns/model/*.h", # AWS_SNS_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-sns/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "sqs", - srcs = glob([ - "aws-cpp-sdk-sqs/source/*.cpp", # AWS_SQS_SOURCE - "aws-cpp-sdk-sqs/source/model/*.cpp", # AWS_SQS_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-sqs/include/aws/sqs/*.h", # AWS_SQS_HEADERS - "aws-cpp-sdk-sqs/include/aws/sqs/model/*.h", # AWS_SQS_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-sqs/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "ssm", - srcs = glob([ - "aws-cpp-sdk-ssm/source/*.cpp", # AWS_SSM_SOURCE - "aws-cpp-sdk-ssm/source/model/*.cpp", # AWS_SSM_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-ssm/include/aws/ssm/*.h", # AWS_SSM_HEADERS - "aws-cpp-sdk-ssm/include/aws/ssm/model/*.h", # AWS_SSM_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-ssm/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "transfer", - srcs = glob([ - "aws-cpp-sdk-transfer/source/transfer/*.cpp", # TRANSFER_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-transfer/include/aws/transfer/*.h", # TRANSFER_HEADERS - ]), - includes = [ - 
"aws-cpp-sdk-transfer/include", - ], - deps = [ - ":core", - ":s3", - ], -) - -cc_library( - name = "kinesis", - srcs = glob([ - "aws-cpp-sdk-kinesis/source/*.cpp", # AWS_KINESIS_SOURCE - "aws-cpp-sdk-kinesis/source/model/*.cpp", # AWS_KINESIS_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-kinesis/include/aws/kinesis/*.h", # AWS_KINESIS_HEADERS - "aws-cpp-sdk-kinesis/include/aws/kinesis/model/*.h", # AWS_KINESIS_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-kinesis/include", - ], - deps = [ - ":core", - ], -) - -cc_library( - name = "autoscaling", - srcs = glob([ - "aws-cpp-sdk-autoscaling/source/*.cpp", # AWS_AUTOSCALING_SOURCE - "aws-cpp-sdk-autoscaling/source/model/*.cpp", # AWS_AUTOSCALING_MODEL_SOURCE - ]), - hdrs = glob([ - "aws-cpp-sdk-autoscaling/include/aws/autoscaling/*.h", # AWS_AUTOSCALING_HEADERS - "aws-cpp-sdk-autoscaling/include/aws/autoscaling/model/*.h", # AWS_AUTOSCALING_MODEL_HEADERS - ]), - includes = [ - "aws-cpp-sdk-autoscaling/include", - ], - deps = [ - ":core", - ], -) - -genrule( - name = "SDKConfig_h", - outs = ["aws-cpp-sdk-core/include/aws/core/SDKConfig.h"], - cmd_bash = "touch '$@'", -) diff --git a/third_party/cpp_repositories.bzl b/third_party/cpp_repositories.bzl index a08f8a6b..d5c32623 100644 --- a/third_party/cpp_repositories.bzl +++ b/third_party/cpp_repositories.bzl @@ -13,74 +13,10 @@ # limitations under the License. 
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("@google_privacysandbox_servers_common//:cpp_deps.bzl", shared_cpp_dependencies = "cpp_dependencies") def cpp_repositories(): """Entry point for all external repositories used for C++/C dependencies.""" - shared_cpp_dependencies() - - http_archive( - name = "aws-checksums", - build_file = "//third_party:aws_checksums.BUILD", - sha256 = "6e6bed6f75cf54006b6bafb01b3b96df19605572131a2260fddaf0e87949ced0", - strip_prefix = "aws-checksums-0.1.5", - urls = [ - "https://github.com/awslabs/aws-checksums/archive/v0.1.5.tar.gz", - ], - ) - - http_archive( - name = "aws-c-common", - build_file = "//third_party:aws_c_common.BUILD", - sha256 = "01c2a58553a37b3aa5914d9e0bf7bf14507ff4937bc5872a678892ca20fcae1f", - strip_prefix = "aws-c-common-0.4.29", - urls = [ - "https://github.com/awslabs/aws-c-common/archive/v0.4.29.tar.gz", - ], - ) - - http_archive( - name = "aws-c-event-stream", - build_file = "//third_party:aws_c_event_stream.BUILD", - sha256 = "31d880d1c868d3f3df1e1f4b45e56ac73724a4dc3449d04d47fc0746f6f077b6", - strip_prefix = "aws-c-event-stream-0.1.4", - urls = [ - "https://github.com/awslabs/aws-c-event-stream/archive/v0.1.4.tar.gz", - ], - ) - - http_archive( - name = "aws_sdk_cpp", - build_file = "//third_party:aws_sdk_cpp.BUILD", - patch_cmds = [ - """sed -i.bak 's/UUID::RandomUUID/Aws::Utils::UUID::RandomUUID/g' aws-cpp-sdk-core/source/client/AWSClient.cpp""", - # Apply fix in https://github.com/aws/aws-sdk-cpp/commit/9669a1c1d9a96621cd0846679cbe973c648a64b3 - """sed -i.bak 's/Tags\\.entry/Tag/g' aws-cpp-sdk-sqs/source/model/TagQueueRequest.cpp""", - ], - sha256 = "749322a8be4594472512df8a21d9338d7181c643a00e08a0ff12f07e831e3346", - strip_prefix = "aws-sdk-cpp-1.8.186", - urls = [ - "https://github.com/aws/aws-sdk-cpp/archive/1.8.186.tar.gz", - ], - ) - - http_archive( - name = "curl", - build_file = "//third_party:curl.BUILD", - sha256 = 
"ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6", - strip_prefix = "curl-7.49.1", - urls = ["https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz"], - ) - - http_archive( - name = "zlib_archive", - build_file = "//third_party:zlib.BUILD", - sha256 = "91844808532e5ce316b3c010929493c0244f3d37593afd6de04f71821d5136d9", - strip_prefix = "zlib-1.2.12", - urls = ["https://mirror.bazel.build/zlib.net/zlib-1.2.12.tar.gz"], - ) - #riegeli http_archive( name = "com_google_riegeli", diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD deleted file mode 100644 index 8f11146b..00000000 --- a/third_party/curl.BUILD +++ /dev/null @@ -1,266 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") - -# Description: -# curl is a tool for talking to web servers. -licenses(["notice"]) # MIT/X derivative license - -cc_library( - name = "curl", - srcs = ["include/curl_config.h"] + glob([ - "lib/**/*.h", - "lib/**/*.c", - ]), - hdrs = glob(["include/curl/*.h"]) + [":configure"], - copts = [ - "-Iexternal/curl/lib", - "-D_GNU_SOURCE", - "-DHAVE_CONFIG_H", - "-DCURL_DISABLE_FTP", - "-DCURL_DISABLE_NTLM", # turning it off in configure is not enough - "-DHAVE_LIBZ", - "-DHAVE_ZLIB_H", - "-Wno-string-plus-int", - "-DCURL_MAX_WRITE_SIZE=65536", - ], - defines = ["CURL_STATICLIB"], - includes = ["include"], - linkopts = ["-lrt"], - visibility = ["//visibility:public"], - deps = [ - "@boringssl//:ssl", - "@zlib", - ], -) - -cc_binary( - name = "curl_bin", - srcs = glob([ - "src/*.h", - "src/*.c", - ]) + ["lib/config-win32.h"], - copts = [ - "-Iexternal/curl/lib", - "-D_GNU_SOURCE", - "-DHAVE_CONFIG_H", - "-DCURL_DISABLE_LIBCURL_OPTION", - "-Wno-string-plus-int", - ], - deps = [":curl"], -) - -genrule( - name = "configure", - outs = ["include/curl_config.h"], - cmd_bash = """cat <'$@' -#ifndef EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_ -#define EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_ -#include -#if defined(OPENSSL_IS_BORINGSSL) -# define HAVE_BORINGSSL 1 
-#endif -#define CURL_CA_BUNDLE "/etc/ssl/certs/ca-certificates.crt" -#define GETSERVBYPORT_R_ARGS 6 -#define GETSERVBYPORT_R_BUFSIZE 4096 -#define HAVE_CLOCK_GETTIME_MONOTONIC 1 -#define HAVE_CRYPTO_CLEANUP_ALL_EX_DATA 1 -#define HAVE_FSETXATTR_5 1 -#define HAVE_GETHOSTBYADDR_R 1 -#define HAVE_GETHOSTBYADDR_R_8 1 -#define HAVE_GETHOSTBYNAME_R 1 -#define HAVE_GETHOSTBYNAME_R_6 1 -#define HAVE_GETSERVBYPORT_R 1 -#define HAVE_LIBSSL 1 -#define HAVE_MALLOC_H 1 -#define HAVE_MSG_NOSIGNAL 1 -#define HAVE_OPENSSL_CRYPTO_H 1 -#define HAVE_OPENSSL_ERR_H 1 -#define HAVE_OPENSSL_PEM_H 1 -#define HAVE_OPENSSL_PKCS12_H 1 -#define HAVE_OPENSSL_RSA_H 1 -#define HAVE_OPENSSL_SSL_H 1 -#define HAVE_OPENSSL_X509_H 1 -#define HAVE_RAND_EGD 1 -#define HAVE_RAND_STATUS 1 -#define HAVE_SSL_GET_SHUTDOWN 1 -#define HAVE_TERMIOS_H 1 -#define OS "x86_64-pc-linux-gnu" -#define RANDOM_FILE "/dev/urandom" -#define USE_OPENSSL 1 -#define CURL_DISABLE_DICT 1 -#define CURL_DISABLE_FILE 1 -#define CURL_DISABLE_GOPHER 1 -#define CURL_DISABLE_IMAP 1 -#define CURL_DISABLE_LDAP 1 -#define CURL_DISABLE_LDAPS 1 -#define CURL_DISABLE_POP3 1 -#define CURL_DISABLE_SMTP 1 -#define CURL_DISABLE_TELNET 1 -#define CURL_DISABLE_TFTP 1 -#define CURL_EXTERN_SYMBOL __attribute__ ((__visibility__ ("default"))) -#define ENABLE_IPV6 1 -#define GETHOSTNAME_TYPE_ARG2 size_t -#define GETNAMEINFO_QUAL_ARG1 const -#define GETNAMEINFO_TYPE_ARG1 struct sockaddr * -#define GETNAMEINFO_TYPE_ARG2 socklen_t -#define GETNAMEINFO_TYPE_ARG46 socklen_t -#define GETNAMEINFO_TYPE_ARG7 int -#define HAVE_ALARM 1 -#define HAVE_ALLOCA_H 1 -#define HAVE_ARPA_INET_H 1 -#define HAVE_ARPA_TFTP_H 1 -#define HAVE_ASSERT_H 1 -#define HAVE_BASENAME 1 -#define HAVE_BOOL_T 1 -#define HAVE_CONNECT 1 -#define HAVE_DLFCN_H 1 -#define HAVE_ERRNO_H 1 -#define HAVE_FCNTL 1 -#define HAVE_FCNTL_H 1 -#define HAVE_FCNTL_O_NONBLOCK 1 -#define HAVE_FDOPEN 1 -#define HAVE_FORK 1 -#define HAVE_FREEADDRINFO 1 -#define HAVE_FREEIFADDRS 1 -#define HAVE_FSETXATTR 1 
-#define HAVE_FTRUNCATE 1 -#define HAVE_GAI_STRERROR 1 -#define HAVE_GETADDRINFO 1 -#define HAVE_GETADDRINFO_THREADSAFE 1 -#define HAVE_GETEUID 1 -#define HAVE_GETHOSTBYADDR 1 -#define HAVE_GETHOSTBYNAME 1 -#define HAVE_GETHOSTNAME 1 -#define HAVE_GETIFADDRS 1 -#define HAVE_GETNAMEINFO 1 -#define HAVE_GETPPID 1 -#define HAVE_GETPROTOBYNAME 1 -#define HAVE_GETPWUID 1 -#define HAVE_GETPWUID_R 1 -#define HAVE_GETRLIMIT 1 -#define HAVE_GETTIMEOFDAY 1 -#define HAVE_GMTIME_R 1 -#define HAVE_IFADDRS_H 1 -#define HAVE_IF_NAMETOINDEX 1 -#define HAVE_INET_ADDR 1 -#define HAVE_INET_NTOP 1 -#define HAVE_INET_PTON 1 -#define HAVE_INTTYPES_H 1 -#define HAVE_IOCTL 1 -#define HAVE_IOCTL_FIONBIO 1 -#define HAVE_IOCTL_SIOCGIFADDR 1 -#define HAVE_LIBGEN_H 1 -#define HAVE_LIBZ 1 -#define HAVE_LIMITS_H 1 -#define HAVE_LL 1 -#define HAVE_LOCALE_H 1 -#define HAVE_LOCALTIME_R 1 -#define HAVE_LONGLONG 1 -#define HAVE_MEMORY_H 1 -#define HAVE_NETDB_H 1 -#define HAVE_NETINET_IN_H 1 -#define HAVE_NETINET_TCP_H 1 -#define HAVE_NET_IF_H 1 -#define HAVE_PERROR 1 -#define HAVE_PIPE 1 -#define HAVE_POLL 1 -#define HAVE_POLL_FINE 1 -#define HAVE_POLL_H 1 -#define HAVE_POSIX_STRERROR_R 1 -#define HAVE_PWD_H 1 -#define HAVE_RECV 1 -#define HAVE_SELECT 1 -#define HAVE_SEND 1 -#define HAVE_SETJMP_H 1 -#define HAVE_SETLOCALE 1 -#define HAVE_SETRLIMIT 1 -#define HAVE_SETSOCKOPT 1 -#define HAVE_SGTTY_H 1 -#define HAVE_SIGACTION 1 -#define HAVE_SIGINTERRUPT 1 -#define HAVE_SIGNAL 1 -#define HAVE_SIGNAL_H 1 -#define HAVE_SIGSETJMP 1 -#define HAVE_SIG_ATOMIC_T 1 -#define HAVE_SOCKADDR_IN6_SIN6_SCOPE_ID 1 -#define HAVE_SOCKET 1 -#define HAVE_SOCKETPAIR 1 -#define HAVE_STDBOOL_H 1 -#define HAVE_STDINT_H 1 -#define HAVE_STDIO_H 1 -#define HAVE_STDLIB_H 1 -#define HAVE_STRCASECMP 1 -#define HAVE_STRDUP 1 -#define HAVE_STRERROR_R 1 -#define HAVE_STRINGS_H 1 -#define HAVE_STRING_H 1 -#define HAVE_STRNCASECMP 1 -#define HAVE_STRSTR 1 -#define HAVE_STRTOK_R 1 -#define HAVE_STRTOLL 1 -#define 
HAVE_STRUCT_SOCKADDR_STORAGE 1 -#define HAVE_STRUCT_TIMEVAL 1 -#define HAVE_SYS_IOCTL_H 1 -#define HAVE_SYS_PARAM_H 1 -#define HAVE_SYS_POLL_H 1 -#define HAVE_SYS_RESOURCE_H 1 -#define HAVE_SYS_SELECT_H 1 -#define HAVE_SYS_SOCKET_H 1 -#define HAVE_SYS_STAT_H 1 -#define HAVE_SYS_TIME_H 1 -#define HAVE_SYS_TYPES_H 1 -#define HAVE_SYS_UIO_H 1 -#define HAVE_SYS_UN_H 1 -#define HAVE_SYS_WAIT_H 1 -#define HAVE_SYS_XATTR_H 1 -#define HAVE_TIME_H 1 -#define HAVE_UNAME 1 -#define HAVE_UNISTD_H 1 -#define HAVE_UTIME 1 -#define HAVE_UTIME_H 1 -#define HAVE_VARIADIC_MACROS_C99 1 -#define HAVE_VARIADIC_MACROS_GCC 1 -#define HAVE_WRITABLE_ARGV 1 -#define HAVE_WRITEV 1 -#define HAVE_ZLIB_H 1 -#define LT_OBJDIR ".libs/" -#define PACKAGE "curl" -#define PACKAGE_BUGREPORT "a suitable curl mailing list: https://curl.haxx.se/mail/" -#define PACKAGE_NAME "curl" -#define PACKAGE_STRING "curl -" -#define PACKAGE_TARNAME "curl" -#define PACKAGE_URL "" -#define PACKAGE_VERSION "-" -#define RECV_TYPE_ARG1 int -#define RECV_TYPE_ARG2 void * -#define RECV_TYPE_ARG3 size_t -#define RECV_TYPE_ARG4 int -#define RECV_TYPE_RETV ssize_t -#define RETSIGTYPE void -#define SELECT_QUAL_ARG5 -#define SELECT_TYPE_ARG1 int -#define SELECT_TYPE_ARG234 fd_set * -#define SELECT_TYPE_ARG5 struct timeval * -#define SELECT_TYPE_RETV int -#define SEND_QUAL_ARG2 const -#define SEND_TYPE_ARG1 int -#define SEND_TYPE_ARG2 void * -#define SEND_TYPE_ARG3 size_t -#define SEND_TYPE_ARG4 int -#define SEND_TYPE_RETV ssize_t -#define SIZEOF_INT 4 -#define SIZEOF_LONG 8 -#define SIZEOF_OFF_T 8 -#define SIZEOF_SHORT 2 -#define SIZEOF_SIZE_T 8 -#define SIZEOF_TIME_T 8 -#define SIZEOF_VOIDP 8 -#define STDC_HEADERS 1 -#define STRERROR_R_TYPE_ARG3 size_t -#define TIME_WITH_SYS_TIME 1 -#define VERSION "-" -#ifndef _DARWIN_USE_64_BIT_INODE -# define _DARWIN_USE_64_BIT_INODE 1 -#endif -#endif // EXTERNAL_CURL_INCLUDE_CURL_CONFIG_H_ -EOF""", -) diff --git a/third_party/highwayhash.BUILD b/third_party/highwayhash.BUILD index 
ac0c6350..568b8c2b 100644 --- a/third_party/highwayhash.BUILD +++ b/third_party/highwayhash.BUILD @@ -259,8 +259,8 @@ cc_library( deps = [ ":arch_specific", ":compiler_specific", - ":hh_types", ":hh_portable", + ":hh_types", ] + select({ ":cpu_ppc": [":hh_vsx"], "//conditions:default": [ diff --git a/third_party/quiche.bzl b/third_party/quiche.bzl deleted file mode 100644 index 2038d322..00000000 --- a/third_party/quiche.bzl +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -def quiche_dependencies(): - http_archive( - name = "com_github_google_quiche", - urls = ["https://github.com/google/quiche/archive/c06013fca03cc95f662cb3b09ad582b0336258aa.tar.gz"], - strip_prefix = "quiche-c06013fca03cc95f662cb3b09ad582b0336258aa", - ) - - http_archive( - name = "com_google_quic_trace", - sha256 = "079331de8c3cbf145a3b57adb3ad4e73d733ecfa84d3486e1c5a9eaeef286549", # Last updated 2022-05-18 - strip_prefix = "quic-trace-c7b993eb750e60c307e82f75763600d9c06a6de1", - urls = ["https://github.com/google/quic-trace/archive/c7b993eb750e60c307e82f75763600d9c06a6de1.tar.gz"], - ) - - http_archive( - name = "com_google_googleurl", - sha256 = "a1bc96169d34dcc1406ffb750deef3bc8718bd1f9069a2878838e1bd905de989", - urls = ["https://storage.googleapis.com/quiche-envoy-integration/googleurl_9cdb1f4d1a365ebdbcbf179dadf7f8aa5ee802e7.tar.gz"], - ) diff --git a/third_party/zlib.BUILD b/third_party/zlib.BUILD deleted file mode 100644 index 6ac0234f..00000000 --- a/third_party/zlib.BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -package(default_visibility = ["//visibility:public"]) - -cc_library( - name = "zlib", - srcs = glob([ - "*.c", - "*.h", - ]), - hdrs = ["zlib.h"], - copts = [ - "-Wno-shift-negative-value", - "-DZ_HAVE_UNISTD_H", - ], - includes = ["."], -) diff --git a/tools/bidding_auction_data_generator/bidding_auction_data_cli.cc b/tools/bidding_auction_data_generator/bidding_auction_data_cli.cc index 8badf45a..420f9eba 100644 --- a/tools/bidding_auction_data_generator/bidding_auction_data_cli.cc +++ b/tools/bidding_auction_data_generator/bidding_auction_data_cli.cc @@ -49,8 +49,8 @@ ABSL_FLAG(int64_t, logical_commit_time, absl::ToUnixMicros(absl::Now()), ABSL_FLAG(int, num_keys_per_batch, 50, "The number of keys in one batch of http request"); -constexpr kv_server::DeltaMutationType kMutationType = - kv_server::DeltaMutationType::Update; 
+constexpr kv_server::KeyValueMutationType kMutationType = + kv_server::KeyValueMutationType::Update; using kv_server::HttpValueRetriever; using Output = std::vector>; diff --git a/tools/bidding_auction_data_generator/delta_key_value_writer.cc b/tools/bidding_auction_data_generator/delta_key_value_writer.cc index 7daaceed..22970d3f 100644 --- a/tools/bidding_auction_data_generator/delta_key_value_writer.cc +++ b/tools/bidding_auction_data_generator/delta_key_value_writer.cc @@ -34,15 +34,19 @@ DeltaKeyValueWriter::Create(std::ostream& output_stream) { absl::Status DeltaKeyValueWriter::Write( const absl::flat_hash_map& key_value_map, - int64_t logical_commit_time, DeltaMutationType mutation_type) { + int64_t logical_commit_time, KeyValueMutationType mutation_type) { for (const auto& [k, v] : key_value_map) { - DeltaFileRecordStruct record_struct; - record_struct.key = k; - record_struct.value = v; - record_struct.logical_commit_time = logical_commit_time; - record_struct.mutation_type = mutation_type; + KeyValueMutationRecordStruct kv_mutation_struct; + kv_mutation_struct.key = k; + kv_mutation_struct.value = v; + kv_mutation_struct.logical_commit_time = logical_commit_time; + kv_mutation_struct.mutation_type = mutation_type; - if (const auto status = delta_record_writer_->WriteRecord(record_struct); + DataRecordStruct data_record_struct; + data_record_struct.record = kv_mutation_struct; + + if (const auto status = + delta_record_writer_->WriteRecord(data_record_struct); !status.ok()) { LOG(ERROR) << "Failed to write a delta record with key: " << k << " and value: " << v << status; diff --git a/tools/bidding_auction_data_generator/delta_key_value_writer.h b/tools/bidding_auction_data_generator/delta_key_value_writer.h index 83ba4e1a..2753fdc5 100644 --- a/tools/bidding_auction_data_generator/delta_key_value_writer.h +++ b/tools/bidding_auction_data_generator/delta_key_value_writer.h @@ -35,7 +35,7 @@ class DeltaKeyValueWriter { std::ostream& output_stream); 
absl::Status Write( const absl::flat_hash_map& key_value_map, - int64_t logical_commit_time, DeltaMutationType mutation_type); + int64_t logical_commit_time, KeyValueMutationType mutation_type); private: explicit DeltaKeyValueWriter( diff --git a/tools/bidding_auction_data_generator/delta_key_value_writer_test.cc b/tools/bidding_auction_data_generator/delta_key_value_writer_test.cc index b0cbaa21..88b47d72 100644 --- a/tools/bidding_auction_data_generator/delta_key_value_writer_test.cc +++ b/tools/bidding_auction_data_generator/delta_key_value_writer_test.cc @@ -24,10 +24,11 @@ namespace kv_server { namespace { constexpr int64_t kTestLogicalCommitTime = 1234567890; -constexpr DeltaMutationType kTestDeltaMutationType = DeltaMutationType::Update; +constexpr KeyValueMutationType kTestDeltaMutationType = + KeyValueMutationType::Update; -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKVMutationRecord() { + KeyValueMutationRecordStruct record; record.key = "key1"; record.value = R"({"field": "test"})"; record.logical_commit_time = kTestLogicalCommitTime; @@ -56,15 +57,15 @@ TEST(DeltaKeyValueWriterTest, ValidateDeltaDataTest) { EXPECT_TRUE(stream_reader ->ReadStreamRecords( [](std::string_view record_string) -> absl::Status { - DeltaFileRecordStruct record; - auto fbs_record = flatbuffers::GetRoot( - record_string.data()); - record.key = fbs_record->key()->string_view(); - record.value = fbs_record->value()->string_view(); - record.mutation_type = fbs_record->mutation_type(); - record.logical_commit_time = - fbs_record->logical_commit_time(); - EXPECT_EQ(record, GetDeltaRecord()); + const auto* data_record = + flatbuffers::GetRoot( + record_string.data()); + EXPECT_EQ(data_record->record_type(), + Record::KeyValueMutationRecord); + const auto kv_record = + GetTypedRecordStruct( + *data_record); + EXPECT_EQ(kv_record, GetKVMutationRecord()); return absl::OkStatus(); }) .ok()); diff --git a/tools/collect-logs 
b/tools/collect-logs index d8128a25..bda4a098 100755 --- a/tools/collect-logs +++ b/tools/collect-logs @@ -28,41 +28,60 @@ function _cleanup() { exit ${STATUS} } -ZIP_FILENAME="$1" +function copy_log_outputs() { + declare -r _rootdest="$1" + declare -r _prune_to_dir="$2" + declare -r _filepath="$3" + declare -r _logpath="${_filepath##*/${_prune_to_dir}/}" + declare -r _destdir="${_rootdest}/${_logpath%/*}" + declare -r _fname="${_filepath##*/}" + declare -r _destfname="${_fname/#test./sponge_log.}" + mkdir -p "${_destdir}" + cp "${_filepath}" "${_destdir}/${_destfname}" +} +export -f copy_log_outputs + +function extract_test_outputs() { + declare -r _rootdest="$1" + declare -r _filepath="$2" + declare -r _logpath="${_filepath##*bazel-testlogs/}" + declare -r _destdir="${_rootdest}/${_logpath%/*/*}" + mkdir -p "${_destdir}" + unzip -q -d "${_destdir}" "${_filepath}" +} +export -f extract_test_outputs + + +declare ZIP_FILENAME="$1" if [[ ${ZIP_FILENAME##*.} != zip ]]; then ZIP_FILENAME=logs.zip fi SCRIPT_DIR="$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" +readonly SCRIPT_DIR WORKSPACE="${WORKSPACE-"$(readlink -f "${SCRIPT_DIR}"/..)"}" -TMPDIR="$(mktemp --directory)" -readonly TMPDIR -mkdir -p "${TMPDIR}"/{test,other} -# copy all test.log and test.xml files to the testlogs dir +readonly WORKSPACE +OUTDIR="$(mktemp --directory)" +readonly OUTDIR +export OUTDIR +mkdir -p "${OUTDIR}"/{test,other} + if [[ -d "${WORKSPACE}"/bazel-testlogs ]]; then - find -L "${WORKSPACE}"/bazel-testlogs -name test.log -exec cp --parents {} "${TMPDIR}"/test ';' - find -L "${WORKSPACE}"/bazel-testlogs -name test.xml -exec cp --parents {} "${TMPDIR}"/test ';' + # copy all test.log and test.xml files + find -L "${WORKSPACE}"/bazel-testlogs -type f '(' -name test.log -o -name test.xml ')' -exec bash -c 'copy_log_outputs "${OUTDIR}"/test bazel-testlogs "$0"' {} ';' + # extract test outputs + find -L "${WORKSPACE}"/bazel-testlogs -type f -name outputs.zip -exec bash -c 'extract_test_outputs 
"${OUTDIR}"/test "$0"' {} ';' fi -# copy log files under bazel-out (other than test.log) to the buildlogs dir if [[ -d "${WORKSPACE}"/bazel-out ]]; then - find -L "${WORKSPACE}"/bazel-out -name "*.log" ! -name "test.log" -exec cp --parents {} "${TMPDIR}"/other ';' -fi -if compgen -G "${TMPDIR}/test/src/workspace/bazel-testlogs/*" &>/dev/null; then - mv --target-directory="${TMPDIR}"/test "${TMPDIR}"/test/src/workspace/bazel-testlogs/* - rmdir "${TMPDIR}"/test/{src/workspace/bazel-testlogs,src/workspace,src} -fi - -if command -v rename &>/dev/null; then - # Rename the copied test.log and test.xml files to sponge_log.log and sponge_log.xml - find -L "${TMPDIR}" -name test.log -exec rename 's/test.log/sponge_log.log/' {} ';' - find -L "${TMPDIR}" -name test.xml -exec rename 's/test.xml/sponge_log.xml/' {} ';' + # copy log files under bazel-out (except for test.log) + find -L "${WORKSPACE}"/bazel-out -type f -name "*.log" ! -name test.log -exec bash -c 'copy_log_outputs "${OUTDIR}"/other bazel-out "$0"' {} ';' fi declare -r DISTDIR="${WORKSPACE}"/dist mkdir -p "${DISTDIR}" ( - cd "${TMPDIR}" + cd "${OUTDIR}" zip -r -q "${DISTDIR}/${ZIP_FILENAME}" -- * ) printf "stored bazel logs to %s\n" "${DISTDIR}/${ZIP_FILENAME}" &>/dev/stderr unzip -Z -h "${DISTDIR}/${ZIP_FILENAME}" &>/dev/stderr -rm -rf "${TMPDIR}" +rm -rf "${OUTDIR}" diff --git a/tools/data_cli/commands/format_data_command.cc b/tools/data_cli/commands/format_data_command.cc index f38e6668..72ab8a99 100644 --- a/tools/data_cli/commands/format_data_command.cc +++ b/tools/data_cli/commands/format_data_command.cc @@ -34,6 +34,10 @@ namespace { constexpr std::string_view kDeltaFormat = "delta"; constexpr std::string_view kCsvFormat = "csv"; +constexpr std::string_view kKeyValueMutationRecord = + "key_value_mutation_record"; +constexpr std::string_view kUserDefinedFunctionsConfig = + "user_defined_functions_config"; absl::Status ValidateParams(const FormatDataCommand::Params& params) { if (params.input_format.empty()) 
{ @@ -42,6 +46,9 @@ absl::Status ValidateParams(const FormatDataCommand::Params& params) { if (params.output_format.empty()) { return absl::InvalidArgumentError("Output format cannot be empty."); } + if (params.record_type.empty()) { + return absl::InvalidArgumentError("Record type cannot be empty."); + } std::string lw_output_format = absl::AsciiStrToLower(params.output_format); if (absl::AsciiStrToLower(params.input_format) == lw_output_format) { return absl::InvalidArgumentError(absl::StrCat( @@ -51,12 +58,32 @@ absl::Status ValidateParams(const FormatDataCommand::Params& params) { return absl::OkStatus(); } +absl::StatusOr GetRecordType(std::string_view record_type) { + std::string lw_record_type = absl::AsciiStrToLower(record_type); + if (lw_record_type == kKeyValueMutationRecord) { + return DataRecordType::kKeyValueMutationRecord; + } + if (lw_record_type == kUserDefinedFunctionsConfig) { + return DataRecordType::kUserDefinedFunctionsConfig; + } + return absl::InvalidArgumentError( + absl::StrCat("Record type ", record_type, " is not supported.")); +} + absl::StatusOr> CreateRecordReader( const FormatDataCommand::Params& params, std::istream& input_stream) { std::string lw_input_format = absl::AsciiStrToLower(params.input_format); if (lw_input_format == kCsvFormat) { + const auto record_type = GetRecordType(params.record_type); + if (!record_type.ok()) { + return record_type.status(); + } return std::make_unique>( - input_stream); + input_stream, CsvDeltaRecordStreamReader::Options{ + .field_separator = params.csv_column_delimiter, + .value_separator = params.csv_value_delimiter, + .record_type = std::move(record_type.value()), + }); } if (lw_input_format == kDeltaFormat) { return std::make_unique>( @@ -70,8 +97,16 @@ absl::StatusOr> CreateRecordWriter( const FormatDataCommand::Params& params, std::ostream& output_stream) { std::string lw_output_format = absl::AsciiStrToLower(params.output_format); if (lw_output_format == kCsvFormat) { + const auto 
record_type = GetRecordType(params.record_type); + if (!record_type.ok()) { + return record_type.status(); + } return std::make_unique>( - output_stream); + output_stream, CsvDeltaRecordStreamWriter::Options{ + .field_separator = params.csv_column_delimiter, + .value_separator = params.csv_value_delimiter, + .record_type = std::move(record_type.value()), + }); } if (lw_output_format == kDeltaFormat) { KVFileMetadata metadata; @@ -108,8 +143,8 @@ absl::StatusOr> FormatDataCommand::Create( absl::Status FormatDataCommand::Execute() { absl::Status status = record_reader_->ReadRecords( - [record_writer = record_writer_.get()](DeltaFileRecordStruct record) { - return record_writer->WriteRecord(record); + [record_writer = record_writer_.get()](DataRecordStruct data_record) { + return record_writer->WriteRecord(data_record); }); record_writer_->Close(); return status; diff --git a/tools/data_cli/commands/format_data_command.h b/tools/data_cli/commands/format_data_command.h index 915f7dae..74033388 100644 --- a/tools/data_cli/commands/format_data_command.h +++ b/tools/data_cli/commands/format_data_command.h @@ -57,6 +57,9 @@ class FormatDataCommand : public Command { struct Params { std::string_view input_format; std::string_view output_format; + char csv_column_delimiter; + char csv_value_delimiter; + std::string_view record_type; }; static absl::StatusOr> Create( diff --git a/tools/data_cli/commands/format_data_command_test.cc b/tools/data_cli/commands/format_data_command_test.cc index e278dd22..eb5e0887 100644 --- a/tools/data_cli/commands/format_data_command_test.cc +++ b/tools/data_cli/commands/format_data_command_test.cc @@ -26,32 +26,56 @@ namespace kv_server { namespace { -FormatDataCommand::Params GetParams() { - return FormatDataCommand::Params{.input_format = "CSV", - .output_format = "DELTA"}; +FormatDataCommand::Params GetParams( + std::string_view record_type = "KEY_VALUE_MUTATION_RECORD") { + return FormatDataCommand::Params{ + .input_format = "CSV", + 
.output_format = "DELTA", + .csv_column_delimiter = ',', + .csv_value_delimiter = '|', + .record_type = std::move(record_type), + }; } -DeltaFileRecordStruct GetDeltaRecord() { - DeltaFileRecordStruct record; +KeyValueMutationRecordStruct GetKVMutationRecord() { + KeyValueMutationRecordStruct record; record.key = "key"; record.value = "value"; record.logical_commit_time = 1234567890; - record.mutation_type = DeltaMutationType::Update; + record.mutation_type = KeyValueMutationType::Update; return record; } +UserDefinedFunctionsConfigStruct GetUdfConfig() { + UserDefinedFunctionsConfigStruct udf_config_record; + udf_config_record.language = UserDefinedFunctionsLanguage::Javascript; + udf_config_record.code_snippet = "function hello(){}"; + udf_config_record.handler_name = "hello"; + udf_config_record.logical_commit_time = 1234567890; + return udf_config_record; +} + +DataRecordStruct GetDataRecord(const RecordT& record) { + DataRecordStruct data_record; + data_record.record = record; + return data_record; +} + KVFileMetadata GetMetadata() { KVFileMetadata metadata; return metadata; } -TEST(FormatDataCommandTest, ValidateGeneratingCsvToDeltaData) { +TEST(FormatDataCommandTest, ValidateGeneratingCsvToDeltaData_KVMutations) { std::stringstream csv_stream; std::stringstream delta_stream; CsvDeltaRecordStreamWriter csv_writer(csv_stream); - EXPECT_TRUE(csv_writer.WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE(csv_writer.WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE(csv_writer.WriteRecord(GetDeltaRecord()).ok()); + EXPECT_TRUE( + csv_writer.WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + csv_writer.WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + csv_writer.WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); csv_writer.Close(); EXPECT_FALSE(csv_stream.str().empty()); auto command = @@ -59,40 +83,114 @@ TEST(FormatDataCommandTest, ValidateGeneratingCsvToDeltaData) { EXPECT_TRUE(command.ok()) << 
command.status(); EXPECT_TRUE((*command)->Execute().ok()); DeltaRecordStreamReader delta_reader(delta_stream); - testing::MockFunction record_callback; + testing::MockFunction record_callback; EXPECT_CALL(record_callback, Call) .Times(3) - .WillRepeatedly([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .WillRepeatedly([](DataRecordStruct record) { + EXPECT_EQ(record, GetDataRecord(GetKVMutationRecord())); return absl::OkStatus(); }); EXPECT_TRUE(delta_reader.ReadRecords(record_callback.AsStdFunction()).ok()); } -TEST(FormatDataCommandTest, ValidateGeneratingDeltaToCsvData) { +TEST(FormatDataCommandTest, ValidateGeneratingDeltaToCsvData_KvMutations) { std::stringstream delta_stream; std::stringstream csv_stream; auto delta_writer = DeltaRecordStreamWriter::Create( delta_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); EXPECT_TRUE(delta_writer.ok()) << delta_writer.status(); - EXPECT_TRUE((*delta_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*delta_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*delta_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*delta_writer)->WriteRecord(GetDeltaRecord()).ok()); - EXPECT_TRUE((*delta_writer)->WriteRecord(GetDeltaRecord()).ok()); + EXPECT_TRUE( + (*delta_writer)->WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + (*delta_writer)->WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + (*delta_writer)->WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + (*delta_writer)->WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); + EXPECT_TRUE( + (*delta_writer)->WriteRecord(GetDataRecord(GetKVMutationRecord())).ok()); (*delta_writer)->Close(); auto command = FormatDataCommand::Create( - FormatDataCommand::Params{.input_format = "DELTA", - .output_format = "CSV"}, + FormatDataCommand::Params{ + .input_format = "DELTA", + .output_format = "CSV", + .csv_column_delimiter = ',', + 
.csv_value_delimiter = '|', + .record_type = "KEY_VALUE_MUTATION_RECORD", + }, delta_stream, csv_stream); EXPECT_TRUE(command.ok()) << command.status(); EXPECT_TRUE((*command)->Execute().ok()); CsvDeltaRecordStreamReader csv_reader(csv_stream); - testing::MockFunction record_callback; + testing::MockFunction record_callback; EXPECT_CALL(record_callback, Call) .Times(5) - .WillRepeatedly([](DeltaFileRecordStruct record) { - EXPECT_EQ(record, GetDeltaRecord()); + .WillRepeatedly([](DataRecordStruct record) { + EXPECT_EQ(record, GetDataRecord(GetKVMutationRecord())); + return absl::OkStatus(); + }); + EXPECT_TRUE(csv_reader.ReadRecords(record_callback.AsStdFunction()).ok()); +} + +TEST(FormatDataCommandTest, ValidateGeneratingCsvToDeltaData_UdfConfig) { + std::stringstream csv_stream; + std::stringstream delta_stream; + CsvDeltaRecordStreamWriter csv_writer( + csv_stream, + CsvDeltaRecordStreamWriter::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig}); + EXPECT_TRUE(csv_writer.WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + EXPECT_TRUE(csv_writer.WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + EXPECT_TRUE(csv_writer.WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + csv_writer.Close(); + EXPECT_FALSE(csv_stream.str().empty()); + auto command = FormatDataCommand::Create( + GetParams(/*record_type=*/"USER_DEFINED_FUNCTIONS_CONFIG"), csv_stream, + delta_stream); + EXPECT_TRUE(command.ok()) << command.status(); + EXPECT_TRUE((*command)->Execute().ok()); + DeltaRecordStreamReader delta_reader(delta_stream); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .Times(3) + .WillRepeatedly([](DataRecordStruct record) { + EXPECT_EQ(record, GetDataRecord(GetUdfConfig())); + return absl::OkStatus(); + }); + EXPECT_TRUE(delta_reader.ReadRecords(record_callback.AsStdFunction()).ok()); +} + +TEST(FormatDataCommandTest, ValidateGeneratingDeltaToCsvData_UdfConfig) { + std::stringstream delta_stream; + std::stringstream 
csv_stream; + auto delta_writer = DeltaRecordStreamWriter::Create( + delta_stream, DeltaRecordWriter::Options{.metadata = GetMetadata()}); + EXPECT_TRUE(delta_writer.ok()) << delta_writer.status(); + EXPECT_TRUE((*delta_writer)->WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + EXPECT_TRUE((*delta_writer)->WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + EXPECT_TRUE((*delta_writer)->WriteRecord(GetDataRecord(GetUdfConfig())).ok()); + (*delta_writer)->Close(); + auto command = FormatDataCommand::Create( + FormatDataCommand::Params{ + .input_format = "DELTA", + .output_format = "CSV", + .csv_column_delimiter = ',', + .csv_value_delimiter = '|', + .record_type = "USER_DEFINED_FUNCTIONS_CONFIG", + }, + delta_stream, csv_stream); + EXPECT_TRUE(command.ok()) << command.status(); + EXPECT_TRUE((*command)->Execute().ok()); + CsvDeltaRecordStreamReader csv_reader( + csv_stream, + CsvDeltaRecordStreamReader::Options{ + .record_type = DataRecordType::kUserDefinedFunctionsConfig, + }); + testing::MockFunction record_callback; + EXPECT_CALL(record_callback, Call) + .Times(3) + .WillRepeatedly([](DataRecordStruct record) { + EXPECT_EQ(record, GetDataRecord(GetUdfConfig())); return absl::OkStatus(); }); EXPECT_TRUE(csv_reader.ReadRecords(record_callback.AsStdFunction()).ok()); @@ -116,6 +214,23 @@ TEST(FormatDataCommandTest, ValidateIncorrectInputParams) { << status; } +TEST(FormatDataCommandTest, ValidateIncorrectRecordTypeParams) { + std::stringstream unused_stream; + auto params = GetParams(""); + absl::Status status = + FormatDataCommand::Create(params, unused_stream, unused_stream).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; + EXPECT_STREQ(status.message().data(), "Record type cannot be empty.") + << status; + params.record_type = "invalid record type"; + status = + FormatDataCommand::Create(params, unused_stream, unused_stream).status(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument) << status; + 
EXPECT_STREQ(status.message().data(), + "Record type invalid record type is not supported.") + << status; +} + TEST(FormatDataCommandTest, ValidateIncorrectOutputParams) { std::stringstream unused_stream; auto params = GetParams(); diff --git a/tools/data_cli/data_cli.cc b/tools/data_cli/data_cli.cc index b7cc21cd..507f7a1b 100644 --- a/tools/data_cli/data_cli.cc +++ b/tools/data_cli/data_cli.cc @@ -51,6 +51,13 @@ ABSL_FLAG( ABSL_FLAG( bool, in_memory_compaction, true, "If true, delta file compaction to generate snapshots is done in memory."); +ABSL_FLAG(std::string, csv_column_delimiter, ",", + "Column delimiter for csv files"); +ABSL_FLAG(std::string, csv_value_delimiter, "|", + "Value delimiter for csv files"); +ABSL_FLAG(std::string, record_type, "key_value_mutation_record", + "Data record type. Possible " + "options=(KEY_VALUE_MUTATION_RECORD|USER_DEFINED_FUNCTIONS_CONFIG)."); constexpr std::string_view kUsageMessage = R"( Usage: data_cli @@ -61,6 +68,9 @@ Usage: data_cli [--input_format] (Optional) Defaults to "CSV". Possible options=(CSV|DELTA) [--output_file] (Optional) Defaults to stdout. Output file to write converted records to. [--output_format] (Optional) Defaults to "DELTA". Possible options=(CSV|DELTA). + [--record_type] (Optional) Defaults to "KEY_VALUE_MUTATION_RECORD". Possible + options=(KEY_VALUE_MUTATION_RECORD|USER_DEFINED_FUNCTIONS_CONFIG). + If reading/writing a UDF config, use "USER_DEFINED_FUNCTIONS_CONFIG". Examples: (1) Generate a csv file to a delta file and write output records to std::cout. - data_cli format_data --input_file="$PWD/data.csv" @@ -71,6 +81,9 @@ Usage: data_cli (3) Pipe csv records and generate delta file records and write to std cout. - cat "$PWD/data.csv" | data_cli format_data --input_format=CSV + (4) Generate a delta file with UDF configs back to csv file and write output to a file. 
+ - data_cli format_data --input_file="$PWD/delta" --input_format=DELTA --output_file="$PWD/delta.csv" --output_format=CSV --record_type=USER_DEFINED_FUNCTIONS_CONFIG + - generate_snapshot Compacts a range of delta files into a single snapshot file. [--starting_file] (Required) Oldest delta file or base snapshot to include in compaction. [--ending_delta_file] (Required) Most recent delta file to include compaction. @@ -134,7 +147,12 @@ int main(int argc, char** argv) { auto format_data_command = FormatDataCommand::Create( FormatDataCommand::Params{ .input_format = absl::GetFlag(FLAGS_input_format), - .output_format = absl::GetFlag(FLAGS_output_format)}, + .output_format = absl::GetFlag(FLAGS_output_format), + .csv_column_delimiter = + absl::GetFlag(FLAGS_csv_column_delimiter)[0], + .csv_value_delimiter = absl::GetFlag(FLAGS_csv_value_delimiter)[0], + .record_type = absl::GetFlag(FLAGS_record_type), + }, *i_stream, *o_stream); if (!format_data_command.ok()) { LOG(ERROR) << "Failed to create command to format data. " diff --git a/tools/request_simulation/BUILD b/tools/request_simulation/BUILD new file mode 100644 index 00000000..0c8084c5 --- /dev/null +++ b/tools/request_simulation/BUILD @@ -0,0 +1,80 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") + +cc_library( + name = "request_generation_util", + srcs = ["request_generation_util.cc"], + hdrs = ["request_generation_util.h"], + deps = [ + "//tools/request_simulation/request:raw_request_cc_proto", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "message_queue", + srcs = ["message_queue.cc"], + hdrs = ["message_queue.h"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "rate_limiter", + srcs = ["rate_limiter.cc"], + hdrs = ["rate_limiter.h"], + deps = [ + "//components/util:sleepfor", + "@com_google_absl//absl/status", + "@google_privacysandbox_servers_common//src/cpp/util:duration", + ], +) + +cc_test( + name = "request_generation_util_test", + size = "small", + srcs = ["request_generation_util_test.cc"], + deps = [ + ":request_generation_util", + "@com_github_grpc_grpc//:grpc++", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "message_queue_test", + size = "small", + srcs = ["message_queue_test.cc"], + deps = [ + ":message_queue", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "rate_limiter_test", + size = "small", + srcs = ["rate_limiter_test.cc"], + deps = [ + ":rate_limiter", + "//components/util:sleepfor_mock", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tools/request_simulation/message_queue.cc b/tools/request_simulation/message_queue.cc new file mode 100644 index 00000000..e178339d --- /dev/null +++ b/tools/request_simulation/message_queue.cc @@ -0,0 +1,53 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/request_simulation/message_queue.h" + +#include + +namespace kv_server { + +void MessageQueue::Push(std::string message) { + absl::MutexLock lock(&mutex_); + if (queue_.size() < capacity_) { + queue_.push_back(std::move(message)); + } +} + +absl::StatusOr MessageQueue::Pop() { + absl::MutexLock lock(&mutex_); + if (queue_.empty()) { + return absl::FailedPreconditionError("Queue is empty"); + } + auto front = queue_.front(); + queue_.pop_front(); + return front; +} + +bool MessageQueue::Empty() const { + absl::MutexLock lock(&mutex_); + return queue_.empty(); +} + +size_t MessageQueue::Size() const { + absl::MutexLock lock(&mutex_); + return queue_.size(); +} + +void MessageQueue::Clear() { + absl::MutexLock lock(&mutex_); + queue_.clear(); +} + +} // namespace kv_server diff --git a/tools/request_simulation/message_queue.h b/tools/request_simulation/message_queue.h new file mode 100644 index 00000000..9033163d --- /dev/null +++ b/tools/request_simulation/message_queue.h @@ -0,0 +1,56 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef TOOLS_REQUEST_SIMULATION_MESSAGE_QUEUE_H_ +#define TOOLS_REQUEST_SIMULATION_MESSAGE_QUEUE_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace kv_server { + +// Synchronized message queue to stage the request body +class MessageQueue { + public: + explicit MessageQueue(int64_t capacity) : capacity_(capacity) {} + // Pushes new message to the queue + void Push(std::string message); + // Pops off message from the queue + absl::StatusOr Pop(); + // Checks if the queue is empty + bool Empty() const; + // Returns the size of the queue + size_t Size() const; + // Clears the queue + void Clear(); + ~MessageQueue() = default; + + // MessageQueue is neither copyable nor movable. + MessageQueue(const MessageQueue&) = delete; + MessageQueue& operator=(const MessageQueue&) = delete; + + private: + mutable absl::Mutex mutex_; + int64_t capacity_ ABSL_GUARDED_BY(mutex_); + std::deque queue_ ABSL_GUARDED_BY(mutex_); +}; +} // namespace kv_server + +#endif // TOOLS_REQUEST_SIMULATION_MESSAGE_QUEUE_H_ diff --git a/tools/request_simulation/message_queue_test.cc b/tools/request_simulation/message_queue_test.cc new file mode 100644 index 00000000..27e588de --- /dev/null +++ b/tools/request_simulation/message_queue_test.cc @@ -0,0 +1,53 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/request_simulation/message_queue.h" + +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(TestMessageQueue, TestQueueOperation) { + MessageQueue queue(100); + // Push first element + queue.Push("first"); + // Push second element + queue.Push("second"); + EXPECT_EQ(queue.Size(), 2); + auto pop = queue.Pop(); + EXPECT_TRUE(pop.ok()); + EXPECT_EQ(pop.value(), "first"); + EXPECT_EQ(queue.Size(), 1); + EXPECT_FALSE(queue.Empty()); + pop = queue.Pop(); + EXPECT_TRUE(pop.ok()); + EXPECT_EQ(pop.value(), "second"); + EXPECT_TRUE(queue.Empty()); + pop = queue.Pop(); + EXPECT_FALSE(pop.ok()); +} + +TEST(TestMessageQueue, TestCapacityConstraint) { + MessageQueue queue(1); + queue.Push("first"); + queue.Push("second"); + EXPECT_EQ(queue.Size(), 1); + auto pop = queue.Pop(); + EXPECT_TRUE(pop.ok()); + EXPECT_EQ(pop.value(), "first"); +} + +} // namespace +} // namespace kv_server diff --git a/tools/request_simulation/rate_limiter.cc b/tools/request_simulation/rate_limiter.cc new file mode 100644 index 00000000..139d2996 --- /dev/null +++ b/tools/request_simulation/rate_limiter.cc @@ -0,0 +1,53 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tools/request_simulation/rate_limiter.h" + +#include + +namespace kv_server { + +absl::Duration RateLimiter::Acquire() { return Acquire(1); } + +absl::Duration RateLimiter::Acquire(int permits) { + absl::MutexLock lock(&mu_); + const auto start_time = clock_.Now(); + while (permits_.load(std::memory_order_relaxed) - permits < 0) { + sleep_for_.Duration( + absl::Milliseconds(1000 * permits / permits_fill_rate_)); + RefillPermits(); + } + permits_.fetch_sub(permits, std::memory_order_relaxed) - permits; + return clock_.Now() - start_time; +} + +void RateLimiter::RefillPermits() { + const auto elapsed_time_ns = + ToChronoNanoseconds(last_refill_time_.GetElapsedTime()); + if (elapsed_time_ns <= std::chrono::nanoseconds::zero()) { + return; + } + const int64_t permits_to_fill = + (permits_fill_rate_ / 1e9) * elapsed_time_ns.count(); + permits_.fetch_add(permits_to_fill, std::memory_order_relaxed) + + permits_to_fill; + last_refill_time_.Reset(); +} + +void RateLimiter::SetFillRate(int64_t permits_per_second) { + absl::MutexLock lock(&mu_); + permits_fill_rate_ = permits_per_second; +} + +} // namespace kv_server diff --git a/tools/request_simulation/rate_limiter.h b/tools/request_simulation/rate_limiter.h new file mode 100644 index 00000000..ea1f96e7 --- /dev/null +++ b/tools/request_simulation/rate_limiter.h @@ -0,0 +1,72 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TOOLS_REQUEST_SIMULATION_RATE_LIMITER_H_ +#define TOOLS_REQUEST_SIMULATION_RATE_LIMITER_H_ + +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "components/util/sleepfor.h" +#include "src/cpp/util/duration.h" +namespace kv_server { + +// A simple permit-based rate limiter. The permits are refilled at given rate +// passed in the constructor. The fill rate can also be updated during runtime +class RateLimiter { + public: + RateLimiter(int64_t initial_permits, int64_t permits_per_second, + privacy_sandbox::server_common::SteadyClock& clock, + SleepFor& sleep_for) + : permits_fill_rate_(permits_per_second), + last_refill_time_(clock), + clock_(clock), + sleep_for_(sleep_for) { + permits_.store(initial_permits, std::memory_order_relaxed); + } + ~RateLimiter() = default; + // Acquires a single permit, returns waiting duration + absl::Duration Acquire() ABSL_LOCKS_EXCLUDED(mu_); + // Acquires a number of permits, returns waiting duration + absl::Duration Acquire(int permits) ABSL_LOCKS_EXCLUDED(mu_); + // Sets the fill rate + void SetFillRate(int64_t permits_per_second) ABSL_LOCKS_EXCLUDED(mu_); + + // RateLimiter is neither copyable nor movable. 
+ RateLimiter(const RateLimiter&) = delete; + RateLimiter& operator=(const RateLimiter&) = delete; + + private: + void RefillPermits() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + mutable absl::Mutex mu_; + // Permits fill rate in permits per second + mutable int64_t permits_fill_rate_ ABSL_GUARDED_BY(mu_); + // Last refill time in nanoseconds + mutable privacy_sandbox::server_common::Stopwatch last_refill_time_ + ABSL_GUARDED_BY(mu_); + // Number of permits available + mutable std::atomic permits_; + privacy_sandbox::server_common::SteadyClock& clock_; + SleepFor& sleep_for_; + friend class RateLimiterTestPeer; +}; + +} // namespace kv_server + +#endif // TOOLS_REQUEST_SIMULATION_RATE_LIMITER_H_ diff --git a/tools/request_simulation/rate_limiter_test.cc b/tools/request_simulation/rate_limiter_test.cc new file mode 100644 index 00000000..b7802d07 --- /dev/null +++ b/tools/request_simulation/rate_limiter_test.cc @@ -0,0 +1,113 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tools/request_simulation/rate_limiter.h" + +#include "components/util/sleepfor_mock.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace kv_server { + +using privacy_sandbox::server_common::SimulatedSteadyClock; +using privacy_sandbox::server_common::SteadyTime; +using testing::_; +using testing::Return; + +class RateLimiterTestPeer { + public: + RateLimiterTestPeer() = delete; + static int64_t ReadCurrentPermits(const RateLimiter& r) { + absl::MutexLock lock(&r.mu_); + return r.permits_.load(std::memory_order_relaxed); + } + static int64_t ReadRefillRate(const RateLimiter& r) { + absl::MutexLock lock(&r.mu_); + return r.permits_fill_rate_; + } + static SteadyTime ReadLastRefillTime(const RateLimiter& r) { + absl::MutexLock lock(&r.mu_); + return r.last_refill_time_.GetStartTime(); + } +}; +namespace { + +class RateLimiterTest : public ::testing::Test { + protected: + SimulatedSteadyClock sim_clock_; +}; + +TEST_F(RateLimiterTest, TestRefill) { + MockSleepFor sleep_for; + EXPECT_CALL(sleep_for, Duration(_)).WillRepeatedly(Return(true)); + RateLimiter rate_limiter(1, 1, sim_clock_, sleep_for); + rate_limiter.Acquire(); + sim_clock_.AdvanceTime(absl::Seconds(1)); + rate_limiter.Acquire(); + EXPECT_EQ(RateLimiterTestPeer::ReadCurrentPermits(rate_limiter), 0); + + rate_limiter.SetFillRate(5); + sim_clock_.AdvanceTime(absl::Seconds(1)); + rate_limiter.Acquire(); + EXPECT_EQ(RateLimiterTestPeer::ReadCurrentPermits(rate_limiter), 4); +} + +TEST_F(RateLimiterTest, TestAcquireMultiplePermits) { + MockSleepFor sleep_for; + EXPECT_CALL(sleep_for, Duration(_)).WillRepeatedly(Return(true)); + // No refill + int permits_to_acquire = 5; + RateLimiter rate_limiter(permits_to_acquire, 0, sim_clock_, sleep_for); + // Acquire all available permits + rate_limiter.Acquire(permits_to_acquire); + EXPECT_EQ(RateLimiterTestPeer::ReadCurrentPermits(rate_limiter), 0); +} + +TEST_F(RateLimiterTest, TestLastRefillTimeUpdate) { + MockSleepFor sleep_for; + 
EXPECT_CALL(sleep_for, Duration(_)).WillRepeatedly(Return(true)); + RateLimiter rate_limiter(1, 1, sim_clock_, sleep_for); + const auto initial_refill_time = + RateLimiterTestPeer::ReadLastRefillTime(rate_limiter); + // trigger refill + sim_clock_.AdvanceTime(absl::Seconds(1)); + rate_limiter.Acquire(2); + const auto last_refill_time = + RateLimiterTestPeer::ReadLastRefillTime(rate_limiter); + EXPECT_EQ(last_refill_time - initial_refill_time, absl::Seconds(1)); + sim_clock_.AdvanceTime(absl::Seconds(1)); + // trigger refill again + rate_limiter.Acquire(1); + const auto last_refill_time2 = + RateLimiterTestPeer::ReadLastRefillTime(rate_limiter); + EXPECT_EQ(last_refill_time2 - last_refill_time, absl::Seconds(1)); +} + +TEST_F(RateLimiterTest, TestPermitsFillRate) { + MockSleepFor sleep_for; + EXPECT_CALL(sleep_for, Duration(_)).WillRepeatedly(Return(true)); + + RateLimiter rate_limiter(0, 100, sim_clock_, sleep_for); + sim_clock_.AdvanceTime(absl::Seconds(2)); + rate_limiter.Acquire(); + EXPECT_EQ(RateLimiterTestPeer::ReadCurrentPermits(rate_limiter), 199); + + rate_limiter.SetFillRate(1000); + sim_clock_.AdvanceTime(absl::Seconds(1)); + rate_limiter.Acquire(200); + EXPECT_EQ(RateLimiterTestPeer::ReadCurrentPermits(rate_limiter), 999); +} + +} // namespace +} // namespace kv_server diff --git a/third_party/open_telemetry.bzl b/tools/request_simulation/request/BUILD similarity index 54% rename from third_party/open_telemetry.bzl rename to tools/request_simulation/request/BUILD index 54ee0372..efc18015 100644 --- a/third_party/open_telemetry.bzl +++ b/tools/request_simulation/request/BUILD @@ -1,4 +1,4 @@ -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
load("@rules_proto//proto:defs.bzl", "proto_library")

package(default_visibility = ["//tools:__subpackages__"])

# Proto schema for the raw request wrapper used by the request simulation
# tool; depends on google.api.HttpBody for the opaque request payload.
proto_library(
    name = "raw_request_proto",
    srcs = ["raw_request.proto"],
    deps = [
        "@com_google_googleapis//google/api:httpbody_proto",
    ],
)

# C++ bindings for :raw_request_proto.
cc_proto_library(
    name = "raw_request_cc_proto",
    deps = [
        ":raw_request_proto",
    ],
)
syntax = "proto3";

package kv_server;

import "google/api/httpbody.proto";

// Raw request to be sent to any service API that supports this request format.
message RawRequest {
  // The data in the raw_body can be plain text json string or encrypted blob.
  google.api.HttpBody raw_body = 1;
}
+ +#include "tools/request_simulation/request_generation_util.h" + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tools/request_simulation/request/raw_request.pb.h" + +namespace kv_server { + +// TODO(b/289240702): Make the request json schema easily configurable from +// external configure file + +constexpr std::string_view kKVV2KeyValueDSPRequestBodyFormat = R"json( +{"context": {},"partitions": [{ "id": 0, "compressionGroup": 0,"keyGroups": [{ "tags": [ "custom", "keys" ],"keyList": [ %s ] }] }] })json"; + +std::vector GenerateRandomKeys(int number_of_keys, int key_size) { + std::vector result; + for (int i = 0; i < number_of_keys; ++i) { + result.push_back( + std::string(key_size, 'A' + (std::rand() % number_of_keys))); + } + return result; +} + +std::string CreateKVDSPRequestBodyInJson(const std::vector& keys) { + const std::string comma_seperated_keys = + absl::StrJoin(keys, ",", [](std::string* out, const std::string& key) { + absl::StrAppend(out, "\"", key, "\""); + }); + return absl::StrFormat(kKVV2KeyValueDSPRequestBodyFormat, + comma_seperated_keys); +} + +// TODO(b/289240702): Explore if there is a way to create dynamic Message +// directly from request body in json and protobuf descriptor. So that request +// schema does not have to be tied to a specific proto file, user can only need +// to pass the protoset binary file and the system automatically generate +// corresponding Message based on the message name and constructed descriptor +// pool. 
Currently it is not easy because +// google::protobuf::util::JsonStringToMessage does not work well with message +// dependent on google.api.HttpBody +kv_server::RawRequest CreatePlainTextRequest( + const std::string& request_in_json) { + kv_server::RawRequest request; + request.mutable_raw_body()->set_data(request_in_json); + return request; +} + +} // namespace kv_server diff --git a/tools/request_simulation/request_generation_util.h b/tools/request_simulation/request_generation_util.h new file mode 100644 index 00000000..2dbf0cae --- /dev/null +++ b/tools/request_simulation/request_generation_util.h @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef TOOLS_REQUEST_SIMULATION_REQUEST_GENERATION_UTIL_H_ +#define TOOLS_REQUEST_SIMULATION_REQUEST_GENERATION_UTIL_H_ + +#include +#include + +#include "tools/request_simulation/request/raw_request.pb.h" + +namespace kv_server { + +// Generates random keys based on the number of keys and size of each key +std::vector GenerateRandomKeys(int number_of_keys, int key_size); + +// Creates KV DSP request body in json +std::string CreateKVDSPRequestBodyInJson(const std::vector& keys); + +// Creates proto message from request body in json +kv_server::RawRequest CreatePlainTextRequest( + const std::string& request_in_json); + +} // namespace kv_server + +#endif // TOOLS_REQUEST_SIMULATION_REQUEST_GENERATION_UTIL_H_ diff --git a/tools/request_simulation/request_generation_util_test.cc b/tools/request_simulation/request_generation_util_test.cc new file mode 100644 index 00000000..208ced1b --- /dev/null +++ b/tools/request_simulation/request_generation_util_test.cc @@ -0,0 +1,39 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "tools/request_simulation/request_generation_util.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/str_format.h" +#include "google/protobuf/util/json_util.h" +#include "gtest/gtest.h" + +namespace kv_server { +namespace { + +TEST(TestCreateMessage, ProtoMessageMatchJson) { + const auto keys = kv_server::GenerateRandomKeys(10, 3); + const std::string request_in_json = + kv_server::CreateKVDSPRequestBodyInJson(keys); + const auto request = kv_server::CreatePlainTextRequest(request_in_json); + EXPECT_EQ(request_in_json, request.raw_body().data()); + std::string encoded_request_body; + google::protobuf::util::MessageToJsonString(request.raw_body(), + &encoded_request_body); + std::string expect_encoded_request_body = absl::StrCat( + "{\"data\":", "\"", absl::Base64Escape(request_in_json), "\"", "}"); + EXPECT_EQ(encoded_request_body, expect_encoded_request_body); +} +} // namespace +} // namespace kv_server diff --git a/tools/serving_data_generator/test_serving_data_generator.cc b/tools/serving_data_generator/test_serving_data_generator.cc index f161d520..5d6e076a 100644 --- a/tools/serving_data_generator/test_serving_data_generator.cc +++ b/tools/serving_data_generator/test_serving_data_generator.cc @@ -29,31 +29,70 @@ #include "riegeli/records/record_writer.h" ABSL_FLAG(std::string, key, "foo", "Specify the key for lookups"); -ABSL_FLAG(int, value_size, 100, "Specify the size of value for the key"); +ABSL_FLAG(int, value_size, 10, "Specify the size of value for the key"); ABSL_FLAG(std::string, output_dir, "", "Output file directory"); ABSL_FLAG(int, num_records, 5, "Number of records to generate"); -ABSL_FLAG(int64_t, timestamp, 123123123, "Record timestamp"); +ABSL_FLAG(int64_t, timestamp, absl::ToUnixMicros(absl::Now()), + "Record timestamp"); +ABSL_FLAG(bool, generate_set_record, false, + "Whether to generate set record or not"); +ABSL_FLAG(int, num_values_in_set, 10, + "Number of values in the set to generate"); -using 
kv_server::DeltaFileRecordStruct; -using kv_server::DeltaMutationType; +using kv_server::DataRecordStruct; +using kv_server::KeyValueMutationRecordStruct; +using kv_server::KeyValueMutationType; using kv_server::KVFileMetadata; using kv_server::ToDeltaFileName; +using kv_server::ToFlatBufferBuilder; using kv_server::ToStringView; -void WriteRecords(std::string_view key, int value_size, - riegeli::RecordWriterBase& writer) { +void WriteKeyValueRecords(std::string_view key, int value_size, + riegeli::RecordWriterBase& writer) { const int repetition = absl::GetFlag(FLAGS_num_records); int64_t timestamp = absl::GetFlag(FLAGS_timestamp); - + std::string query(" "); for (int i = 0; i < repetition; ++i) { const std::string value(value_size, 'A' + (i % 50)); - writer.WriteRecord(ToStringView(DeltaFileRecordStruct{ - DeltaMutationType::Update, timestamp++, absl::StrCat(key, i), value} - .ToFlatBuffer())); + auto kv_record = KeyValueMutationRecordStruct{ + KeyValueMutationType::Update, timestamp++, absl::StrCat(key, i), value}; + writer.WriteRecord(ToStringView( + ToFlatBufferBuilder(DataRecordStruct{.record = std::move(kv_record)}))); + absl::StrAppend(&query, "\"", absl::StrCat(key, i), "\"", ", "); } + LOG(INFO) << "Print keys to query " << query; LOG(INFO) << "write done"; } +void WriteKeyValueSetRecords(std::string_view key, int value_size, + riegeli::RecordWriterBase& writer) { + const int repetition = absl::GetFlag(FLAGS_num_records); + int64_t timestamp = absl::GetFlag(FLAGS_timestamp); + const int num_values_in_set = absl::GetFlag(FLAGS_num_values_in_set); + std::string query(" "); + for (int i = 0; i < repetition; ++i) { + std::vector set_copy; + for (int j = 0; j < num_values_in_set; ++j) { + const std::string value(value_size, 'A' + (j % 50)); + set_copy.emplace_back( + absl::StrCat(value, std::to_string(std::rand() % num_values_in_set))); + } + std::vector set; + for (const auto& v : set_copy) { + set.emplace_back(v); + } + absl::StrAppend(&query, 
absl::StrCat(key, i), " | "); + KeyValueMutationRecordStruct record; + record.value = set; + record.mutation_type = KeyValueMutationType::Update; + record.logical_commit_time = timestamp++; + record.key = absl::StrCat(key, i); + writer.WriteRecord(ToStringView(ToFlatBufferBuilder(record))); + } + LOG(INFO) << "Example set query for all keys" << query; + LOG(INFO) << "write done for set records"; +} + int main(int argc, char** argv) { const std::vector commands = absl::ParseCommandLine(argc, argv); const std::string output_dir = absl::GetFlag(FLAGS_output_dir); @@ -71,7 +110,12 @@ int main(int argc, char** argv) { *metadata.MutableExtension(kv_server::kv_file_metadata) = file_metadata; options.set_metadata(std::move(metadata)); auto record_writer = riegeli::RecordWriter(std::move(os_writer), options); - WriteRecords(key, value_size, record_writer); + if (absl::GetFlag(FLAGS_generate_set_record)) { + WriteKeyValueSetRecords(key, value_size, record_writer); + } else { + WriteKeyValueRecords(key, value_size, record_writer); + } + record_writer.Close(); }; diff --git a/tools/udf/sample_udf/DELTA_1679074299455085 b/tools/udf/sample_udf/DELTA_1688672314376241 similarity index 52% rename from tools/udf/sample_udf/DELTA_1679074299455085 rename to tools/udf/sample_udf/DELTA_1688672314376241 index 45271766..9e5e169e 100644 Binary files a/tools/udf/sample_udf/DELTA_1679074299455085 and b/tools/udf/sample_udf/DELTA_1688672314376241 differ diff --git a/tools/udf/sample_udf/run_query_udf.js b/tools/udf/sample_udf/run_query_udf.js new file mode 100644 index 00000000..65215f1a --- /dev/null +++ b/tools/udf/sample_udf/run_query_udf.js @@ -0,0 +1,41 @@ +/** + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/**
 * Sample UDF that runs a set query for each key group tagged with both
 * "custom" and "queries", using the first key in the group's keyList as the
 * query expression. Failed or empty query results are silently skipped.
 *
 * @param {!Object} input request containing `keyGroups`.
 * @return {!Object} UDF output (API version 1) with one entry per
 *     successfully queried key group.
 */
function HandleRequest(input) {
  const keyGroupOutputs = [];
  for (const keyGroup of input.keyGroups) {
    const hasQueryTags =
        keyGroup.tags.includes("custom") && keyGroup.tags.includes("queries");
    if (!hasQueryTags) {
      continue;
    }
    if (!Array.isArray(keyGroup.keyList) || !keyGroup.keyList.length) {
      continue;
    }
    // Only the first key in the keyList is treated as the query.
    const runQueryArray = runQuery(keyGroup.keyList[0]);
    // runQuery returns an array of strings when successful and "code" on
    // failure; only successful, non-empty results are added to the output.
    if (!Array.isArray(runQueryArray) || !runQueryArray.length) {
      continue;
    }
    keyGroupOutputs.push({
      tags: keyGroup.tags,
      keyValues: { result: { value: runQueryArray } },
    });
  }
  return { keyGroupOutputs, udfOutputApiVersion: 1 };
}
+ if (getValuesResult.hasOwnProperty("kvPairs")) { + const kvPairs = getValuesResult.kvPairs; + const keyValuesOutput = {}; + for (const key in kvPairs) { + if (kvPairs[key].hasOwnProperty("value")) { + keyValuesOutput[key] = { "value": kvPairs[key].value }; + } } + keyGroupOutput.keyValues = keyValuesOutput; + keyGroupOutputs.push(keyGroupOutput); } - keyGroupOutput.keyValues = keyValuesOutput; - keyGroupOutputs.push(keyGroupOutput); } return {keyGroupOutputs, udfOutputApiVersion: 1}; } diff --git a/tools/udf/udf_generator/BUILD b/tools/udf/udf_generator/BUILD index fb435465..c00c2ceb 100644 --- a/tools/udf/udf_generator/BUILD +++ b/tools/udf/udf_generator/BUILD @@ -28,6 +28,8 @@ cc_binary( "//public/data_loading:filename_utils", "//public/data_loading:records_utils", "//public/data_loading:riegeli_metadata_cc_proto", + "//public/data_loading/writers:delta_record_stream_writer", + "//public/data_loading/writers:delta_record_writer", "//public/udf:constants", "@com_github_google_glog//:glog", "@com_google_absl//absl/flags:flag", diff --git a/tools/udf/udf_generator/udf_delta_file_generator.cc b/tools/udf/udf_generator/udf_delta_file_generator.cc index 148be4f6..03b8352e 100644 --- a/tools/udf/udf_generator/udf_delta_file_generator.cc +++ b/tools/udf/udf_generator/udf_delta_file_generator.cc @@ -27,6 +27,8 @@ #include "public/data_loading/filename_utils.h" #include "public/data_loading/records_utils.h" #include "public/data_loading/riegeli_metadata.pb.h" +#include "public/data_loading/writers/delta_record_stream_writer.h" +#include "public/data_loading/writers/delta_record_writer.h" #include "public/udf/constants.h" #include "riegeli/bytes/ostream_writer.h" #include "riegeli/records/record_writer.h" @@ -42,13 +44,19 @@ ABSL_FLAG(std::string, output_path, "", ABSL_FLAG(int64_t, timestamp, 123123123, "Record timestamp. 
Default is 123123123."); -using kv_server::DeltaFileRecordStruct; -using kv_server::DeltaMutationType; +using kv_server::DataRecordStruct; +using kv_server::DeltaRecordStreamWriter; +using kv_server::DeltaRecordWriter; +using kv_server::KeyValueMutationRecordStruct; +using kv_server::KeyValueMutationType; using kv_server::kUdfCodeSnippetKey; using kv_server::kUdfHandlerNameKey; using kv_server::KVFileMetadata; using kv_server::ToDeltaFileName; +using kv_server::ToFlatBufferBuilder; using kv_server::ToStringView; +using kv_server::UserDefinedFunctionsConfigStruct; +using kv_server::UserDefinedFunctionsLanguage; absl::StatusOr ReadCodeSnippetAsString(std::string udf_file_path) { std::ifstream ifs(udf_file_path); @@ -60,58 +68,57 @@ absl::StatusOr ReadCodeSnippetAsString(std::string udf_file_path) { return udf; } -absl::Status WriteRecord(std::string udf_file_path, - std::string_view udf_handler_name, - riegeli::RecordWriterBase& writer) { - const int64_t timestamp = absl::GetFlag(FLAGS_timestamp); +absl::Status WriteUdfConfig(std::ostream* output_stream) { + if (!*output_stream) { + return absl::NotFoundError("Invalid output"); + } + const std::string udf_file_path = absl::GetFlag(FLAGS_udf_file_path); + const std::string udf_handler_name = absl::GetFlag(FLAGS_udf_handler_name); + int64_t logical_commit_time = absl::GetFlag(FLAGS_timestamp); absl::StatusOr code_snippet = ReadCodeSnippetAsString(std::move(udf_file_path)); if (!code_snippet.ok()) { return code_snippet.status(); } - writer.WriteRecord(ToStringView( - DeltaFileRecordStruct{DeltaMutationType::Update, timestamp, - kUdfCodeSnippetKey, std::move(code_snippet.value())} - .ToFlatBuffer())); - writer.WriteRecord( - ToStringView(DeltaFileRecordStruct{DeltaMutationType::Update, timestamp, - kUdfHandlerNameKey, udf_handler_name} - .ToFlatBuffer())); - LOG(INFO) << "write done"; + KVFileMetadata metadata; + auto delta_record_writer = DeltaRecordStreamWriter::Create( + *output_stream, 
DeltaRecordWriter::Options{.metadata = metadata}); + if (!delta_record_writer.ok()) { + return delta_record_writer.status(); + } + + UserDefinedFunctionsConfigStruct udf_config = { + .code_snippet = std::move(*code_snippet), + .handler_name = std::move(udf_handler_name), + .logical_commit_time = logical_commit_time, + .language = UserDefinedFunctionsLanguage::Javascript}; + if (absl::Status status = delta_record_writer.value()->WriteRecord( + DataRecordStruct{.record = std::move(udf_config)}); + !status.ok()) { + return status; + } + delta_record_writer.value()->Close(); return absl::OkStatus(); } +absl::StatusOr CreateDeltaFileName(std::string_view output_dir) { + absl::Time now = absl::Now(); + const auto maybe_name = ToDeltaFileName(absl::ToUnixMicros(now)); + if (!maybe_name.ok()) { + return maybe_name.status(); + } + return absl::StrCat(output_dir, "/", maybe_name.value()); +} + int main(int argc, char** argv) { absl::ParseCommandLine(argc, argv); const std::string output_path = absl::GetFlag(FLAGS_output_path); const std::string output_dir = absl::GetFlag(FLAGS_output_dir); - auto write_records = [](std::ostream* os) { - if (!*os) { - return absl::NotFoundError("Invalid output path"); - } - const std::string udf_file_path = absl::GetFlag(FLAGS_udf_file_path); - const std::string udf_handler_name = absl::GetFlag(FLAGS_udf_handler_name); - - auto os_writer = riegeli::OStreamWriter(os); - riegeli::RecordWriterBase::Options options; - options.set_uncompressed(); - riegeli::RecordsMetadata metadata; - KVFileMetadata file_metadata; - *metadata.MutableExtension(kv_server::kv_file_metadata) = file_metadata; - options.set_metadata(std::move(metadata)); - auto record_writer = riegeli::RecordWriter(std::move(os_writer), options); - const auto write_status = - WriteRecord(std::move(udf_file_path), udf_handler_name, record_writer); - record_writer.Close(); - return write_status; - }; - - absl::Status write_status; if (output_path == "-" || (output_path.empty() && 
output_dir.empty())) { LOG(INFO) << "Writing records to console"; - write_status = write_records(&std::cout); + const auto write_status = WriteUdfConfig(&std::cout); if (!write_status.ok()) { LOG(ERROR) << "Error writing records: " << write_status; return -1; @@ -132,9 +139,10 @@ int main(int argc, char** argv) { outfile = absl::StrCat(output_dir, "/", maybe_name.value()); } } + LOG(INFO) << "Writing records to " << outfile; std::ofstream ofs(outfile); - write_status = write_records(&ofs); + const auto write_status = WriteUdfConfig(&ofs); ofs.close(); if (!write_status.ok()) { LOG(ERROR) << "Error writing records: " << write_status; diff --git a/tools/udf/udf_tester/BUILD b/tools/udf/udf_tester/BUILD index cb3aae41..23baff89 100644 --- a/tools/udf/udf_tester/BUILD +++ b/tools/udf/udf_tester/BUILD @@ -27,6 +27,7 @@ cc_binary( "//components/data_server/cache:key_value_cache", "//components/udf:cache_get_values_hook", "//components/udf:udf_client", + "//components/udf:udf_config_builder", "//public/data_loading:data_loading_fbs", "//public/data_loading/readers:delta_record_stream_reader", "//public/udf:constants", diff --git a/tools/udf/udf_tester/udf_delta_file_tester.cc b/tools/udf/udf_tester/udf_delta_file_tester.cc index 1ea57929..9af4acdf 100644 --- a/tools/udf/udf_tester/udf_delta_file_tester.cc +++ b/tools/udf/udf_tester/udf_delta_file_tester.cc @@ -21,6 +21,7 @@ #include "components/data_server/cache/key_value_cache.h" #include "components/udf/cache_get_values_hook.h" #include "components/udf/udf_client.h" +#include "components/udf/udf_config_builder.h" #include "glog/logging.h" #include "nlohmann/json.hpp" #include "public/data_loading/data_loading_generated.h" @@ -40,64 +41,72 @@ ABSL_FLAG(std::string, namespace_tag, "keys", namespace kv_server { +absl::Status LoadCacheFromKVMutationRecord( + const KeyValueMutationRecordStruct& record, Cache& cache) { + switch (record.mutation_type) { + case KeyValueMutationType::Update: { + LOG(INFO) << "Updating cache with 
key " << record.key << ", value " + << std::get(record.value) + << ", logical commit time " << record.logical_commit_time; + cache.UpdateKeyValue(record.key, std::get(record.value), + record.logical_commit_time); + break; + } + case KeyValueMutationType::Delete: { + cache.DeleteKey(record.key, record.logical_commit_time); + break; + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Invalid mutation type: ", + EnumNameKeyValueMutationType(record.mutation_type))); + } + return absl::OkStatus(); +} + absl::Status LoadCacheFromFile(std::string file_path, Cache& cache) { std::ifstream delta_file(file_path); DeltaRecordStreamReader record_reader(delta_file); absl::Status status = - record_reader.ReadRecords([&cache](const DeltaFileRecordStruct& record) { - switch (record.mutation_type) { - case DeltaMutationType::Update: { - LOG(INFO) << "Updating cache with key " << record.key << ", value " - << record.value << ", logical commit time " - << record.logical_commit_time; - cache.UpdateKeyValue(record.key, record.value, - record.logical_commit_time); - break; - } - case DeltaMutationType::Delete: { - cache.DeleteKey(record.key, record.logical_commit_time); - break; - } - default: - return absl::InvalidArgumentError( - absl::StrCat("Invalid mutation type: ", - EnumNameDeltaMutationType(record.mutation_type))); + record_reader.ReadRecords([&cache](const DataRecordStruct& data_record) { + // Only load KVMutationRecords into cache. 
+ if (std::holds_alternative( + data_record.record)) { + return LoadCacheFromKVMutationRecord( + std::get(data_record.record), + cache); } return absl::OkStatus(); }); return status; } +void ReadCodeConfigFromUdfConfig( + const UserDefinedFunctionsConfigStruct& udf_config, + CodeConfig& code_config) { + code_config.js = udf_config.code_snippet; + code_config.logical_commit_time = udf_config.logical_commit_time; + code_config.udf_handler_name = udf_config.handler_name; +} + absl::Status ReadCodeConfigFromFile(std::string file_path, CodeConfig& code_config) { std::ifstream delta_file(file_path); DeltaRecordStreamReader record_reader(delta_file); absl::Status status = record_reader.ReadRecords( - [&code_config](const DeltaFileRecordStruct& record) { - if (record.mutation_type != DeltaMutationType::Update) { - // Ignore non-updates + [&code_config](const DataRecordStruct& data_record) { + if (std::holds_alternative( + data_record.record)) { + ReadCodeConfigFromUdfConfig( + std::get(data_record.record), + code_config); return absl::OkStatus(); } - if (record.key == kUdfHandlerNameKey) { - code_config.udf_handler_name = record.value; - } - if (record.key == kUdfCodeSnippetKey) { - code_config.js = record.value; - } - return absl::OkStatus(); + return absl::InvalidArgumentError("Invalid record type."); }); if (!status.ok()) { return status; } - - if (code_config.udf_handler_name.empty()) { - return absl::InvalidArgumentError( - "Missing `udf_handler_name` key in delta file."); - } - if (code_config.js.empty()) { - return absl::InvalidArgumentError( - "Missing `udf_code_snippet` key in delta file."); - } return absl::OkStatus(); } @@ -155,8 +164,12 @@ absl::Status TestUdf(std::string kv_delta_file_path, } LOG(INFO) << "Starting UDF client"; - auto udf_client = UdfClient::Create( - UdfClient::ConfigWithGetValuesHook(*NewCacheGetValuesHook(*cache), 1)); + UdfConfigBuilder config_builder; + auto hook = NewCacheGetValuesHook(*cache); + auto udf_client = + 
UdfClient::Create(config_builder.RegisterGetValuesHook(*hook) + .SetNumberOfWorkers(1) + .Config()); if (!udf_client.ok()) { LOG(ERROR) << "Error starting UDF execution engine: " << udf_client.status(); diff --git a/version.txt b/version.txt index 2774f858..142464bf 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.10.0 \ No newline at end of file +0.11.0 \ No newline at end of file