[DOC-166] Simple LLM inference example (WIP) #218

Draft · wants to merge 3 commits into `master`
96 changes: 96 additions & 0 deletions docs/examples/distributed/LLM_inference/README.md
@@ -0,0 +1,96 @@
# FastAPI + HuggingFace + SLURM

Proof-of-concept for an API that performs inference with a Large Language Model (LLM) on the Mila cluster.

![LLM_api](https://user-images.githubusercontent.com/13387299/184188304-3ce82a7f-29a6-49ed-86ba-4842db4e207e.png)

## The goal:

- One ML researcher or student submits this as a job on a SLURM cluster, and other users can then query the single shared model instance over HTTP or through a simple Python client.

## Installation:

To run the server locally:

```console
> conda create -n llm python=3.10
> conda activate llm
> pip install git+https://www.github.com/lebrice/LLM_api.git
```

(WIP) To connect to a running LLM server:

(Requires Python >= 3.7)
```console
> pip install git+https://www.github.com/lebrice/LLM_api.git
```


## Usage:

Available options:
```console
$ python app/server.py --help
usage: server.py [-h] [--model str] [--hf_cache_dir Path] [--port int]
[--reload bool] [--offload_folder Path] [--use_public_ip bool]

API for querying a large language model.

options:
-h, --help show this help message and exit

Settings ['settings']:
Configuration settings for the API.

--model str HuggingFace model to use. Examples: facebook/opt-13b,
facebook/opt-30b, facebook/opt-66b, bigscience/bloom,
etc. (default: facebook/opt-13b)
--hf_cache_dir Path (default: $SCRATCH/cache/huggingface)
--port int The port to run the server on. (default: 12345)
--reload bool Whether to restart the server (and reload the model) when
the source code changes. (default: False)
--offload_folder Path
Folder where the model weights will be offloaded if the
entire model doesn't fit in memory. (default:
$SLURM_TMPDIR)
--use_public_ip bool Set to True to make the server available on the node's
public IP, rather than localhost. Setting this to False
is useful when using VSCode to debug the server, since
the port forwarding is done automatically for you.
Setting this to True makes it so many users on the
cluster can share the same server. However, at the
moment, you would still need to do the port forwarding
setup yourself, if you want to access the server from
outside the cluster. (default: False)
```

Spinning up the server:
```console
> python app/server.py
HF_HOME='/home/mila/n/normandf/scratch/cache/huggingface'
TRANSFORMERS_CACHE='/home/mila/n/normandf/scratch/cache/huggingface/transformers'
Running the server with the following settings: {"model_capacity": "13b", "hf_cache_dir": "~/scratch/cache/huggingface", "port": 12345, "reload": false, "offload_folder": "/Tmp/slurm.1968686.0"}
INFO: Started server process [25042]
INFO: Waiting for application startup.
Writing address_string='cn-b003:8000' to server.txt
INFO: Application startup complete.
INFO: Uvicorn running on http://127.0.0.1:12345 (Press CTRL+C to quit)
```
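
Once the server is running, it can also be queried over plain HTTP. The sketch below uses `requests` against the `/complete/` route that the bundled client targets; the host and port are taken from the example output above (by default the server only listens on localhost), and the response format is an assumption here, so inspect the actual reply for your deployment.

```python
import requests

# Host/port taken from the startup log above; replace them with the values of your own
# server (the server also writes its address to a `server.txt` file on startup).
SERVER_URL = "http://127.0.0.1:12345"

response = requests.get(
    f"{SERVER_URL}/complete/",
    params={"prompt": "Once upon a time, there lived a great wizard."},
)
response.raise_for_status()
# The response schema isn't documented here; print it to see what the server returns.
print(response.text)
```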

(WIP) Run as a SLURM job:

```console
> sbatch run_server.sh
```

(WIP) Using the Python client to connect to a running server:

```python
import time

from app.client import get_completion_text, server_is_up

# Poll until the shared server is reachable, then request a completion from it.
while not server_is_up():
    print("Waiting for the server to be online...")
    time.sleep(10)
print("Server is up!")
rest_of_story = get_completion_text("Once upon a time, there lived a great wizard.")
```
125 changes: 125 additions & 0 deletions docs/examples/distributed/LLM_inference/client.py
@@ -0,0 +1,125 @@
""" TODO: Client-side code to communicate with the server that is running somewhere on the cluster.

IDEAS:
- Could look for slurm jobs that have a given name, like `deploy.sh` and extract the port from the
job's command-line ags!
"""
import random
import subprocess
import time
from pathlib import Path

import requests


def _fetch_job_info(name):
# Mock this for testing
command = ["squeue", "-h", f"--name={name}", "--format=\"%A %j %T %P %U %k %N\""]
return subprocess.check_output(command, text=True)


def get_slurm_job_by_name(name):
"""Retrieve a list of jobs that match a given job name"""

    output = _fetch_job_info(name)
jobs = []

    def parse_meta(comment):
        data = dict()
        if comment != "(null)":
            items = comment.split('|')
            for kv in items:
                try:
                    k, v = kv.split('=', maxsplit=1)
                    data[k] = v
                except ValueError:
                    # Skip malformed entries that don't contain an '=' sign.
                    pass

        return data
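
    # Illustrative example (the comment contents are set by the server job; the values here
    # are an assumption): a comment such as
    #   "model=facebook/opt-13b|host=cn-b003|port=12345|shared=y"
    # is parsed by parse_meta into
    #   {"model": "facebook/opt-13b", "host": "cn-b003", "port": "12345", "shared": "y"}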

for line in output.splitlines():
job_id, job_name, status, partition, user, comment, nodes = line.split(' ')

        jobs.append({
            "job_id": job_id,
            "job_name": job_name,
            "status": status,
            "partition": partition,
            "user": user,
            "comment": parse_meta(comment),
            "nodes": nodes,
        })

return jobs
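
# Illustrative shape of one entry returned by get_slurm_job_by_name (values are examples only):
#   {"job_id": "1968686", "job_name": "inference_server_SHARED.sh", "status": "RUNNING",
#    "partition": "main", "user": "normandf",
#    "comment": {"model": "...", "host": "cn-b003", "port": "12345", "shared": "y"},
#    "nodes": "cn-b003"}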


def find_suitable_inference_server(jobs, model):
"""Select suitable jobs from a list, looking for a specific model"""
selected = []

def is_shared(job):
return job["comment"].get("shared", 'y') == 'y'

def is_running(job):
return job['status'] == "RUNNING"

def has_model(job, model):
if model is None:
return True

# FIXME:
# /network/weights/llama.var/llama2/Llama-2-7b-hf != meta-llama/Llama-2-7b-hf
#
return job['comment']['model'] == model

def select(job):
selected.append({
"model": job['comment']["model"],
"host": job["comment"]["host"],
"port": job["comment"]["port"],
})

for job in jobs:
if is_shared(job) and is_running(job) and has_model(job, model):
select(job)

return selected


def get_inference_server(model=None):
"""Retrieve an inference server from slurm jobs"""
jobs = get_slurm_job_by_name('inference_server_SHARED.sh')

servers = find_suitable_inference_server(jobs, model)

try:
return random.choice(servers)
except IndexError:
return None


def get_server_url_and_port(model=None) -> "tuple[str, int] | None":
    """Return (host, port) of a suitable inference server, or None if none was found."""
    server = get_inference_server(model)

if server is None:
return None

return server['host'], int(server['port'])


def debug():
    # WIP: Not working yet.
    while not Path("server.txt").exists():
        time.sleep(1)
        print("Waiting for server to start...")

    server_info = get_server_url_and_port()
    if server_info is None:
        print("No running inference server was found.")
        return
    server_url, port = server_info
    print(f"Found server at {server_url}:{port}")
response = requests.get(
f"http://{server_url}:{port}/complete/",
params={
"prompt": "Hello, my name is Bob. I love fishing, hunting, and my favorite food is",
},
)
print(response)


if __name__ == "__main__":
debug()