-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: LingKa <[email protected]>
- Loading branch information
Showing
7 changed files
with
327 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2023-present LingKa <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache 2.0 | ||
|
||
"""Xline clients""" | ||
|
||
__version__ = "0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
""" | ||
Client Errors | ||
""" | ||
|
||
from api.curp.curp_error_pb2 import ( | ||
ProposeError as _ProposeError, | ||
CommandSyncError as _CommandSyncError, | ||
WaitSyncError as _WaitSyncError, | ||
) | ||
from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError | ||
|
||
|
||
class ResDecodeError(Exception): | ||
"""Response decode error""" | ||
|
||
pass | ||
|
||
|
||
class ProposeError(BaseException): | ||
"""Propose error""" | ||
|
||
inner: _ProposeError | ||
|
||
def __init__(self, err: _ProposeError) -> None: | ||
self.inner = err | ||
|
||
|
||
class CommandSyncError(BaseException): | ||
"""Command sync error""" | ||
|
||
inner: _CommandSyncError | ||
|
||
def __init__(self, err: _CommandSyncError) -> None: | ||
self.inner = err | ||
|
||
|
||
class WaitSyncError(BaseException): | ||
"""Wait sync error""" | ||
|
||
inner: _WaitSyncError | ||
|
||
def __init__(self, err: _WaitSyncError) -> None: | ||
self.inner = err | ||
|
||
|
||
class ExecuteError(BaseException): | ||
"""Execute error""" | ||
|
||
inner: _ExecuteError | ||
|
||
def __init__(self, err: _ExecuteError) -> None: | ||
self.inner = err |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
""" | ||
Protocol Client | ||
""" | ||
|
||
from __future__ import annotations | ||
import asyncio | ||
import logging | ||
import grpc | ||
|
||
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer | ||
from api.curp.message_pb2_grpc import ProtocolStub | ||
from api.curp.message_pb2 import FetchClusterRequest, FetchClusterResponse, Member | ||
from api.curp.curp_command_pb2 import ProposeRequest, ProposeResponse, WaitSyncedRequest, WaitSyncedResponse | ||
from api.xline.xline_command_pb2 import Command, CommandResponse, SyncResponse | ||
from client.error import ResDecodeError, CommandSyncError, WaitSyncError, ExecuteError | ||
from api.curp.curp_error_pb2 import ( | ||
CommandSyncError as _CommandSyncError, | ||
WaitSyncError as _WaitSyncError, | ||
) | ||
from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError | ||
|
||
|
||
class ProtocolClient: | ||
""" | ||
Protocol client | ||
Attributes: | ||
local_server_id: Local server id. Only use in an inner client. | ||
state: state of a client | ||
inner: inner protocol clients | ||
connects: all servers's `Connect` | ||
""" | ||
|
||
leader_id: int | ||
inner: list[ProtocolStub] | ||
connects: RepeatedCompositeFieldContainer[Member] | ||
|
||
def __init__( | ||
self, | ||
leader_id: int, | ||
stubs: list[ProtocolStub], | ||
connects: RepeatedCompositeFieldContainer[Member], | ||
) -> None: | ||
self.leader_id = leader_id | ||
self.inner = stubs | ||
self.connects = connects | ||
|
||
@classmethod | ||
def build_from_addrs(cls, addrs: list[str]) -> ProtocolClient: | ||
"""Build client from addresses, this method will fetch all members from servers""" | ||
|
||
stubs: list[ProtocolStub] = [] | ||
|
||
for addr in addrs: | ||
channel = grpc.insecure_channel(addr) | ||
stub = ProtocolStub(channel) | ||
stubs.append(stub) | ||
|
||
cluster = fetch_cluster(stubs) | ||
|
||
return cls( | ||
cluster.leader_id, | ||
stubs, | ||
cluster.members, | ||
) | ||
|
||
def propose(self, cmd: Command, use_fast_path: bool = False) -> tuple[CommandResponse, SyncResponse | None]: | ||
"""Propose the request to servers, if use_fast_path is false, it will wait for the synced index""" | ||
logging.info("propose start. propose id: %s", cmd.propose_id) | ||
|
||
if use_fast_path: | ||
return asyncio.run(self.fast_path(cmd)) | ||
else: | ||
return asyncio.run(self.slow_path(cmd)) | ||
|
||
async def fast_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse | None]: | ||
"""Fast path of propose""" | ||
logging.info("fast path start. propose id: %s", cmd.propose_id) | ||
|
||
fast_task = asyncio.create_task(self.fast_round(cmd)) | ||
slow_task = asyncio.create_task(self.slow_round(cmd)) | ||
|
||
done, pending = await asyncio.wait([fast_task, slow_task], return_when=asyncio.FIRST_COMPLETED) | ||
|
||
for task in pending: | ||
task.cancel() | ||
|
||
for task in done: | ||
first, second = await task | ||
if isinstance(first, CommandResponse) and isinstance(second, bool): | ||
return (first, None) | ||
if isinstance(second, CommandResponse) and isinstance(first, SyncResponse): | ||
return (second, first) | ||
|
||
msg = "fast path error" | ||
raise Exception(msg) | ||
|
||
async def slow_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse]: | ||
"""Slow path of propose""" | ||
|
||
results = await asyncio.gather(self.fast_round(cmd), self.slow_round(cmd)) | ||
|
||
for result in results: | ||
if isinstance(result[0], SyncResponse) and isinstance(result[1], CommandResponse): | ||
return (result[1], result[0]) | ||
|
||
msg = "slow path error" | ||
raise Exception(msg) | ||
|
||
async def fast_round(self, cmd: Command) -> tuple[CommandResponse | None, bool]: | ||
""" | ||
The fast round of Curp protocol | ||
It broadcast the requests to all the curp servers. | ||
""" | ||
|
||
logging.info("fast round start. propose id: %s", cmd.propose_id) | ||
|
||
ok_cnt = 0 | ||
is_received_leader_res = False | ||
cmd_res = CommandResponse() | ||
exe_err = ExecuteError(_ExecuteError()) | ||
propose_tasks = [] | ||
|
||
for stub in self.inner: | ||
propose_tasks.append(asyncio.create_task(propose_wrapper(stub, cmd))) | ||
|
||
done, pending = await asyncio.wait(propose_tasks, return_when=asyncio.FIRST_COMPLETED) | ||
|
||
for task in pending: | ||
task.cancel() | ||
|
||
for _res in done: | ||
res = await _res | ||
if res.HasField("result"): | ||
cmd_result = res.result | ||
ok_cnt += 1 | ||
is_received_leader_res = True | ||
if cmd_result.HasField("er"): | ||
cmd_res.ParseFromString(cmd_result.er) | ||
if cmd_result.HasField("error"): | ||
exe_err.inner.ParseFromString(cmd_result.error) | ||
raise exe_err | ||
elif res.HasField("error"): | ||
raise res.error | ||
else: | ||
ok_cnt += 1 | ||
|
||
if is_received_leader_res and ok_cnt >= super_quorum(len(self.connects)): | ||
logging.info("fast round succeed. propose id: %s", cmd.propose_id) | ||
return (cmd_res, True) | ||
|
||
logging.info("fast round failed. propose id: %s", cmd.propose_id) | ||
return (cmd_res, False) | ||
|
||
async def slow_round(self, cmd: Command) -> tuple[SyncResponse, CommandResponse]: | ||
"""The slow round of Curp protocol""" | ||
|
||
logging.info("slow round start. propose id: %s", cmd.propose_id) | ||
|
||
addr = "" | ||
sync_res = SyncResponse() | ||
cmd_res = CommandResponse() | ||
exe_err = CommandSyncError(_CommandSyncError()) | ||
after_sync_err = WaitSyncError(_WaitSyncError()) | ||
|
||
for member in self.connects: | ||
if member.id == self.leader_id: | ||
addr = member.name | ||
break | ||
|
||
channel = grpc.aio.insecure_channel(addr) | ||
stub = ProtocolStub(channel) | ||
res = await stub.WaitSynced(WaitSyncedRequest(propose_id=cmd.propose_id)) | ||
|
||
if res.HasField("success"): | ||
success = res.success | ||
sync_res.ParseFromString(success.after_sync_result) | ||
cmd_res.ParseFromString(success.exe_result) | ||
logging.info("slow round succeed. propose id: %s", cmd.propose_id) | ||
return (sync_res, cmd_res) | ||
if res.HasField("error"): | ||
cmd_sync_err = res.error | ||
if cmd_sync_err.HasField("execute"): | ||
exe_err.inner.ParseFromString(cmd_sync_err.execute) | ||
raise exe_err | ||
if cmd_sync_err.HasField("after_sync"): | ||
after_sync_err.inner.ParseFromString(cmd_sync_err.after_sync) | ||
raise after_sync_err | ||
|
||
err_msg = "Response decode error" | ||
raise ResDecodeError(err_msg) | ||
|
||
|
||
def fetch_cluster(stubs: list[ProtocolStub]) -> FetchClusterResponse: | ||
""" | ||
Fetch cluster from server | ||
TODO: fetch cluster | ||
""" | ||
for stub in stubs: | ||
res: FetchClusterResponse = stub.FetchCluster(FetchClusterRequest()) | ||
return res | ||
|
||
|
||
def super_quorum(nodes: int) -> int: | ||
""" | ||
Get the superquorum for curp protocol | ||
Although curp can proceed with f + 1 available replicas, it needs f + 1 + (f + 1)/2 replicas | ||
(for superquorum of witnesses) to use 1 RTT operations. With less than superquorum replicas, | ||
clients must ask masters to commit operations in f + 1 replicas before returning result.(2 RTTs). | ||
""" | ||
fault_tolerance = nodes // 2 | ||
quorum = fault_tolerance + 1 | ||
superquorum = fault_tolerance + (quorum // 2) + 1 | ||
return superquorum | ||
|
||
|
||
async def propose_wrapper(stub: ProtocolStub, req: Command) -> ProposeResponse: | ||
"""Wrapper of propose""" | ||
res: ProposeResponse = stub.Propose(ProposeRequest(command=req.SerializeToString())) | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
# SPDX-FileCopyrightText: 2023-present LingKa <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache 2.0 | ||
|
||
import sys | ||
|
||
sys.path.append("./api/curp") | ||
|
||
sys.path.append("./api/xline") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Tests for the protocol client.""" | ||
|
||
import unittest | ||
import uuid | ||
|
||
from api.xline.xline_command_pb2 import Command, RequestWithToken | ||
from api.xline.rpc_pb2 import PutRequest | ||
from client.protocol import ProtocolClient | ||
|
||
|
||
class TestProtocolClient(unittest.TestCase): | ||
"""test protocol client""" | ||
|
||
def setUp(self) -> None: | ||
curp_members = ["172.20.0.3:2379", "172.20.0.4:2379", "172.20.0.5:2379"] | ||
|
||
cmd = Command( | ||
request=RequestWithToken( | ||
put_request=PutRequest( | ||
key=b"hello", | ||
value=b"py-xline", | ||
) | ||
), | ||
propose_id=f"client-{uuid.uuid4()}", | ||
) | ||
|
||
self.cmd = cmd | ||
self.client = ProtocolClient.build_from_addrs(curp_members) | ||
|
||
def test_fast_path(self): | ||
"""test fast path""" | ||
er, _ = self.client.propose(self.cmd, True) | ||
self.assertTrue(er.HasField("put_response")) | ||
|
||
def test_slow_path(self): | ||
"""test slow path""" | ||
er, asr = self.client.propose(self.cmd, False) | ||
self.assertIsNotNone(asr) | ||
self.assertTrue(er.HasField("put_response")) |