[Feature] Support for configure magic on Spark Python Kubernetes Kernels (WIP) #1105

Open
wants to merge 8 commits into
base: main
176 changes: 175 additions & 1 deletion enterprise_gateway/services/kernels/handlers.py
@@ -3,13 +3,20 @@
"""Tornado handlers for kernel CRUD and communication."""
import json
import os
from datetime import datetime, timezone
from functools import partial

import jupyter_server.services.kernels.handlers as jupyter_server_handlers
import tornado
from jupyter_client.jsonutil import date_default
from jupyter_server.base.handlers import APIHandler
from tornado import web

try:
from jupyter_client.jsonutil import json_default
except ImportError:
from jupyter_client.jsonutil import date_default as json_default

from ...mixins import CORSMixin, JSONErrorsMixin, TokenAuthorizationMixin


@@ -146,11 +153,178 @@ def get(self, kernel_id):
self.finish(json.dumps(model, default=date_default))


default_handlers = []
class ConfigureMagicHandler(CORSMixin, JSONErrorsMixin, APIHandler):
@web.authenticated
async def post(self, kernel_id):
self.log.info(f"Update request received for kernel: {kernel_id}")
km = self.kernel_manager
km.check_kernel_id(kernel_id)
payload = self.get_json_body()
self.log.debug(f"Request payload: {payload}")
if payload is None:
self.log.info("Empty payload in the request body.")
Member

These info messages aren't necessary since the message returned to the client will indicate where it came from.

self.finish(
json.dumps(
{"message": "Empty payload received. No operation performed on the Kernel."},
Member

Suggested change
{"message": "Empty payload received. No operation performed on the Kernel."},
{"message": f"Empty payload received. No operation performed on kernel: {kernel_id}"},

default=date_default,
)
)
return
if not isinstance(payload, dict):
self.log.info("Payload is not in acceptable format.")
raise web.HTTPError(400, "Invalid JSON payload received.")
Member

Same goes for these info messages that precede exceptions. I believe EG will be logging these exceptions.

Could you please append something like f" for kernel: {kernel_id}" to the end? As EG becomes more multi-tenant, it's important that we include a unique identifier in messages wherever possible.

if payload.get("env", None) is None: # We only allow env field for now.
self.log.info("Payload is missing the required env field.")
raise web.HTTPError(400, "Missing required field `env` in payload.")
kernel = km.get_kernel(kernel_id)
if kernel.restarting: # handle duplicate request.
self.log.info(
"An existing restart request is still in progress. Skipping this request."
)
raise web.HTTPError(
400, "Duplicate Kernel update request received for Id: %s." % kernel_id
Member

Suggested change
400, "Duplicate Kernel update request received for Id: %s." % kernel_id
400, "Duplicate configure kernel request received for kernel: %s." % kernel_id

)
try:
# update Kernel metadata
kernel.set_user_extra_overrides(payload)
await km.restart_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh", zmq_messages=payload.get("zmq_messages", {})
)
except web.HTTPError as he:
self.log.exception(
f"HTTPError exception occurred in refreshing kernel: {kernel_id}: {he}"
)
await km.shutdown_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
)
raise he
except Exception as e:
self.log.exception(f"An exception occurred in updating kernel : {kernel_id}: {e}")
Member

Suggested change
self.log.exception(f"An exception occurred in updating kernel : {kernel_id}: {e}")
self.log.exception(f"An exception occurred in re-configuring kernel: {kernel_id}: {e}")

await km.shutdown_kernel(kernel_id)
kernel.fire_kernel_event_callbacks(
event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
)
raise web.HTTPError(
500,
"Error occurred while refreshing Kernel: %s." % kernel_id,
Member

Suggested change
"Error occurred while refreshing Kernel: %s." % kernel_id,
"Error occurred while refreshing kernel: %s." % kernel_id,

reason=f"{e}",
)
else:
response_body = {"message": f"Successfully refreshed kernel with id: {kernel_id}"}
self.finish(json.dumps(response_body, default=date_default))
return
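
For reference, the order of checks in `post` above can be sketched as a standalone function. This is a minimal sketch, not the PR's code: `classify_payload` and its reason strings are hypothetical, and the real handler raises `web.HTTPError` instead of returning tuples.

```python
from typing import Any, Optional, Tuple


def classify_payload(payload: Any) -> Tuple[Optional[int], str]:
    """Mirror the handler's checks and return (http_status, reason).

    A status of None means the request passes validation and the
    kernel would be restarted with the new overrides.
    """
    if payload is None:
        # Empty body: the handler replies with a no-op message (default 200).
        return 200, "empty payload, nothing to do"
    if not isinstance(payload, dict):
        return 400, "invalid JSON payload"
    if payload.get("env") is None:
        # Only the `env` field is supported for now.
        return 400, "missing required field `env`"
    return None, "ok"


status, reason = classify_payload({"env": {"SPARK_OPTS": "--num-executors 4"}})
```

The `SPARK_OPTS` value is only a made-up example of an environment override.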


class RemoteZMQChannelsHandler(
TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, jupyter_server_handlers.ZMQChannelsHandler
):
def open(self, kernel_id):
self.log.info(f"Websocket open request received for kernel: {kernel_id}")
Member

info -> debug

super().open(kernel_id)
km = self.kernel_manager
km.add_kernel_event_callbacks(kernel_id, self.on_kernel_refresh, "kernel_refresh")
km.add_kernel_event_callbacks(
kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
)

def on_kernel_refresh(self, **kwargs):
self.log.info("Refreshing the client websocket to kernel connection.")
self.refresh_zmq_sockets()
zmq_messages = kwargs.get("zmq_messages", {})
if "stream_reply" in zmq_messages:
self.log.debug("Sending stream_reply success message.")
success_message = zmq_messages.get("stream_reply")
success_message["content"] = {
"name": "stdout",
"text": "The Kernel is successfully refreshed.",
Member

This is throughout. Please lowercase kernel unless part of a name (e.g., KernelManager).

Contributor Author

My bad. Will change in all places.

}
self._send_ws_message(success_message)
if "exec_reply" in zmq_messages:
self.log.debug("Sending exec_reply message.")
self._send_ws_message(zmq_messages.get("exec_reply"))
if "idle_reply" in zmq_messages:
self.log.debug("Sending idle_reply message.")
self._send_ws_message(zmq_messages.get("idle_reply"))
self._send_status_message(
"kernel_refreshed"
) # In the future, UI clients might start to consume this.

def on_kernel_refresh_failure(self, **kwargs):
self.log.error("Kernel %s refresh failed!", self.kernel_id)
zmq_messages = kwargs.get("zmq_messages", {})
if "error_reply" in zmq_messages:
self.log.debug("Sending stream_reply error message.")
error_message = zmq_messages.get("error_reply")
error_message["content"] = {
"ename": "KernelRefreshFailed",
"evalue": "The Kernel refresh operation failed.",
"traceback": ["The Kernel refresh operation failed."],
}
self._send_ws_message(error_message)
if "exec_reply" in zmq_messages:
self.log.debug("Sending exec_reply message.")
exec_reply = zmq_messages.get("exec_reply").copy()
if "metadata" in exec_reply:
exec_reply["metadata"]["status"] = "error"
exec_reply["content"]["status"] = "error"
exec_reply["content"]["ename"] = "KernelRefreshFailed"
exec_reply["content"]["evalue"] = "The Kernel refresh operation failed."
exec_reply["content"]["traceback"] = ["The Kernel refresh operation failed."]
self._send_ws_message(exec_reply)
if "idle_reply" in zmq_messages:
self.log.info("Sending idle reply message.")
self._send_ws_message(zmq_messages.get("idle_reply"))
self.log.debug("sending kernel dead message.")
self._send_status_message("dead")
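
One thing to watch in `on_kernel_refresh_failure`: `dict.copy()` is shallow, so writing to `exec_reply["content"]` also changes the `content` dict still held in the caller's `zmq_messages`. A minimal sketch of building the error reply from a deep copy instead; `build_error_reply` and the template values are hypothetical and only illustrate the difference.

```python
import copy


def build_error_reply(exec_reply: dict, text: str) -> dict:
    """Return an error version of exec_reply without touching the input."""
    reply = copy.deepcopy(exec_reply)  # deep copy: nested dicts are not shared
    if "metadata" in reply:
        reply["metadata"]["status"] = "error"
    content = reply.setdefault("content", {})
    content.update(
        status="error",
        ename="KernelRefreshFailed",
        evalue=text,
        traceback=[text],
    )
    return reply


template = {"metadata": {}, "content": {"status": "ok"}}
failed = build_error_reply(template, "The kernel refresh operation failed.")
print(failed["content"]["status"])    # error
print(template["content"]["status"])  # ok (the template is unchanged)
```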

def refresh_zmq_sockets(self):
self.close_existing_streams()
kernel = self.kernel_manager.get_kernel(self.kernel_id)
self.session.key = kernel.session.key # refresh the session key
self.log.debug("Creating new ZMQ Socket streams.")
self.create_stream()
for channel, stream in self.channels.items():
self.log.debug(f"Updating channel: {channel}")
stream.on_recv_stream(self._on_zmq_reply)

def close_existing_streams(self):
self.log.debug(f"Closing existing channels for kernel: {self.kernel_id}")
for channel, stream in self.channels.items():
if stream is not None and not stream.closed():
self.log.debug(f"Close channel : {channel}")
stream.on_recv(None)
stream.close()
self.channels = {}

def _send_ws_message(self, kernel_msg):
self.log.debug(f"Sending websocket message: {kernel_msg}")
if "header" in kernel_msg and isinstance(kernel_msg["header"], dict):
kernel_msg["header"]["date"] = datetime.now(timezone.utc)
self.write_message(json.dumps(kernel_msg, default=json_default))
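
`json.dumps` cannot serialize `datetime` objects on its own, which is why `_send_ws_message` passes a `default=` hook. A minimal stdlib-only sketch of the same idea follows; `iso_default` is a hypothetical stand-in for the `json_default` helper imported from `jupyter_client`, not its actual implementation, and `stamp_and_dump` is likewise a made-up name.

```python
import json
from datetime import datetime, timezone


def iso_default(obj):
    """Fallback serializer: render datetimes as ISO 8601 strings."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"{type(obj).__name__} is not JSON serializable")


def stamp_and_dump(kernel_msg: dict) -> str:
    """Set a timezone-aware UTC date on the header, then serialize."""
    if isinstance(kernel_msg.get("header"), dict):
        kernel_msg["header"]["date"] = datetime.now(timezone.utc)
    return json.dumps(kernel_msg, default=iso_default)


wire = stamp_and_dump({"header": {"msg_type": "status"}, "content": {}})
```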

def on_close(self):
self.log.info(f"Websocket close request received for kernel: {self.kernel_id}")
super().on_close()
self.kernel_manager.remove_kernel_event_callbacks(
self.kernel_id, self.on_kernel_refresh, "kernel_refresh"
)
self.kernel_manager.remove_kernel_event_callbacks(
self.kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
)
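
The handler relies on `add_kernel_event_callbacks`, `remove_kernel_event_callbacks`, and `fire_kernel_event_callbacks`, whose implementations are not part of this file's diff. A minimal sketch of how such a registry could work, assuming callbacks are stored per kernel id and event name; the `EventRegistry` class and its method names are hypothetical, and the PR's actual implementation may differ.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List, Tuple


class EventRegistry:
    """Hypothetical per-kernel, per-event callback store."""

    def __init__(self) -> None:
        self._callbacks: DefaultDict[Tuple[str, str], List[Callable]] = defaultdict(list)

    def add(self, kernel_id: str, callback: Callable, event: str) -> None:
        self._callbacks[(kernel_id, event)].append(callback)

    def remove(self, kernel_id: str, callback: Callable, event: str) -> None:
        callbacks = self._callbacks.get((kernel_id, event), [])
        if callback in callbacks:
            callbacks.remove(callback)

    def fire(self, kernel_id: str, event: str, **kwargs) -> None:
        # Iterate over a copy so a callback may unregister itself safely.
        for callback in list(self._callbacks.get((kernel_id, event), [])):
            callback(**kwargs)


seen = []
registry = EventRegistry()
registry.add("k1", lambda **kw: seen.append(kw), "kernel_refresh")
registry.fire("k1", "kernel_refresh", zmq_messages={})
```

Removing the callbacks in `on_close`, as the handler does, keeps a closed websocket from being called on a later refresh.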


_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
default_handlers = [(r"/api/configure/%s" % _kernel_id_regex, ConfigureMagicHandler)]
Member

This feels like it should be r"/api/kernels/configure/%s". Is there an issue with placing it there? I don't see that conflicting with the patterns of the default handlers.

Contributor Author

Sure, I can make the change.
I am also having second thoughts about the name. Should this be called the kernel refresh API instead?
r"/api/kernels/refresh/%s"
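
Either path works with the existing `_kernel_id_regex`, since the kernel id is captured as a named group. A quick check of the proposed `/api/kernels/configure/%s` pattern is sketched below. The sample id is made up, `kernel_route` is an assumed shape for an existing per-kernel route, and `re.fullmatch` is used only for illustration since Tornado does the matching in the real server.

```python
import re

_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
configure_route = r"/api/kernels/configure/%s" % _kernel_id_regex
kernel_route = r"/api/kernels/%s" % _kernel_id_regex  # assumed existing route shape

sample_id = "3f2b6c1e-8a4d-4c2b-9e57-0d1a2b3c4d5e"  # made-up kernel id
url = f"/api/kernels/configure/{sample_id}"

match = re.fullmatch(configure_route, url)
print(match.group("kernel_id"))

# "configure" contains no hyphens, so it cannot be mistaken for a kernel id.
print(re.fullmatch(kernel_route, url))  # None
```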

for path, cls in jupyter_server_handlers.default_handlers:
if cls.__name__ in globals():
# Use the same named class from here if it exists
default_handlers.append((path, globals()[cls.__name__]))
elif (
cls.__name__ == jupyter_server_handlers.ZMQChannelsHandler.__name__
): # TODO: check for override.
default_handlers.append((path, RemoteZMQChannelsHandler))
Member

Isn't this meant to replace ZMQChannelsHandler? I guess I don't understand why ZMQChannelsHandler isn't satisfied by the first condition - but I'm not that familiar with globals() (sorry).

Contributor Author

We need to discuss this further.
What I am trying to do here is replace ZMQChannelsHandler with RemoteZMQChannelsHandler for handling the channels requests.

I tried to reuse the same class name in EG but ran into an issue where the websocket connection was failing.
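
On the question above: `cls.__name__ in globals()` only matches when this module defines a class with the same name as the upstream one. Because the subclass here is called `RemoteZMQChannelsHandler`, the name `ZMQChannelsHandler` is not found in the module namespace, so the explicit `elif` is needed. A minimal sketch with stand-in classes follows; all names are hypothetical, and a plain dict stands in for `globals()`.

```python
class ZMQChannelsHandler:  # stand-in for the upstream websocket handler
    pass


class KernelHandler:  # stand-in for another upstream handler
    pass


class RemoteZMQChannelsHandler(ZMQChannelsHandler):  # local subclass, new name
    pass


# A local override that reuses the upstream class name.
LocalKernelHandler = type("KernelHandler", (KernelHandler,), {})

# Stand-in for this module's globals().
local_names = {
    "KernelHandler": LocalKernelHandler,
    "RemoteZMQChannelsHandler": RemoteZMQChannelsHandler,
}

upstream = [
    ("/kernels/id", KernelHandler),
    ("/kernels/id/channels", ZMQChannelsHandler),
]


def pick_handlers(upstream, local_names):
    chosen = []
    for path, cls in upstream:
        if cls.__name__ in local_names:
            # A same-named local class exists: use it.
            chosen.append((path, local_names[cls.__name__]))
        elif cls.__name__ == "ZMQChannelsHandler":
            # The local name differs, so the swap has to be spelled out.
            chosen.append((path, local_names["RemoteZMQChannelsHandler"]))
        else:
            chosen.append((path, cls))
    return chosen


handlers = pick_handlers(upstream, local_names)
print("ZMQChannelsHandler" in local_names)  # False
```

Renaming the local class to `ZMQChannelsHandler` would let the first branch handle it, which is the approach the author reports having trouble with.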

else:
# Gen a new type with CORS and token auth
bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls)