Skip to content

Commit

Permalink
sec-websocket-protocol, support base64.channel.k8s.io and v4.base64.c…
Browse files Browse the repository at this point in the history
…hannel.k8s.io

Signed-off-by: DrAuYueng <[email protected]>
  • Loading branch information
DrAuYueng committed Sep 4, 2023
1 parent 68d5a14 commit 7a85540
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 13 deletions.
10 changes: 6 additions & 4 deletions kubernetes/base/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import ws_client


def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwargs):
def _websocket_request(websocket_request, force_kwargs, websocket_headers, api_method, *args, **kwargs):
"""Override the ApiClient.request method with an alternative websocket based
method and call the supplied Kubernetes API method with that in place."""
if force_kwargs:
Expand All @@ -31,11 +31,13 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa
configuration = api_client.config
prev_request = api_client.request
try:
api_client.request = functools.partial(websocket_request, configuration)
api_client.request = functools.partial(websocket_request, configuration, websocket_headers)
return api_method(*args, **kwargs)
finally:
api_client.request = prev_request


stream = functools.partial(_websocket_request, ws_client.websocket_call, None)
portforward = functools.partial(_websocket_request, ws_client.portforward_call, {'_preload_content':False})
stream = functools.partial(_websocket_request, ws_client.websocket_call, None, None)
portforward = functools.partial(_websocket_request, ws_client.portforward_call, {'_preload_content': False}, None)

wsstream = functools.partial(_websocket_request, ws_client.websocket_call, None)
49 changes: 40 additions & 9 deletions kubernetes/base/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import six
import yaml
import base64

from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO
Expand All @@ -39,6 +40,20 @@
ERROR_CHANNEL = 3
RESIZE_CHANNEL = 4


class _Base64Codec:
def encode(self, data, charset="utf-8"):
if data is None:
return None
b = base64.b64encode(data.encode(charset))
return b.decode(charset)

def decode(self, data):
if data is None:
return None
return base64.b64decode(data).decode()


class _IgnoredIO:
def write(self, _x):
pass
Expand All @@ -65,6 +80,9 @@ def __init__(self, configuration, url, headers, capture_all):
self.sock = create_websocket(configuration, url, headers)
self._connected = True
self._returncode = None
self._base64_codec = None
if headers and headers.get("sec-websocket-protocol") in ["base64.channel.k8s.io", "v4.base64.channel.k8s.io"]:
self._base64_codec = _Base64Codec()

def peek_channel(self, channel, timeout=0):
"""Peek a channel and return part of the input,
Expand Down Expand Up @@ -109,10 +127,14 @@ def write_channel(self, channel, data):
binary = six.PY3 and type(data) == six.binary_type
opcode = ABNF.OPCODE_BINARY if binary else ABNF.OPCODE_TEXT

channel_prefix = chr(channel)
if binary:
channel_prefix = six.binary_type(channel_prefix, "ascii")

if self._base64_codec:
channel_prefix = str(channel)
data = self._base64_codec.encode(data)
else:
channel_prefix = chr(channel)
if binary:
channel_prefix = six.binary_type(channel_prefix, "ascii")

payload = channel_prefix + data
self.sock.send(payload, opcode=opcode)

Expand Down Expand Up @@ -200,8 +222,13 @@ def update(self, timeout=0):
if six.PY3:
data = data.decode("utf-8", "replace")
if len(data) > 1:
channel = ord(data[0])
channel = data[0]
data = data[1:]
if self._base64_codec:
channel = int(channel)
data = self._base64_codec.decode(data)
else:
channel = ord(channel)
if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
# keeping all messages in the order they received
Expand Down Expand Up @@ -508,13 +535,15 @@ def websocket_proxycare(connect_opt, configuration, url, headers):
return(connect_opt)


def websocket_call(configuration, _method, url, **kwargs):
def websocket_call(configuration, websocket_headers, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required. method, url, and kwargs are the parameters of
apiClient.request method."""

url = get_websocket_url(url, kwargs.get("query_params"))
headers = kwargs.get("headers")
headers = kwargs.get("headers", {})
if websocket_headers:
headers.update(websocket_headers)
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)
Expand All @@ -529,7 +558,7 @@ def websocket_call(configuration, _method, url, **kwargs):
raise ApiException(status=0, reason=str(e))


def portforward_call(configuration, _method, url, **kwargs):
def portforward_call(configuration, websocket_headers, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required for port forwarding. args and kwargs are the
parameters of apiClient.request method."""
Expand All @@ -553,7 +582,9 @@ def portforward_call(configuration, _method, url, **kwargs):
raise ApiValueError("Missing required parameter `ports`")

url = get_websocket_url(url, query_params)
headers = kwargs.get("headers")
headers = kwargs.get("headers", {})
if websocket_headers:
headers.update(websocket_headers)

try:
websocket = create_websocket(configuration, url, headers)
Expand Down
99 changes: 99 additions & 0 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from kubernetes.client.api import core_v1_api
from kubernetes.e2e_test import base
from kubernetes.stream import stream, portforward
from kubernetes.stream.stream import wsstream
from kubernetes.stream.ws_client import ERROR_CHANNEL
from kubernetes.client.rest import ApiException

Expand All @@ -51,6 +52,7 @@ def manifest_with_command(name, command):
'spec': {
'containers': [{
'image': 'busybox',
'imagePullPolicy': 'IfNotPresent',
'name': 'sleep',
"args": [
"/bin/sh",
Expand Down Expand Up @@ -160,6 +162,103 @@ def test_pod_apis(self):
resp = api.delete_namespaced_pod(name=name, body={},
namespace='default')

def test_pod_apis_with_selected_websocket_protocol(self):
client = api_client.ApiClient(configuration=self.config)
api = core_v1_api.CoreV1Api(client)

name = 'busybox-test-' + short_uuid()
pod_manifest = manifest_with_command(
name, "while true;do date;sleep 5; done")

# wait for the default service account to be created
timeout = time.time() + 30
while True:
if time.time() > timeout:
print('timeout waiting for default service account creation')
break
try:
resp = api.read_namespaced_service_account(name='default',
namespace='default')
except ApiException as e:
if (six.PY3 and e.status != HTTPStatus.NOT_FOUND) or (
six.PY3 is False and e.status != httplib.NOT_FOUND):
print('error: %s' % e)
self.fail(
msg="unexpected error getting default service account")
print('default service not found yet: %s' % e)
time.sleep(1)
continue
self.assertEqual('default', resp.metadata.name)
break

resp = api.create_namespaced_pod(body=pod_manifest,
namespace='default')
self.assertEqual(name, resp.metadata.name)
self.assertTrue(resp.status.phase)

while True:
resp = api.read_namespaced_pod(name=name,
namespace='default')
self.assertEqual(name, resp.metadata.name)
self.assertTrue(resp.status.phase)
if resp.status.phase != 'Pending':
break
time.sleep(1)

exec_command = ['/bin/sh',
'-c',
'for i in $(seq 1 3); do date; done']
ws_header = {'sec-websocket-protocol': 'v4.base64.channel.k8s.io'}
resp = wsstream(ws_header, api.connect_get_namespaced_pod_exec,
name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
self.assertEqual(3, len(resp.splitlines()))

exec_command = 'uptime'
resp = wsstream(ws_header, api.connect_post_namespaced_pod_exec,
name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
self.assertEqual(1, len(resp.splitlines()))

resp = wsstream(ws_header, api.connect_post_namespaced_pod_exec,
name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
_preload_content=False)

resp.write_stdin("echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual("test string 1", line)
resp.write_stdin("echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual("test string 2", line)
resp.write_stdin("exit\n")
resp.update(timeout=5)
while True:
line = resp.read_channel(ERROR_CHANNEL)
if line != '':
break
time.sleep(1)
status = json.loads(line)
self.assertEqual(status['status'], 'Success')
resp.update(timeout=5)
self.assertFalse(resp.is_open())

number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)

resp = api.delete_namespaced_pod(name=name, body={},
namespace='default')

def test_exit_code(self):
client = api_client.ApiClient(configuration=self.config)
api = core_v1_api.CoreV1Api(client)
Expand Down

0 comments on commit 7a85540

Please sign in to comment.