From 1e7aa67e307ce0859a936b2bc92443ae81d23606 Mon Sep 17 00:00:00 2001 From: LingKa Date: Fri, 1 Dec 2023 13:33:54 +0800 Subject: [PATCH] feat: add protocol client Signed-off-by: LingKa --- .github/workflows/protobuf/action.yaml | 7 +- client/__about__.py | 7 + client/error.py | 52 ++++++ client/protocol.py | 218 +++++++++++++++++++++++++ scripts/quick_start.sh | 2 +- tests/__init__.py | 6 + tests/protocol_test.py | 39 +++++ 7 files changed, 325 insertions(+), 6 deletions(-) create mode 100644 client/__about__.py create mode 100644 client/error.py create mode 100644 client/protocol.py create mode 100755 tests/protocol_test.py diff --git a/.github/workflows/protobuf/action.yaml b/.github/workflows/protobuf/action.yaml index 3a6caf0..f624b22 100644 --- a/.github/workflows/protobuf/action.yaml +++ b/.github/workflows/protobuf/action.yaml @@ -3,11 +3,8 @@ name: Generate API runs: using: "composite" steps: - - name: Install gRPC - run: python3 -m pip install grpcio - shell: bash - - name: Install gRPC tools - run: python3 -m pip install grpcio-tools + - name: Install gRPC & gRPC tools + run: python3 -m pip install grpcio grpcio-tools shell: bash - name: Initialize Git Submodules diff --git a/client/__about__.py b/client/__about__.py new file mode 100644 index 0000000..b950b91 --- /dev/null +++ b/client/__about__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present LingKa +# +# SPDX-License-Identifier: Apache 2.0 + +"""Xline clients""" + +__version__ = "0.0.1" diff --git a/client/error.py b/client/error.py new file mode 100644 index 0000000..c0ac471 --- /dev/null +++ b/client/error.py @@ -0,0 +1,52 @@ +""" +Client Errors +""" + +from api.curp.curp_error_pb2 import ( + ProposeError as _ProposeError, + CommandSyncError as _CommandSyncError, + WaitSyncError as _WaitSyncError, +) +from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError + + +class ResDecodeError(Exception): + """Response decode error""" + + pass + + +class ProposeError(BaseException): + """Propose error""" + + inner: _ProposeError + + def __init__(self, err: _ProposeError) -> None: + self.inner = err + + +class CommandSyncError(BaseException): + """Command sync error""" + + inner: _CommandSyncError + + def __init__(self, err: _CommandSyncError) -> None: + self.inner = err + + +class WaitSyncError(BaseException): + """Wait sync error""" + + inner: _WaitSyncError + + def __init__(self, err: _WaitSyncError) -> None: + self.inner = err + + +class ExecuteError(BaseException): + """Execute error""" + + inner: _ExecuteError + + def __init__(self, err: _ExecuteError) -> None: + self.inner = err diff --git a/client/protocol.py b/client/protocol.py new file mode 100644 index 0000000..7ae5d19 --- /dev/null +++ b/client/protocol.py @@ -0,0 +1,218 @@ +""" +Protocol Client +""" + +from __future__ import annotations +import asyncio +import logging +import grpc + +from google.protobuf.internal.containers import RepeatedCompositeFieldContainer +from api.curp.message_pb2_grpc import ProtocolStub +from api.curp.message_pb2 import FetchClusterRequest, FetchClusterResponse, Member +from api.curp.curp_command_pb2 import ProposeRequest, ProposeResponse, WaitSyncedRequest, WaitSyncedResponse +from api.xline.xline_command_pb2 import Command, CommandResponse, SyncResponse +from client.error import ResDecodeError, CommandSyncError, WaitSyncError, ExecuteError +from api.curp.curp_error_pb2 import ( + CommandSyncError as _CommandSyncError, + WaitSyncError as _WaitSyncError, +) +from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError + + +class ProtocolClient: + """ + Protocol client + + Attributes: + local_server_id: Local server id. Only use in an inner client. + state: state of a client + inner: inner protocol clients + connects: all servers's `Connect` + """ + + leader_id: int + inner: list[ProtocolStub] + connects: RepeatedCompositeFieldContainer[Member] + + def __init__( + self, + leader_id: int, + stubs: list[ProtocolStub], + connects: RepeatedCompositeFieldContainer[Member], + ) -> None: + self.leader_id = leader_id + self.inner = stubs + self.connects = connects + + @classmethod + def build_from_addrs(cls, addrs: list[str]) -> ProtocolClient: + """Build client from addresses, this method will fetch all members from servers""" + + stubs: list[ProtocolStub] = [] + + for addr in addrs: + channel = grpc.insecure_channel(addr) + stub = ProtocolStub(channel) + stubs.append(stub) + + cluster = fetch_cluster(stubs) + + return cls( + cluster.leader_id, + stubs, + cluster.members, + ) + + def propose(self, cmd: Command, use_fast_path: bool = False) -> tuple[CommandResponse, SyncResponse | None]: + """Propose the request to servers, if use_fast_path is false, it will wait for the synced index""" + if use_fast_path: + return asyncio.run(self.fast_path(cmd)) + else: + return asyncio.run(self.slow_path(cmd)) + + async def fast_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse | None]: + """Fast path of propose""" + + fast_task = asyncio.create_task(self.fast_round(cmd)) + slow_task = asyncio.create_task(self.slow_round(cmd)) + + done, pending = await asyncio.wait([fast_task, slow_task], return_when=asyncio.FIRST_COMPLETED) + + for task in pending: + task.cancel() + + for task in done: + first, second = await task + if isinstance(first, CommandResponse) and isinstance(second, bool): + return (first, None) + if isinstance(second, CommandResponse) and isinstance(first, SyncResponse): + return (second, first) + + msg = "fast path error" + raise Exception(msg) + + async def slow_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse]: + """Slow path of propose""" + + fast_task = asyncio.create_task(self.fast_round(cmd)) + slow_task = asyncio.create_task(self.slow_round(cmd)) + + results = await asyncio.gather(fast_task, slow_task) + + for result in results: + if isinstance(result[0], SyncResponse) and isinstance(result[1], CommandResponse): + return (result[1], result[0]) + + msg = "slow path error" + raise Exception(msg) + + async def fast_round(self, cmd: Command) -> tuple[CommandResponse | None, bool]: + """ + The fast round of Curp protocol + It broadcast the requests to all the curp servers. + """ + + logging.info("fast round start. propose id: %s", cmd.propose_id) + + ok_cnt = 0 + is_received_leader_res = False + cmd_res = CommandResponse() + exe_err = ExecuteError(_ExecuteError()) + + for stub in self.inner: + res = await propose_wrapper(stub, cmd) + + if res.HasField("result"): + cmd_result = res.result + ok_cnt += 1 + is_received_leader_res = True + if cmd_result.HasField("er"): + cmd_res.ParseFromString(cmd_result.er) + if cmd_result.HasField("error"): + exe_err.inner.ParseFromString(cmd_result.error) + raise exe_err + elif res.HasField("error"): + raise res.error + else: + ok_cnt += 1 + + if is_received_leader_res and ok_cnt >= super_quorum(len(self.connects)): + logging.info("fast round succeed. propose id: %s", cmd.propose_id) + return (cmd_res, True) + + logging.info("fast round failed. propose id: %s", cmd.propose_id) + return (cmd_res, False) + + async def slow_round(self, cmd: Command) -> tuple[SyncResponse, CommandResponse]: + """The slow round of Curp protocol""" + + logging.info("slow round start. propose id: %s", cmd.propose_id) + + addr = "" + sync_res = SyncResponse() + cmd_res = CommandResponse() + exe_err = CommandSyncError(_CommandSyncError()) + after_sync_err = WaitSyncError(_WaitSyncError()) + + for member in self.connects: + if member.id == self.leader_id: + addr = member.name + break + + channel = grpc.insecure_channel(addr) + stub = ProtocolStub(channel) + res = await wait_synced_wrapper(stub, cmd) + + if res.HasField("success"): + success = res.success + sync_res.ParseFromString(success.after_sync_result) + cmd_res.ParseFromString(success.exe_result) + logging.info("slow round succeed. propose id: %s", cmd.propose_id) + return (sync_res, cmd_res) + if res.HasField("error"): + cmd_sync_err = res.error + if cmd_sync_err.HasField("execute"): + exe_err.inner.ParseFromString(cmd_sync_err.execute) + raise exe_err + if cmd_sync_err.HasField("after_sync"): + after_sync_err.inner.ParseFromString(cmd_sync_err.after_sync) + raise after_sync_err + + err_msg = "Response decode error" + raise ResDecodeError(err_msg) + + +def fetch_cluster(stubs: list[ProtocolStub]) -> FetchClusterResponse: + """ + Fetch cluster from server + TODO: fetch cluster + """ + for stub in stubs: + res: FetchClusterResponse = stub.FetchCluster(FetchClusterRequest()) + return res + + +def super_quorum(nodes: int) -> int: + """ + Get the superquorum for curp protocol + Although curp can proceed with f + 1 available replicas, it needs f + 1 + (f + 1)/2 replicas + (for superquorum of witnesses) to use 1 RTT operations. With less than superquorum replicas, + clients must ask masters to commit operations in f + 1 replicas before returning result.(2 RTTs). + """ + fault_tolerance = nodes // 2 + quorum = fault_tolerance + 1 + superquorum = fault_tolerance + (quorum // 2) + 1 + return superquorum + + +async def propose_wrapper(stub: ProtocolStub, req: Command) -> ProposeResponse: + """Wrapper of propose""" + res: ProposeResponse = stub.Propose(ProposeRequest(command=req.SerializeToString())) + return res + + +async def wait_synced_wrapper(stub: ProtocolStub, req: Command) -> WaitSyncedResponse: + """Wrapper of wait sync""" + res: WaitSyncedResponse = stub.WaitSynced(WaitSyncedRequest(propose_id=req.propose_id)) + return res diff --git a/scripts/quick_start.sh b/scripts/quick_start.sh index f522f54..57adf88 100755 --- a/scripts/quick_start.sh +++ b/scripts/quick_start.sh @@ -55,7 +55,7 @@ stop_all() { run_container() { echo container starting size=${1} - image="ghcr.io/xline-kv/xline:latest" + image="ghcr.io/xline-kv/xline:b573f16" for ((i = 1; i <= ${size}; i++)); do docker run -d -it --rm --name=node${i} --net=xline_net --ip=${SERVERS[$i]} --cap-add=NET_ADMIN --cpu-shares=1024 -m=512M -v ${DIR}:/mnt ${image} bash & done diff --git a/tests/__init__.py b/tests/__init__.py index 91a12bb..69fc850 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,9 @@ # SPDX-FileCopyrightText: 2023-present LingKa # # SPDX-License-Identifier: Apache 2.0 + +import sys + +sys.path.append("./api/curp") + +sys.path.append("./api/xline") diff --git a/tests/protocol_test.py b/tests/protocol_test.py new file mode 100755 index 0000000..5fce5b4 --- /dev/null +++ b/tests/protocol_test.py @@ -0,0 +1,39 @@ +"""Tests for the protocol client.""" + +import unittest +import uuid + +from api.xline.xline_command_pb2 import Command, RequestWithToken +from api.xline.rpc_pb2 import PutRequest +from client.protocol import ProtocolClient + + +class TestProtocolClient(unittest.TestCase): + """test protocol client""" + + def setUp(self) -> None: + curp_members = ["172.20.0.3:2379", "172.20.0.4:2379", "172.20.0.5:2379"] + + cmd = Command( + request=RequestWithToken( + put_request=PutRequest( + key=b"hello", + value=b"py-xline", + ) + ), + propose_id=f"client-{uuid.uuid4()}", + ) + + self.cmd = cmd + self.client = ProtocolClient.build_from_addrs(curp_members) + + def test_fast_path(self): + """test fast path""" + er, _ = self.client.propose(self.cmd, True) + self.assertTrue(er.HasField("put_response")) + + def test_slow_path(self): + """test slow path""" + er, asr = self.client.propose(self.cmd, False) + self.assertIsNotNone(asr) + self.assertTrue(er.HasField("put_response"))