Merge pull request #144 from esbmc/local-llms
Add Support For Local LLMs
Yiannis128 authored Sep 12, 2024
2 parents ee754d4 + 7c0195d commit 20097a9
Showing 17 changed files with 104 additions and 224 deletions.
1 change: 1 addition & 0 deletions Pipfile
@@ -25,6 +25,7 @@ clang = "*"
langchain = "*"
langchain-openai = "*"
langchain-community = "*"
langchain-ollama = "*"
lizard = "*"

[dev-packages]
4 changes: 2 additions & 2 deletions config.json
@@ -90,10 +90,10 @@
"system": [
{
"role": "System",
"content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location. "
"content": "From now on, act as an Automated Code Repair Tool that repairs AI C code. You will be shown AI C code, along with ESBMC output. Pay close attention to the ESBMC output, which contains a stack trace along with the type of error that occurred and its location that you need to fix. Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text."
}
],
"initial": "Provide the repaired C code as output, as would an Automated Code Repair Tool. Aside from the corrected source code, do not output any other text. The ESBMC output is {esbmc_output} The source code is {source_code}"
"initial": "The ESBMC output is:\n\n```\n{esbmc_output}\n```\n\nThe source code is:\n\n```c\n{source_code}\n```\n Using the ESBMC output, show the fixed text."
}
}
}
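The reworked `initial` prompt keeps the `{esbmc_output}` and `{source_code}` placeholders. As a rough, standalone sketch of how such placeholders get filled at runtime (the template text and sample values below are illustrative assumptions, not code from this commit), LangChain's `ChatPromptTemplate` substitutes format values into the message templates:

```python
from langchain.prompts.chat import ChatPromptTemplate

# Illustrative template mirroring the shape of the new "initial" prompt.
template = ChatPromptTemplate.from_messages(
    [
        ("system", "Act as an Automated Code Repair Tool."),
        (
            "human",
            "The ESBMC output is:\n\n{esbmc_output}\n\n"
            "The source code is:\n\n{source_code}\n\n"
            "Using the ESBMC output, show the fixed text.",
        ),
    ]
)

# Hypothetical stand-in values; in esbmc-ai these come from ESBMC and the user's file.
prompt = template.format_prompt(
    esbmc_output="[dereference failure: NULL pointer dereference]",
    source_code="int main() { int *p = 0; return *p; }",
)
print(prompt.to_messages())
```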
5 changes: 3 additions & 2 deletions esbmc_ai/__main__.py
@@ -11,12 +11,13 @@
import readline
from typing import Optional

from langchain_core.language_models import BaseChatModel

from esbmc_ai.commands.fix_code_command import FixCodeCommandResult

_ = readline

import argparse
from langchain.base_language import BaseLanguageModel


import esbmc_ai.config as config
@@ -365,7 +366,7 @@ def main() -> None:
del esbmc_output

printv(f"Initializing the LLM: {config.ai_model.name}\n")
chat_llm: BaseLanguageModel = config.ai_model.create_llm(
chat_llm: BaseChatModel = config.ai_model.create_llm(
api_keys=config.api_keys,
temperature=config.chat_prompt_user_mode.temperature,
requests_max_tries=config.requests_max_tries,
157 changes: 25 additions & 132 deletions esbmc_ai/ai_models.py
@@ -3,23 +3,14 @@
from abc import abstractmethod
from typing import Any, Iterable, Optional, Union
from enum import Enum
from langchain_core.language_models import BaseChatModel
from pydantic.v1.types import SecretStr
from typing_extensions import override

from langchain.prompts import PromptTemplate
from langchain.base_language import BaseLanguageModel

from langchain_openai import ChatOpenAI
from langchain_community.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference,
)
from langchain_ollama import ChatOllama

from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import (
BaseMessage,
PromptValue,
@@ -30,6 +21,8 @@


class AIModel(object):
"""This base class represents an abstract AI model."""

name: str
tokens: int

@@ -48,7 +41,7 @@ def create_llm(
temperature: float = 1.0,
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseLanguageModel:
) -> BaseChatModel:
"""Initializes a large language model model with the provided parameters."""
raise NotImplementedError()

@@ -132,7 +125,9 @@ def apply_chat_template(
messages: Iterable[BaseMessage],
**format_values: Any,
) -> PromptValue:
# Default one, identity function essentially.
"""Applies the formatted values onto the message chat template. For example,
if the message contains the token {source}, then format_values contains a
value for {source} then it will be substituted."""
escaped_messages = AIModel.escape_messages(messages, list(format_values.keys()))
message_tuples = AIModel.convert_messages_to_tuples(escaped_messages)
return ChatPromptTemplate.from_messages(messages=message_tuples).format_prompt(
@@ -151,7 +146,7 @@ def create_llm(
temperature: float = 1.0,
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseLanguageModel:
) -> BaseChatModel:
assert api_keys.openai, "No OpenAI api key has been specified..."
return ChatOpenAI(
model=self.name,
@@ -163,134 +158,32 @@ def create_llm(
model_kwargs={},
)


class AIModelTextGen(AIModel):
"""Below are only used for models that need them, such as models that
are using the provider "text_inference_server"."""

def __init__(
self,
name: str,
tokens: int,
url: str,
config_message: str = "{history}\n\n{user_prompt}",
system_template: str = "{content}",
human_template: str = "{content}",
ai_template: str = "{content}",
stop_sequences: list[str] = [],
) -> None:
class OllamaAIModel(AIModel):
def __init__(self, name: str, tokens: int, url: str) -> None:
super().__init__(name, tokens)

self.url: str = url
self.chat_template: PromptTemplate = PromptTemplate.from_template(
template=config_message,
)
"""The chat template to place all messages in."""

self.system_template: SystemMessagePromptTemplate = (
SystemMessagePromptTemplate.from_template(
template=system_template,
)
)
"""Template for each system message."""

self.human_template: HumanMessagePromptTemplate = (
HumanMessagePromptTemplate.from_template(
template=human_template,
)
)
"""Template for each human message."""

self.ai_template: AIMessagePromptTemplate = (
AIMessagePromptTemplate.from_template(
template=ai_template,
)
)
"""Template for each AI message."""

self.stop_sequences: list[str] = stop_sequences


@override
def create_llm(
self,
api_keys: APIKeyCollection,
temperature: float = 1.0,
requests_max_tries: int = 5,
requests_timeout: float = 60,
) -> BaseLanguageModel:
return HuggingFaceTextGenInference(
client=None,
async_client=None,
inference_server_url=self.url,
server_kwargs={
"headers": {"Authorization": f"Bearer {api_keys.huggingface}"}
},
# FIXME Need to find a way to make output bigger. When token
# tracking for this LLM type is added.
max_new_tokens=5000,
def create_llm(self, api_keys: APIKeyCollection, temperature: float = 1, requests_max_tries: int = 5, requests_timeout: float = 60) -> BaseChatModel:
# Ollama does not use API keys
_ = api_keys
_ = requests_max_tries
return ChatOllama(
base_url=self.url,
model=self.name,
temperature=temperature,
stop_sequences=self.stop_sequences,
max_retries=requests_max_tries,
timeout=requests_timeout,
)

@override
def apply_chat_template(
self,
messages: Iterable[BaseMessage],
**format_values: Any,
) -> PromptValue:
"""Text generation LLMs take single string of text as input. So the conversation
is converted into a string and returned back in a single prompt value. The config
message is also applied to the conversation."""

escaped_messages = AIModel.escape_messages(messages, list(format_values.keys()))

formatted_messages: list[BaseMessage] = []
for msg in escaped_messages:
formatted_msg: BaseMessage
if msg.type == "ai":
formatted_msg = self.ai_template.format(content=msg.content)
elif msg.type == "system":
formatted_msg = self.system_template.format(content=msg.content)
elif msg.type == "human":
formatted_msg = self.human_template.format(content=msg.content)
else:
raise ValueError(
f"Got unsupported message type: {msg.type}: {msg.content}"
)
formatted_messages.append(formatted_msg)

return self.chat_template.format_prompt(
history="\n\n".join([str(msg.content) for msg in formatted_messages[:-1]]),
user_prompt=formatted_messages[-1].content,
**format_values,
client_kwargs={
"timeout":requests_timeout,
},
)


class _AIModels(Enum):
"""Private enum that contains predefined AI Models. OpenAI models are not
defined because they are fetched from the API."""

FALCON_7B = AIModelTextGen(
name="falcon-7b",
tokens=8192,
url="https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct",
config_message='>>DOMAIN<<You are a helpful assistant that answers any questions asked based on the previous messages in the conversation. The questions are asked by Human. The "AI" is the assistant. The AI shall not impersonate any other entity in the interaction including System and Human. The Human may refer to the AI directly, the AI should refer to the Human directly back, for example, when asked "How do you suggest a fix?", the AI shall respond "You can try...". The AI should use markdown formatting in its responses. The AI should follow the instructions given by System.\n\n>>SUMMARY<<{history}\n\n{user_prompt}\n\n',
ai_template=">>ANSWER<<{content}",
human_template=">>QUESTION<<Human:{content}>>ANSWER<<",
system_template="System: {content}",
)
STARCHAT_BETA = AIModelTextGen(
name="starchat-beta",
tokens=8192,
url="https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta",
config_message="{history}\n{user_prompt}\n<|assistant|>\n",
system_template="<|system|>\n{content}\n<|end|>",
ai_template="<|assistant|>\n{content}\n<|end|>",
human_template="<|user|>\n{content}\n<|end|>",
stop_sequences=["<|end|>"],
)
# FALCON_7B = OllamaAIModel(...)
pass


_custom_ai_models: list[AIModel] = []
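A minimal usage sketch of the new `OllamaAIModel`, assuming an Ollama server on the default local endpoint and an available model named `llama3.1` (both are assumptions for illustration; the commit itself only adds the class):

```python
from esbmc_ai.ai_models import OllamaAIModel

# Assumed local Ollama endpoint and model name; adjust for your setup.
model = OllamaAIModel(name="llama3.1", tokens=8192, url="http://localhost:11434")

# Ollama does not use API keys, and the new create_llm discards the api_keys
# argument, so None is passed here purely as a shortcut for the sketch.
llm = model.create_llm(api_keys=None, temperature=0.7)  # type: ignore[arg-type]

print(llm.invoke("Reply with a single word.").content)
```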
3 changes: 3 additions & 0 deletions esbmc_ai/chats/__init__.py
@@ -1,5 +1,8 @@
# Author: Yiannis Charalambous

"""This module contains different chat interfaces. Along with `BaseChatInterface`
that provides necessary boilet-plate for implementing an LLM based chat."""

from .base_chat_interface import BaseChatInterface
from .latest_state_solution_generator import LatestStateSolutionGenerator
from .solution_generator import SolutionGenerator
29 changes: 7 additions & 22 deletions esbmc_ai/chats/base_chat_interface.py
@@ -3,25 +3,26 @@
from abc import abstractmethod
from typing import Optional

from langchain.base_language import BaseLanguageModel
from langchain.schema import (
AIMessage,
BaseMessage,
HumanMessage,
LLMResult,
PromptValue,
)
from langchain_core.language_models import BaseChatModel

from esbmc_ai.config import ChatPromptSettings
from esbmc_ai.chat_response import ChatResponse, FinishReason
from esbmc_ai.ai_models import AIModel


class BaseChatInterface(object):
"""Base class for interacting with an LLM. It allows for interactions with
text generation LLMs and also chat LLMs."""

def __init__(
self,
ai_model_agent: ChatPromptSettings,
llm: BaseLanguageModel,
llm: BaseChatModel,
ai_model: AIModel,
) -> None:
super().__init__()
@@ -31,7 +32,7 @@ def __init__(
ai_model_agent.system_messages.messages
)
self.messages: list[BaseMessage] = []
self.llm: BaseLanguageModel = llm
self.llm: BaseChatModel = llm

@abstractmethod
def compress_message_stack(self) -> None:
@@ -84,25 +85,9 @@ def send_message(self, message: Optional[str] = None) -> ChatResponse:
all_messages = self._system_messages.copy()
all_messages.extend(self.messages.copy())

# Transform message stack to ChatPromptValue: If this is a ChatLLM then the
# function will simply be an identity function that does nothing and simply
# returns the messages as a ChatPromptValue. If this is a text generation
# LLM, then the function should inject the config message around the
# conversation to make the LLM behave like a ChatLLM.
# Do not replace any values.
message_prompts: PromptValue = self.ai_model.apply_chat_template(
messages=all_messages,
)

response: ChatResponse
try:
result: LLMResult = self.llm.generate_prompt(
prompts=[message_prompts],
)

response_message: BaseMessage = AIMessage(
content=result.generations[0][0].text
)
response_message: BaseMessage = self.llm.invoke(input=all_messages)

self.push_to_message_stack(message=response_message)

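With `send_message` now calling `self.llm.invoke(input=all_messages)`, any `BaseChatModel`, including the new `ChatOllama`, is driven through the same code path. A self-contained sketch of that invocation pattern, with an assumed local server and model name:

```python
from langchain.schema import HumanMessage, SystemMessage
from langchain_ollama import ChatOllama

# Assumed local Ollama server and model; not part of this commit.
llm = ChatOllama(base_url="http://localhost:11434", model="llama3.1", temperature=0.7)

messages = [
    SystemMessage(content="Act as an Automated Code Repair Tool."),
    HumanMessage(content="int main() { int *p = 0; return *p; }"),
]

# BaseChatModel.invoke accepts a message list and returns a single AIMessage.
response = llm.invoke(input=messages)
print(response.content)
```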
4 changes: 2 additions & 2 deletions esbmc_ai/chats/solution_generator.py
@@ -2,8 +2,8 @@

from re import S
from typing import Optional
from langchain_core.language_models import BaseChatModel
from typing_extensions import override
from langchain.base_language import BaseLanguageModel
from langchain.schema import BaseMessage, HumanMessage

from esbmc_ai.chat_response import ChatResponse, FinishReason
@@ -83,7 +83,7 @@ class SolutionGenerator(BaseChatInterface):
def __init__(
self,
ai_model_agent: DynamicAIModelAgent | ChatPromptSettings,
llm: BaseLanguageModel,
llm: BaseChatModel,
ai_model: AIModel,
scenario: str = "",
source_code_format: str = "full",
6 changes: 3 additions & 3 deletions esbmc_ai/chats/user_chat.py
@@ -2,11 +2,11 @@

from typing_extensions import override

from langchain.base_language import BaseLanguageModel
from langchain.memory import ConversationSummaryMemory
from langchain.schema import BaseMessage, SystemMessage
from langchain_core.language_models import BaseChatModel
from langchain_community.chat_message_histories import ChatMessageHistory

from langchain.schema import BaseMessage, SystemMessage

from esbmc_ai.config import AIAgentConversation, ChatPromptSettings
from esbmc_ai.ai_models import AIModel
@@ -21,7 +21,7 @@ def __init__(
self,
ai_model_agent: ChatPromptSettings,
ai_model: AIModel,
llm: BaseLanguageModel,
llm: BaseChatModel,
source_code: str,
esbmc_output: str,
set_solution_messages: AIAgentConversation,