[usability] deps streamlining
Yizhen committed Sep 30, 2024
1 parent 94021d3 commit 60ee14b
Showing 12 changed files with 195 additions and 113 deletions.
26 changes: 16 additions & 10 deletions contrib/long-context/sft_summarizer.py
@@ -8,21 +8,27 @@
from colorama import Fore, init
from typing import Optional, List

from trl.commands.cli_utils import TrlParser
import torch
from datasets import load_dataset
from dataclasses import dataclass, field
from tqdm.rich import tqdm
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback
from trl import (
ModelConfig,
SFTTrainer,
DataCollatorForCompletionOnlyLM,
SFTConfig,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)

from lmflow.utils.versioning import is_trl_available

if is_trl_available():
from trl import (
ModelConfig,
SFTTrainer,
DataCollatorForCompletionOnlyLM,
SFTConfig,
get_peft_config,
get_quantization_config,
get_kbit_device_map,
)
from trl.commands.cli_utils import TrlParser
else:
raise ImportError("Please install trl package to use sft_summarizer.py")

@dataclass
class UserArguments:
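The guard above depends on is_trl_available() from lmflow.utils.versioning, whose implementation is not shown in this diff. As a rough sketch, such a helper can be built on importlib.util.find_spec, which tests importability without actually importing the package (hypothetical code, not necessarily what lmflow does):

# Hypothetical sketch of an availability check like
# lmflow.utils.versioning.is_trl_available; the real implementation
# is not part of this diff and may differ.
import importlib.util

def _is_package_available(name: str) -> bool:
    # find_spec returns None when the package cannot be located, so this
    # checks importability without importing or initializing the package.
    return importlib.util.find_spec(name) is not None

def is_trl_available() -> bool:
    return _is_package_available("trl")

The same shape would cover is_gradio_available, is_flask_available, and the other helpers used in the files below.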
14 changes: 10 additions & 4 deletions examples/chatbot_gradio.py
@@ -3,22 +3,28 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell chatbot implemented with lmflow APIs.
"""
from dataclasses import dataclass, field
import logging
import json
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
from typing import Optional
import warnings
import gradio as gr
from dataclasses import dataclass, field

import torch
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
from lmflow.utils.versioning import is_gradio_available

if is_gradio_available():
import gradio as gr
else:
raise ImportError("Gradio is not available. Please install it via `pip install gradio`.")

MAX_BOXES = 20

22 changes: 12 additions & 10 deletions examples/vis_chatbot_gradio.py
@@ -3,28 +3,30 @@
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple Multimodal chatbot implemented with lmflow APIs.
"""
import logging
from dataclasses import dataclass, field
import json
import logging
import time

from PIL import Image
from lmflow.pipeline.inferencer import Inferencer
import warnings
from typing import Optional

import numpy as np
import os
import sys
from PIL import Image
import torch
import warnings
import gradio as gr
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import (VisModelArguments, DatasetArguments, \
InferencerArguments, AutoArguments)
from lmflow.utils.versioning import is_gradio_available

if is_gradio_available():
import gradio as gr
else:
raise ImportError("Gradio is not available. Please install it via `pip install gradio`.")


MAX_BOXES = 20

14 changes: 2 additions & 12 deletions requirements.txt
@@ -4,23 +4,13 @@ datasets==2.14.6
tokenizers>=0.13.3
peft>=0.10.0
torch>=2.0.1
wandb==0.14.0
wandb
deepspeed>=0.14.4
trl==0.8.0
sentencepiece
transformers>=4.31.0
flask
flask_cors
icetk
cpm_kernels==1.0.11
evaluate==0.4.0
scikit-learn==1.2.2
lm-eval==0.3.0
dill<0.3.5
bitsandbytes>=0.40.0
pydantic
gradio
accelerate>=0.27.2
einops>=0.6.1
vllm>=0.4.3
ray>=2.22.0
einops>=0.6.1
21 changes: 12 additions & 9 deletions service/app.py
@@ -1,19 +1,22 @@
from dataclasses import dataclass, field
import json
import torch
import os
from typing import Optional

from flask import Flask, request, stream_with_context
from flask import render_template
from flask_cors import CORS
from accelerate import Accelerator
from dataclasses import dataclass, field
import torch
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.args import ModelArguments
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
from lmflow.utils.versioning import is_flask_available

if is_flask_available():
from flask import Flask, request, stream_with_context
from flask import render_template
from flask_cors import CORS
else:
raise ImportError("Flask is not available. Please install flask and flask_cors.")

WINDOW_LENGTH = 512

8 changes: 8 additions & 0 deletions setup.py
@@ -41,3 +41,11 @@
],
requires_python=">=3.9",
)

# optionals
# lm-eval==0.3.0
# vllm>=0.4.3
# ray>=2.22.0
# flask
# flask_cors
# trl==0.8.0
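The commented-out optionals above could instead be declared as installable extras; a minimal sketch, assuming a standard setuptools extras_require layout (illustrative, not part of this commit):

# Hypothetical setup.py fragment: expose the optional dependencies as
# extras so users can opt in with e.g. `pip install lmflow[vllm]`.
# Group names and pins mirror the comment block above.
from setuptools import setup

setup(
    name="lmflow",
    extras_require={
        "trl": ["trl==0.8.0"],
        "vllm": ["vllm>=0.4.3", "ray>=2.22.0"],
        "flask": ["flask", "flask_cors"],
        "eval": ["lm-eval==0.3.0"],
    },
)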
68 changes: 19 additions & 49 deletions src/lmflow/models/hf_decoder_model.py
@@ -22,31 +22,13 @@
import logging
import os, shutil
from typing import List, Union, Optional, Dict
from pathlib import Path

import ray
import ray.data
import torch
import transformers
import bitsandbytes
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import BitsAndBytesConfig
from transformers import (
CONFIG_MAPPING,
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
)
from peft import (
LoraConfig,
PeftModel,
TaskType,
get_peft_config,
get_peft_model,
prepare_model_for_kbit_training
)
from vllm import SamplingParams
from peft import PeftModel

from lmflow.datasets.dataset import Dataset
from lmflow.models.hf_model_mixin import HFModelMixin
@@ -63,39 +63,23 @@
tokenize_function,
conversation_tokenize_function
)
from lmflow.utils.versioning import is_ray_available, is_vllm_available, is_flash_attn_available


logger = logging.getLogger(__name__)


MODELS_SUPPORT_FLASH_ATTENTION = [
"LlamaForCausalLM",
"GPTNeoForCausalLM",
"GPT2ForCausalLM",
"BloomForCausalLM"
]

GPU_SUPPORT_FLASH_ATTENTION = {
"A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
"A40": ["GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
"A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
}

try:
if is_flash_attn_available():
import flash_attn
if int(flash_attn.__version__.split(".")[0]) == 2:
GPU_SUPPORT_FLASH_ATTENTION = {
"A100": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
"A40": ["LlamaForCausalLM","GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"],
"A6000": ["LlamaForCausalLM", "GPTNeoForCausalLM", "GPT2ForCausalLM", "BloomForCausalLM"]
}
except Exception as e:
if e.__class__ == ModuleNotFoundError:
logger.warning(
"flash_attn is not installed. Install flash_attn for better performance."
)
else:
logger.warning(f'An error occurred when importing flash_attn, flash attention is disabled: {e}')
else:
logger.warning("Consider install flash_attn for better performance.")

if is_vllm_available():
from vllm import SamplingParams

if is_ray_available():
import ray
import ray.data


class HFDecoderModel(DecoderModel, HFModelMixin, Tunable):
@@ -380,6 +346,8 @@ def inference(
)

if use_vllm:
if not is_vllm_available():
raise ImportError("vllm is not installed. Please install vllm to use VLLM inference.")
res = self.__vllm_inference(inputs, **kwargs)
else:
res = self.__inference(inputs, **kwargs)
@@ -493,7 +461,7 @@ def prepare_inputs_for_inference(
enable_distributed_inference: bool = False,
use_vllm: bool = False,
**kwargs,
) -> Union[List[str], ray.data.Dataset, Dict[str, torch.Tensor]]:
) -> Union[List[str], "ray.data.Dataset", Dict[str, torch.Tensor]]:
"""
Prepare inputs for inference.
@@ -514,6 +482,8 @@ def prepare_inputs_for_inference(
The prepared inputs for inference.
"""
if use_vllm:
if not is_ray_available() and enable_distributed_inference:
raise ImportError("ray is not installed. Please install ray to use distributed vllm inference.")
inference_inputs = self.__prepare_inputs_for_vllm_inference(
dataset=dataset,
apply_chat_template=apply_chat_template,
@@ -534,7 +504,7 @@ def __prepare_inputs_for_vllm_inference(
dataset: Dataset,
apply_chat_template: bool = True,
enable_distributed_inference: bool = False,
) -> Union[List[str], ray.data.Dataset]:
) -> Union[List[str], "ray.data.Dataset"]:
if dataset.get_type() == 'text_only':
if apply_chat_template:
dataset = dataset.map(
@@ -606,7 +576,7 @@ def preprocess_conversation(sample):

inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0]

if enable_distributed_inference:
inference_inputs = ray.data.from_items(inference_inputs) # -> Dict[str, np.ndarray], {"item": array(['...', '...', '...'])}

return inference_inputs
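Note how the return annotations switch from ray.data.Dataset to the quoted "ray.data.Dataset": a string annotation is never evaluated at runtime, so the module still imports when ray is absent. An equivalent pattern (a sketch, not what this commit uses) makes the intent explicit with typing.TYPE_CHECKING:

# Sketch of the same idea via typing.TYPE_CHECKING: the ray import runs
# only under static type checkers, never at runtime, so a missing ray
# package cannot break module import. Illustrative only.
from typing import TYPE_CHECKING, Dict, List, Union

import torch

if TYPE_CHECKING:
    import ray.data

def prepare_inputs_for_inference(
    dataset,
    use_vllm: bool = False,
) -> Union[List[str], "ray.data.Dataset", Dict[str, torch.Tensor]]:
    # Signature illustration only; the quoted annotation keeps the
    # module importable even when ray is not installed.
    raise NotImplementedError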
7 changes: 6 additions & 1 deletion src/lmflow/pipeline/dpo_aligner.py
@@ -8,12 +8,17 @@
from pathlib import Path
from typing import Dict, Optional

from trl import DPOTrainer
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import TrainingArguments

from lmflow.pipeline.base_aligner import BaseAligner
from lmflow.utils.versioning import is_trl_available

if is_trl_available():
from trl import DPOTrainer
else:
raise ImportError("Please install trl package to use dpo_aligner.py")


def get_paired_dataset(
7 changes: 6 additions & 1 deletion src/lmflow/pipeline/utils/dpov2_trainer.py
@@ -14,9 +14,14 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from trl import DPOTrainer

from lmflow.pipeline.utils.dpov2_dataprocessor import PreferenceDataCollatorWithPadding
from lmflow.utils.versioning import is_trl_available

if is_trl_available():
from trl import DPOTrainer
else:
raise ImportError("Please install trl package to use dpo_aligner.py")


logger = logging.getLogger(__name__)
20 changes: 15 additions & 5 deletions src/lmflow/pipeline/vllm_inferencer.py
@@ -13,11 +13,7 @@
from typing import List, Union, Optional, Dict, Any

import numpy as np
import ray
import ray.data
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from transformers import AutoTokenizer
from vllm import SamplingParams, LLM

from lmflow.datasets import Dataset
from lmflow.pipeline.base_pipeline import BasePipeline
@@ -30,11 +26,25 @@
from lmflow.utils.common import make_shell_args_from_dataclass
from lmflow.utils.constants import RETURN_CODE_ERROR_BUFFER, MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE
from lmflow.utils.data_utils import VLLMInferenceResultWithInput
from lmflow.utils.versioning import is_vllm_available, is_ray_available


logger = logging.getLogger(__name__)


if is_vllm_available():
from vllm import SamplingParams, LLM
else:
raise ImportError("VLLM is not available, please install vllm.")

if is_ray_available():
import ray
import ray.data
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
else:
logger.warning("Ray is not available, distributed vllm inference will not be supported.")


class InferencerWithOffloading(BasePipeline):
def __init__(
self,
@@ -343,7 +353,7 @@ def inference(self) -> List[VLLMInferenceResultWithInput]:
# > at interpreter shutdown, possibly due to daemon threads
logger.warning(
"^^^^^^^^^^ Please ignore the above error, as it comes from the subprocess. "
"This may due a kill signal with unfinished stdout/stderr writing in the subprocess. "
"This may due to a kill signal with unfinished stdout/stderr writing in the subprocess. "
)
else:
if cli_res.returncode != 0:
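This file treats its two optional dependencies differently: a missing vllm raises at import time because the whole pipeline needs it, while a missing ray only logs a warning, deferring the failure until distributed inference is actually requested. A minimal sketch of that deferred, call-time guard (hypothetical helper name; the commit performs the equivalent check inline elsewhere):

from lmflow.utils.versioning import is_ray_available

def _ensure_ray_for_distributed_inference(enable_distributed_inference: bool) -> None:
    # Soft-degrade at import time; hard-fail only when the optional
    # feature is actually requested.
    if enable_distributed_inference and not is_ray_available():
        raise ImportError(
            "ray is not installed. Please install ray to use distributed vllm inference."
        )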