Skip to content

Commit

Permalink
feat: add lock client
Browse files Browse the repository at this point in the history
Signed-off-by: LingKa <[email protected]>
  • Loading branch information
LingKa28 committed Dec 14, 2023
1 parent 55d290a commit 16db03b
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 4 deletions.
11 changes: 9 additions & 2 deletions client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from client.kv import KvClient
from client.lease import LeaseClient, LeaseIdGenerator
from client.watch import WatchClient
from client.lock import LockClient
from client.auth import AuthClient


Expand All @@ -17,18 +18,23 @@ class Client:
kv_client: Kv client
lease_client: Lease client
watch_client: Watch client
lock_client: Lock client
auth_client: Auth client
"""

kv_client: KvClient
lease_client: LeaseClient
watch_client: WatchClient
lock_client: LockClient
auth_client: AuthClient

def __init__(self, kv: KvClient, lease: LeaseClient, watch: WatchClient, auth: AuthClient) -> None:
def __init__(
self, kv: KvClient, lease: LeaseClient, watch: WatchClient, lock: LockClient, auth: AuthClient
) -> None:
self.kv_client = kv
self.lease_client = lease
self.watch_client = watch
self.lock_client = lock
self.auth_client = auth

@classmethod
Expand All @@ -45,6 +51,7 @@ async def connect(cls, addrs: list[str]) -> Client:
kv_client = KvClient("client", protocol_client, "")
lease_client = LeaseClient("client", protocol_client, channel, "", id_gen)
watch_client = WatchClient(channel)
lock_client = LockClient("client", protocol_client, channel, "", id_gen)
auth_client = AuthClient("client", protocol_client, channel, "")

return cls(kv_client, lease_client, watch_client, auth_client)
return cls(kv_client, lease_client, watch_client, lock_client, auth_client)
236 changes: 236 additions & 0 deletions client/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""Lock Client"""

import uuid
from urllib import parse
from typing import Optional
from grpc import Channel
from client.protocol import ProtocolClient as CurpClient
from client.lease import LeaseClient, LeaseIdGenerator
from client.watch import WatchClient
from api.xline.xline_command_pb2 import Command, RequestWithToken, CommandResponse, SyncResponse
from api.xline.v3lock_pb2 import LockRequest as _LockRequest, LockResponse, UnlockResponse
from api.xline.kv_pb2 import Event
from api.xline.rpc_pb2 import (
PutRequest,
RangeRequest,
LeaseGrantRequest,
TxnRequest,
Compare,
RequestOp,
ResponseHeader,
WatchCreateRequest,
DeleteRangeRequest,
)


class LockRequest:
"""
Request for `Lock`
Attributes:
inner: The inner request.
ttl: The ttl of the lease that attached to the lock.
"""

inner: _LockRequest
ttl: Optional[int]

def __init__(self, req: _LockRequest, ttl: int = 60) -> None:
self.inner = req
self.ttl = ttl

Check warning on line 40 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L39-L40

Added lines #L39 - L40 were not covered by tests


class LockClient:
"""
Client for Lock operations.
Attributes:
name: Name of the LockClient.
curp_client: The client running the CURP protocol, communicate with all servers.
lease_client: The lease client.
watch_client: The watch client.
token: Auth token
"""

name: str
curp_client: CurpClient
lease_client: LeaseClient
watch_client: WatchClient
token: Optional[str]

def __init__(
self, name: str, curp_client: CurpClient, channel: Channel, token: str, id_gen: LeaseIdGenerator
) -> None:
self.name = name
self.curp_client = curp_client
self.lease_client = LeaseClient(name=name, curp_client=curp_client, channel=channel, token=token, id_gen=id_gen)
self.watch_client = WatchClient(channel=channel)
self.token = token

async def lock(self, name: bytes, lease_id: int = 0, ttl: Optional[int] = None) -> LockResponse:
"""
Acquires a distributed shared lock on a given named lock.
On success, it will return a unique key that exists so long as the
lock is held by the caller. This key can be used in conjunction with
transactions to safely ensure updates to Xline only occur while holding
lock ownership. The lock is held until Unlock is called on the key or the
lease associate with the owner expires.
"""
if lease_id == 0:
lease_res = await self.lease_client.grant(LeaseGrantRequest(TTL=ttl))
lease_id = lease_res.ID

prefix = f"{parse.quote(name)}/"
key = f"{prefix}{lease_id:x}"
res = await self.lock_inner(prefix, key, lease_id)

return res

async def unlock(self, key: bytes) -> UnlockResponse:
"""
Takes a key returned by Lock and releases the hold on lock. The
next Lock caller waiting for the lock will then be woken up and given
ownership of the lock.
"""
header = await self.delete_key(key)
return UnlockResponse(header=header)

async def lock_inner(self, prefix: str, key: str, lease_id: int) -> LockResponse:
"""
The inner lock logic
"""
txn = self.create_acquire_txn(prefix, lease_id)
req = RequestWithToken(token=self.token, txn_request=txn)
er, asr = await self.propose(req, False)

txn_res = er.txn_response
if asr is None:
msg = "sync_res always has value when use slow path"
raise Exception(msg)

Check warning on line 109 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L108-L109

Added lines #L108 - L109 were not covered by tests
my_rev = asr.revision
owner_res = txn_res.responses[1].response_range
owner_key = owner_res.kvs

header = ResponseHeader()
if len(owner_key) > 0 and owner_key[0].create_revision == my_rev:
header = owner_res.header

Check warning on line 116 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L116

Added line #L116 was not covered by tests
else:
await self.wait_delete(prefix, my_rev)
range_req = RangeRequest(key=key.encode())
req = RequestWithToken(token=self.token, range_request=range_req)
try:
er, _ = await self.propose(req, True)
range_res = er.range_response
if len(range_res.kvs) == 0:
msg = "rpc error session expired"
raise Exception(msg)

Check warning on line 126 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L125-L126

Added lines #L125 - L126 were not covered by tests
header = range_res.header
except Exception:
await self.delete_key(key.encode())

Check warning on line 129 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L128-L129

Added lines #L128 - L129 were not covered by tests

return LockResponse(header=header, key=key.encode())

def create_acquire_txn(self, prefix: str, lease_id: int) -> TxnRequest:
"""
Create txn for try acquire lock
"""
key = f"{prefix}{lease_id:x}"
cmp = Compare(
result=Compare.CompareResult.EQUAL, target=Compare.CompareTarget.CREATE, key=key.encode(), range_end=b""
)
put = RequestOp(request_put=PutRequest(key=key.encode(), value=b"", lease=lease_id))
get = RequestOp(request_range=RangeRequest(key=key.encode()))
range_end = self.get_prefix(key.encode())
get_owner = RequestOp(
request_range=RangeRequest(
key=prefix.encode(),
range_end=range_end,
sort_order=RangeRequest.SortOrder.ASCEND,
sort_target=RangeRequest.SortTarget.CREATE,
limit=1,
)
)
return TxnRequest(compare=[cmp], success=[put, get_owner], failure=[get, get_owner])

def get_prefix(self, key: bytes) -> bytes:
"""Get prefix"""
MAX_VALUE = 255
end = list(key)
i = len(end) - 1
while i >= 0:
if end[i] < MAX_VALUE:
end[i] = (end[i] + 1) % 256
del end[i + 1 :]
return bytes(end)
i -= 1
return bytes([0])

Check warning on line 166 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L165-L166

Added lines #L165 - L166 were not covered by tests

async def propose(self, req: RequestWithToken, use_fast_path: bool) -> tuple[CommandResponse, SyncResponse | None]:
"""
Send request using fast path.
"""
propose_id = self.generate_propose_id()
cmd = Command(request=req, propose_id=propose_id)

if use_fast_path:
res = await self.curp_client.propose(cmd, True)
return res
else:
res = await self.curp_client.propose(cmd, False)
if res[1] is None:
msg = "syncResp is always Some when useFastPath is false"
raise Exception(msg)

Check warning on line 182 in client/lock.py

View check run for this annotation

Codecov / codecov/patch

client/lock.py#L181-L182

Added lines #L181 - L182 were not covered by tests
return res

def generate_propose_id(self) -> str:
"""Generate propose id with the given prefix."""
propose_id = f"{self.name}-{uuid.uuid4()}"
return propose_id

async def wait_delete(self, pfx: str, my_rev: int) -> None:
"""
Wait until last key deleted.
"""
rev = my_rev - 1
while True:
range_end = self.get_prefix(pfx.encode())
get_req = RangeRequest(
key=pfx.encode(),
range_end=range_end,
sort_order=RangeRequest.SortOrder.DESCEND,
sort_target=RangeRequest.SortTarget.CREATE,
max_create_revision=rev,
)
req = RequestWithToken(token=self.token, range_request=get_req)

er, _ = await self.propose(req, False)
range_res = er.range_response

last_key: bytes = b""
if len(range_res.kvs) > 0:
last_key = range_res.kvs[0].key
else:
return

reps, watcher_id = self.watch_client.watch(WatchCreateRequest(key=last_key))
async for res in reps:
watch_id = res.watch_id
f = False
for e in res.events:
if e.type == Event.DELETE:
self.watch_client.cancel(watcher_id, watch_id)
f = True
break
if f:
break

async def delete_key(self, key: bytes) -> ResponseHeader:
"""
Delete key.
"""
del_req = DeleteRangeRequest(key=key, range_end=b"\0")
req = RequestWithToken(token=self.token, delete_range_request=del_req)

er, _ = await self.propose(req, True)
del_res = er.delete_range_response
return del_res.header
6 changes: 4 additions & 2 deletions client/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import uuid
import asyncio
from typing import AsyncIterable
from grpc import Channel
from grpc.aio import StreamStreamCall
from api.xline.rpc_pb2_grpc import WatchStub
from api.xline.rpc_pb2 import (
WatchRequest,
WatchCreateRequest,
WatchCancelRequest,
WatchResponse,
)


Expand Down Expand Up @@ -43,7 +45,7 @@ async def watch():
yield WatchRequest(create_request=req)

while not self.is_cancel:
await asyncio.sleep(0.5)
await asyncio.sleep(0.2)

yield WatchRequest(cancel_request=WatchCancelRequest(watch_id=self.watch_id))

Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(self, channel: Channel) -> None:
self.watch_client = WatchStub(channel=channel)
self.watchers = {}

def watch(self, req: WatchCreateRequest) -> tuple[StreamStreamCall, str]:
def watch(self, req: WatchCreateRequest) -> tuple[AsyncIterable[WatchResponse], str]:
"""
Create Watcher to watch
"""
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,6 @@ exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]

[tool.hatch.build.targets.wheel]
packages = ["client"]
Loading

0 comments on commit 16db03b

Please sign in to comment.