Mistral Family support and Logging
Add support for Mistral models, which use the same prompt template as Mixtral. This dramatically improves the performance
of the full pipeline on a laptop running llama.cpp: even in GGUF form, Mixtral is too big for most consumer hardware,
so completions and prompting with it hang for long periods of time.
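
For illustration (not part of the commit), here is the shared Mixtral-family template from this diff applied to a prompt; a minimal sketch built only from the constant below:

# The updated Mixtral-family template from this commit, applied to a prompt.
_MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST] </s>"

formatted = _MODEL_PROMPT_MIXTRAL.format(prompt="Summarize the document.")
print(formatted)  # <s> [INST] Summarize the document. [/INST] </s>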

The new logging consists of optional debug logs that print all of the prompts in each LLMBlock, as well as the prompt currently being generated.
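
Since these logs go through the standard Python logging module, they can be switched on without code changes. A minimal sketch, assuming the logger name follows the module path ("instructlab.sdg" is an assumption, not confirmed by this diff):

# Hypothetical snippet: enable the optional debug logs.
# The logger name "instructlab.sdg" is an assumption based on the module path.
import logging

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("instructlab.sdg").setLevel(logging.DEBUG)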

I also added a loading bar for each LLMBlock pass that increments each time we get a prompt response back.
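
The loading bar follows tqdm's manual-update pattern, the same shape as the _generate change below: size the bar from the prompt count, then tick it once per response. A standalone sketch with placeholder prompts:

# Minimal sketch of the progress-bar pattern used in _generate below:
# the bar's total comes from the prompt count, and it ticks once per response.
from tqdm import tqdm

prompts = ["first prompt", "second prompt", "third prompt"]  # placeholder data
progress_bar = tqdm(range(len(prompts)), desc="example_block Prompt Generation")
for prompt in prompts:
    response_text = prompt.upper()  # stand-in for the completion call
    progress_bar.update(1)
progress_bar.close()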

Signed-off-by: Charlie Doern <[email protected]>
cdoern committed Oct 10, 2024
1 parent a51b9c8 commit ff6a199
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 6 additions & 2 deletions src/instructlab/sdg/llmblock.py
@@ -8,6 +8,7 @@
 
 # Third Party
 from datasets import Dataset
+from tqdm import tqdm
 import openai
 
 # Local
@@ -18,7 +19,7 @@
 MODEL_FAMILY_MIXTRAL = "mixtral"
 MODEL_FAMILY_MERLINITE = "merlinite"
 
-_MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST]"
+_MODEL_PROMPT_MIXTRAL = "<s> [INST] {prompt} [/INST] </s>"
 _MODEL_PROMPT_MERLINITE = "'<|system|>\nYou are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.\n<|user|>\n{prompt}\n<|assistant|>\n'"
 
 _MODEL_PROMPTS = {
@@ -157,20 +158,23 @@ def _gen_kwargs(self, gen_kwargs, **defaults):
 
     def _generate(self, samples) -> list:
         prompts = [self._format_prompt(sample) for sample in samples]
-
+        logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
         if self.server_supports_batched:
             response = self.ctx.client.completions.create(
                 prompt=prompts, **self.gen_kwargs
             )
             return [choice.text.strip() for choice in response.choices]
 
         results = []
+        progress_bar = tqdm(range(len(prompts)), desc=f"{self.block_name} Prompt Generation")
         for prompt in prompts:
+            logger.debug(f"CREATING COMPLETION FOR PROMPT: {prompt}")
             for _ in range(self.gen_kwargs.get("n", 1)):
                 response = self.ctx.client.completions.create(
                     prompt=prompt, **self.gen_kwargs
                 )
                 results.append(response.choices[0].text.strip())
+                progress_bar.update(1)
         return results
 
     def generate(self, samples: Dataset) -> Dataset:
4 changes: 1 addition & 3 deletions src/instructlab/sdg/utils/models.py
@@ -14,9 +14,7 @@
 MODEL_FAMILIES = set(("merlinite", "mixtral"))
 
 # Map model names to their family
-MODEL_FAMILY_MAPPINGS = {
-    "granite": "merlinite",
-}
+MODEL_FAMILY_MAPPINGS = {"granite": "merlinite", "mistral": "mixtral"}
 
 
 def get_model_family(model_family, model_path):
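
For context, get_model_family (whose body is outside this diff) is what consumes this mapping. A hedged sketch of how a model path could resolve to a family through it; the matching logic and fallback below are assumptions for illustration, not the function's actual implementation:

# Sketch only: an assumed way a model path could resolve to a family via
# MODEL_FAMILY_MAPPINGS. The real get_model_family body is not in this diff.
import os

MODEL_FAMILY_MAPPINGS = {"granite": "merlinite", "mistral": "mixtral"}

def guess_model_family(model_path: str) -> str:
    base = os.path.basename(model_path).lower()
    for name, family in MODEL_FAMILY_MAPPINGS.items():
        if name in base:
            return family
    return "merlinite"  # assumed fallback, for illustration only

print(guess_model_family("models/mistral-7b-instruct.Q4_K_M.gguf"))  # -> mixtral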

