From bba6509cb47665b2df54a76d5035a4921b1adcc8 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Mon, 16 Sep 2024 12:03:48 +0530 Subject: [PATCH] add ollama to flytekit-inference (#2677) * add ollama to flytekit-inference Signed-off-by: Samhita Alla * add ollama to setup.py Signed-off-by: Samhita Alla * add support for creating models Signed-off-by: Samhita Alla * escape quote Signed-off-by: Samhita Alla * fix type hint Signed-off-by: Samhita Alla * lint Signed-off-by: Samhita Alla * update readme Signed-off-by: Samhita Alla * add support for flytefile in init container Signed-off-by: Samhita Alla * debug Signed-off-by: Samhita Alla * encode the modelfile Signed-off-by: Samhita Alla * flytefile in init container Signed-off-by: Samhita Alla * add input to args Signed-off-by: Samhita Alla * update inputs code and readme Signed-off-by: Samhita Alla * clean up Signed-off-by: Samhita Alla * cleanup Signed-off-by: Samhita Alla * add comment Signed-off-by: Samhita Alla * move sleep to python code snippets Signed-off-by: Samhita Alla * move input download code to init container Signed-off-by: Samhita Alla * debug Signed-off-by: Samhita Alla * move base code and ollama service ready to outer condition Signed-off-by: Samhita Alla * fix tests Signed-off-by: Samhita Alla * swap images Signed-off-by: Samhita Alla * remove tmp and update readme Signed-off-by: Samhita Alla * download to tmp if the file isn't in tmp Signed-off-by: Samhita Alla --------- Signed-off-by: Samhita Alla --- .github/workflows/pythonbuild.yml | 16 +- plugins/flytekit-inference/README.md | 59 ++++++ .../flytekitplugins/inference/__init__.py | 3 + .../flytekitplugins/inference/nim/serve.py | 4 +- .../inference/ollama/__init__.py | 0 .../flytekitplugins/inference/ollama/serve.py | 180 ++++++++++++++++++ .../inference/sidecar_template.py | 87 ++++++++- plugins/flytekit-inference/setup.py | 6 +- .../flytekit-inference/tests/test_ollama.py | 109 +++++++++++ plugins/setup.py | 3 +- 10 files changed, 450 insertions(+), 17 deletions(-) create mode 100644 plugins/flytekit-inference/flytekitplugins/inference/ollama/__init__.py create mode 100644 plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py create mode 100644 plugins/flytekit-inference/tests/test_ollama.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index db1c462eab..41991b960f 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -42,7 +42,7 @@ jobs: python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}} steps: - uses: actions/checkout@v4 - - name: 'Clear action cache' + - name: "Clear action cache" uses: ./.github/actions/clear-action-cache - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -81,7 +81,7 @@ jobs: python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}} steps: - uses: actions/checkout@v4 - - name: 'Clear action cache' + - name: "Clear action cache" uses: ./.github/actions/clear-action-cache - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -133,7 +133,7 @@ jobs: steps: - uses: actions/checkout@v4 - - name: 'Clear action cache' + - name: "Clear action cache" uses: ./.github/actions/clear-action-cache - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 @@ -244,15 +244,16 @@ jobs: matrix: os: [ubuntu-latest] python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}} - makefile-cmd: [integration_test_codecov, 
integration_test_lftransfers_codecov] + makefile-cmd: + [integration_test_codecov, integration_test_lftransfers_codecov] steps: # As described in https://github.com/pypa/setuptools_scm/issues/414, SCM needs git history # and tags to work. - uses: actions/checkout@v4 with: fetch-depth: 0 - - name: 'Clear action cache' - uses: ./.github/actions/clear-action-cache # sandbox has disk pressure, so we need to clear the cache to get more disk space. + - name: "Clear action cache" + uses: ./.github/actions/clear-action-cache # sandbox has disk pressure, so we need to clear the cache to get more disk space. - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -335,6 +336,7 @@ jobs: - flytekit-hive - flytekit-huggingface - flytekit-identity-aware-proxy + - flytekit-inference - flytekit-k8s-pod - flytekit-kf-mpi - flytekit-kf-pytorch @@ -414,7 +416,7 @@ jobs: plugin-names: "flytekit-kf-pytorch" steps: - uses: actions/checkout@v4 - - name: 'Clear action cache' + - name: "Clear action cache" uses: ./.github/actions/clear-action-cache - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 diff --git a/plugins/flytekit-inference/README.md b/plugins/flytekit-inference/README.md index ab33f97441..1bc5c8475e 100644 --- a/plugins/flytekit-inference/README.md +++ b/plugins/flytekit-inference/README.md @@ -67,3 +67,62 @@ def model_serving() -> str: return completion.choices[0].message.content ``` + +## Ollama + +The Ollama plugin allows you to serve LLMs locally. +You can either pull an existing model or create a new one. + +```python +from textwrap import dedent + +from flytekit import ImageSpec, Resources, task, workflow +from flytekitplugins.inference import Ollama, Model +from flytekit.extras.accelerators import A10G +from openai import OpenAI + + +image = ImageSpec( + name="ollama_serve", + registry="...", + packages=["flytekitplugins-inference"], +) + +ollama_instance = Ollama( + model=Model( + name="llama3-mario", + modelfile=dedent("""\ + FROM llama3 + ADAPTER {inputs.gguf} + PARAMETER temperature 1 + PARAMETER num_ctx 4096 + SYSTEM You are Mario from super mario bros, acting as an assistant.\ + """), + ) +) + + +@task( + container_image=image, + pod_template=ollama_instance.pod_template, + accelerator=A10G, + requests=Resources(gpu="0"), +) +def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]: + responses = [] + client = OpenAI( + base_url=f"{ollama_instance.base_url}/v1", api_key="ollama" + ) # api key required but ignored + + for question in questions: + completion = client.chat.completions.create( + model="llama3-mario", + messages=[ + {"role": "user", "content": question}, + ], + max_tokens=256, + ) + responses.append(completion.choices[0].message.content) + + return responses +``` diff --git a/plugins/flytekit-inference/flytekitplugins/inference/__init__.py b/plugins/flytekit-inference/flytekitplugins/inference/__init__.py index a96ce6fc80..cfd14b09a8 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/__init__.py +++ b/plugins/flytekit-inference/flytekitplugins/inference/__init__.py @@ -8,6 +8,9 @@ NIM NIMSecrets + Model + Ollama """ from .nim.serve import NIM, NIMSecrets +from .ollama.serve import Model, Ollama diff --git a/plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py b/plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py index 66149c299b..50d326a5f8 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py +++ 
b/plugins/flytekit-inference/flytekitplugins/inference/nim/serve.py @@ -34,7 +34,9 @@ def __init__( gpu: int = 1, mem: str = "20Gi", shm_size: str = "16Gi", - env: Optional[dict[str, str]] = None, + env: Optional[ + dict[str, str] + ] = None, # https://docs.nvidia.com/nim/large-language-models/latest/configuration.html#environment-variables hf_repo_ids: Optional[list[str]] = None, lora_adapter_mem: Optional[str] = None, ): diff --git a/plugins/flytekit-inference/flytekitplugins/inference/ollama/__init__.py b/plugins/flytekit-inference/flytekitplugins/inference/ollama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py new file mode 100644 index 0000000000..f13acc10c3 --- /dev/null +++ b/plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py @@ -0,0 +1,180 @@ +import base64 +from dataclasses import dataclass +from typing import Optional + +from ..sidecar_template import ModelInferenceTemplate + + +@dataclass +class Model: + """Represents the configuration for a model used in a Kubernetes pod template. + + :param name: The name of the model. + :param mem: The amount of memory allocated for the model, specified as a string. Default is "500Mi". + :param cpu: The number of CPU cores allocated for the model. Default is 1. + :param modelfile: The actual model file as a JSON-serializable string. This represents the file content. Default is `None` if not applicable. + """ + + name: str + mem: str = "500Mi" + cpu: int = 1 + modelfile: Optional[str] = None + + +class Ollama(ModelInferenceTemplate): + def __init__( + self, + *, + model: Model, + image: str = "ollama/ollama", + port: int = 11434, + cpu: int = 1, + gpu: int = 1, + mem: str = "15Gi", + ): + """Initialize Ollama class for managing a Kubernetes pod template. + + :param model: An instance of the Model class containing the model's configuration, including its name, memory, CPU, and file. + :param image: The Docker image to be used for the container. Default is "ollama/ollama". + :param port: The port number on which the container should expose its service. Default is 11434. + :param cpu: The number of CPU cores requested for the container. Default is 1. + :param gpu: The number of GPUs requested for the container. Default is 1. + :param mem: The amount of memory requested for the container, specified as a string. Default is "15Gi". 
+ """ + self._model_name = model.name + self._model_mem = model.mem + self._model_cpu = model.cpu + self._model_modelfile = model.modelfile + + super().__init__( + image=image, + port=port, + cpu=cpu, + gpu=gpu, + mem=mem, + download_inputs=(True if self._model_modelfile and "{inputs" in self._model_modelfile else False), + ) + + self.setup_ollama_pod_template() + + def setup_ollama_pod_template(self): + from kubernetes.client.models import ( + V1Container, + V1ResourceRequirements, + V1SecurityContext, + V1VolumeMount, + ) + + container_name = "create-model" if self._model_modelfile else "pull-model" + + base_code = """ +import base64 +import time +import ollama +import requests +""" + + ollama_service_ready = f""" +# Wait for Ollama service to be ready +max_retries = 30 +retry_interval = 1 +for _ in range(max_retries): + try: + response = requests.get('{self.base_url}') + if response.status_code == 200: + print('Ollama service is ready') + break + except requests.RequestException: + pass + time.sleep(retry_interval) +else: + print('Ollama service did not become ready in time') + exit(1) +""" + if self._model_modelfile: + encoded_modelfile = base64.b64encode(self._model_modelfile.encode("utf-8")).decode("utf-8") + + if "{inputs" in self._model_modelfile: + python_code = f""" +{base_code} +import json + +with open('/shared/inputs.json', 'r') as f: + inputs = json.load(f) + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + +inputs = {{'inputs': AttrDict(inputs)}} + +encoded_model_file = '{encoded_modelfile}' + +modelfile = base64.b64decode(encoded_model_file).decode('utf-8').format(**inputs) +modelfile = modelfile.replace('{{', '{{{{').replace('}}', '}}}}') + +with open('Modelfile', 'w') as f: + f.write(modelfile) + +{ollama_service_ready} + +# Debugging: Shows the status of model creation. +for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True): + print(chunk) +""" + else: + python_code = f""" +{base_code} + +encoded_model_file = '{encoded_modelfile}' + +modelfile = base64.b64decode(encoded_model_file).decode('utf-8') + +with open('Modelfile', 'w') as f: + f.write(modelfile) + +{ollama_service_ready} + +# Debugging: Shows the status of model creation. +for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True): + print(chunk) +""" + else: + python_code = f""" +{base_code} + +{ollama_service_ready} + +# Debugging: Shows the status of model pull. 
+for chunk in ollama.pull('{self._model_name}', stream=True): + print(chunk) +""" + + command = f'python3 -c "{python_code}"' + + self.pod_template.pod_spec.init_containers.append( + V1Container( + name=container_name, + image="python:3.11-slim", + command=["/bin/sh", "-c"], + args=[f"pip install requests && pip install ollama && {command}"], + resources=V1ResourceRequirements( + requests={ + "cpu": self._model_cpu, + "memory": self._model_mem, + }, + limits={ + "cpu": self._model_cpu, + "memory": self._model_mem, + }, + ), + security_context=V1SecurityContext( + run_as_user=0, + ), + volume_mounts=[ + V1VolumeMount(name="shared-data", mount_path="/shared"), + V1VolumeMount(name="tmp", mount_path="/tmp"), + ], + ) + ) diff --git a/plugins/flytekit-inference/flytekitplugins/inference/sidecar_template.py b/plugins/flytekit-inference/flytekitplugins/inference/sidecar_template.py index 549b400895..28091d46d5 100644 --- a/plugins/flytekit-inference/flytekitplugins/inference/sidecar_template.py +++ b/plugins/flytekit-inference/flytekitplugins/inference/sidecar_template.py @@ -1,20 +1,20 @@ from typing import Optional from flytekit import PodTemplate +from flytekit.configuration.default_images import DefaultImages class ModelInferenceTemplate: def __init__( self, image: Optional[str] = None, - health_endpoint: str = "/", + health_endpoint: Optional[str] = None, port: int = 8000, cpu: int = 1, gpu: int = 1, mem: str = "1Gi", - env: Optional[ - dict[str, str] - ] = None, # https://docs.nvidia.com/nim/large-language-models/latest/configuration.html#environment-variables + env: Optional[dict[str, str]] = None, + download_inputs: bool = False, ): from kubernetes.client.models import ( V1Container, @@ -24,6 +24,8 @@ def __init__( V1PodSpec, V1Probe, V1ResourceRequirements, + V1Volume, + V1VolumeMount, ) self._image = image @@ -33,6 +35,7 @@ def __init__( self._gpu = gpu self._mem = mem self._env = env + self._download_inputs = download_inputs self._pod_template = PodTemplate() @@ -60,14 +63,84 @@ def __init__( ), restart_policy="Always", # treat this container as a sidecar env=([V1EnvVar(name=k, value=v) for k, v in self._env.items()] if self._env else None), - startup_probe=V1Probe( - http_get=V1HTTPGetAction(path=self._health_endpoint, port=self._port), - failure_threshold=100, # The model server initialization can take some time, so the failure threshold is increased to accommodate this delay. + startup_probe=( + V1Probe( + http_get=V1HTTPGetAction( + path=self._health_endpoint, + port=self._port, + ), + failure_threshold=100, # The model server initialization can take some time, so the failure threshold is increased to accommodate this delay. 
+ ) + if self._health_endpoint + else None ), ), ], + volumes=[ + V1Volume(name="shared-data", empty_dir={}), + V1Volume(name="tmp", empty_dir={}), + ], ) + if self._download_inputs: + input_download_code = """ +import os +import json +import sys + +from flyteidl.core import literals_pb2 as _literals_pb2 +from flytekit.core import utils +from flytekit.core.context_manager import FlyteContextManager +from flytekit.interaction.string_literals import literal_map_string_repr +from flytekit.models import literals as _literal_models +from flytekit.models.core.types import BlobType +from flytekit.types.file import FlyteFile + +input_arg = sys.argv[-1] + +ctx = FlyteContextManager.current_context() +local_inputs_file = os.path.join(ctx.execution_state.working_dir, 'inputs.pb') +ctx.file_access.get_data( + input_arg, + local_inputs_file, +) +input_proto = utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) +idl_input_literals = _literal_models.LiteralMap.from_flyte_idl(input_proto) + +inputs = literal_map_string_repr(idl_input_literals) + +for var_name, literal in idl_input_literals.literals.items(): + if literal.scalar and literal.scalar.blob: + if ( + literal.scalar.blob.metadata.type.dimensionality + == BlobType.BlobDimensionality.SINGLE + ): + downloaded_file = FlyteFile.from_source(literal.scalar.blob.uri).download() + + tmp_destination = None + if not downloaded_file.startswith('/tmp'): + tmp_destination = '/tmp' + os.path.basename(downloaded_file) + shutil.copy(downloaded_file, tmp_destination) + + inputs[var_name] = tmp_destination or downloaded_file + +with open('/shared/inputs.json', 'w') as f: + json.dump(inputs, f) +""" + + self._pod_template.pod_spec.init_containers.append( + V1Container( + name="input-downloader", + image=DefaultImages.default_image(), + command=["/bin/sh", "-c"], + args=[f'python3 -c "{input_download_code}" {{{{.input}}}}'], + volume_mounts=[ + V1VolumeMount(name="shared-data", mount_path="/shared"), + V1VolumeMount(name="tmp", mount_path="/tmp"), + ], + ), + ) + @property def pod_template(self): return self._pod_template diff --git a/plugins/flytekit-inference/setup.py b/plugins/flytekit-inference/setup.py index a344b3857c..fbc00b43e4 100644 --- a/plugins/flytekit-inference/setup.py +++ b/plugins/flytekit-inference/setup.py @@ -15,7 +15,11 @@ author_email="admin@flyte.org", description="This package enables seamless use of model inference sidecar services within Flyte", namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}", f"flytekitplugins.{PLUGIN_NAME}.nim"], + packages=[ + f"flytekitplugins.{PLUGIN_NAME}", + f"flytekitplugins.{PLUGIN_NAME}.nim", + f"flytekitplugins.{PLUGIN_NAME}.ollama", + ], install_requires=plugin_requires, license="apache2", python_requires=">=3.8", diff --git a/plugins/flytekit-inference/tests/test_ollama.py b/plugins/flytekit-inference/tests/test_ollama.py new file mode 100644 index 0000000000..0e8ced374c --- /dev/null +++ b/plugins/flytekit-inference/tests/test_ollama.py @@ -0,0 +1,109 @@ +from flytekitplugins.inference import Ollama, Model + + +def test_ollama_init_valid_params(): + ollama_instance = Ollama( + mem="30Gi", + port=11435, + model=Model(name="mistral-nemo"), + ) + + assert len(ollama_instance.pod_template.pod_spec.init_containers) == 2 + assert ( + ollama_instance.pod_template.pod_spec.init_containers[0].image + == "ollama/ollama" + ) + assert ( + ollama_instance.pod_template.pod_spec.init_containers[0].resources.requests[ + "memory" + ] + == "30Gi" + ) + assert ( + 
ollama_instance.pod_template.pod_spec.init_containers[0].ports[0].container_port + == 11435 + ) + assert ( + "mistral-nemo" + in ollama_instance.pod_template.pod_spec.init_containers[1].args[0] + ) + assert ( + "ollama.pull" + in ollama_instance.pod_template.pod_spec.init_containers[1].args[0] + ) + + +def test_ollama_default_params(): + ollama_instance = Ollama(model=Model(name="phi")) + + assert ollama_instance.base_url == "http://localhost:11434" + assert ollama_instance._cpu == 1 + assert ollama_instance._gpu == 1 + assert ollama_instance._health_endpoint == None + assert ollama_instance._mem == "15Gi" + assert ollama_instance._model_name == "phi" + assert ollama_instance._model_cpu == 1 + assert ollama_instance._model_mem == "500Mi" + + +def test_ollama_modelfile(): + ollama_instance = Ollama( + model=Model( + name="llama3-mario", + modelfile="FROM llama3\nPARAMETER temperature 1\nPARAMETER num_ctx 4096\nSYSTEM You are Mario from super mario bros, acting as an assistant.", + ) + ) + + assert len(ollama_instance.pod_template.pod_spec.init_containers) == 2 + assert ( + "ollama.create" + in ollama_instance.pod_template.pod_spec.init_containers[1].args[0] + ) + assert ( + "format(**inputs)" + not in ollama_instance.pod_template.pod_spec.init_containers[1].args[0] + ) + + +def test_ollama_modelfile_with_inputs(): + ollama_instance = Ollama( + model=Model( + name="tinyllama-finetuned", + modelfile='''FROM tinyllama:latest +ADAPTER {inputs.ggml} +TEMPLATE """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +{{ if .System }}### Instruction: +{{ .System }}{{ end }} + +{{ if .Prompt }}### Input: +{{ .Prompt }}{{ end }} + +### Response: +""" +SYSTEM "You're a kitty. Answer using kitty sounds." +PARAMETER stop "### Response:" +PARAMETER stop "### Instruction:" +PARAMETER stop "### Input:" +PARAMETER stop "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request." 
+PARAMETER num_predict 200 +''', + ) + ) + + assert len(ollama_instance.pod_template.pod_spec.init_containers) == 3 + assert ( + "model-server" in ollama_instance.pod_template.pod_spec.init_containers[0].name + ) + assert ( + "input-downloader" + in ollama_instance.pod_template.pod_spec.init_containers[1].name + ) + assert ( + "ollama.create" + in ollama_instance.pod_template.pod_spec.init_containers[2].args[0] + ) + assert ( + "format(**inputs)" + in ollama_instance.pod_template.pod_spec.init_containers[2].args[0] + ) diff --git a/plugins/setup.py b/plugins/setup.py index ea35649ed7..8f042a9d3a 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -23,9 +23,11 @@ "flytekitplugins-duckdb": "flytekit-duckdb", "flytekitplugins-data-fsspec": "flytekit-data-fsspec", "flytekitplugins-envd": "flytekit-envd", + "flytekitplugins-flyteinteractive": "flytekit-flyteinteractive", "flytekitplugins-great_expectations": "flytekit-greatexpectations", "flytekitplugins-hive": "flytekit-hive", "flytekitplugins-huggingface": "flytekit-huggingface", + "flytekitplugins-inference": "flytekit-inference", "flytekitplugins-pod": "flytekit-k8s-pod", "flytekitplugins-kfmpi": "flytekit-kf-mpi", "flytekitplugins-kfpytorch": "flytekit-kf-pytorch", @@ -45,7 +47,6 @@ "flytekitplugins-sqlalchemy": "flytekit-sqlalchemy", "flytekitplugins-vaex": "flytekit-vaex", "flytekitplugins-whylogs": "flytekit-whylogs", - "flytekitplugins-flyteinteractive": "flytekit-flyteinteractive", }
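
For readers of this patch, here is a minimal, self-contained sketch of the Modelfile templating that the generated `create-model` init container performs when the Modelfile references task inputs (e.g. `ADAPTER {inputs.gguf}`). The `gguf` input name and the `/tmp/adapter.gguf` path below are illustrative placeholders rather than part of the plugin API, and the sketch omits the brace re-escaping the real init container applies for Ollama's Go-style template syntax.

```python
import base64
from textwrap import dedent


class AttrDict(dict):
    """Lets str.format() resolve `{inputs.gguf}`-style attribute lookups."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


# Modelfile as supplied to Model(modelfile=...); `{inputs.gguf}` refers to a
# FlyteFile task input.
modelfile = dedent("""\
    FROM llama3
    ADAPTER {inputs.gguf}
    PARAMETER temperature 1
    SYSTEM You are Mario from super mario bros, acting as an assistant.""")

# At pod-template construction time the Modelfile is base64-encoded so it can
# be embedded safely inside the init container's shell command.
encoded = base64.b64encode(modelfile.encode("utf-8")).decode("utf-8")

# Inside the init container, the input-downloader init container has already
# written the resolved task inputs to /shared/inputs.json; a fake payload
# stands in for it here.
downloaded_inputs = {"gguf": "/tmp/adapter.gguf"}  # hypothetical local path

rendered = (
    base64.b64decode(encoded)
    .decode("utf-8")
    .format(**{"inputs": AttrDict(downloaded_inputs)})
)

print(rendered)
# FROM llama3
# ADAPTER /tmp/adapter.gguf
# PARAMETER temperature 1
# SYSTEM You are Mario from super mario bros, acting as an assistant.
```

In the deployed init container, the rendered Modelfile is then written to disk and streamed through `ollama.create(model=..., path='Modelfile', stream=True)` once the readiness loop sees a 200 response from `base_url`.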