diff --git a/enterprise_gateway/services/kernels/handlers.py b/enterprise_gateway/services/kernels/handlers.py
index 35a07ec8a..73e280d0a 100644
--- a/enterprise_gateway/services/kernels/handlers.py
+++ b/enterprise_gateway/services/kernels/handlers.py
@@ -3,13 +3,20 @@
 """Tornado handlers for kernel CRUD and communication."""
 import json
 import os
+from datetime import datetime, timezone
 from functools import partial
 
 import jupyter_server.services.kernels.handlers as jupyter_server_handlers
 import tornado
 from jupyter_client.jsonutil import date_default
+from jupyter_server.base.handlers import APIHandler
 from tornado import web
 
+try:
+    from jupyter_client.jsonutil import json_default
+except ImportError:
+    from jupyter_client.jsonutil import date_default as json_default
+
 from ...mixins import CORSMixin, JSONErrorsMixin, TokenAuthorizationMixin
 
@@ -146,11 +153,179 @@ def get(self, kernel_id):
         self.finish(json.dumps(model, default=date_default))
 
 
-default_handlers = []
+class ConfigureMagicHandler(CORSMixin, JSONErrorsMixin, APIHandler):
+    @web.authenticated
+    async def post(self, kernel_id):
+        self.log.info(f"Update request received for kernel: {kernel_id}")
+        km = self.kernel_manager
+        km.check_kernel_id(kernel_id)
+        payload = self.get_json_body()
+        self.log.debug(f"Request payload: {payload}")
+        if payload is None:
+            self.finish(
+                json.dumps(
+                    {
+                        "message": f"Empty payload received. No operation performed on kernel: {kernel_id}"
+                    },
+                    default=date_default,
+                )
+            )
+            return
+        if not isinstance(payload, dict):
+            raise web.HTTPError(400, f"Invalid JSON payload received for kernel: {kernel_id}.")
+        if payload.get("env", None) is None:  # We only allow the env field for now.
+            raise web.HTTPError(
+                400, f"Missing required field `env` in payload for kernel: {kernel_id}."
+            )
+        kernel = km.get_kernel(kernel_id)
+        if kernel.restarting:  # handle duplicate request.
+            self.log.info(
+                "An existing restart request is still in progress. Skipping this request."
+            )
+            raise web.HTTPError(
+                400, f"Duplicate configure kernel request received for kernel: {kernel_id}."
+            )
+        try:
+            # update kernel metadata
+            kernel.set_user_extra_overrides(payload)
+            await km.restart_kernel(kernel_id)
+            kernel.fire_kernel_event_callbacks(
+                event="kernel_refresh", zmq_messages=payload.get("zmq_messages", {})
+            )
+        except web.HTTPError as he:
+            self.log.exception(
+                f"HTTPError exception occurred while re-configuring kernel: {kernel_id}: {he}"
+            )
+            await km.shutdown_kernel(kernel_id)
+            kernel.fire_kernel_event_callbacks(
+                event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
+            )
+            raise he
+        except Exception as e:
+            self.log.exception(
+                f"An exception occurred while re-configuring kernel: {kernel_id}: {e}"
+            )
+            await km.shutdown_kernel(kernel_id)
+            kernel.fire_kernel_event_callbacks(
+                event="kernel_refresh_failure", zmq_messages=payload.get("zmq_messages", {})
+            )
+            raise web.HTTPError(
+                500,
+                f"Error occurred while re-configuring kernel: {kernel_id}",
+                reason=f"{e}",
+            )
+        else:
+            response_body = {"message": f"Successfully re-configured kernel: {kernel_id}."}
+            self.finish(json.dumps(response_body, default=date_default))
+            return
+
+
+class RemoteZMQChannelsHandler(
+    TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, jupyter_server_handlers.ZMQChannelsHandler
+):
+    def open(self, kernel_id):
+        self.log.debug(f"Websocket open request received for kernel: {kernel_id}")
+        super().open(kernel_id)
+        km = self.kernel_manager
+        km.add_kernel_event_callbacks(kernel_id, self.on_kernel_refresh, "kernel_refresh")
+        km.add_kernel_event_callbacks(
+            kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
+        )
+
+    def on_kernel_refresh(self, **kwargs):
+        self.log.info("Refreshing the client websocket to kernel connection.")
+        self.refresh_zmq_sockets()
+        zmq_messages = kwargs.get("zmq_messages", {})
+        if "stream_reply" in zmq_messages:
+            self.log.debug("Sending stream_reply success message.")
+            success_message = zmq_messages.get("stream_reply")
+            success_message["content"] = {
+                "name": "stdout",
+                "text": "The kernel is successfully refreshed.",
+            }
+            self._send_ws_message(success_message)
+        if "exec_reply" in zmq_messages:
+            self.log.debug("Sending exec_reply message.")
+            self._send_ws_message(zmq_messages.get("exec_reply"))
+        if "idle_reply" in zmq_messages:
+            self.log.debug("Sending idle_reply message.")
+            self._send_ws_message(zmq_messages.get("idle_reply"))
+        self._send_status_message(
+            "kernel_refreshed"
+        )  # In the future, UI clients might start to consume this.
+
+    def on_kernel_refresh_failure(self, **kwargs):
+        self.log.error("kernel %s refresh failed!", self.kernel_id)
+        zmq_messages = kwargs.get("zmq_messages", {})
+        if "error_reply" in zmq_messages:
+            self.log.debug("Sending error_reply message.")
+            error_message = zmq_messages.get("error_reply")
+            error_message["content"] = {
+                "ename": "KernelRefreshFailed",
+                "evalue": "The kernel refresh operation failed.",
+                "traceback": ["The kernel refresh operation failed."],
+            }
+            self._send_ws_message(error_message)
+        if "exec_reply" in zmq_messages:
+            self.log.debug("Sending exec_reply message.")
+            exec_reply = zmq_messages.get("exec_reply").copy()
+            if "metadata" in exec_reply:
+                exec_reply["metadata"]["status"] = "error"
+            exec_reply["content"]["status"] = "error"
+            exec_reply["content"]["ename"] = "KernelRefreshFailed"
+            exec_reply["content"]["evalue"] = "The kernel refresh operation failed."
+            exec_reply["content"]["traceback"] = ["The kernel refresh operation failed."]
+            self._send_ws_message(exec_reply)
+        if "idle_reply" in zmq_messages:
+            self.log.debug("Sending idle_reply message.")
+            self._send_ws_message(zmq_messages.get("idle_reply"))
+        self.log.debug("Sending kernel dead message.")
+        self._send_status_message("dead")
+
+    def refresh_zmq_sockets(self):
+        self.close_existing_streams()
+        kernel = self.kernel_manager.get_kernel(self.kernel_id)
+        self.session.key = kernel.session.key  # refresh the session key
+        self.log.debug("Creating new ZMQ socket streams.")
+        self.create_stream()
+        for channel, stream in self.channels.items():
+            self.log.debug(f"Updating channel: {channel}")
+            stream.on_recv_stream(self._on_zmq_reply)
+
+    def close_existing_streams(self):
+        self.log.debug(f"Closing existing channels for kernel: {self.kernel_id}")
+        for channel, stream in self.channels.items():
+            if stream is not None and not stream.closed():
+                self.log.debug(f"Closing channel: {channel}")
+                stream.on_recv(None)
+                stream.close()
+        self.channels = {}
+
+    def _send_ws_message(self, kernel_msg):
+        self.log.debug(f"Sending websocket message: {kernel_msg}")
+        if "header" in kernel_msg and isinstance(kernel_msg["header"], dict):
+            kernel_msg["header"]["date"] = datetime.now(timezone.utc)
+        self.write_message(json.dumps(kernel_msg, default=json_default))
+
+    def on_close(self):
+        self.log.info(f"Websocket close request received for kernel: {self.kernel_id}")
+        super().on_close()
+        self.kernel_manager.remove_kernel_event_callbacks(
+            self.kernel_id, self.on_kernel_refresh, "kernel_refresh"
+        )
+        self.kernel_manager.remove_kernel_event_callbacks(
+            self.kernel_id, self.on_kernel_refresh_failure, "kernel_refresh_failure"
+        )
+
+
+_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
+default_handlers = [(r"/api/kernels/configure/%s" % _kernel_id_regex, ConfigureMagicHandler)]
 for path, cls in jupyter_server_handlers.default_handlers:
     if cls.__name__ in globals():
         # Use the same named class from here if it exists
         default_handlers.append((path, globals()[cls.__name__]))
+    elif cls.__name__ == jupyter_server_handlers.ZMQChannelsHandler.__name__:
+        default_handlers.append((path, RemoteZMQChannelsHandler))
     else:
         # Gen a new type with CORS and token auth
         bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls)
diff --git a/enterprise_gateway/services/kernels/remotemanager.py b/enterprise_gateway/services/kernels/remotemanager.py
index 0971d3f2c..e6abeb822 100644
--- a/enterprise_gateway/services/kernels/remotemanager.py
+++ b/enterprise_gateway/services/kernels/remotemanager.py
@@ -9,9 +9,10 @@
 import uuid
 
 from jupyter_client.ioloop.manager import AsyncIOLoopKernelManager
+from jupyter_client.multikernelmanager import kernel_method
 from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager
 from tornado import web
-from traitlets import directional_link
+from traitlets import Dict, directional_link
 from traitlets import log as traitlets_log
 
 from enterprise_gateway.mixins import EnterpriseGatewayConfigMixin
@@ -369,6 +370,14 @@ def new_kernel_id(self, **kwargs):
         return new_kernel_id(kernel_id_fn=super().new_kernel_id, log=self.log, **kwargs)
 
+    @kernel_method
+    def add_kernel_event_callbacks(self, kernel_id, callback, event="kernel_refresh"):
+        """Delegate: register an event callback on the kernel with the given kernel_id."""
+
+    @kernel_method
+    def remove_kernel_event_callbacks(self, kernel_id, callback, event="kernel_refresh"):
+        """Delegate: remove an event callback from the kernel with the given kernel_id."""
+
 
 class RemoteKernelManager(EnterpriseGatewayConfigMixin, AsyncIOLoopKernelManager):
     """
@@ -378,6 +387,13 @@ class RemoteKernelManager(EnterpriseGatewayConfigMixin, AsyncIOLoopKernelManager):
     returned - upon which methods of poll(), wait(), send_signal(), and kill() can be called.
     """
 
+    event_callbacks = Dict()
+
+    def _event_callbacks_default(self):
+        return dict(
+            kernel_refresh=[], kernel_refresh_failure=[]
+        )  # add a new default entry here when introducing a new event.
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.process_proxy = None
@@ -385,7 +401,8 @@ def __init__(self, **kwargs):
         self.public_key = None
         self.sigint_value = None
         self.kernel_id = None
-        self.user_overrides = {}
+        self.user_overrides = {}  # populated via the create-kernel request.
+        self.configure_kernel_overrides = {}  # populated via the configure-kernel request.
         self.kernel_launch_timeout = default_kernel_launch_timeout
         self.restarting = False  # need to track whether we're in a restart situation or not
 
@@ -462,15 +479,23 @@ def _capture_user_overrides(self, **kwargs):
         self.kernel_launch_timeout = float(
             env.get("KERNEL_LAUNCH_TIMEOUT", default_kernel_launch_timeout)
         )
-        self.user_overrides.update(
-            {
-                key: value
-                for key, value in env.items()
-                if key.startswith("KERNEL_")
-                or key in self.env_process_whitelist
-                or key in self.env_whitelist
-            }
-        )
+        # kwargs['env'] gets updated with each kernel start / restart.
+        # user_overrides preserves the original envs with which the kernel was started.
+        if not self.user_overrides:
+            self.user_overrides.update(
+                {
+                    key: value
+                    for key, value in env.items()
+                    if key.startswith("KERNEL_")
+                    or key in self.env_process_whitelist
+                    or key in self.env_whitelist
+                }
+            )
+        extra_env = self._capture_user_update_overrides(**kwargs)
+        env.update(
+            self.user_overrides
+        )  # required to refresh the env variables present in the kernel spec file.
+        env.update(extra_env)
 
     def format_kernel_cmd(self, extra_arguments=None):
         """
@@ -507,7 +532,6 @@ async def _launch_kernel(self, kernel_cmd, **kwargs):
         # Apply user_overrides to enable defaulting behavior from kernelspec.env stanza. Note that we do this
         # BEFORE setting KERNEL_GATEWAY and removing {EG,KG}_AUTH_TOKEN so those operations cannot be overridden.
-        env.update(self.user_overrides)
 
         # No longer using Kernel Gateway, but retain references of B/C purposes
         env["KERNEL_GATEWAY"] = "1"
@@ -521,7 +545,6 @@ async def _launch_kernel(self, kernel_cmd, **kwargs):
                 self.kernel_spec.display_name, kernel_cmd
             )
         )
-
         proxy = await self.process_proxy.launch_process(kernel_cmd, **kwargs)
         return proxy
 
@@ -725,3 +748,98 @@ def mapping_kernel_manager(self):
             return self.parent
         except AttributeError:
             return None
+
+    def _capture_user_update_overrides(self, **kwargs):
+        allowed_override_keys = [
+            "KERNEL_EXTRA_SPARK_OPTS",
+            "KERNEL_LAUNCH_TIMEOUT",
+        ]  # TODO need to read this list from an env variable
+        user_requested_env_overrides = self.configure_kernel_overrides.get("env", {})
+        allowed_env_overrides = {}
+        for override_key in allowed_override_keys:
+            if override_key in user_requested_env_overrides:
+                self.log.info(f"Key exists in extra overrides: {override_key}")
+                if override_key == "KERNEL_LAUNCH_TIMEOUT":
+                    allowed_env_overrides[override_key] = str(
+                        user_requested_env_overrides.get(override_key)
+                    )
+                else:
+                    allowed_env_overrides[override_key] = self.user_overrides.get(
+                        override_key, ""
+                    ) + user_requested_env_overrides.get(override_key)
+        return allowed_env_overrides
+
+    def set_user_extra_overrides(self, update_payload):
+        # TODO need to read this list from an env variable
+        allowed_override_keys = ["KERNEL_EXTRA_SPARK_OPTS", "KERNEL_LAUNCH_TIMEOUT"]
+        env_overrides = update_payload.get("env", {})
+        if not isinstance(env_overrides, dict):
+            error_message = "Expected `env` to be of type: {} but found: {}.".format(
+                dict.__name__, type(env_overrides).__name__
+            )
+            self.log.info(error_message)
+            raise web.HTTPError(400, error_message)
+        self.log.debug(f"Validating the user overrides: {env_overrides}")
+        for env_name in env_overrides:
+            if env_name not in allowed_override_keys:
+                raise web.HTTPError(400, f"Updating ENV: `{env_name}` is not currently supported.")
+        self.configure_kernel_overrides = update_payload
+
+    def add_kernel_event_callbacks(self, callback_func, event="kernel_refresh"):
+        """Register a callback to fire on a particular kernel event.
+
+        :param callback_func: the callable to invoke when the event fires.
+        :param event:
+            - 'kernel_refresh' (default): the kernel has received an update request and has successfully restarted.
+        :return: nothing.
+        """
+        try:
+            self.log.debug(
+                f"add_kernel_event_callbacks: called for event: {event}: callback: {callback_func.__name__}"
+            )
+            self.event_callbacks[event].append(callback_func)
+        except Exception:
+            self.log.error(
+                "Failed to add callback for event: {}: callback: {}".format(
+                    event, callback_func.__name__
+                ),
+                exc_info=True,
+            )
+
+    def remove_kernel_event_callbacks(self, callback_func, event="kernel_refresh"):
+        """Deregister a callback from this kernel event.
+
+        :param callback_func: the callback to be removed, if it exists.
+        :param event: 'kernel_refresh'
+        :return: nothing.
+ """ + + self.log.debug( + f"remove_kernel_event_callbacks: called for event: {event}: callback: {callback_func.__name__}" + ) + try: + self.event_callbacks[event].remove(callback_func) + except Exception as e: + self.log.error( + "Failed to remove callback for event: {}: callback: {}".format( + event, callback_func.__name__ + ), + exc_info=True, + ) + + def fire_kernel_event_callbacks(self, **kwargs): + """fire the callbacks for a particular kernel event""" + event = kwargs.get("event") + self.log.debug(f"fire_kernel_event_callbacks: called for event: {event}") + for callback in self.event_callbacks[event]: + try: + self.log.debug(f"triggering callback to {callback.__name__}") + callback(**kwargs) + except Exception as e: + # TODO handle exception here..what should we do in this case if we are not able to refresh. + self.log.exception( + "Exception while executing event: {} with callback {} failed".format( + event, callback.__name__ + ), + exc_info=True, + ) diff --git a/enterprise_gateway/services/processproxies/container.py b/enterprise_gateway/services/processproxies/container.py index 5ffa4c9b5..824aec454 100644 --- a/enterprise_gateway/services/processproxies/container.py +++ b/enterprise_gateway/services/processproxies/container.py @@ -175,6 +175,9 @@ def kill(self): result = None if self.container_name: # We only have something to terminate if we have a name + self.log.info( + f"Terminating kernel: {self.kernel_id} with container name: : {self.container_name}" + ) result = self.terminate_container_resources() return result diff --git a/etc/kernel-launchers/python/scripts/configure_magic.py b/etc/kernel-launchers/python/scripts/configure_magic.py new file mode 100644 index 000000000..21889e9b7 --- /dev/null +++ b/etc/kernel-launchers/python/scripts/configure_magic.py @@ -0,0 +1,293 @@ +import base64 +import json +import logging +import os +import sys +import time +from json import JSONDecodeError + +import requests +from IPython.core.magic import Magics, cell_magic, magics_class +from requests.packages.urllib3.exceptions import InsecureRequestWarning + +requests.packages.urllib3.disable_warnings(InsecureRequestWarning) + +logger = logging.getLogger(__name__) +logger.name = "configure_magic" +logger.setLevel(logging.DEBUG) +logger.propagate = True + +RESERVED_SPARK_CONFIGS = [ + "spark.kubernetes.container.image", + "spark.kubernetes.driver.container.image", + "spark.kubernetes.executor.container.image", + "spark.kubernetes.namespace", + "spark.kubernetes.driver.label.component", + "spark.kubernetes.executor.label.component", + "spark.kubernetes.driver.label.kernel_id", + "spark.kubernetes.executor.label.kernel_id", + "spark.kubernetes.driver.label.app", + "spark.kubernetes.executor.label.app", +] + + +class InvalidPayloadException(Exception): + pass + + +@magics_class +class ConfigureMagic(Magics): + SUPPORTED_MAGIC_FIELDS = { + "driverMemory": " --conf spark.driver.memory={} ", + "driverCores": " --conf spark.driver.cores={} ", + "executorMemory": " --conf spark.executor.memory={} ", + "executorCores": " --conf spark.executor.cores={} ", + "numExecutors": " --conf spark.executor.instances={} ", + "conf": "--conf {}={} ", + "launchTimeout": "{}", + } + MAX_LAUNCH_TIMEOUT = 500 + + KERNEL_ID_NOT_FOUND = ( + "We could not find any associated Kernel to apply the magic. Please restart Kernel." + ) + EMPTY_INVALID_MAGIC_PAYLOAD = ( + "The magic payload is either empty or not in the correct format." + " Please recheck and execute." 
+    )
+    INVALID_JSON_PAYLOAD = "The magic payload could not be parsed into a valid JSON object. Please recheck and execute."
+    SERVER_ERROR = "An error occurred while updating the kernel configuration: {}."
+    UNKNOWN_ERROR = "An error occurred while processing the payload."
+    RESERVED_SPARK_CONFIGS_ERROR = (
+        "You are not allowed to override the {} Spark config as it is reserved."
+    )
+    MAX_LAUNCH_TIMEOUT_ERROR = "The allowed range for kernel launchTimeout is (0, {}].".format(
+        MAX_LAUNCH_TIMEOUT
+    )
+    INVALID_PAYLOAD_ERROR = "{} with error: {}"
+
+    def __init__(self, shell=None, **kwargs):
+        logger.info("Initializing ConfigureMagic...")
+        super().__init__(shell=None, **kwargs)
+        self.shell = shell
+        self.kernel_id = os.environ.get("KERNEL_ID", None)
+        self.endpoint_ip = os.environ.get("ENDPOINT_IP", "")
+        self.endpoint_port = int(os.environ.get("ENDPOINT_PORT", 8888))
+        self.protocol = "http"  # TODO make this configurable.
+        if not self.endpoint_ip:
+            logger.info("Environment var: ENDPOINT_IP not set. Falling back to using localhost.")
+            self.endpoint_ip = "localhost"
+            self.endpoint_port = 18888
+            self.protocol = "http"
+        self.update_kernel_url = "{}://{}:{}/api/kernels/configure/{}".format(
+            self.protocol, self.endpoint_ip, self.endpoint_port, self.kernel_id
+        )
+        logger.debug(f"Kernel update URL set to: {self.update_kernel_url}")
+        self.headers = {
+            "Content-Type": "application/json",
+        }
+        logger.info("Successfully loaded configure magic.")
+
+    @cell_magic
+    def configure(self, line, cell=""):
+        if self.kernel_id is None:
+            logger.error(ConfigureMagic.KERNEL_ID_NOT_FOUND)
+            return ConfigureMagic.KERNEL_ID_NOT_FOUND
+        logger.info(f"Magic cell payload received: {cell}")
+        magic_payload = None
+        try:
+            cell_payload = json.loads(cell)
+            magic_payload = self.convert_to_kernel_update_payload(cell_payload)
+            if not magic_payload:
{magic_payload}") + return ConfigureMagic.EMPTY_INVALID_MAGIC_PAYLOAD + except ValueError as ve: + logger.exception(f"Could not parse JSON object from input {cell}: error: {ve}.") + return ConfigureMagic.INVALID_JSON_PAYLOAD + except JSONDecodeError as jde: + logger.exception(f"Could not parse JSON object from input: {cell}: error: {jde}.") + return ConfigureMagic.INVALID_JSON_PAYLOAD + except InvalidPayloadException as ipe: + logger.exception( + "InvalidPayloadException occurred while processing magic payload: {}: error: {}".format( + cell, ipe + ) + ) + return ConfigureMagic.INVALID_PAYLOAD_ERROR.format( + InvalidPayloadException.__name__, ipe + ) + except Exception as ex: + logger.exception( + f"Exception occurred while processing magic payload: {cell}: error: {ex}" + ) + return ConfigureMagic.UNKNOWN_ERROR + else: + magic_payload["zmq_messages"] = self.prepare_zmq_messages() + logger.debug(f"Payload to refresh: {magic_payload}") + result = self.update_kernel(magic_payload) + return result + return "Done" + + def prepare_zmq_messages(self): + messages = {} + messages["idle_reply"] = self.prepare_iopub_idle_reply_payload() + messages["exec_reply"] = self.prepare_shell_reply_payload() + messages["stream_reply"] = self.prepare_iopub_stream_reply_payload() + messages["error_reply"] = self.prepare_iopub_error_reply_payload() + return messages + + def prepare_iopub_error_reply_payload(self): + ipykernel = self.shell.kernel + reply_content = {"ename": "MagicExecutionError", "evalue": "UnknownError", "traceback": []} + parent_headers = self.shell.parent_header["header"].copy() + metadata = {} # ipykernel.init_metadata(parent_headers) + error_payload = ipykernel.session.msg( + msg_type="error", content=reply_content, parent=parent_headers, metadata=metadata + ) + error_payload["channel"] = "iopub" + error_payload["buffers"] = [] + return error_payload + + def prepare_iopub_idle_reply_payload(self): + ipykernel = self.shell.kernel + reply_content = {"execution_state": "idle"} + parent_headers = self.shell.parent_header["header"].copy() + metadata = {} # ipykernel.init_metadata(parent_headers) + idle_payload = ipykernel.session.msg( + msg_type="status", content=reply_content, parent=parent_headers, metadata=metadata + ) + idle_payload["channel"] = "iopub" + idle_payload["buffers"] = [] + return idle_payload + + def prepare_iopub_stream_reply_payload(self): + ipykernel = self.shell.kernel + reply_content = {"name": "stdout", "text": " "} + parent_headers = self.shell.parent_header["header"].copy() + metadata = ipykernel.init_metadata(parent_headers) + idle_payload = ipykernel.session.msg( + msg_type="stream", content=reply_content, parent=parent_headers, metadata=metadata + ) + idle_payload["channel"] = "iopub" + idle_payload["buffers"] = [] + return idle_payload + + def prepare_shell_reply_payload(self): + ipykernel = self.shell.kernel + reply_content = { + "status": "ok", + "execution_count": ipykernel.execution_count, + "user_expressions": {}, + "payload": [], + } + parent_headers = self.shell.parent_header["header"].copy() + metadata = ipykernel.init_metadata(parent_headers) + metadata = ipykernel.finish_metadata(parent_headers, metadata, reply_content) + shell_payload = ipykernel.session.msg( + msg_type="execute_reply", + content=reply_content, + parent=parent_headers, + metadata=metadata, + ) + shell_payload["channel"] = "shell" + shell_payload["buffers"] = [] + return shell_payload + + def convert_to_kernel_update_payload(self, cell_payload={}): + if not cell_payload or type(cell_payload) != 
+            return None
+        magic_payload = {}
+        spark_overrides = ""
+        for magic_key, spark_conf in ConfigureMagic.SUPPORTED_MAGIC_FIELDS.items():
+            magic_prop = cell_payload.get(magic_key, None)
+            if magic_prop is not None:
+                if magic_key == "conf" and isinstance(magic_prop, dict):
+                    conf_dict = magic_prop
+                    conf = " "
+                    for k, v in conf_dict.items():
+                        if k in RESERVED_SPARK_CONFIGS:
+                            raise InvalidPayloadException(
+                                ConfigureMagic.RESERVED_SPARK_CONFIGS_ERROR.format(k)
+                            )
+                        conf += spark_conf.format(k, v)
+                    spark_overrides += conf
+                elif magic_key == "launchTimeout":
+                    if int(magic_prop) <= 0 or int(magic_prop) > ConfigureMagic.MAX_LAUNCH_TIMEOUT:
+                        raise InvalidPayloadException(ConfigureMagic.MAX_LAUNCH_TIMEOUT_ERROR)
+                    self.populate_env_in_payload(
+                        magic_payload, "KERNEL_LAUNCH_TIMEOUT", str(magic_prop)
+                    )
+                else:
+                    spark_overrides += spark_conf.format(magic_prop)
+        logger.debug(f"KERNEL_EXTRA_SPARK_OPTS: {spark_overrides}")
+        if len(spark_overrides.strip()) != 0:
+            # Do not strip spark_overrides while populating
+            self.populate_env_in_payload(magic_payload, "KERNEL_EXTRA_SPARK_OPTS", spark_overrides)
+        return magic_payload
+
+    def populate_env_in_payload(self, payload, env_key, env_value):
+        if not payload.get("env", None):
+            payload["env"] = {}
+        if env_key and env_value:
+            payload["env"][env_key] = env_value
+        else:
+            logger.error(f"Either the key or the value is not defined: {env_key}, {env_value}")
+
+    def update_kernel(self, payload_dict):
+        try:
+            logger.info(
+                "Sending request to update the kernel. Please wait while the kernel is refreshed."
+            )
+            # Flush output before sending the request.
+            sys.stdout.flush()
+            sys.stderr.flush()
+            time.sleep(0.005)  # brief pause to let the flushed output reach the client
+            response = requests.post(
+                self.update_kernel_url,
+                data=json.dumps(payload_dict, default=str),
+                headers=self.headers,
+                verify=False,
+            )
+            response_body = response.json() if response is not None else {}
+            # The lines below execute only if the request failed or in the runaway-kernel case.
+            logger.debug(
+                f"Response received for kernel update: {response.status_code}: body: {response_body}"
+            )
+            if response.status_code != 200:
+                error_message = (
+                    response_body["message"]
+                    if response_body.get("message", None)
+                    else "Internal Error."
+                )
+                logger.error(
+                    "An error occurred while updating kernel: {}: {}".format(
+                        response.status_code, error_message
+                    )
+                )
+                return ConfigureMagic.SERVER_ERROR.format(error_message)
+            else:
+                # If we reach this point, we have a runaway kernel: this pod should have gone down.
+                logger.error(
+                    "Successfully updated kernel but with possible duplicate / runaway kernel scenario."
+                )
+                return f"Status: {response}. Possible kernel leak."
+        except Exception as ex:
+            logger.exception(f"Runtime exception occurred: {ex}")
+            return ConfigureMagic.SERVER_ERROR.format(ex)
+        except KeyboardInterrupt:
+            logger.info(
+                "Received interrupt to shutdown the kernel. Please wait while the kernel is refreshed."
+            )
+
+
+def load_ipython_extension(ipython):
+    # The `ipython` argument is the currently active `InteractiveShell`
+    # instance, which can be used in any way. This allows you to register
+    # new magics or aliases, for example.
+    logger.info("Loading ConfigureMagic ...")
+    ipython.register_magics(ConfigureMagic)
+
+
+def unload_ipython_extension(ipython):
+    logger.info("Unloading ConfigureMagic is a NO-OP. You will need to restart the kernel.")
+    return "NO-OP"
diff --git a/etc/kernel-launchers/python/scripts/launch_ipykernel.py b/etc/kernel-launchers/python/scripts/launch_ipykernel.py
index 6b9d2094c..8b4851b04 100644
--- a/etc/kernel-launchers/python/scripts/launch_ipykernel.py
+++ b/etc/kernel-launchers/python/scripts/launch_ipykernel.py
@@ -543,7 +543,11 @@ def start_ipython(
     cluster_type = arguments["cluster_type"] or arguments["rpp_cluster_type"]
     kernel_class_name = arguments["kernel_class_name"]
     ip = "0.0.0.0"
-
+    os.environ["KERNEL_ID"] = str(kernel_id)
+    os.environ["ENDPOINT_IP"] = str(response_addr).split(":")[0]
+    os.environ[
+        "ENDPOINT_PORT"
+    ] = "8888"  # os.environ values must be strings. TODO read this from the launcher by introducing a new argument.
     if connection_file is None and kernel_id is None:
         raise RuntimeError(
             "At least one of the parameters: 'connection_file' or '--kernel-id' must be provided!"
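
Reviewer note (not part of the patch): a minimal usage sketch of the new configure path. The kernel id, the localhost:8888 gateway address, and the `configure_magic` module name used with `%load_ext` are illustrative assumptions; the handler is decorated with `@web.authenticated`, so a real request may also need an auth token header (omitted here). Per `set_user_extra_overrides`, only the `env` keys `KERNEL_EXTRA_SPARK_OPTS` and `KERNEL_LAUNCH_TIMEOUT` pass validation.

    # In a notebook backed by this launcher, load the extension and run the cell magic:
    #
    #   %load_ext configure_magic      # assumes configure_magic.py is importable
    #
    #   %%configure
    #   {"driverMemory": "2g", "numExecutors": 4, "launchTimeout": 120,
    #    "conf": {"spark.sql.shuffle.partitions": "64"}}
    #
    # ConfigureMagic converts the cell into env overrides and POSTs them to the
    # new endpoint; the equivalent direct request looks like:
    import json

    import requests

    kernel_id = "00000000-0000-0000-0000-000000000000"  # placeholder: use a live kernel's id
    url = f"http://localhost:8888/api/kernels/configure/{kernel_id}"  # assumed gateway address
    payload = {
        "env": {
            "KERNEL_EXTRA_SPARK_OPTS": " --conf spark.driver.memory=2g  --conf spark.executor.instances=4 ",
            "KERNEL_LAUNCH_TIMEOUT": "120",
        }
    }
    resp = requests.post(url, data=json.dumps(payload), headers={"Content-Type": "application/json"})
    print(resp.status_code, resp.json().get("message"))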