Skip to content

Commit

Permalink
feat: protocol client
Browse files Browse the repository at this point in the history
Signed-off-by: LingKa <[email protected]>
  • Loading branch information
LingKa28 committed Dec 2, 2023
1 parent 621d72a commit f1edd18
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 6 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/protobuf/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ name: Generate API
runs:
using: "composite"
steps:
- name: Install gRPC
run: python3 -m pip install grpcio
shell: bash
- name: Install gRPC tools
run: python3 -m pip install grpcio-tools
- name: Install gRPC & gRPC tools
run: python3 -m pip install grpcio grpcio-tools
shell: bash

- name: Initialize Git Submodules
Expand Down
7 changes: 7 additions & 0 deletions client/__about__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2023-present LingKa <[email protected]>
#
# SPDX-License-Identifier: Apache 2.0

"""Xline clients"""

# Version string of the client package.
__version__ = "0.0.1"
52 changes: 52 additions & 0 deletions client/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Client Errors
"""

from api.curp.curp_error_pb2 import (
ProposeError as _ProposeError,
CommandSyncError as _CommandSyncError,
WaitSyncError as _WaitSyncError,
)
from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError


class ResDecodeError(Exception):
    """Raised when a server response cannot be decoded into a known shape."""


class ProposeError(Exception):
    """Propose error.

    Python exception wrapper around the ``ProposeError`` value generated in
    ``api.curp.curp_error_pb2``.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers see it; ``BaseException`` is reserved for
    interpreter-exit style signals.

    Attributes:
        inner: The wrapped error payload.
    """

    inner: _ProposeError

    def __init__(self, err: _ProposeError) -> None:
        # Forward the payload so ``args``, ``str()`` and ``repr()`` expose it.
        super().__init__(err)
        self.inner = err


class CommandSyncError(Exception):
    """Command sync error.

    Python exception wrapper around the ``CommandSyncError`` value generated
    in ``api.curp.curp_error_pb2``.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers see it.

    Attributes:
        inner: The wrapped error payload.
    """

    inner: _CommandSyncError

    def __init__(self, err: _CommandSyncError) -> None:
        # Forward the payload so ``args``, ``str()`` and ``repr()`` expose it.
        super().__init__(err)
        self.inner = err


class WaitSyncError(Exception):
    """Wait sync error.

    Python exception wrapper around the ``WaitSyncError`` value generated in
    ``api.curp.curp_error_pb2``.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers see it.

    Attributes:
        inner: The wrapped error payload.
    """

    inner: _WaitSyncError

    def __init__(self, err: _WaitSyncError) -> None:
        # Forward the payload so ``args``, ``str()`` and ``repr()`` expose it.
        super().__init__(err)
        self.inner = err


class ExecuteError(Exception):
    """Execute error.

    Python exception wrapper around the ``ExecuteError`` value generated in
    ``api.xline.xline_error_pb2``.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers see it.

    Attributes:
        inner: The wrapped error payload.
    """

    inner: _ExecuteError

    def __init__(self, err: _ExecuteError) -> None:
        # Forward the payload so ``args``, ``str()`` and ``repr()`` expose it.
        super().__init__(err)
        self.inner = err
220 changes: 220 additions & 0 deletions client/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""
Protocol Client
"""

from __future__ import annotations

import asyncio
import logging

import grpc
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer

from api.curp.message_pb2_grpc import ProtocolStub
from api.curp.message_pb2 import FetchClusterRequest, FetchClusterResponse, Member
from api.curp.curp_command_pb2 import ProposeRequest, ProposeResponse, WaitSyncedRequest, WaitSyncedResponse
from api.xline.xline_command_pb2 import Command, CommandResponse, SyncResponse
from client.error import ResDecodeError, ProposeError, CommandSyncError, WaitSyncError, ExecuteError
from api.curp.curp_error_pb2 import (
    CommandSyncError as _CommandSyncError,
    WaitSyncError as _WaitSyncError,
)
from api.xline.xline_error_pb2 import ExecuteError as _ExecuteError


class ProtocolClient:
    """
    Protocol client

    Attributes:
        leader_id: id of the member the client treats as the cluster leader.
        inner: one protocol stub per server address the client was built from.
        connects: the cluster members returned by ``fetch_cluster``.
    """

    leader_id: int
    inner: list[ProtocolStub]
    connects: RepeatedCompositeFieldContainer[Member]

    def __init__(
        self,
        leader_id: int,
        stubs: list[ProtocolStub],
        connects: RepeatedCompositeFieldContainer[Member],
    ) -> None:
        self.leader_id = leader_id
        self.inner = stubs
        self.connects = connects

    @classmethod
    def build_from_addrs(cls, addrs: list[str]) -> ProtocolClient:
        """Build client from addresses, this method will fetch all members from servers"""

        # One synchronous, insecure channel per address; the stubs are kept
        # for the lifetime of the client.
        stubs: list[ProtocolStub] = [ProtocolStub(grpc.insecure_channel(addr)) for addr in addrs]

        cluster = fetch_cluster(stubs)

        return cls(
            cluster.leader_id,
            stubs,
            cluster.members,
        )

    def propose(self, cmd: Command, use_fast_path: bool = False) -> tuple[CommandResponse, SyncResponse | None]:
        """
        Propose the request to servers.

        Args:
            cmd: the command to propose.
            use_fast_path: when true, return as soon as the fast round
                succeeds; otherwise always wait for the synced result.

        Returns:
            ``(command response, sync response)``; the sync response is
            ``None`` when the fast round alone decided the outcome.
        """
        logging.info("propose start. propose id: %s", cmd.propose_id)

        # Each call drives its own event loop, so this must not be invoked
        # from inside an already running loop.
        if use_fast_path:
            return asyncio.run(self.fast_path(cmd))
        return asyncio.run(self.slow_path(cmd))

    async def fast_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse | None]:
        """
        Fast path of propose.

        Runs the fast round and the slow round concurrently. A successful
        fast round is returned immediately with no sync response. If the fast
        round finishes without success, the slow round's result is awaited
        instead of being discarded.

        Raises:
            Whatever ``fast_round`` or ``slow_round`` raises.
        """
        logging.info("fast path start. propose id: %s", cmd.propose_id)

        fast_task = asyncio.create_task(self.fast_round(cmd))
        slow_task = asyncio.create_task(self.slow_round(cmd))

        try:
            done, _ = await asyncio.wait([fast_task, slow_task], return_when=asyncio.FIRST_COMPLETED)

            if fast_task in done:
                cmd_res, fast_ok = fast_task.result()
                if fast_ok:
                    return (cmd_res, None)
                # The fast round did not succeed: its response is not
                # authoritative, so fall through to the slow round.

            sync_res, cmd_res = await slow_task
            return (cmd_res, sync_res)
        finally:
            # Whichever round is no longer needed must not keep running.
            for task in (fast_task, slow_task):
                if not task.done():
                    task.cancel()

    async def slow_path(self, cmd: Command) -> tuple[CommandResponse, SyncResponse]:
        """
        Slow path of propose.

        The fast round is still executed because it is what sends the
        proposal to the servers; only the slow round's result is returned.
        """

        _, (sync_res, cmd_res) = await asyncio.gather(self.fast_round(cmd), self.slow_round(cmd))
        return (cmd_res, sync_res)

    async def fast_round(self, cmd: Command) -> tuple[CommandResponse | None, bool]:
        """
        The fast round of Curp protocol
        It broadcasts the request to all the curp servers.

        Returns:
            ``(command response, succeeded)``. ``succeeded`` is true only
            when a response carrying a result was received and the number of
            accepting servers reached the super quorum. The command response
            is a default ``CommandResponse`` if no result was received.

        Raises:
            ExecuteError: a response's result carries an execution error.
            ProposeError: a server answered the proposal with an error.
        """

        logging.info("fast round start. propose id: %s", cmd.propose_id)

        ok_cnt = 0
        is_received_leader_res = False
        cmd_res = CommandResponse()

        propose_tasks = [asyncio.create_task(propose_wrapper(stub, cmd)) for stub in self.inner]

        try:
            # Every response is needed to count towards the super quorum, so
            # wait for all proposals rather than only the first one.
            done, _ = await asyncio.wait(propose_tasks)
        finally:
            # Only reached with unfinished tasks if this coroutine is
            # cancelled while waiting.
            for task in propose_tasks:
                if not task.done():
                    task.cancel()

        for task in done:
            res = task.result()
            if res.HasField("result"):
                cmd_result = res.result
                ok_cnt += 1
                is_received_leader_res = True
                if cmd_result.HasField("er"):
                    cmd_res.ParseFromString(cmd_result.er)
                if cmd_result.HasField("error"):
                    exe_err = ExecuteError(_ExecuteError())
                    exe_err.inner.ParseFromString(cmd_result.error)
                    raise exe_err
            elif res.HasField("error"):
                # ``res.error`` is a response field, not an exception; wrap it
                # so that it can actually be raised.
                raise ProposeError(res.error)
            else:
                ok_cnt += 1

        if is_received_leader_res and ok_cnt >= super_quorum(len(self.connects)):
            logging.info("fast round succeed. propose id: %s", cmd.propose_id)
            return (cmd_res, True)

        logging.info("fast round failed. propose id: %s", cmd.propose_id)
        return (cmd_res, False)

    async def slow_round(self, cmd: Command) -> tuple[SyncResponse, CommandResponse]:
        """
        The slow round of Curp protocol

        Asks the leader to wait until the proposal is synced.

        Returns:
            ``(sync response, command response)`` decoded from the leader's
            success payload.

        Raises:
            CommandSyncError: the leader reported an ``execute`` error.
            WaitSyncError: the leader reported an ``after_sync`` error.
            ResDecodeError: the response matched none of the known shapes.
        """

        logging.info("slow round start. propose id: %s", cmd.propose_id)

        # The leader's ``name`` is used as the dial target; it stays empty if
        # no member matches ``leader_id``.
        addr = next((member.name for member in self.connects if member.id == self.leader_id), "")

        # The channel is only needed for this one call; close it afterwards
        # instead of leaking one channel per proposal.
        async with grpc.aio.insecure_channel(addr) as channel:
            stub = ProtocolStub(channel)
            res = await stub.WaitSynced(WaitSyncedRequest(propose_id=cmd.propose_id))

        if res.HasField("success"):
            success = res.success
            sync_res = SyncResponse()
            cmd_res = CommandResponse()
            sync_res.ParseFromString(success.after_sync_result)
            cmd_res.ParseFromString(success.exe_result)
            logging.info("slow round succeed. propose id: %s", cmd.propose_id)
            return (sync_res, cmd_res)
        if res.HasField("error"):
            cmd_sync_err = res.error
            if cmd_sync_err.HasField("execute"):
                # NOTE(review): the ``execute`` bytes are decoded as a
                # CommandSyncError message, as before - confirm that this is
                # the intended message type.
                exe_err = CommandSyncError(_CommandSyncError())
                exe_err.inner.ParseFromString(cmd_sync_err.execute)
                raise exe_err
            if cmd_sync_err.HasField("after_sync"):
                after_sync_err = WaitSyncError(_WaitSyncError())
                after_sync_err.inner.ParseFromString(cmd_sync_err.after_sync)
                raise after_sync_err

        err_msg = "Response decode error"
        raise ResDecodeError(err_msg)


def fetch_cluster(stubs: list[ProtocolStub]) -> FetchClusterResponse:
    """
    Fetch cluster from server

    Asks the given stubs, in order, for the cluster description and returns
    the first response obtained. A stub whose call fails with
    ``grpc.RpcError`` is skipped in favour of the next one.

    TODO: fetch cluster

    Args:
        stubs: protocol stubs to query; must not be empty.

    Returns:
        The ``FetchClusterResponse`` of the first stub that answered.

    Raises:
        ValueError: ``stubs`` is empty.
        grpc.RpcError: every stub failed; the last failure is re-raised.
    """
    if not stubs:
        msg = "cannot fetch cluster: no protocol stubs were provided"
        raise ValueError(msg)

    last_err = None
    for stub in stubs:
        try:
            res: FetchClusterResponse = stub.FetchCluster(FetchClusterRequest())
        except grpc.RpcError as err:
            # One unreachable server must not prevent asking the others.
            logging.warning("fetch cluster failed on one server: %s", err)
            last_err = err
        else:
            return res

    raise last_err


def super_quorum(nodes: int) -> int:
    """
    Return the super quorum size of the curp protocol for ``nodes`` replicas.

    With ``f = nodes // 2`` tolerated faults, a plain quorum is ``f + 1``
    replicas; completing an operation in 1 RTT additionally requires a super
    quorum of witnesses, ``f + (f + 1) // 2 + 1`` replicas. Below that size
    the client has to wait for the operation to be committed in ``f + 1``
    replicas before returning a result (2 RTTs).
    """
    tolerated_faults = nodes // 2
    half_of_quorum = (tolerated_faults + 1) // 2
    return tolerated_faults + half_of_quorum + 1


async def propose_wrapper(stub: ProtocolStub, req: Command) -> ProposeResponse:
    """
    Send ``req`` to one server as a ``ProposeRequest``.

    The command is serialized into the request's ``command`` field, and the
    value returned by ``stub.Propose`` is handed back unchanged. The stub
    call itself is not awaited.
    """
    request = ProposeRequest(command=req.SerializeToString())
    response: ProposeResponse = stub.Propose(request)
    return response
2 changes: 1 addition & 1 deletion scripts/quick_start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ stop_all() {
run_container() {
echo container starting
size=${1}
image="ghcr.io/xline-kv/xline:latest"
image="ghcr.io/xline-kv/xline:b573f16"
for ((i = 1; i <= ${size}; i++)); do
docker run -d -it --rm --name=node${i} --net=xline_net --ip=${SERVERS[$i]} --cap-add=NET_ADMIN --cpu-shares=1024 -m=512M -v ${DIR}:/mnt ${image} bash &
done
Expand Down
6 changes: 6 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# SPDX-FileCopyrightText: 2023-present LingKa <[email protected]>
#
# SPDX-License-Identifier: Apache 2.0

import sys

# Make the directories under ./api importable by bare module name; the paths
# are relative to the current working directory.
sys.path.extend(("./api/curp", "./api/xline"))
39 changes: 39 additions & 0 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Tests for the protocol client."""

import unittest
import uuid

from api.xline.xline_command_pb2 import Command, RequestWithToken
from api.xline.rpc_pb2 import PutRequest
from client.protocol import ProtocolClient


class TestProtocolClient(unittest.TestCase):
    """Tests for ``ProtocolClient`` that talk to the three hard-coded server addresses."""

    def setUp(self) -> None:
        """Build a fresh put command and a client for every test."""
        put = PutRequest(key=b"hello", value=b"py-xline")
        # A new propose id per test keeps the two proposals distinct.
        self.cmd = Command(
            request=RequestWithToken(put_request=put),
            propose_id=f"client-{uuid.uuid4()}",
        )
        self.client = ProtocolClient.build_from_addrs(
            ["172.20.0.3:2379", "172.20.0.4:2379", "172.20.0.5:2379"],
        )

    def test_fast_path(self):
        """The fast path yields a put response."""
        result = self.client.propose(self.cmd, True)
        exe_res = result[0]
        self.assertTrue(exe_res.HasField("put_response"))

    def test_slow_path(self):
        """The slow path yields both a sync response and a put response."""
        exe_res, sync_res = self.client.propose(self.cmd, False)
        self.assertIsNotNone(sync_res)
        self.assertTrue(exe_res.HasField("put_response"))

0 comments on commit f1edd18

Please sign in to comment.