From e09779383c9ea225e85abb0ff7cfb5b05b2e02b1 Mon Sep 17 00:00:00 2001 From: pppppM Date: Wed, 20 Mar 2024 04:39:14 +0800 Subject: [PATCH 1/6] hybrid data pipeline --- .../hybrid/function_call.json | 54 ++ .../internlm2_chat_1_8b_function_call.py | 204 ++++++++ .../hybrid/internlm2_chat_1_8b_llava_sft.py | 224 +++++++++ .../hybrid/multi_modal.json | 41 ++ xtuner/dataset/hybrid/__init__.py | 14 + xtuner/dataset/hybrid/_pack.py | 131 +++++ xtuner/dataset/hybrid/collate.py | 74 +++ xtuner/dataset/hybrid/dataset.py | 465 ++++++++++++++++++ xtuner/dataset/hybrid/hybrid.py | 68 +++ xtuner/dataset/hybrid/mappings.py | 172 +++++++ xtuner/model/__init__.py | 3 +- xtuner/model/hybrid.py | 191 +++++++ xtuner/model/utils.py | 118 ++++- xtuner/types/__init__.py | 6 + xtuner/types/chat.py | 145 ++++++ xtuner/types/chat_template.py | 183 +++++++ xtuner/types/train.py | 297 +++++++++++ xtuner/utils/__init__.py | 3 +- xtuner/utils/config.py | 131 +++++ xtuner/utils/tokenizer.py | 46 ++ 20 files changed, 2566 insertions(+), 4 deletions(-) create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json create mode 100644 xtuner/dataset/hybrid/__init__.py create mode 100644 xtuner/dataset/hybrid/_pack.py create mode 100644 xtuner/dataset/hybrid/collate.py create mode 100644 xtuner/dataset/hybrid/dataset.py create mode 100644 xtuner/dataset/hybrid/hybrid.py create mode 100644 xtuner/dataset/hybrid/mappings.py create mode 100644 xtuner/model/hybrid.py create mode 100644 xtuner/types/__init__.py create mode 100644 xtuner/types/chat.py create mode 100644 xtuner/types/chat_template.py create mode 100644 xtuner/types/train.py create mode 100644 xtuner/utils/config.py create mode 100644 xtuner/utils/tokenizer.py diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json new file mode 100644 index 000000000..719c37766 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/function_call.json @@ -0,0 +1,54 @@ +[ + { + "messages": [ + { + "role": "user", + "content": "I want to know today's weather in Shanghai" + }, + + { + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", + "function_call": { + "name": "get_current_weather", + "parameters": { + "location": "Shanghai" + } + } + }, + + { + "role": "function", + "name": "get_current_weather", + "content": "{'temperature': 22}" + }, + { + "role": "assistant", + "content": "The weather in Shanghai is 22 celsius" + } + + + ], + + "functions": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + "unit": {"type": "string"} + }, + "required": ["location"] + } + } + } + ] + } + +] + + \ No newline at end of file diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py new file mode 100644 index 000000000..a6d2a8049 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py @@ -0,0 +1,204 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR + +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + + +from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn +from xtuner.dataset.hybrid.mappings import openai_to_raw_training +from xtuner.engine.hooks import DatasetInfoHook +from xtuner.engine.runner import TrainLoop +from xtuner.model import HybridFinetune +from xtuner.types import HybridChatTemplate + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f/' +visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336' +# Specify the pretrained pth +pretrained_pth = None +# Data +data_dir = './' +data_files = ['function_call.json'] +max_length = 2048 + +# Chat Template +chat_template = dict( + type=HybridChatTemplate, + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n') + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 1 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + + +model = dict( + type=HybridFinetune, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16)) + 
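# Note: with the `chat_template` above, one function-call exchange from
# function_call.json is rendered roughly as follows (an illustrative sketch
# composed from the template's format strings; the exact prompt also depends
# on the tokenizer):
#
#   <|im_start|>system name=<|plugin|>
#   [{"name": "get_current_weather", "description": "...", "parameters": {...}}]<|im_end|>
#   <|im_start|>user
#   I want to know today's weather in Shanghai<|im_end|>
#   <|im_start|>assistant
#   Sure, I will search for the weather of Shanghai.<|action_start|><|plugin|>
#   {"name": "get_current_weather", "parameters": {"location": "Shanghai"}}<|action_end|><|im_end|>
#   <|im_start|>environment name=<|plugin|>
#   {'temperature': 22}<|im_end|>
#   <|im_start|>assistant
#   The weather in Shanghai is 22 celsius<|im_end|>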
+####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=HybridDataset, + data_dir=data_dir, + data_files=data_files, + sample_ratio=1, + tokenizer=tokenizer, + chat_template=chat_template, + max_length=max_length, + pack_to_max_length=True, + num_workers = dataloader_num_workers, + mappings=[openai_to_raw_training]) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=hybrid_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + # dict( + # type=EvaluateChatHook, + # tokenizer=tokenizer, + # image_processor=image_processor, + # every_n_iters=evaluation_freq, + # evaluation_inputs=evaluation_inputs, + # evaluation_images=evaluation_images, + # system=SYSTEM, + # prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py new file mode 100644 index 000000000..97bae7ac3 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + CLIPImageProcessor, CLIPVisionModel) + +from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn +from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training) +from xtuner.engine.hooks import DatasetInfoHook +from xtuner.engine.runner import TrainLoop +from xtuner.model import HybridFinetune +from xtuner.types import HybridChatTemplate + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +llm_name_or_path = '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f/' +visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336' +# Specify the pretrained pth +pretrained_pth = None +# Data +data_dir = './llava_data/' +data_files = ['LLaVA-Instruct-150K/llava_v1_5_mix665k.json'] +image_dir = data_dir + 'llava_images' +max_length = 1024 * 32 + +# Chat Template +chat_template = dict( + type=HybridChatTemplate, + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n') + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 1 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_images = 
'https://llava-vl.github.io/static/images/view.jpg' +evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture'] + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + padding_side='right') + +image_processor = dict( + type=CLIPImageProcessor.from_pretrained, + pretrained_model_name_or_path=visual_encoder_name_or_path, + trust_remote_code=True) + +model = dict( + type=HybridFinetune, + freeze_llm=False, + freeze_visual_encoder=True, + pretrained_pth=pretrained_pth, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=llm_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16), + visual_encoder=dict( + type=CLIPVisionModel.from_pretrained, + pretrained_model_name_or_path=visual_encoder_name_or_path)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=HybridDataset, + data_dir=data_dir, + data_files=data_files, + data_cached='cached_llava', + image_dir=image_dir, + sample_ratio=1, + tokenizer=tokenizer, + chat_template=chat_template, + image_processor=image_processor, + pad_img_to_squared=True, + max_length=max_length, + pack_to_max_length=True, + num_workers=dataloader_num_workers, + mappings=[ + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, + ]) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=hybrid_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + # dict( + # type=EvaluateChatHook, + # tokenizer=tokenizer, + # image_processor=image_processor, + # every_n_iters=evaluation_freq, + # evaluation_inputs=evaluation_inputs, + # evaluation_images=evaluation_images, + # system=SYSTEM, + # prompt_template=prompt_template) +] + +# configure default hooks +default_hooks = dict( + # record the 
time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json new file mode 100644 index 000000000..0b1576131 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json @@ -0,0 +1,41 @@ +[ + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": "image1.jpg" + }, + { + "type": "image_url", + "image_url": "image2.jpg" + }, + { + "type": "text", + "text": "What are the colors of the bus in the first image?" + } + ] + }, + + { + "role": "assistant", + "content": "The bus in the image is white and red." + }, + + { + "role": "user", + "content": "Where is the cat positioned in the second image?" + }, + + { + "role": "assistant", + "content": "The cat is positioned on top of the back of the couch in the living room." 
+ } + + ] + } +] + + \ No newline at end of file diff --git a/xtuner/dataset/hybrid/__init__.py b/xtuner/dataset/hybrid/__init__.py new file mode 100644 index 000000000..febf1a497 --- /dev/null +++ b/xtuner/dataset/hybrid/__init__.py @@ -0,0 +1,14 @@ +from .collate import hybrid_collate_fn +from .dataset import HybridDataset +from .mappings import (insert_img_pad_tokens, llava_to_openai, map_protocol, + map_sequential, openai_to_raw_training) + +__all__ = [ + 'hybrid_collate_fn', + 'HybridDataset', + 'insert_img_pad_tokens', + 'llava_to_openai', + 'map_protocol', + 'map_sequential', + 'openai_to_raw_training', +] diff --git a/xtuner/dataset/hybrid/_pack.py b/xtuner/dataset/hybrid/_pack.py new file mode 100644 index 000000000..946020909 --- /dev/null +++ b/xtuner/dataset/hybrid/_pack.py @@ -0,0 +1,131 @@ +import bisect +import itertools +import random + +import torch + + +class _PackDataset(torch.utils.data.Dataset): + + def __init__(self, dataset, max_length=2048): + super().__init__() + + self.max_length = max_length + + # unpack dataset + self.dataset = dataset + + self._ori_img_urls = dataset['image_urls'] + self._ori_img_rngs = dataset['image_ranges'] + self._ori_lens = dataset['tokens'] + + self._num_packed_samples = sum(self._ori_lens) // self.max_length + + inds = [i for i in range(len(self.dataset))] + random.shuffle(inds) + self.shfl_inds = inds + + shfl_lens = [self._ori_lens[i] for i in inds] + # shuffled cumulative lengths + shfl_acc_lens = list(itertools.accumulate(shfl_lens)) + + self._shfl_item_rngs_left = [0] + shfl_acc_lens[:-1] + self._shfl_item_rngs_right = shfl_acc_lens + + shfl_img_urls = [self._ori_img_urls[i] for i in inds] + self._flat_shfl_img_urls = list(itertools.chain(*shfl_img_urls)) + + flat_shfl_acc_img_rngs = [] + flat_shfl_acc_img_rngs_left = [] + flat_shfl_acc_img_rngs_right = [] + for i in range(len(self.dataset)): + shfl_i = self.shfl_inds[i] + img_rngs = self._ori_img_rngs[shfl_i] + for left, right in img_rngs: + acc_left = left + self._shfl_item_rngs_left[i] + acc_right = right + self._shfl_item_rngs_left[i] + + flat_shfl_acc_img_rngs_left.append(acc_left) + flat_shfl_acc_img_rngs_right.append(acc_right) + flat_shfl_acc_img_rngs.append([acc_left, acc_right]) + assert len(flat_shfl_acc_img_rngs) == len(self._flat_shfl_img_urls) + + self._flat_shfl_acc_img_rngs = flat_shfl_acc_img_rngs + self._flat_shfl_acc_img_rngs_left = flat_shfl_acc_img_rngs_left + self._flat_shfl_acc_img_rngs_right = flat_shfl_acc_img_rngs_right + + def _pack_img_urls_and_rngs_in_range(self, begin, end): + + left = bisect.bisect(self._flat_shfl_acc_img_rngs_right, begin) + right = bisect.bisect(self._flat_shfl_acc_img_rngs_left, end) + + filter_urls = self._flat_shfl_img_urls[left:right] + filter_rngs = self._flat_shfl_acc_img_rngs[left:right] + + inner_rngs = [] + for rng in filter_rngs: + inner_left = max(begin, rng[0]) - begin + inner_right = min(end, rng[1]) - begin + + if inner_right - inner_left > 0: + inner_rngs.append([inner_left, inner_right]) + + return filter_urls, inner_rngs + + def _pack_ids_and_labels_in_range(self, begin, end): + + left = bisect.bisect(self._shfl_item_rngs_right, begin) + right = bisect.bisect(self._shfl_item_rngs_left, end) + + trunc_ids = [] + trunc_labels = [] + cumulative_len = [] + position_ids = [] + for i in range(left, right): + cumulative_len.append(len(trunc_ids)) + + item_begin = self._shfl_item_rngs_left[i] + item_end = self._shfl_item_rngs_right[i] + + inner_l = max(begin, item_begin) - item_begin + inner_r = min(end, item_end) - 
item_begin
+            position_ids.extend([i for i in range(inner_r - inner_l)])
+
+            ori_idx = self.shfl_inds[i]
+            ori_input_ids = self.dataset[ori_idx]['input_ids']
+            ori_labels = self.dataset[ori_idx]['labels']
+
+            trunc_ids.extend(ori_input_ids[inner_l:inner_r])
+            trunc_labels.extend(ori_labels[inner_l:inner_r])
+
+        return trunc_ids, trunc_labels, cumulative_len, position_ids
+
+    def __len__(self):
+        return self._num_packed_samples
+
+    def __getitem__(self, item):
+
+        begin = item * self.max_length
+        end = (item + 1) * self.max_length
+
+        _res = self._pack_ids_and_labels_in_range(begin, end)
+        packed_ids, packed_labels, cumulative_len, position_ids = _res
+        assert self.max_length == len(packed_ids) == len(packed_labels)
+
+        _res = self._pack_img_urls_and_rngs_in_range(begin, end)
+        packed_img_urls, packed_img_rngs = _res
+
+        for left, right in packed_img_rngs:
+            assert len(set(packed_ids[left:right])) == 1
+
+        packed = {
+            'input_ids': packed_ids,
+            'labels': packed_labels,
+            'tokens': self.max_length,
+            'image_urls': packed_img_urls,
+            'image_ranges': packed_img_rngs,
+            'cumulative_len': cumulative_len,
+            'position_ids': position_ids
+        }
+
+        return packed
diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py
new file mode 100644
index 000000000..925b9ac01
--- /dev/null
+++ b/xtuner/dataset/hybrid/collate.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Sequence
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
+
+
+def hybrid_collate_fn(instances: Sequence[Dict],
+                      pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
+                      return_hf_format: bool = False):
+
+    input_ids = []
+    labels = []
+    pixel_values = []
+    cumulative_len = []
+    image_ranges = []
+    image_belong = []
+    position_ids = []
+
+    for i, data in enumerate(instances):
+        input_ids.append(torch.LongTensor(data['input_ids']))
+        labels.append(torch.LongTensor(data['labels']))
+        position_ids.append(torch.IntTensor(data['position_ids']))
+
+        if 'cumulative_len' in data:
+            cumulative_len.append(torch.IntTensor(data['cumulative_len']))
+
+        # record the batch index once per image so that it stays aligned
+        # with `pixel_values` and `image_ranges`
+        image_belong.extend([i] * len(data['pixel_values']))
+        pixel_values.extend(data['pixel_values'])
+        image_ranges.extend(torch.IntTensor(data['image_ranges']))
+
+    if len(pixel_values) > 0:
+        assert len(image_ranges) > 0
+        assert len(image_belong) > 0
+
+        pixel_values = torch.stack(pixel_values)
+        image_ranges = torch.stack(image_ranges)
+        image_belong = torch.IntTensor(image_belong)
+    else:
+        pixel_values = None
+        image_ranges = None
+        image_belong = None
+
+    if len(instances) > 1:
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=pad_index)
+        labels = pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX)
+        position_ids = pad_sequence(
+            position_ids, batch_first=True, padding_value=0)
+    else:
+        input_ids = torch.stack(input_ids)
+        labels = torch.stack(labels)
+        position_ids = torch.stack(position_ids)
+
+    if len(cumulative_len) == 0:
+        cumulative_len = None
+
+    data_dict = {
+        'input_ids': input_ids,
+        'position_ids': position_ids,
+        'attention_mask': input_ids.ne(pad_index),
+        'labels': labels,
+        'pixel_values': pixel_values,
+        'cumulative_len': cumulative_len,
+        'image_ranges': image_ranges,
+        'image_belong': image_belong
+    }
+
+    if return_hf_format:
+        return data_dict
+    else:
+        return {'data': data_dict, 'data_samples': None}
diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py
new file mode 100644
index 000000000..69327a4e3
--- /dev/null
+++ 
b/xtuner/dataset/hybrid/dataset.py @@ -0,0 +1,465 @@ +import json +import os +import random +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from functools import partial +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import torch +from datasets import Dataset, load_from_disk +from mmengine import print_log +from PIL import Image +from torch import distributed as dist +from torch import nn +from tqdm import tqdm + +from xtuner.dataset.hybrid._pack import _PackDataset +from xtuner.dataset.hybrid.mappings import map_protocol, map_sequential +from xtuner.dataset.utils import expand2square +from xtuner.registry import BUILDER +from xtuner.types import HybridChatTemplate +from xtuner.utils import build_tokenizer + +os.environ['TOKENIZERS_PARALLELISM'] = 'true' + + +@map_protocol( + input_keys=dict(input_ids=list), + added_keys=dict(tokens=int), +) +def _register_tokens(data, tokenizer=None, chat_template=None): + data['tokens'] = len(data['input_ids']) + return data + + +@map_protocol( + input_keys=dict(input_ids=list), + added_keys=dict(cumulative_len=list), +) +def _register_cumulative_len(data, tokenizer=None, chat_template=None): + data['cumulative_len'] = [0, len(data['input_ids'])] + return data + + +@map_protocol( + input_keys=dict(input_ids=list), + added_keys=dict(position_ids=list), +) +def _register_position_ids(data, tokenizer=None, chat_template=None): + data['position_ids'] = [i for i in range(len(data['input_ids']))] + return data + + +@map_protocol( + added_keys=dict(image_ranges=list), ) +def _register_empty_img_ranges(data, tokenizer=None, chat_template=None): + if 'image_ranges' not in data: + data['image_ranges'] = [] + return data + + +@map_protocol( + input_keys=dict( + input_ids=list, + labels=list, + tokens=int, + image_urls=list, + image_ranges=list, + position_ids=list, + cumulative_len=list), + output_keys=dict( + input_ids=list, + labels=list, + tokens=int, + image_urls=list, + image_ranges=list, + position_ids=list, + cumulative_len=list)) +def _check_mapped_data(item, tokenizer=None, chat_template=None): + assert isinstance(item['input_ids'][0], int) + assert isinstance(item['labels'][0], int) + + if len(item['image_urls']) > 0: + assert isinstance(item['image_urls'][0], str) + + if len(item['image_ranges']) > 0: + assert isinstance(item['image_ranges'][0], list) + assert isinstance(item['image_ranges'][0][0], int) + + return item + + +class HybridDataset(torch.utils.data.Dataset): + """ + Args: + tokenizer: The tokenizer processes some raw text as input and outputs + an Encoding. + max_length: Max length of the sequence. + pack_to_max_length: Whether to pack the dataset to the `max_length `. + This usually improves gpu utilization and therefore reduces + training time. + shuffle_before_pack: Whether to shuffle the dataset before + packing them. 
+ use_varlen_attn: If use_varlen_attn is True, we calculate attention + the actual length of the sequence rather than the actual length + of the sequence + """ + + def __init__(self, + tokenizer, + chat_template: Union[Dict, HybridChatTemplate], + sample_ratio: int = 1.0, + max_length: int = 2048, + pack_to_max_length: bool = False, + num_workers: int = 8, + mappings: Union[Callable, List[Callable]] = [], + data_dir: Optional[str] = None, + data_files: Optional[Union[str, List[str]]] = None, + data_cached: Optional[str] = None, + image_dir: Optional[str] = None, + image_processor: Optional[nn.Module] = None, + pad_img_to_squared: bool = True): + super().__init__() + + assert data_dir or data_files or data_cached + + self.tokenizer = build_tokenizer(tokenizer) + + if isinstance(chat_template, HybridChatTemplate): + self.chat_template = chat_template + elif isinstance(chat_template, dict): + self.chat_template = BUILDER.build(chat_template) + else: + raise TypeError + + if isinstance(image_processor, dict): + image_processor = BUILDER.build(image_processor) + self.image_processor = image_processor + + if image_dir: + self.image_dir = Path(image_dir) + else: + self.image_dir = Path('') + + self.pad_img_to_squared = pad_img_to_squared + + self.sample_ratio = sample_ratio + self.max_length = max_length + self.pack_to_max_length = pack_to_max_length + + mappings.append(_register_cumulative_len) + mappings.append(_register_position_ids) + mappings.append(_register_tokens) + mappings.append(_register_empty_img_ranges) + mappings.append(_check_mapped_data) + map_fn = map_sequential(mappings) + self.map_fn = partial( + map_fn, tokenizer=self.tokenizer, chat_template=self.chat_template) + + self.num_workers = num_workers + if data_cached: + self.data_dir = data_dir + self.data_files = data_files + self.data_cached = data_cached + else: + data_dir = Path(data_dir) + if data_files is None: + data_files = [str(f) for f in data_dir.rglob('*.json')] + elif isinstance(data_files, list): + data_files = [str(data_dir / Path(f)) for f in data_files] + elif isinstance(data_files, str): + data_files = [str(data_dir / data_files)] + else: + raise TypeError + + self.data_dir = str(data_dir) + self.data_files = data_files + self.data_cached = data_cached + + self.dataset = self.build_dataset() + + def build_dataset(self): + + if not (dist.is_available() and dist.is_initialized()): + return self._build_dataset() + + timeout = timedelta( + minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30))) + print_log(f'xtuner_dataset_timeout = {timeout}', logger='current') + + gloo_group = dist.new_group(backend='gloo', timeout=timeout) + + if dist.get_rank() == 0: + dataset = self._build_dataset() + objects = [dataset] + else: + objects = [None] + + dist.monitored_barrier(group=gloo_group, timeout=timeout) + dist.broadcast_object_list(objects, src=0) + + return objects[0] + + def _build_dataset(self): + + if self.data_cached: + dataset = load_from_disk(self.data_cached) + if self.pack_to_max_length: + dataset = self._pack_dataset(dataset) + return dataset + + dataset = [] + for file in self.data_files: + dataset.extend(json.load(open(file))) + print_log(f'Loaded json data from {file}', logger='current') + + if self.sample_ratio < 1: + num_samples = int(self.sample_ratio * len(dataset)) + dataset = random.sample(dataset, num_samples) + print_log( + f'Randomly selected {num_samples} samples', logger='current') + + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + dataset = list( + tqdm( + 
executor.map(self.map_fn, dataset),
+                    desc='Map Dataset',
+                    total=len(dataset)))
+
+        dataset = self.filter_non_labels_data(dataset)
+
+        self.analysis_tokens_labels(dataset)
+        self.analysis_image_samples(dataset)
+
+        dataset = Dataset.from_list(dataset)
+
+        if self.pack_to_max_length:
+            dataset = self._pack_dataset(dataset)
+
+        return dataset
+
+    def _pack_dataset(self, dataset):
+
+        unpacked_samples = len(dataset)
+        dataset = _PackDataset(dataset, self.max_length)
+        packed_samples = len(dataset)
+        print_log(
+            'Before pack multi samples to max length: '
+            f'{unpacked_samples} samples',
+            logger='current')
+        print_log(
+            'After pack multi samples to max length: '
+            f'{packed_samples} samples',
+            logger='current')
+        return dataset
+
+    def filter_non_labels_data(self, dataset):
+
+        # Keep a sample only if at least one label within `max_length`
+        # requires loss; iterating over the labels themselves avoids an
+        # IndexError for samples shorter than `max_length`.
+        filter_fn = lambda item: any(
+            label >= 0 for label in item['labels'][:self.max_length]
+        )  # noqa: E731
+
+        ori_samples = len(dataset)
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            results = list(
+                tqdm(
+                    executor.map(filter_fn, dataset),
+                    desc='Filter Dataset',
+                    total=len(dataset)))
+
+        new_dataset = [x for x, passed in zip(dataset, results) if passed]
+
+        new_samples = len(new_dataset)
+        print_log(f'Before filter: {ori_samples} samples', logger='current')
+        print_log(f'After filter: {new_samples} samples', logger='current')
+        print_log(
+            f'Filtered {ori_samples - new_samples} samples '
+            '(all labels are ignored)',
+            logger='current')
+        return new_dataset
+
+    def analysis_image_samples(self, dataset):
+
+        img_sample_counter = lambda item: len(item['image_urls']
+                                              ) > 0  # noqa: E501, E731
+        img_counter = lambda item: len(item['image_urls'])  # noqa: E501, E731
+
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            images = list(
+                tqdm(
+                    executor.map(img_counter, dataset),
+                    desc='Count Images',
+                    total=len(dataset)))
+
+            samples = list(
+                tqdm(
+                    executor.map(img_sample_counter, dataset),
+                    desc='Count Contain Image Samples',
+                    total=len(dataset)))
+
+        num_images = sum(images)
+        num_samples = sum(samples)
+        print_log(
+            f'There are a total of {num_samples} samples with images, '
+            f'amounting to {num_images} images.',
+            logger='current')
+
+    def analysis_tokens_labels(self, dataset):
+
+        label_counter = lambda item: sum([1 for i in item['labels']
+                                          if i >= 0])  # noqa: E501, E731
+        token_counter = lambda item: len(item['input_ids'])
+
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            tokens = list(
+                tqdm(
+                    executor.map(token_counter, dataset),
+                    desc='Count Tokens',
+                    total=len(dataset)))
+
+            labels = list(
+                tqdm(
+                    executor.map(label_counter, dataset),
+                    desc='Count Labels',
+                    total=len(dataset)))
+
+        num_tokens = sum(tokens)
+        num_labels = sum(labels)
+        print_log(
+            f'There are a total of {num_tokens} tokens, '
+            f'of which {num_labels} tokens need loss calculation.',
+            logger='current')
+
+    def cache(self, cache_dir: str):
+        cache_dir = Path(cache_dir)
+
+        if self.pack_to_max_length:
+            hf_dataset = Dataset.from_list(self.dataset.dataset)
+        else:
+            hf_dataset = Dataset.from_list(self.dataset)
+
+        hf_dataset.save_to_disk(cache_dir)
+
+        dset_conf = {
+            'image_dir': str(self.image_dir),
+            'data_files': self.data_files,
+            'max_length': self.max_length,
+            'chat_template': self.chat_template.model_dump(),
+            'pack_to_max_length': self.pack_to_max_length,
+            'tokenizer': type(self.tokenizer).__name__,
+        }
+
+        with open(cache_dir / 'dataset_configuration.json', 'w') as f:
+            json.dump(dset_conf, f)
+
+        self.tokenizer.save_pretrained(cache_dir / 
'tokenizer') + self.image_processor.save_pretrained(cache_dir / 'image_processor') + + def load_image(self, url): + image_file = self.image_dir / url + image = Image.open(image_file).convert('RGB') + + if self.pad_img_to_squared: + background = tuple( + int(x * 255) for x in self.image_processor.image_mean) + image = expand2square(image, background) + + image = self.image_processor.preprocess( + image, return_tensors='pt')['pixel_values'][0] + + return image + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item: int) -> Dict[str, List]: + + data = self.dataset[item] + + pixel_values = [] + for url in data['image_urls']: + image = self.load_image(url) + + pixel_values.append(image) + + data['pixel_values'] = pixel_values + + return data + + +if __name__ == '__main__': + + from transformers import CLIPImageProcessor + + chat_template = HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' + ) + + processor = CLIPImageProcessor.from_pretrained( + 'openai/clip-vit-large-patch14-336', + trust_remote_code=True, + ) + + from xtuner.dataset.hybrid.mappings import ( + insert_img_pad_tokens, llava_to_openai, openai_to_raw_training) + + data_dir = './llava_data/LLaVA-Instruct-150K/' + image_dir = './llava_data/llava_images/' + data_files = 'llava_v1_5_mix665k.json' + + dataset = HybridDataset( + 'internlm/internlm2-chat-1_8b', + chat_template, + sample_ratio=1, + max_length=32*1024, + data_dir=data_dir, + data_files=data_files, + image_dir=image_dir, + image_processor=processor, + pack_to_max_length=True, + mappings=[ + llava_to_openai, openai_to_raw_training, insert_img_pad_tokens, + ], + num_workers=4) + + print(dataset[0]) + + dataset.cache('cached_llava') + dataset = HybridDataset( + 'internlm/internlm2-chat-1_8b', + chat_template, + sample_ratio=1, + max_length=32*1024, + data_cached='cached_llava', + image_dir=image_dir, + image_processor=processor, + pack_to_max_length=True, + mappings=[ + llava_to_openai, openai_to_raw_training, insert_img_pad_tokens, + ], + num_workers=4) + print(dataset[0]) + + from mmengine.dataset import DefaultSampler + from torch.utils.data import DataLoader + + from xtuner.dataset.hybrid.collate import hybrid_collate_fn + loader = DataLoader( + dataset, + 4, + num_workers=0, + collate_fn=hybrid_collate_fn, + sampler=DefaultSampler(dataset, shuffle=True)) + + for data in tqdm(loader): + continue diff --git a/xtuner/dataset/hybrid/hybrid.py b/xtuner/dataset/hybrid/hybrid.py new file mode 100644 index 000000000..289d21de2 --- /dev/null +++ b/xtuner/dataset/hybrid/hybrid.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Sequence
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
+from xtuner.types import RawTrainingData
+
+
+def hybrid_collate_fn(instances: Sequence[Dict],
+                      pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
+                      return_hf_format: bool = False):
+
+    input_ids = []
+    labels = []
+    pixel_values = []
+    cumulative_len = []
+    image_ranges = []
+    # indexes = []
+
+    for item in instances:
+        input_ids.append(torch.LongTensor(item['input_ids']))
+        labels.append(torch.LongTensor(item['labels']))
+
+        if 'cumulative_len' in item:
+            cumulative_len.append(torch.IntTensor(item['cumulative_len']))
+
+        pixel_values.extend(item['pixel_values'])
+        # image_ranges.extend(torch.IntTensor(item['image_ranges']))
+
+    if len(pixel_values) > 0:
+        pixel_values = torch.stack(pixel_values)
+    else:
+        pixel_values = None
+
+    if len(instances) > 1:
+        input_ids = pad_sequence(
+            input_ids, batch_first=True, padding_value=pad_index)
+        labels = pad_sequence(
+            labels, batch_first=True, padding_value=IGNORE_INDEX)
+    else:
+        input_ids = torch.stack(input_ids)
+        labels = torch.stack(labels)
+
+    # if len(image_ranges) > 0:
+    #     image_ranges = torch.stack(image_ranges)
+    # else:
+    #     image_ranges = None
+
+    if len(cumulative_len) == 0:
+        cumulative_len = None
+
+    data_dict = {
+        'input_ids': input_ids,
+        'attention_mask': input_ids.ne(pad_index),
+        'labels': labels,
+        'pixel_values': pixel_values,
+        'cumulative_len': cumulative_len,
+        # 'image_ranges': image_ranges,
+    }
+
+    if return_hf_format:
+        return data_dict
+    else:
+        return {'data': data_dict, 'data_samples': None}
diff --git a/xtuner/dataset/hybrid/mappings.py b/xtuner/dataset/hybrid/mappings.py
new file mode 100644
index 000000000..e104885ff
--- /dev/null
+++ b/xtuner/dataset/hybrid/mappings.py
@@ -0,0 +1,172 @@
+import re
+from typing import Callable, Dict, List, Type, Union
+
+from mmengine.config.lazy import LazyObject
+
+from xtuner.types import TrainingHybridChatMessages
+
+
+def map_protocol(
+    input_keys: Dict[str, Type] = {},
+    output_keys: Dict[str, Type] = {},
+    added_keys: Dict[str, Type] = {},
+) -> Callable:
+
+    def decorator(func):
+
+        def wrapper(data, *args, **kwargs):
+
+            for key, _type in input_keys.items():
+                assert key in data
+                if not isinstance(data[key], _type):
+                    raise TypeError(
+                        f'Expected `{key}` to be {_type}, but got '
+                        f'{type(data[key])}.')
+
+            result = func(data, *args, **kwargs)
+
+            for key, _type in output_keys.items():
+                assert key in result
+                assert isinstance(result[key], _type)
+
+            return result
+
+        return wrapper
+
+    setattr(decorator, 'input_keys', input_keys)
+    setattr(decorator, 'output_keys', output_keys)
+    setattr(decorator, 'added_keys', added_keys)
+
+    return decorator
+
+
+def map_sequential(mappings: Union[Callable, List[Callable]]):
+
+    # a single mapping function is wrapped into a list
+    if not isinstance(mappings, list):
+        mappings = [mappings]
+
+    for i in range(len(mappings)):
+        if isinstance(mappings[i], LazyObject):
+            mappings[i] = mappings[i].build()
+
+    def _sequential(item, tokenizer, chat_template):
+
+        for func in mappings:
+            item = func(item, tokenizer, chat_template)
+
+        return item
+
+    return _sequential
+
+
+@map_protocol(
+    input_keys=dict(input_ids=list, labels=list, image_urls=list),
+    output_keys=dict(
+        input_ids=list, labels=list, image_urls=list, image_ranges=list),
+)
+def insert_img_pad_tokens(data, tokenizer, chat_template) -> Dict:
+
+    image_urls = data['image_urls']
+    if len(image_urls) == 0:
+        data['image_ranges'] = []
+        return data
+
+    input_ids = data['input_ids']
+    labels = data['labels']
+
+    img_token = chat_template.image_token_index
+
img_token_inds = [i for i, t in enumerate(input_ids) if t == img_token] + assert len(img_token_inds) == len( + image_urls), f'{img_token_inds} {image_urls}' + + for url, ind in zip(image_urls, img_token_inds): + # image = self.load_image(url) + h, w = 336 // 14, 336 // 14 + + pad_tokens = [tokenizer.pad_token_id] * (h * w) + pad_labels = [labels[ind]] * (h * w) + + input_ids[ind] = pad_tokens + labels[ind] = pad_labels + + new_ids = [] + new_labels = [] + assert len(input_ids) == len(labels) + + img_ranges = [] + for i, _ in enumerate(zip(input_ids, labels)): + if isinstance(input_ids[i], list): + assert isinstance(labels[i], list) + assert len(input_ids[i]) == len(labels[i]) + + img_begin = len(new_ids) + img_end = img_begin + len(input_ids[i]) + img_ranges.append([img_begin, img_end]) + + new_ids.extend(input_ids[i]) + new_labels.extend(labels[i]) + + else: + new_ids.append(input_ids[i]) + new_labels.append(labels[i]) + + data['input_ids'] = new_ids + data['labels'] = new_labels + data['image_ranges'] = img_ranges + + return data + + +@map_protocol( + input_keys=dict(messages=list), + output_keys=dict(input_ids=list, labels=list, image_urls=list), +) +def openai_to_raw_training(item: dict, tokenizer, chat_template) -> Dict: + + data = TrainingHybridChatMessages.from_dict(item) + data = data.tokenize(tokenizer, chat_template) + + return data + + +@map_protocol( + input_keys=dict(conversations=list), + output_keys=dict(messages=list), +) +def llava_to_openai(data, tokenizer=None, chat_template=None): + + image_token = '' + conversations = data['conversations'] + messages = [] + + if 'image' in data: + image_url = data['image'] + else: + image_url = None + + while conversations and conversations[0]['from'] == 'gpt': + # Skip the first one if it is from gpt + conversations = conversations[1:] + + for convs in conversations: + if convs['from'] == 'human': + pattern = f'({image_token})' + chunks = re.split(pattern, convs['value']) + + content = [] + for chunk in chunks: + if chunk == image_token: + assert isinstance(image_url, str), image_url + item = dict(type='image_url', image_url=image_url) + content.append(item) + elif len(chunk.strip()): + item = dict(type='text', text=chunk.strip()) + content.append(item) + + msg = {'role': 'user', 'content': content} + messages.append(msg) + + elif convs['from'] == 'gpt': + msg = {'role': 'assistant', 'content': convs['value']} + messages.append(msg) + else: + raise NotImplementedError + return {'messages': messages} diff --git a/xtuner/model/__init__.py b/xtuner/model/__init__.py index 39547b2d7..c2e45a89d 100644 --- a/xtuner/model/__init__.py +++ b/xtuner/model/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .hybrid import HybridFinetune from .llava import LLaVAModel from .sft import SupervisedFinetune -__all__ = ['SupervisedFinetune', 'LLaVAModel'] +__all__ = ['HybridFinetune', 'SupervisedFinetune', 'LLaVAModel'] diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py new file mode 100644 index 000000000..40db93820 --- /dev/null +++ b/xtuner/model/hybrid.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from collections import OrderedDict + +import torch +from mmengine.model import BaseModel +from peft import LoraConfig +from torch import nn + +from xtuner.registry import BUILDER +from xtuner.utils.config import build_from_cfg_or_obj +from .modules import ProjectorConfig, ProjectorModel, dispatch_modules +from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing, + get_peft_model_state_dict, prepare_for_llm_lora, + prepare_for_vision_lora, + smart_tokenizer_and_embedding_resize) + + +class HybridFinetune(BaseModel): + + def __init__( + self, + llm, + visual_encoder=None, + visual_select_layer=-2, + projector_depth=2, + pretrained_pth=None, + tokenizer=None, + llm_lora=None, + visual_encoder_lora=None, + freeze_llm=False, + freeze_visual_encoder=False, + use_activation_checkpointing=True, + use_varlen_attn=False, + ): + super().__init__() + + # Build the base language model without initialization. + # This will greatly reduce the time to build the model. + with LoadWoInit(): + self.llm = build_from_cfg_or_obj(llm, nn.Module) + if visual_encoder: + visual_encoder = build_from_cfg_or_obj(visual_encoder, + nn.Module) + self.visual_encoder = visual_encoder + self.visual_select_layer = visual_select_layer + self.llm.config.use_cache = False + dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn) + + if tokenizer is not None: + if isinstance(tokenizer, dict): + tokenizer = BUILDER.build(tokenizer) + smart_tokenizer_and_embedding_resize(tokenizer, self.llm) + + projector_config = ProjectorConfig( + visual_hidden_size=self.visual_encoder.config.hidden_size, + llm_hidden_size=self.llm.config.hidden_size, + depth=projector_depth) + self.projector = ProjectorModel(projector_config).to( + self.visual_encoder.dtype) + + self.freeze_llm = freeze_llm + self.freeze_visual_encoder = freeze_visual_encoder + if self.freeze_llm: + self.llm.requires_grad_(False) + if self.freeze_visual_encoder: + self.visual_encoder.requires_grad_(False) + + if use_activation_checkpointing: + # For backward compatibility + enable_hf_model_gradient_checkpointing(self.llm) + enable_hf_model_gradient_checkpointing(self.visual_encoder) + + self.projector.enable_input_require_grads() + self.projector.gradient_checkpointing_enable() + + self.use_llm_lora = llm_lora is not None + self.use_visual_encoder_lora = visual_encoder_lora is not None + + # Prepare the model for LoRA if specified + if self.use_llm_lora: + lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig) + self.llm = prepare_for_llm_lora(self.llm, lora_conf, + use_activation_checkpointing) + + if self.use_visual_encoder_lora: + lora_conf = build_from_cfg_or_obj( + visual_encoder_lora, accept=LoraConfig) + self.visual_encoder = prepare_for_vision_lora( + self.visual_encoder, lora_conf, use_activation_checkpointing) + self._is_init = True + + # Determines whether to calculate attention based on the + # seq_len dimension (use_varlen_attn = False) or the actual length of + # the sequence. + self.use_varlen_attn = use_varlen_attn + + def init_weights(self): + """Parent class method. + + To avoid overwriting the loaded weights, overload it to an empty + function. + """ + pass + + def forward(self, data, data_samples=None, mode='loss'): + """Overload parent class method, only support training.""" + + if mode == 'loss': + return self.compute_loss(data, data_samples) + else: + raise NotImplementedError( + f"{type(self)}'s forward is only supported for use during " + 'training. 
If you want to get predictions or chat, please '
+                "directly use `llm`'s forward.")
+
+    def compute_loss(self, data, data_samples=None):
+
+        input_ids = data['input_ids']
+        labels = data['labels']
+        position_ids = data['position_ids']
+        attention_mask = data['attention_mask']
+        pixel_values = data['pixel_values']
+        img_rngs = data['image_ranges']
+        img_belong = data['image_belong']
+
+        input_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
+            visual_outputs = self.visual_encoder(
+                pixel_values, output_hidden_states=True)
+            img_embeds = self.projector(
+                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
+
+            empty_embs = torch.zeros_like(input_embeds)
+            for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong):
+                left, right = rng
+                if emb.size(0) == right - left:
+                    empty_embs[b_id, left:right, :] = emb
+                elif left == 0:
+                    # only the tail of the image fits: its front was cut
+                    # off when the sample was packed
+                    empty_embs[b_id, left:right, :] = emb[-right:]
+                elif right == empty_embs.size(1):
+                    # only the head of the image fits: its tail was cut
+                    # off when the sample was packed
+                    empty_embs[b_id, left:right, :] = emb[:right - left]
+                else:
+                    raise RuntimeError(
+                        'Image features do not align with the reserved '
+                        f'image token range [{left}, {right}).')
+
+            non_img_mask = (empty_embs == 0)
+            input_embeds = input_embeds * non_img_mask + empty_embs
+
+        outputs = self.llm(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=input_embeds,
+            labels=labels)
+
+        loss_dict = {'loss': outputs.loss}
+        return loss_dict
+
+    def state_dict(self, *args, **kwargs):
+        state_dict = super().state_dict(*args, **kwargs)
+        to_return = OrderedDict()
+        # Step 1. visual_encoder
+        if self.use_visual_encoder_lora:
+            to_return.update(
+                get_peft_model_state_dict(
+                    self.visual_encoder, state_dict=state_dict))
+        elif not self.freeze_visual_encoder:
+            to_return.update({
+                k: v
+                for k, v in state_dict.items() if 'visual_encoder.' in k
+            })
+        # Step 2. LLM
+        if self.use_llm_lora:
+            to_return.update(
+                get_peft_model_state_dict(self.llm, state_dict=state_dict))
+        elif not self.freeze_llm:
+            to_return.update(
+                {k: v
+                 for k, v in state_dict.items() if 'llm.' in k})
+        # Step 3. Projector
+        to_return.update(
+            {k: v
+             for k, v in state_dict.items() if 'projector.' in k})
+        return to_return
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.llm, name)
diff --git a/xtuner/model/utils.py b/xtuner/model/utils.py
index dce86315d..0e5dfb826 100644
--- a/xtuner/model/utils.py
+++ b/xtuner/model/utils.py
@@ -1,13 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+from contextlib import nullcontext
 from typing import List, Optional
 
 import torch
 from mmengine import print_log
 from mmengine.utils.misc import get_object_from_string
-from peft import PeftType
+from peft import (LoraConfig, PeftModel, PeftType, get_peft_model,
+                  prepare_model_for_kbit_training)
 from torch import nn
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers.integrations import is_deepspeed_zero3_enabled
 
 from xtuner.utils import IGNORE_INDEX, IMAGE_TOKEN_INDEX
 
@@ -50,6 +53,34 @@ def find_all_linear_names(model):
     return list(lora_module_names)
 
 
+def collect_linear_suffix_names(model: torch.nn.Module,
+                                exclude: list[str] = []) -> list[str]:
+    """Collect suffix names of nn.Linear modules from a PyTorch model.
+
+    Args:
+        model: The PyTorch model.
+        exclude: A list of suffix names to drop from the collected
+            result. Defaults to an empty list.
+ + Returns: + A list of collected suffix names after excluding specified keys. + """ + suffix_names = set() + + # Iterate through all named modules in the model + for name, module in model.named_modules(): + # Check if the module is an instance of nn.Linear + if isinstance(module, torch.nn.Linear): + names = name.split('.') + suffix_names.add(names[0] if len(names) == 1 else names[-1]) + + # Remove exclude_keys from the collected suffix_names + for key in exclude: + suffix_names.remove(key) + + return list(suffix_names) + + class LoadWoInit: """Context manager that disable parameter initialization.""" @@ -286,6 +317,73 @@ def make_inputs_require_grad(module, input, output): output.requires_grad_(True) +def prepare_for_llm_lora(model: PreTrainedModel, + lora_config: LoraConfig, + gradient_checkpointing: bool = True) -> PeftModel: + model = prepare_model_for_kbit_training(model, gradient_checkpointing) + if lora_config.target_modules is None: + modules = collect_linear_suffix_names(model, exclude=['output']) + lora_config.target_modules = modules + + model = get_peft_model(model, lora_config) + return model + + +def prepare_for_vision_lora(model: PreTrainedModel, + lora_config: LoraConfig, + gradient_checkpointing: bool = True) -> PeftModel: + + if lora_config.target_modules is None: + modules = collect_linear_suffix_names(model) + lora_config.target_modules = modules + + model = get_peft_model(model, lora_config) + return model + + +def smart_tokenizer_and_embedding_resize( + tokenizer: PreTrainedTokenizer, + model: PreTrainedModel, +): + """Resize embedding.""" + if is_deepspeed_zero3_enabled(): + import deepspeed + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings( + ) is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters( + params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + assert isinstance(model.get_output_embeddings(), nn.Linear) + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + num_new_tokens = len(tokenizer) - current_embedding_size + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + print_log( + f'Resized token embeddings from {current_embedding_size} to ' + f'{len(tokenizer)}.', 'current') + + def guess_load_checkpoint(pth_model): if osp.isfile(pth_model): state_dict = torch.load(pth_model, map_location='cpu') @@ -307,3 +405,19 @@ def guess_load_checkpoint(pth_model): else: raise FileNotFoundError(f'Cannot find {pth_model}') return state_dict + + +def enable_hf_model_gradient_checkpointing(model: PreTrainedModel) -> None: + # For backward compatibility + if hasattr(model, 'enable_input_require_grads'): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + + # enable gradient 
checkpointing for memory efficiency + model.gradient_checkpointing_enable() diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py new file mode 100644 index 000000000..cc230e8f8 --- /dev/null +++ b/xtuner/types/__init__.py @@ -0,0 +1,6 @@ +from .chat_template import HybridChatTemplate +from .train import RawTrainingData, TrainingHybridChatMessages + +__all__ = [ + 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages' +] diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py new file mode 100644 index 000000000..616d89b09 --- /dev/null +++ b/xtuner/types/chat.py @@ -0,0 +1,145 @@ +from typing import Dict, List, Literal, Union + +from pydantic import BaseModel + +from .chat_template import HybridChatTemplate + + +class TextContentItem(BaseModel): + type: Literal['text'] + text: str + + def format_content(self, chat_template: HybridChatTemplate) -> str: + return self.text + + +class ImageContentItem(BaseModel): + type: Literal['image_url'] + image_url: str + + def format_content(self, chat_template: HybridChatTemplate) -> str: + return chat_template.image_token + + +MultModalContentType = Union[TextContentItem, ImageContentItem] +ContentType = Union[str, List[MultModalContentType]] + + +class ChatMsg(BaseModel): + role: Literal['assistant', 'user', 'system'] + content: ContentType + + def collect_img_urls(self) -> List[str]: + img_urls = [] + if isinstance(self.content, list): + for item in self.content: + if isinstance(item, ImageContentItem): + img_urls.append(item.image_url) + return img_urls + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + + if isinstance(self.content, str): + text = self.content + elif isinstance(self.content, list): + text = '' + for i, item in enumerate(self.content): + if i == 0: + text += item.format_content(chat_template) + else: + text += '\n' + item.format_content(chat_template) + else: + raise NotImplementedError + + if self.role == 'system': + prompt = chat_template.decorate_system(text) + elif self.role == 'user': + prompt = chat_template.decorate_user(text) + elif self.role == 'assistant': + prompt = chat_template.decorate_assistant(text) + else: + raise NotImplementedError + + return prompt + + +# Function Call + + +class FunctionCall(BaseModel): + name: str + arguments: Dict + + +class FunctionCallMsg(BaseModel): + + role: Literal['assistant'] + content: str + function_call: Union[str, Dict] + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + + return chat_template.decorate_function_call(self.content, + self.function_call) + + +class FunctionResultMsg(BaseModel): + role: Literal['function'] + name: str + content: Union[str, Dict] + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + return chat_template.decorate_function_result(self.content) + + +class Functions(BaseModel): + + # class Parameters(BaseModel): + + # class Property(BaseModel): + # type: str + # description: str + # enum: Optional[List] = None + + # type: Literal['object'] + # properties: Dict[str, Property] + # required: List[str] + + name: str + description: Union[str, Dict] + parameters: Union[str, Dict] + + +HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg] + + +class HybridChatMessages(BaseModel): + + messages: List[HybridChatMsgType] = [] + # images: List[Image.Image] = [] + functions: List[Functions] = [] + + # TODO (pppppM) add audio and video + + def collect_img_urls(self) -> List[str]: + img_urls = [] + for msg in self.messages: + 
img_urls.extend(msg.collect_img_urls()) + return img_urls + + def pop_latest_msg(self): + return self.messages.pop() + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + + prompt = '' + if len(self.functions) > 0: + + functions = [func.model_dump() for func in self.functions] + + prompt += chat_template.decorate_functions(functions) + + for msg in self.messages: + prompt += msg.apply_chat_template(chat_template) + + return prompt diff --git a/xtuner/types/chat_template.py b/xtuner/types/chat_template.py new file mode 100644 index 000000000..847604bfe --- /dev/null +++ b/xtuner/types/chat_template.py @@ -0,0 +1,183 @@ +from typing import Dict, List, Optional + +from pydantic import BaseModel, field_validator + + +class HybridChatTemplate(BaseModel): + """Define a Pydantic data model for a hybrid chat with attributes for + system, user and assistant chat as well as function and interpreter calls + and results.""" + + # Normal Chat + system: str # System message format + user: str # User message format + assistant: str # Assistant message format + stop_words: List[str] # List of stop words + + # Multimodal Chat + # Predefined token and index for images + image_token: str = '' + image_token_index: int = -100 + + # Agent Chat + # Interpreter and function related strings + functions: Optional[str] = None # Function description format + function_call: Optional[str] = None # Function call format + function_result: Optional[str] = None # Function result format + + code_interpreter: Optional[str] = None + code_interpreter_call: Optional[str] = None # Interpreter call format + code_interpreter_result: Optional[str] = None # Interpreter result format + + function_token: Optional[str] = None + code_interpreter_token: Optional[str] = None + action_start_token: Optional[str] = None + action_end_token: Optional[str] = None + + @property + def mm_token_maps(self) -> Dict[str, int]: + """Return a dictionary that maps multimodal tokens to corresponding + token indexes.""" + return {self.image_token: self.image_token_index} + + def decorate_system(self, text: str) -> str: + """Decorate text with the `system` template.""" + return self.system.format(system=text) + + def decorate_assistant(self, text: str) -> str: + """Decorate text with the `assistant` template.""" + return self.assistant.format(assistant=text) + + def decorate_user(self, text: str) -> str: + """Decorate text with the `user` template.""" + return self.user.format(user=text) + + def decorate_functions(self, text: str) -> str: + """Decorate text with the `functions` template.""" + return self.functions.format(functions=text) + + def decorate_function_call(self, text: str, func: str) -> str: + """Decorate text with the `function_call` template.""" + return self.function_call.format(assistant=text, function_call=func) + + def decorate_function_result(self, text: str) -> str: + """Decorate text with the `function_result` template.""" + return self.function_result.format(function_result=text) + + def decorate_code_interpreter(self, text: str) -> str: + """Decorate text with the `code_interpreter` template.""" + return self.code_interpreter.format(code_interpreter=text) + + def decorate_code_interpreter_call(self, text: str) -> str: + """Decorate text with the `code_interpreter_call` template.""" + return self.code_interpreter_call.format(code_interpreter_call=text) + + def decorate_code_interpreter_result(self, text: str) -> str: + """Decorate text with the `code_interpreter_result` template.""" + return 
self.code_interpreter_result.format( + code_interpreter_result=text) + + @field_validator('system') + def check_system(cls, v: str) -> str: + """Validate that `system` contains '{system}'. + + If not, raises a ValueError. + """ + if v is not None and '{system}' not in v: + raise ValueError("system must contain the keyword '{system}'") + return v + + @field_validator('user') + def check_user(cls, v: str) -> str: + """Validate that `user` contains '{user}'. + + If not, raises a ValueError. + """ + if v is not None and '{user}' not in v: + raise ValueError("user must contain the keyword '{user}'") + return v + + @field_validator('assistant') + def check_assistant(cls, v: str) -> str: + """Validate that `assistant` contains '{assistant}'. + + If not, raises a ValueError. + """ + if v is not None and '{assistant}' not in v: + raise ValueError( + "assistant must contain the keyword '{assistant}'") + return v + + @field_validator('function_call') + def check_function_call(cls, v: str) -> str: + """Validate that `function_call` contains '{function_call}'. + + If not, raises a ValueError. + """ + if (v is not None and '{function_call}' not in v + and '{assistant}' not in v): + raise ValueError( + "function_call must contain the keywords '{function_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError( + "function_call must contain the keyword '{assistant}' and " + "'{function_call}'") + return v + + @field_validator('function_result') + def check_function_result(cls, v: str) -> str: + """Validate that `function_result` contains '{function_result}'. + + If not, raises a ValueError. + """ + if v is not None and '{function_result}' not in v: + raise ValueError( + "function_result must contain the keyword '{function_result}'") + return v + + @field_validator('functions') + def check_functions(cls, v: str) -> str: + """Validate that `functions` contains '{functions}'. + + If not, raises a ValueError. + """ + if v is not None and '{functions}' not in v: + raise ValueError( + "functions must contain the keyword '{functions}'") + return v + + @field_validator('code_interpreter') + def check_code_interpreter(cls, v: str) -> str: + """Validate that `code_interpreter` contains '{code_interpreter}'. + + If not, raises a ValueError. + """ + if v is not None and '{code_interpreter}' not in v: + raise ValueError('code_interpreter must contain the keyword ' + "'{code_interpreter}'") + return v + + @field_validator('code_interpreter_call') + def check_code_interpreter_call(cls, v: str) -> str: + """Validate that `code_interpreter_call` contains + '{code_interpreter_call}'. + + If not, raises a ValueError. + """ + if v is not None and '{code_interpreter_call}' not in v: + raise ValueError('code_interpreter_call must contain the keyword ' + "'{code_interpreter_call}'") + return v + + @field_validator('code_interpreter_result') + def check_code_interpreter_result(cls, v: str) -> str: + """Validate that `code_interpreter_result` contains + '{code_interpreter_result}'. + + If not, raises a ValueError. 
+ """ + if v is not None and '{code_interpreter_result}' not in v: + raise ValueError( + 'code_interpreter_result must contain the keyword ' + "'{code_interpreter_result}'") + return v diff --git a/xtuner/types/train.py b/xtuner/types/train.py new file mode 100644 index 000000000..a3775bffe --- /dev/null +++ b/xtuner/types/train.py @@ -0,0 +1,297 @@ +import copy +import re +from typing import Dict, List, Optional, Union + +import torch +from pydantic import BaseModel +from transformers.tokenization_utils import PreTrainedTokenizer + +from xtuner.utils import IGNORE_INDEX +from xtuner.utils.tokenizer import get_bos_token_ids +from .chat import (ChatMsg, FunctionCallMsg, FunctionResultMsg, Functions, + ImageContentItem, TextContentItem) +from .chat_template import HybridChatTemplate + + +class TrainingChatMsg(ChatMsg): + loss: Optional[bool] = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.loss is None: + if self.role == 'system': + self.loss = False + elif self.role == 'user': + self.loss = False + elif self.role == 'assistant': + self.loss = True + else: + raise NotImplementedError + + def _encode_mm_content(self, text: str, mm_token_maps: Dict[str, int], + tokenizer: PreTrainedTokenizer): + + mm_tokens = mm_token_maps.keys() + + pattern = r'(' + '|'.join(mm_tokens) + r')' + chunks = re.split(pattern, text) + + assert len(chunks) > 1 + + token_ids = [] + for c in chunks: + if c in mm_tokens: + token_ids.append(mm_token_maps[c]) + else: + token_ids.extend(tokenizer.encode(c, add_special_tokens=False)) + + return token_ids + + def _with_multi_modal_content(self): + flag = False + + if isinstance(self.content, list): + for item in self.content: + # TODO (pppppM) support video and audio + if isinstance(item, ImageContentItem): + flag = True + break + return flag + + def tokenize( + self, + tokenizer: PreTrainedTokenizer, + chat_template: HybridChatTemplate, + ): + + decorated = self.apply_chat_template(chat_template) + + if self._with_multi_modal_content(): + token_maps = chat_template.mm_token_maps + token_ids = self._encode_mm_content(decorated, token_maps, + tokenizer) + else: + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + image_urls = self.collect_img_urls() + + return { + 'input_ids': token_ids, + 'labels': label_ids, + 'image_urls': image_urls + } + + +class TrainingFunctionCallMsg(FunctionCallMsg): + loss: bool = True + + def tokenize( + self, + tokenizer: PreTrainedTokenizer, + chat_template: HybridChatTemplate, + ): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class TrainingFunctionResultMsg(FunctionResultMsg): + loss: bool = False + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class RawTrainingData(BaseModel): + + input_ids: List[int] + labels: List[int] + image_urls: List[str] = [] + + +class ProcessedTrainingData(BaseModel): + + input_ids: List[int] + 
labels: List[int] + length: int + cumulative_len: List[int] + position_ids: List[int] + image_urls: List[str] = [] + pixel_values: List[torch.Tensor] = [] + image_ranges: List[tuple] = [] + + class Config: + arbitrary_types_allowed = True + + +TraingHybridMessageType = Union[TrainingChatMsg, TrainingFunctionCallMsg, + TrainingFunctionResultMsg] + + +class TrainingHybridChatMessages(BaseModel): + messages: List[TraingHybridMessageType] + functions: Optional[List[Functions]] = None + + @classmethod + def from_dict(cls, item) -> 'TrainingHybridChatMessages': + ''' + item + { + 'messages':[ + {'role':'user', 'content':'hello'}, + {'role':'assistant', 'content':'hello!'}, + ], + 'funcitons': [], + } + + ''' + + assert 'messages' in item, item + + _messages = item['messages'] + messages = [] + functions = None + + for _msg in _messages: + assert 'role' in _msg and 'content' in _msg + _role = _msg['role'] + _content = _msg['content'] + if _role == 'function': + msg_factory = TrainingFunctionResultMsg + assert 'name' in _msg + name = _msg['name'] + msg = msg_factory(role=_role, name=name, content=_content) + messages.append(msg) + continue + + if isinstance(_content, list): + + content = [] + for c_item in _content: + assert 'type' in c_item + _type = c_item['type'] + if _type == 'text': + assert 'text' in c_item + _text = c_item['text'] + content.append(TextContentItem(type=_type, text=_text)) + elif _type == 'image_url': + assert 'image_url' in c_item + _url = c_item['image_url'] + content.append( + ImageContentItem(type=_type, image_url=_url)) + else: + raise NotImplementedError + + msg = TrainingChatMsg(role=_role, content=content) + messages.append(msg) + continue + + if isinstance(_content, str) and 'function_call' in _msg: + _call = _msg['function_call'] + msg = TrainingFunctionCallMsg( + role=_role, content=_content, function_call=_call) + messages.append(msg) + continue + + if isinstance(_content, str): + msg = TrainingChatMsg(role=_role, content=_content) + messages.append(msg) + + # TODO (pppppM) add format warning + + if 'functions' in item: + _functions = item['functions'] + assert isinstance(_functions, list) + functions = [] + + for _func in _functions: + assert 'name' in _func + assert 'description' in _func + assert 'parameters' in _func + + _name = _func['name'] + _desc = _func['description'] + _params = _func['parameters'] + + func = Functions( + name=_name, description=_desc, parameters=_params) + functions.append(func) + + return cls(messages=messages, functions=functions) + + def collect_img_urls(self) -> List[str]: + img_urls = [] + for msg in self.messages: + img_urls.extend(msg.collect_img_urls()) + return img_urls + + def pop_latest_msg(self): + return self.messages.pop() + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + + prompt = '' + + if isinstance(self.functions, list) and len(self.functions) > 0: + + functions = [func.model_dump() for func in self.functions] + + prompt += chat_template.decorate_functions(functions) + + for msg in self.messages: + prompt += msg.apply_chat_template(chat_template) + + return prompt + + def tokenize(self, tokenizer: PreTrainedTokenizer, + chat_template: HybridChatTemplate) -> RawTrainingData: + + input_ids = [] + labels = [] + image_urls = [] + + bos_token_ids = get_bos_token_ids(tokenizer) + input_ids.extend(bos_token_ids) + labels.extend([IGNORE_INDEX] * len(bos_token_ids)) + + for msg in self.messages: + res = msg.tokenize(tokenizer, chat_template) + token_ids, label_ids = res['input_ids'], res['labels'] 
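+            # Each message contributes its own token ids and aligned labels;
+            # messages whose ``loss`` flag is False (e.g. system and user
+            # turns, function results) return IGNORE_INDEX labels, so only
+            # assistant-side tokens are supervised in the packed sequence.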
+ + input_ids.extend(token_ids) + labels.extend(label_ids) + + if 'image_urls' in res: + image_urls.extend(res['image_urls']) + + # TODO (pppppM) Verify whether sep and suffix_as_eos are necessary + + training_data = { + 'input_ids': input_ids, + 'labels': labels, + 'image_urls': image_urls + } + return training_data diff --git a/xtuner/utils/__init__.py b/xtuner/utils/__init__.py index 6bc9a1173..75bcad2bf 100644 --- a/xtuner/utils/__init__.py +++ b/xtuner/utils/__init__.py @@ -3,9 +3,10 @@ IGNORE_INDEX, IMAGE_TOKEN_INDEX) from .stop_criteria import StopWordStoppingCriteria from .templates import PROMPT_TEMPLATE, SYSTEM_TEMPLATE +from .tokenizer import build_tokenizer __all__ = [ 'IGNORE_INDEX', 'DEFAULT_PAD_TOKEN_INDEX', 'PROMPT_TEMPLATE', 'DEFAULT_IMAGE_TOKEN', 'SYSTEM_TEMPLATE', 'StopWordStoppingCriteria', - 'IMAGE_TOKEN_INDEX' + 'IMAGE_TOKEN_INDEX', 'build_tokenizer' ] diff --git a/xtuner/utils/config.py b/xtuner/utils/config.py new file mode 100644 index 000000000..ecd165920 --- /dev/null +++ b/xtuner/utils/config.py @@ -0,0 +1,131 @@ +import dataclasses +from typing import TypeVar, Union + +import torch +from mmengine.config import Config +from mmengine.logging import print_log +from mmengine.utils import get_object_from_string + +from xtuner.registry import BUILDER + + +def convert_dtype_cfg_to_obj(config: Union[dict, list[dict]]) -> None: + """Convert dtype related config to python object. + + When MMEngine Runner is training, it will save the config file for + resuming training. + But in the saved config file, python objects of type torch.dtype are + converted to strings like 'torch.float16'. In order to accommodate this, + after loading the config, all dtype strings need to be converted into + python objects. + + Args: + config: A dict or list that potentially contains dtypes as strings. + + Returns: + None. The input 'config' is modified in-place. + """ + # If the config is a dictionary + if isinstance(config, dict): + for key, value in config.items(): + # Recursively call the function if the value is also a dict + if isinstance(value, dict): + convert_dtype_cfg_to_obj(value) + + # Replace the string with the corresponding dtype object + # if it's a recognized dtype string + elif value in ['torch.float16', 'torch.float32', 'torch.bfloat16']: + config[key] = getattr(torch, value.split('.')[-1]) + + # If the config is a list + elif isinstance(config, list): + for item in config: + convert_dtype_cfg_to_obj(item) + + +def convert_dataclass_cfg_to_obj(config: Union[dict, list[dict]]) -> None: + """Convert dataclass related config to python object. + + Huggingface's code uses dataclasses extensively. + In order to use Huggingface's interfaces in the MMEngine config, + we need to specifically handle these configurations. + + Note: + Before executing this function, you must first run + `convert_dtype_cfg_to_obj`, otherwise the dataclass config containing + dtype cannot be properly converted ! + + Args: + config: A dictionary or list that potentially contains configurations + as dataclasses. + + Returns: + None. The input 'config' is modified in-place. 
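+
+    Example:
+        A minimal sketch (``peft.LoraConfig`` is used purely as an example of
+        a dataclass type; any dataclass configured via ``type=`` works the
+        same way)::
+
+            from peft import LoraConfig
+
+            cfg = {'llm_lora': dict(type=LoraConfig, r=8, lora_alpha=16)}
+            convert_dataclass_cfg_to_obj(cfg)
+            # cfg['llm_lora'] is now a LoraConfig instance, built in-place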
+ """ + # If the config is a dictionary + if isinstance(config, dict): + for key, value in config.items(): + # Recursively call the function if the value is also a dict + if isinstance(value, dict): + convert_dataclass_cfg_to_obj(value) + + # Check if the type of value is a dataclass + if 'type' in value and dataclasses.is_dataclass(value['type']): + builder = value.pop( + 'type') # remove 'type' from value and get its content + + # Convert the builder to an object if it is a string + if isinstance(builder, str): + builder = get_object_from_string(builder) + + # Build a new_value using the remaining items in value + new_value = builder(**value) + # replace the original value with new_value + config[key] = new_value + print_log(f'{key} convert to {builder}') + + # If the config is a list + elif isinstance(config, list): + for item in config: + convert_dataclass_cfg_to_obj(item) + + +OBJ_T = TypeVar('OBJ_T') + + +def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T], + accept: OBJ_T) -> OBJ_T: + """Build a python object from a config or return an existed object. + + Args: + cfg_or_obj (dict, OBJ_T]): an object of a type specified in + `accept_obj_types`, or a dict. + accept_obj (OBJ_T): the type of object that return without any + modification. + + Returns: + If 'cfg_or_obj' is an object of `accept_obj` , it is returned directly. + If 'cfg_or_obj' is a dict, it is built into an object using + `build_from_cfg()`. + + Raises: + TypeError: If `cfg_or_obj` is not dict or `accept_obj`. + """ + + if isinstance(cfg_or_obj, accept): + return cfg_or_obj + + elif isinstance(cfg_or_obj, (dict, Config)): + convert_dtype_cfg_to_obj(cfg_or_obj) + convert_dataclass_cfg_to_obj(cfg_or_obj) + obj = BUILDER.build(cfg_or_obj) + + if not isinstance(obj, accept): + raise TypeError( + f'Expect an object of {accept}, but there is an object of ' + f'{type(obj)}.') + return BUILDER.build(cfg_or_obj) + + else: + raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got ' + f'{type(cfg_or_obj)}') diff --git a/xtuner/utils/tokenizer.py b/xtuner/utils/tokenizer.py new file mode 100644 index 000000000..7c79b9fba --- /dev/null +++ b/xtuner/utils/tokenizer.py @@ -0,0 +1,46 @@ +from typing import List, Union + +from transformers import AutoTokenizer + +from xtuner.registry import BUILDER + + +def build_tokenizer(tokenizer: Union[str, dict]): + + if isinstance(tokenizer, str): + return AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + elif isinstance(tokenizer, dict): + return BUILDER.build(tokenizer) + else: + raise TypeError + + +def get_bos_token_ids(tokenizer) -> List[int]: + + if tokenizer.__class__.__name__ == 'QWenTokenizer': + bos_token_ids = [] + elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': + bos_token_ids = [64790, 64792] + else: + bos_token_ids = tokenizer.bos_token_id + + if isinstance(bos_token_ids, int): + bos_token_ids = [bos_token_ids] + + return bos_token_ids + + +def get_eos_token_ids(tokenizer) -> List[int]: + if tokenizer.__class__.__name__ == 'QWenTokenizer': + eos_token_ids = tokenizer.eos_token_id + assert eos_token_ids is not None, \ + 'Please set eos_token for Qwen tokenizer!' 
+ elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': + eos_token_ids = tokenizer.eos_token_id + else: + eos_token_ids = tokenizer.eos_token_id + + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] + + return eos_token_ids From 40c2fe6853b4279995ae41ba44463023d7222df5 Mon Sep 17 00:00:00 2001 From: pppppM Date: Fri, 22 Mar 2024 10:44:21 +0800 Subject: [PATCH 2/6] fix forward error --- xtuner/dataset/hybrid/collate.py | 31 +++++--- xtuner/dataset/hybrid/hybrid.py | 68 ---------------- xtuner/model/hybrid.py | 128 ++++++++++++++++++++++++------- xtuner/types/chat.py | 44 +++++++++++ xtuner/utils/config.py | 2 +- 5 files changed, 164 insertions(+), 109 deletions(-) delete mode 100644 xtuner/dataset/hybrid/hybrid.py diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py index 925b9ac01..74d2218e7 100644 --- a/xtuner/dataset/hybrid/collate.py +++ b/xtuner/dataset/hybrid/collate.py @@ -16,9 +16,9 @@ def hybrid_collate_fn(instances: Sequence[Dict], pixel_values = [] cumulative_len = [] image_ranges = [] - image_belong = [] + image_belongs = [] position_ids = [] - + for i, data in enumerate(instances): input_ids.append(torch.LongTensor(data['input_ids'])) labels.append(torch.LongTensor(data['labels'])) @@ -27,28 +27,33 @@ def hybrid_collate_fn(instances: Sequence[Dict], if 'cumulative_len' in data: cumulative_len.append(torch.IntTensor(data['cumulative_len'])) - image_belong.append(i) - pixel_values.extend(data['pixel_values']) - image_ranges.extend(torch.IntTensor(data['image_ranges'])) - + + _values = data['pixel_values'] + _ranges = data['image_ranges'] + + assert len(_values) == len(_ranges) + for v, rng in zip(_values, _ranges): + pixel_values.append(v) + image_ranges.append(rng) + image_belongs.append(i) + if len(pixel_values) > 0: assert len(image_ranges) > 0 - assert len(image_belong) > 0 + assert len(image_belongs) > 0 pixel_values = torch.stack(pixel_values) - image_ranges = torch.stack(image_ranges) - image_belong = torch.IntTensor(image_belong) + # image_belongs = torch.IntTensor(image_belongs) else: pixel_values = None image_ranges = None - image_belong = None + image_belongs = None if len(instances) > 1: input_ids = pad_sequence( input_ids, batch_first=True, padding_value=pad_index) labels = pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX) - position_ids = pad_sequence(labels, batch_first=True, padding_value=0) + position_ids = pad_sequence(position_ids, batch_first=True, padding_value=0) else: input_ids = torch.stack(input_ids) labels = torch.stack(labels) @@ -57,6 +62,7 @@ def hybrid_collate_fn(instances: Sequence[Dict], if len(cumulative_len) == 0: cumulative_len = None + # breakpoint() data_dict = { 'input_ids': input_ids, 'position_ids': position_ids, @@ -65,8 +71,9 @@ def hybrid_collate_fn(instances: Sequence[Dict], 'pixel_values': pixel_values, 'cumulative_len': cumulative_len, 'image_ranges': image_ranges, - 'image_belong': image_belong + 'image_belongs': image_belongs } + if return_hf_format: return data_dict diff --git a/xtuner/dataset/hybrid/hybrid.py b/xtuner/dataset/hybrid/hybrid.py deleted file mode 100644 index 289d21de2..000000000 --- a/xtuner/dataset/hybrid/hybrid.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Dict, Sequence - -import torch -from torch.nn.utils.rnn import pad_sequence - -from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX -from xtuner.types import RawTrainingData - - -def hybrid_collate_fn(instances: Sequence[Dict], - pad_index: int = DEFAULT_PAD_TOKEN_INDEX, - return_hf_format: bool = False): - - input_ids = [] - labels = [] - pixel_values = [] - cumulative_len = [] - image_ranges = [] - # indexes = [] - - - for item in instances: - input_ids.append(torch.LongTensor(item['input_ids'])) - labels.append(torch.LongTensor(item['labels'])) - - if 'cumulative_len' in item: - cumulative_len.append(torch.IntTensor(item['cumulative_len'])) - - pixel_values.extend(item['pixel_values']) - # image_ranges.extend(torch.IntTensor(item['image_ranges'])) - - if len(pixel_values) > 0: - pixel_values = torch.stack(pixel_values) - else: - pixel_values = None - - if len(instances) > 1: - input_ids = pad_sequence( - input_ids, batch_first=True, padding_value=pad_index) - labels = pad_sequence( - labels, batch_first=True, padding_value=IGNORE_INDEX) - else: - input_ids = torch.stack(input_ids) - labels = torch.stack(labels) - - # if len(image_ranges) > 0: - # image_ranges = torch.stack(image_ranges) - # else: - # image_ranges = None - - if len(cumulative_len) == 0: - cumulative_len = None - - data_dict = { - 'input_ids': input_ids, - 'attention_mask': input_ids.ne(pad_index), - 'labels': labels, - 'pixel_values': pixel_values, - 'cumulative_len': cumulative_len, - # 'image_ranges': image_ranges, - } - - - if return_hf_format: - return data_dict - else: - return {'data': data_dict, 'data_samples': None} diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py index 40db93820..24b165cd7 100644 --- a/xtuner/model/hybrid.py +++ b/xtuner/model/hybrid.py @@ -4,8 +4,9 @@ import torch from mmengine.model import BaseModel from peft import LoraConfig +from mmengine import print_log from torch import nn - +import math from xtuner.registry import BUILDER from xtuner.utils.config import build_from_cfg_or_obj from .modules import ProjectorConfig, ProjectorModel, dispatch_modules @@ -13,8 +14,8 @@ get_peft_model_state_dict, prepare_for_llm_lora, prepare_for_vision_lora, smart_tokenizer_and_embedding_resize) - - +import torch.distributed as dist +from mmengine import runner class HybridFinetune(BaseModel): def __init__( @@ -106,46 +107,117 @@ def forward(self, data, data_samples=None, mode='loss'): """Overload parent class method, only support training.""" if mode == 'loss': - return self.compute_loss(data, data_samples) + return self.compute_loss(data) else: raise NotImplementedError( f"{type(self)}'s forward is only supported for use during " 'training. 
If you want to get predictions or chat, please ' "directly use `llm`'s forward.") - def compute_loss(self, data, data_samples=None): - + + + def _get_vision_embeds_and_ranges(self, data): + input_ids = data['input_ids'] - labels = data['labels'] - position_ids = data['position_ids'] - attention_mask = data['attention_mask'] pixel_values = data['pixel_values'] img_rngs = data['image_ranges'] - img_belong = data['image_belong'] - - input_embeds = self.llm.get_input_embeddings()(input_ids) + img_belongs = data['image_belongs'] + + bs, tokens = input_ids.shape + + img_embeds = [] + ranges_in_flat_batch = [] if pixel_values is not None: + assert isinstance(pixel_values, torch.Tensor) + assert len(img_rngs) == len(img_belongs) == pixel_values.size(0) + + batch_total_imgs = len(img_rngs) + visual_outputs = self.visual_encoder( pixel_values, output_hidden_states=True) - img_embeds = self.projector( + features = self.projector( visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) - - empty_embs = torch.zeros_like(input_embeds) - for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong): - left, right = rng - if emb.size(0) == right - left: - empty_embs[b_id, left:right, :] = emb - elif not emb.size(0) == right - left and left == 0: - empty_embs[b_id, left:right, :] = emb[-right:] - elif not emb.size( - 0) == right - left and right == empty_embs.size(1): - empty_embs[b_id, left:right, :] = emb[:right - left] + batch_total_imgs, actual_img_tokens, _ = features.shape + + + for i in range(batch_total_imgs): + img_start, img_end = img_rngs[i] + expect_img_tokens = img_end - img_start + img_emb = features[i] + img_bs_ind = img_belongs[i] + + if actual_img_tokens == expect_img_tokens: + img_embeds.append(img_emb) + elif not actual_img_tokens == expect_img_tokens and img_start == 0: + img_embeds.append(img_emb[actual_img_tokens-img_end:]) + elif not actual_img_tokens == expect_img_tokens and img_end == tokens: + img_embeds.append(img_emb[:expect_img_tokens]) else: - breakpoint() + raise RuntimeError + + flat_offset = tokens * img_bs_ind + + left = flat_offset + img_start + right = flat_offset + img_end + ranges_in_flat_batch.append((left, right)) + + return img_embeds, ranges_in_flat_batch + + + def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges): + + assert len(mm_embeds) == len(ranges) + if len(mm_embeds) == 0: + return flat_embeds + + chunk_embeds = [] + chunk_sizes = [] + mm_chunk_ids = [] + + cursor = 0 + _empty_embeds = torch.zeros_like(flat_embeds) + for (start, end), emb in zip(ranges, mm_embeds): + _empty_embeds[start: end] += emb + # if start - cursor > 0: + # chunk_sizes.append(start - cursor) + # cursor = start + + # mm_chunk_ids.append(len(chunk_sizes)) + + + # chunk_embeds.append(emb) + # chunk_sizes.append(end - start) + # cursor = end + + # tokens = flat_embeds.size(0) + # if sum(chunk_sizes) < tokens : + # chunk_sizes.append(tokens - sum(chunk_sizes)) + + # chunk_embs = list(torch.split(flat_embeds, chunk_sizes)) + # for ind, mm_emb in zip(mm_chunk_ids, mm_embeds) : + # chunk_embs[ind] = mm_emb + + # flat_embeds = torch.cat(chunk_embs, dim=0) + flat_embeds = flat_embeds * (_empty_embeds == 0) + + return flat_embeds + _empty_embeds + + def compute_loss(self, data): - non_img_mask = (empty_embs == 0) - input_embeds = input_embeds * non_img_mask + empty_embs + input_ids = data['input_ids'] + labels = data['labels'] + position_ids = data['position_ids'] + attention_mask = data['attention_mask'] + + input_embeds = self.llm.get_input_embeddings()(input_ids) + + bs, 
tokens, dim = input_embeds.shape + flat_embeds = input_embeds.flatten(0,1) + + img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs) + input_embeds = flat_embeds.reshape(bs, tokens, dim) outputs = self.llm( input_ids=None, @@ -153,7 +225,7 @@ def compute_loss(self, data, data_samples=None): attention_mask=attention_mask, inputs_embeds=input_embeds, labels=labels) - + loss_dict = {'loss': outputs.loss} return loss_dict diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index 616d89b09..125f789ee 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -92,6 +92,30 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: return chat_template.decorate_function_result(self.content) +class CodeInterpreterCallMsg(BaseModel): + + role: Literal['assistant'] + content: str + conde_interpreter_call: Union[str, Dict] + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + + return chat_template.decorate_code_interpreter_call( + self.content, self.conde_interpreter_call) + + + +class CodeInterpreterResultMsg(BaseModel): + role: Literal['function'] + name: str + content: Union[str, Dict] + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + return chat_template.decorate_code_internpreter_result(self.content) + + + + class Functions(BaseModel): # class Parameters(BaseModel): @@ -108,6 +132,26 @@ class Functions(BaseModel): name: str description: Union[str, Dict] parameters: Union[str, Dict] + + + +class CodeInterpreter(BaseModel): + + # class Parameters(BaseModel): + + # class Property(BaseModel): + # type: str + # description: str + # enum: Optional[List] = None + + # type: Literal['object'] + # properties: Dict[str, Property] + # required: List[str] + + name: str + description: Union[str, Dict] + + HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg] diff --git a/xtuner/utils/config.py b/xtuner/utils/config.py index ecd165920..0514dd8bf 100644 --- a/xtuner/utils/config.py +++ b/xtuner/utils/config.py @@ -124,7 +124,7 @@ def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T], raise TypeError( f'Expect an object of {accept}, but there is an object of ' f'{type(obj)}.') - return BUILDER.build(cfg_or_obj) + return obj else: raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got ' From 5c8c265f4d6b4dfab1c63d5155b922185ca7ac31 Mon Sep 17 00:00:00 2001 From: pppppM Date: Sun, 24 Mar 2024 18:07:35 +0800 Subject: [PATCH 3/6] support varlen attn --- xtuner/dataset/hybrid/_pack.py | 8 +- xtuner/dataset/hybrid/collate.py | 11 +- xtuner/dataset/hybrid/dataset.py | 44 +++++--- xtuner/model/hybrid.py | 123 ++++++++++----------- xtuner/model/modules/dispatch/internlm2.py | 2 +- xtuner/model/modules/dispatch/utils.py | 1 - xtuner/types/chat.py | 8 +- 7 files changed, 98 insertions(+), 99 deletions(-) diff --git a/xtuner/dataset/hybrid/_pack.py b/xtuner/dataset/hybrid/_pack.py index 946020909..12b29fbca 100644 --- a/xtuner/dataset/hybrid/_pack.py +++ b/xtuner/dataset/hybrid/_pack.py @@ -63,14 +63,15 @@ def _pack_img_urls_and_rngs_in_range(self, begin, end): filter_rngs = self._flat_shfl_acc_img_rngs[left:right] inner_rngs = [] - for rng in filter_rngs: + inner_urls = [] + for url, rng in zip(filter_urls, filter_rngs): inner_left = max(begin, rng[0]) - begin inner_right = min(end, rng[1]) - begin if inner_right - inner_left > 0: inner_rngs.append([inner_left, inner_right]) - - return filter_urls, inner_rngs + 
inner_urls.append(url) + return inner_urls, inner_rngs def _pack_ids_and_labels_in_range(self, begin, end): @@ -98,6 +99,7 @@ def _pack_ids_and_labels_in_range(self, begin, end): trunc_ids.extend(ori_input_ids[inner_l:inner_r]) trunc_labels.extend(ori_labels[inner_l:inner_r]) + cumulative_len.append(self.max_length) return trunc_ids, trunc_labels, cumulative_len, position_ids def __len__(self): diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py index 74d2218e7..e6ed2288a 100644 --- a/xtuner/dataset/hybrid/collate.py +++ b/xtuner/dataset/hybrid/collate.py @@ -18,7 +18,7 @@ def hybrid_collate_fn(instances: Sequence[Dict], image_ranges = [] image_belongs = [] position_ids = [] - + for i, data in enumerate(instances): input_ids.append(torch.LongTensor(data['input_ids'])) labels.append(torch.LongTensor(data['labels'])) @@ -27,16 +27,15 @@ def hybrid_collate_fn(instances: Sequence[Dict], if 'cumulative_len' in data: cumulative_len.append(torch.IntTensor(data['cumulative_len'])) - _values = data['pixel_values'] _ranges = data['image_ranges'] - + assert len(_values) == len(_ranges) for v, rng in zip(_values, _ranges): pixel_values.append(v) image_ranges.append(rng) image_belongs.append(i) - + if len(pixel_values) > 0: assert len(image_ranges) > 0 assert len(image_belongs) > 0 @@ -53,7 +52,8 @@ def hybrid_collate_fn(instances: Sequence[Dict], input_ids, batch_first=True, padding_value=pad_index) labels = pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX) - position_ids = pad_sequence(position_ids, batch_first=True, padding_value=0) + position_ids = pad_sequence( + position_ids, batch_first=True, padding_value=0) else: input_ids = torch.stack(input_ids) labels = torch.stack(labels) @@ -73,7 +73,6 @@ def hybrid_collate_fn(instances: Sequence[Dict], 'image_ranges': image_ranges, 'image_belongs': image_belongs } - if return_hf_format: return data_dict diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py index 69327a4e3..e8f127fc6 100644 --- a/xtuner/dataset/hybrid/dataset.py +++ b/xtuner/dataset/hybrid/dataset.py @@ -257,8 +257,8 @@ def _pack_dataset(self, dataset): def filter_non_labels_data(self, dataset): - filter_fn = lambda item: any(item['labels'][i] >= 0 for i in range( - self.max_length)) # noqa: E501, E731 + def filter_fn(item): + return any(item['labels'][i] >= 0 for i in range(self.max_length)) ori_samples = len(dataset) with ThreadPoolExecutor(max_workers=self.num_workers) as executor: @@ -281,9 +281,12 @@ def filter_non_labels_data(self, dataset): def analysis_image_samples(self, dataset): - img_sample_counter = lambda item: len(item['image_urls'] - ) > 0 # noqa: E501, E731 - img_counter = lambda item: len(item['image_urls']) # noqa: E501, E731 + def img_sample_counter(item): + return len(item['image_urls']) > 0 + + def img_counter(item): + return len(item['image_urls']) + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: images = list( @@ -307,9 +310,11 @@ def analysis_image_samples(self, dataset): def analysis_tokens_labels(self, dataset): - label_counter = lambda item: sum([1 for i in item['labels'] - if i >= 0]) # noqa: E501, E731 - token_counter = lambda item: len(item['input_ids']) + def label_counter(item): + return sum([1 for i in item['labels'] if i >= 0]) + + def token_counter(item): + return len(item['input_ids']) with ThreadPoolExecutor(max_workers=self.num_workers) as executor: tokens = list( @@ -398,10 +403,8 @@ def __getitem__(self, item: int) -> Dict[str, List]: 
assistant='{assistant}<|im_end|>\n', stop_words=['<|im_end|>'], image_token='', - function_call= - '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501 - function_result= - '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501 + function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' ) @@ -410,8 +413,9 @@ def __getitem__(self, item: int) -> Dict[str, List]: trust_remote_code=True, ) - from xtuner.dataset.hybrid.mappings import ( - insert_img_pad_tokens, llava_to_openai, openai_to_raw_training) + from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training) data_dir = './llava_data/LLaVA-Instruct-150K/' image_dir = './llava_data/llava_images/' @@ -421,14 +425,16 @@ def __getitem__(self, item: int) -> Dict[str, List]: 'internlm/internlm2-chat-1_8b', chat_template, sample_ratio=1, - max_length=32*1024, + max_length=32 * 1024, data_dir=data_dir, data_files=data_files, image_dir=image_dir, image_processor=processor, pack_to_max_length=True, mappings=[ - llava_to_openai, openai_to_raw_training, insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, ], num_workers=4) @@ -439,13 +445,15 @@ def __getitem__(self, item: int) -> Dict[str, List]: 'internlm/internlm2-chat-1_8b', chat_template, sample_ratio=1, - max_length=32*1024, + max_length=32 * 1024, data_cached='cached_llava', image_dir=image_dir, image_processor=processor, pack_to_max_length=True, mappings=[ - llava_to_openai, openai_to_raw_training, insert_img_pad_tokens, + llava_to_openai, + openai_to_raw_training, + insert_img_pad_tokens, ], num_workers=4) print(dataset[0]) diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py index 24b165cd7..0f0fc7e76 100644 --- a/xtuner/model/hybrid.py +++ b/xtuner/model/hybrid.py @@ -2,11 +2,11 @@ from collections import OrderedDict import torch +import torch.distributed as dist from mmengine.model import BaseModel from peft import LoraConfig -from mmengine import print_log from torch import nn -import math + from xtuner.registry import BUILDER from xtuner.utils.config import build_from_cfg_or_obj from .modules import ProjectorConfig, ProjectorModel, dispatch_modules @@ -14,8 +14,8 @@ get_peft_model_state_dict, prepare_for_llm_lora, prepare_for_vision_lora, smart_tokenizer_and_embedding_resize) -import torch.distributed as dist -from mmengine import runner + + class HybridFinetune(BaseModel): def __init__( @@ -114,109 +114,106 @@ def forward(self, data, data_samples=None, mode='loss'): 'training. 
If you want to get predictions or chat, please ' "directly use `llm`'s forward.") - - def _get_vision_embeds_and_ranges(self, data): - + input_ids = data['input_ids'] pixel_values = data['pixel_values'] img_rngs = data['image_ranges'] img_belongs = data['image_belongs'] - + bs, tokens = input_ids.shape - + img_embeds = [] ranges_in_flat_batch = [] if pixel_values is not None: assert isinstance(pixel_values, torch.Tensor) assert len(img_rngs) == len(img_belongs) == pixel_values.size(0) - + batch_total_imgs = len(img_rngs) - + visual_outputs = self.visual_encoder( pixel_values, output_hidden_states=True) features = self.projector( visual_outputs.hidden_states[self.visual_select_layer][:, 1:]) - batch_total_imgs, actual_img_tokens, _ = features.shape - - + batch_total_imgs, real_img_tokens, _ = features.shape + for i in range(batch_total_imgs): img_start, img_end = img_rngs[i] - expect_img_tokens = img_end - img_start + exp_img_tokens = img_end - img_start img_emb = features[i] img_bs_ind = img_belongs[i] - - if actual_img_tokens == expect_img_tokens: + + if real_img_tokens == exp_img_tokens: img_embeds.append(img_emb) - elif not actual_img_tokens == expect_img_tokens and img_start == 0: - img_embeds.append(img_emb[actual_img_tokens-img_end:]) - elif not actual_img_tokens == expect_img_tokens and img_end == tokens: - img_embeds.append(img_emb[:expect_img_tokens]) + elif not real_img_tokens == exp_img_tokens and img_start == 0: + img_embeds.append(img_emb[real_img_tokens - img_end:]) + elif (not real_img_tokens == exp_img_tokens + and img_end == tokens): + img_embeds.append(img_emb[:exp_img_tokens]) else: raise RuntimeError - + flat_offset = tokens * img_bs_ind - + left = flat_offset + img_start right = flat_offset + img_end ranges_in_flat_batch.append((left, right)) - + return img_embeds, ranges_in_flat_batch - - + def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges): - + assert len(mm_embeds) == len(ranges) if len(mm_embeds) == 0: return flat_embeds - - chunk_embeds = [] - chunk_sizes = [] - mm_chunk_ids = [] - - cursor = 0 + _empty_embeds = torch.zeros_like(flat_embeds) for (start, end), emb in zip(ranges, mm_embeds): - _empty_embeds[start: end] += emb - # if start - cursor > 0: - # chunk_sizes.append(start - cursor) - # cursor = start - - # mm_chunk_ids.append(len(chunk_sizes)) - - - # chunk_embeds.append(emb) - # chunk_sizes.append(end - start) - # cursor = end - - # tokens = flat_embeds.size(0) - # if sum(chunk_sizes) < tokens : - # chunk_sizes.append(tokens - sum(chunk_sizes)) - - # chunk_embs = list(torch.split(flat_embeds, chunk_sizes)) - # for ind, mm_emb in zip(mm_chunk_ids, mm_embeds) : - # chunk_embs[ind] = mm_emb - - # flat_embeds = torch.cat(chunk_embs, dim=0) + _empty_embeds[start:end] += emb + flat_embeds = flat_embeds * (_empty_embeds == 0) - + return flat_embeds + _empty_embeds - + def compute_loss(self, data): input_ids = data['input_ids'] labels = data['labels'] - position_ids = data['position_ids'] + # position_ids = data['position_ids'] attention_mask = data['attention_mask'] - + # breakpoint() + bs, tokens = input_ids.shape + if self.use_varlen_attn: + assert bs == 1 + + cumulative_len = data['cumulative_len'][0] + max_seqlen = (cumulative_len[1:] - cumulative_len[:-1]).max() + + position_ids = [] + for i in range(1, len(cumulative_len)): + chunk_tokens = cumulative_len[i] - cumulative_len[i - 1] + position_ids.append(torch.arange(chunk_tokens)) + position_ids = torch.cat(position_ids, dim=0).unsqueeze(0) + + from mmengine import MessageHub + rank = 
dist.get_rank() + message_hub = MessageHub.get_instance('varlen_attn_args') + message_hub.update_info(f'cumulative_len_rank_{rank}', + cumulative_len) + message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen) + else: + + position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1) + input_embeds = self.llm.get_input_embeddings()(input_ids) - + bs, tokens, dim = input_embeds.shape - flat_embeds = input_embeds.flatten(0,1) - + flat_embeds = input_embeds.flatten(0, 1) + img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) - flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs) + flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, + flat_bs_img_rngs) input_embeds = flat_embeds.reshape(bs, tokens, dim) outputs = self.llm( @@ -225,7 +222,7 @@ def compute_loss(self, data): attention_mask=attention_mask, inputs_embeds=input_embeds, labels=labels) - + loss_dict = {'loss': outputs.loss} return loss_dict diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index a166e8bae..8aa664ad5 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -173,7 +173,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) + attn_output = flash_attn_varlen_func( q_unpad, k_unpad, diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py index 4cfa26cd1..5355bce74 100644 --- a/xtuner/model/modules/dispatch/utils.py +++ b/xtuner/model/modules/dispatch/utils.py @@ -25,7 +25,6 @@ def upad_qkv(query_layer, key_layer, value_layer, attention_mask, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index 125f789ee..74ac5e30e 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -101,8 +101,7 @@ class CodeInterpreterCallMsg(BaseModel): def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: return chat_template.decorate_code_interpreter_call( - self.content, self.conde_interpreter_call) - + self.content, self.conde_interpreter_call) class CodeInterpreterResultMsg(BaseModel): @@ -114,8 +113,6 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: return chat_template.decorate_code_internpreter_result(self.content) - - class Functions(BaseModel): # class Parameters(BaseModel): @@ -132,7 +129,6 @@ class Functions(BaseModel): name: str description: Union[str, Dict] parameters: Union[str, Dict] - class CodeInterpreter(BaseModel): @@ -150,8 +146,6 @@ class CodeInterpreter(BaseModel): name: str description: Union[str, Dict] - - HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg] From e5716883ce2915c6ecf2a34adf5e7aa6a212e150 Mon Sep 17 00:00:00 2001 From: pppppM Date: Wed, 27 Mar 2024 19:07:39 +0800 Subject: [PATCH 4/6] support code interpreter finetune --- .../internlm2_chat_1_8b/hybrid/agent.json | 62 +++++++++++++++ .../internlm2_chat_1_8b/hybrid/example.py | 29 +++++++ .../internlm2_chat_1_8b_function_call.py | 5 +- .../hybrid/internlm2_chat_1_8b_llava_sft.py | 50 ++++++++---- 
.../hybrid/multi_modal.json | 4 +- xtuner/dataset/hybrid/dataset.py | 7 +- xtuner/types/chat.py | 64 +++++++--------- xtuner/types/chat_template.py | 22 ++++-- xtuner/types/train.py | 76 +++++++++++++++++-- 9 files changed, 249 insertions(+), 70 deletions(-) create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json create mode 100644 xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json new file mode 100644 index 000000000..89a82e4aa --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json @@ -0,0 +1,62 @@ +{ + "messages": [ + {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"}, + { + "role": "user", + "content": "Please help me process and visualize this dataset.", + "files": [{"path": "data.csv", "size": "10K"}] + }, + { + "role": "assistant", + "content": "I have processed the data and visualized it for you.", + "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': '\n\nRainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='Date: %{text}
Wind Direction 9am: %{x}
Rainfall: %{y}
Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```" + }, + { + "role": "code_interpreter", + "content": "![image](xxx.png)" + }, + { + "role": "assistant", + "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background." + }, + { + "role": "user", + "content": "I want to know today's weather in Shanghai" + }, + { + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", + "function_call": { + "name": "get_current_weather", + "parameters": {"location": "Shanghai"} + } + }, + { + "role": "function", + "name": "get_current_weather", + "content": "{'temperature': 22}" + }, + { + "role": "assistant", + "content": "The weather in Shanghai is 22 celsius" + } + ], + + "functions": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + "unit": {"type": "string"}}, + "required": ["location"] + } + } + } + ], + + "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. 
formats)"} \ No newline at end of file diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py new file mode 100644 index 000000000..8e7a0d0dd --- /dev/null +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py @@ -0,0 +1,29 @@ +import json + +from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages + + +chat_template = HybridChatTemplate( + system='<|im_start|>system\n{system}<|im_end|>\n', + user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', + assistant='{assistant}<|im_end|>\n', + stop_words=['<|im_end|>'], + image_token='', + files='<|im_start|>user name=file\n{files}<|im_end|>\n', + function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n', + code_interpreter_call='{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + code_interpreter_result='<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + code_interpreter='<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n' + +) + +agent_data = json.load(open('agent.json')) + +msg = TrainingHybridChatMessages.from_dict(agent_data) +print(msg.apply_chat_template(chat_template)) + +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True) +print(msg.tokenize(tokenizer, chat_template)) \ No newline at end of file diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py index a6d2a8049..2c1df507b 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_function_call.py @@ -4,10 +4,8 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook) from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR - from torch.optim import AdamW from transformers import AutoModelForCausalLM, AutoTokenizer - from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn from xtuner.dataset.hybrid.mappings import openai_to_raw_training @@ -74,7 +72,6 @@ trust_remote_code=True, padding_side='right') - model = dict( type=HybridFinetune, llm=dict( @@ -95,7 +92,7 @@ chat_template=chat_template, max_length=max_length, pack_to_max_length=True, - num_workers = dataloader_num_workers, + num_workers=dataloader_num_workers, mappings=[openai_to_raw_training]) train_dataloader = dict( diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py index 97bae7ac3..a010f44b5 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/internlm2_chat_1_8b_llava_sft.py @@ -4,9 +4,11 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook) from mmengine.optim import 
AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig from torch.optim import AdamW from transformers import (AutoModelForCausalLM, AutoTokenizer, - CLIPImageProcessor, CLIPVisionModel) + BitsAndBytesConfig, CLIPImageProcessor, + CLIPVisionModel) from xtuner.dataset.hybrid import HybridDataset, hybrid_collate_fn from xtuner.dataset.hybrid.mappings import (insert_img_pad_tokens, @@ -21,15 +23,17 @@ # PART 1 Settings # ####################################################################### # Model +# llm_name_or_path = '/mnt/petrelfs/share_data/basemodel/checkpoints/llm/hf_hub/models--internlm--internlm2-chat-1_8b/snapshots/aa8a7450c2227a3b6733b3c6fe33fefbb2ca54f9/' llm_name_or_path = '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f/' visual_encoder_name_or_path = 'openai/clip-vit-large-patch14-336' +use_varlen_attn = False # Specify the pretrained pth pretrained_pth = None # Data data_dir = './llava_data/' data_files = ['LLaVA-Instruct-150K/llava_v1_5_mix665k.json'] image_dir = data_dir + 'llava_images' -max_length = 1024 * 32 +max_length = 1024 * 2 # Chat Template chat_template = dict( @@ -46,12 +50,12 @@ functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n') # Scheduler & Optimizer -batch_size = 1 # per_device +batch_size = 16 # per_device accumulative_counts = 1 -dataloader_num_workers = 4 +dataloader_num_workers = 0 max_epochs = 1 optim_type = AdamW -lr = 2e-4 +lr = 0 betas = (0.9, 0.999) weight_decay = 0 max_norm = 1 # grad clip @@ -86,14 +90,34 @@ freeze_llm=False, freeze_visual_encoder=True, pretrained_pth=pretrained_pth, + use_varlen_attn=use_varlen_attn, llm=dict( type=AutoModelForCausalLM.from_pretrained, pretrained_model_name_or_path=llm_name_or_path, trust_remote_code=True, - torch_dtype=torch.float16), + torch_dtype=torch.bfloat16, + attn_implementation='flash_attention_2', + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + llm_lora=dict( + type=LoraConfig, + r=512, + lora_alpha=256, + lora_dropout=0.05, + bias='none', + task_type='CAUSAL_LM'), visual_encoder=dict( type=CLIPVisionModel.from_pretrained, - pretrained_model_name_or_path=visual_encoder_name_or_path)) + pretrained_model_name_or_path=visual_encoder_name_or_path), + visual_encoder_lora=dict( + type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none')) ####################################################################### # PART 3 Dataset & Dataloader # @@ -102,16 +126,16 @@ type=HybridDataset, data_dir=data_dir, data_files=data_files, - data_cached='cached_llava', + # data_cached='cached_llava', image_dir=image_dir, - sample_ratio=1, + sample_ratio=0.1, tokenizer=tokenizer, chat_template=chat_template, image_processor=image_processor, pad_img_to_squared=True, max_length=max_length, - pack_to_max_length=True, - num_workers=dataloader_num_workers, + pack_to_max_length=False, + num_workers=4, mappings=[ llava_to_openai, openai_to_raw_training, @@ -120,7 +144,7 @@ train_dataloader = dict( batch_size=batch_size, - num_workers=dataloader_num_workers, + num_workers=4, dataset=llava_dataset, sampler=dict(type=DefaultSampler, shuffle=True), collate_fn=dict(type=hybrid_collate_fn)) @@ -182,7 +206,7 @@ # record the time of every iteration. 
timer=dict(type=IterTimerHook), # print log every 10 iterations. - logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), # enable the parameter scheduler. param_scheduler=dict(type=ParamSchedulerHook), # save checkpoint per `save_steps`. diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json index 0b1576131..ebe5cf457 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/multi_modal.json @@ -13,7 +13,7 @@ "image_url": "image2.jpg" }, { - "type": "text", + "type": "text", "text": "What are the colors of the bus in the first image?" } ] @@ -37,5 +37,3 @@ ] } ] - - \ No newline at end of file diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py index e8f127fc6..b2699e048 100644 --- a/xtuner/dataset/hybrid/dataset.py +++ b/xtuner/dataset/hybrid/dataset.py @@ -287,7 +287,6 @@ def img_sample_counter(item): def img_counter(item): return len(item['image_urls']) - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: images = list( tqdm( @@ -403,8 +402,10 @@ def __getitem__(self, item: int) -> Dict[str, List]: assistant='{assistant}<|im_end|>\n', stop_words=['<|im_end|>'], image_token='', - function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' ) diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index 74ac5e30e..cd0a4d4a7 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Union +from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel @@ -9,7 +9,7 @@ class TextContentItem(BaseModel): type: Literal['text'] text: str - def format_content(self, chat_template: HybridChatTemplate) -> str: + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: return self.text @@ -17,10 +17,18 @@ class ImageContentItem(BaseModel): type: Literal['image_url'] image_url: str - def format_content(self, chat_template: HybridChatTemplate) -> str: + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: return chat_template.image_token +class FileContentItem(BaseModel): + type: Literal['file_url'] + file_url: str + + def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: + return self.file_url + + MultModalContentType = Union[TextContentItem, ImageContentItem] ContentType = Union[str, List[MultModalContentType]] @@ -28,6 +36,7 @@ def format_content(self, chat_template: HybridChatTemplate) -> str: class ChatMsg(BaseModel): role: Literal['assistant', 'user', 'system'] content: ContentType + files: List[Union[str, Dict]] = [] def collect_img_urls(self) -> List[str]: img_urls = [] @@ -45,16 +54,20 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: text = '' for i, item in enumerate(self.content): if i == 0: - text += 
item.format_content(chat_template) + text += item.apply_chat_template(chat_template) else: - text += '\n' + item.format_content(chat_template) + text += '\n' + item.apply_chat_template(chat_template) else: raise NotImplementedError if self.role == 'system': prompt = chat_template.decorate_system(text) elif self.role == 'user': + if len(self.files) > 0: + stop_word = chat_template.stop_words[0] + text += f'\n{stop_word}\n{chat_template.decorate_files(self.files)}' prompt = chat_template.decorate_user(text) + elif self.role == 'assistant': prompt = chat_template.decorate_assistant(text) else: @@ -105,50 +118,22 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: class CodeInterpreterResultMsg(BaseModel): - role: Literal['function'] - name: str + role: Literal['code_interpreter'] content: Union[str, Dict] def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: - return chat_template.decorate_code_internpreter_result(self.content) + return chat_template.decorate_code_interpreter_result(self.content) class Functions(BaseModel): - # class Parameters(BaseModel): - - # class Property(BaseModel): - # type: str - # description: str - # enum: Optional[List] = None - - # type: Literal['object'] - # properties: Dict[str, Property] - # required: List[str] - name: str description: Union[str, Dict] parameters: Union[str, Dict] -class CodeInterpreter(BaseModel): - - # class Parameters(BaseModel): - - # class Property(BaseModel): - # type: str - # description: str - # enum: Optional[List] = None - - # type: Literal['object'] - # properties: Dict[str, Property] - # required: List[str] - - name: str - description: Union[str, Dict] - - -HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg] +HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg, + CodeInterpreterCallMsg, CodeInterpreterResultMsg] class HybridChatMessages(BaseModel): @@ -156,6 +141,7 @@ class HybridChatMessages(BaseModel): messages: List[HybridChatMsgType] = [] # images: List[Image.Image] = [] functions: List[Functions] = [] + code_interpreter: Optional[str] = None # TODO (pppppM) add audio and video @@ -171,6 +157,10 @@ def pop_latest_msg(self): def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: prompt = '' + + if self.code_interpreter: + prompt += chat_template.decorate_functions(self.code_interpreter) + if len(self.functions) > 0: functions = [func.model_dump() for func in self.functions] diff --git a/xtuner/types/chat_template.py b/xtuner/types/chat_template.py index 847604bfe..4318c6104 100644 --- a/xtuner/types/chat_template.py +++ b/xtuner/types/chat_template.py @@ -20,7 +20,10 @@ class HybridChatTemplate(BaseModel): image_token_index: int = -100 # Agent Chat + # Interpreter and function related strings + files: Optional[str] = None + functions: Optional[str] = None # Function description format function_call: Optional[str] = None # Function call format function_result: Optional[str] = None # Function result format @@ -52,6 +55,10 @@ def decorate_user(self, text: str) -> str: """Decorate text with the `user` template.""" return self.user.format(user=text) + def decorate_files(self, text: str) -> str: + """Decorate text with the `functions` template.""" + return self.files.format(files=text) + def decorate_functions(self, text: str) -> str: """Decorate text with the `functions` template.""" return self.functions.format(functions=text) @@ -68,9 +75,10 @@ def decorate_code_interpreter(self, text: str) -> str: """Decorate text with the 
`code_interpreter` template.""" return self.code_interpreter.format(code_interpreter=text) - def decorate_code_interpreter_call(self, text: str) -> str: + def decorate_code_interpreter_call(self, text: str, func: str) -> str: """Decorate text with the `code_interpreter_call` template.""" - return self.code_interpreter_call.format(code_interpreter_call=text) + return self.code_interpreter_call.format( + assistant=text, code_interpreter_call=func) def decorate_code_interpreter_result(self, text: str) -> str: """Decorate text with the `code_interpreter_result` template.""" @@ -164,9 +172,13 @@ def check_code_interpreter_call(cls, v: str) -> str: If not, raises a ValueError. """ - if v is not None and '{code_interpreter_call}' not in v: - raise ValueError('code_interpreter_call must contain the keyword ' - "'{code_interpreter_call}'") + if (v is not None and '{code_interpreter_call}' not in v + and '{assistant}' not in v): + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") + if v is not None and '{assistant}' not in v: + raise ValueError('code_interpreter_call must contain the keywords ' + "'{assistant}' and '{code_interpreter_call}'") return v @field_validator('code_interpreter_result') diff --git a/xtuner/types/train.py b/xtuner/types/train.py index a3775bffe..6235e1d41 100644 --- a/xtuner/types/train.py +++ b/xtuner/types/train.py @@ -8,8 +8,9 @@ from xtuner.utils import IGNORE_INDEX from xtuner.utils.tokenizer import get_bos_token_ids -from .chat import (ChatMsg, FunctionCallMsg, FunctionResultMsg, Functions, - ImageContentItem, TextContentItem) +from .chat import (ChatMsg, CodeInterpreterCallMsg, CodeInterpreterResultMsg, + FileContentItem, FunctionCallMsg, FunctionResultMsg, + Functions, ImageContentItem, TextContentItem) from .chat_template import HybridChatTemplate @@ -125,6 +126,40 @@ def tokenize(self, tokenizer, chat_template: HybridChatTemplate): return {'input_ids': token_ids, 'labels': label_ids} +class TrainingCodeInterpreterCallMsg(CodeInterpreterCallMsg): + loss: bool = True + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + +class TrainingCodeInterpreterResultMsg(CodeInterpreterResultMsg): + loss: bool = False + + def tokenize(self, tokenizer, chat_template: HybridChatTemplate): + + decorated = self.apply_chat_template(chat_template) + + token_ids = tokenizer.encode(decorated, add_special_tokens=False) + + if self.loss: + label_ids = copy.deepcopy(token_ids) + else: + label_ids = [IGNORE_INDEX] * len(token_ids) + + return {'input_ids': token_ids, 'labels': label_ids} + + class RawTrainingData(BaseModel): input_ids: List[int] @@ -148,12 +183,15 @@ class Config: TraingHybridMessageType = Union[TrainingChatMsg, TrainingFunctionCallMsg, - TrainingFunctionResultMsg] + TrainingFunctionResultMsg, + TrainingCodeInterpreterCallMsg, + TrainingCodeInterpreterResultMsg] class TrainingHybridChatMessages(BaseModel): messages: List[TraingHybridMessageType] functions: Optional[List[Functions]] = None + code_interpreter: Optional[str] = None @classmethod def from_dict(cls, item) -> 'TrainingHybridChatMessages': @@ -187,6 +225,12 @@ def from_dict(cls, item) -> 'TrainingHybridChatMessages': messages.append(msg) 
continue + if _role == 'code_interpreter': + msg_factory = TrainingCodeInterpreterResultMsg + msg = msg_factory(role=_role, content=_content) + messages.append(msg) + continue + if isinstance(_content, list): content = [] @@ -202,6 +246,10 @@ def from_dict(cls, item) -> 'TrainingHybridChatMessages': _url = c_item['image_url'] content.append( ImageContentItem(type=_type, image_url=_url)) + elif _type == 'file_url': + assert 'file_url' in c_item + _url = c_item['file_url'] + content.append(FileContentItem(file_url=_url)) else: raise NotImplementedError @@ -215,13 +263,25 @@ def from_dict(cls, item) -> 'TrainingHybridChatMessages': role=_role, content=_content, function_call=_call) messages.append(msg) continue + elif isinstance(_content, str) and 'code_interpreter' in _msg: + _call = _msg['function_call'] + msg = TrainingCodeInterpreterCallMsg( + role=_role, content=_content, code_interpreter_call=_call) + messages.append(msg) + continue if isinstance(_content, str): + # breakpoint() msg = TrainingChatMsg(role=_role, content=_content) messages.append(msg) # TODO (pppppM) add format warning + if 'code_interpreter' in item: + + assert isinstance(item['code_interpreter'], str) + code_interpreter = item['code_interpreter'] + if 'functions' in item: _functions = item['functions'] assert isinstance(_functions, list) @@ -240,7 +300,10 @@ def from_dict(cls, item) -> 'TrainingHybridChatMessages': name=_name, description=_desc, parameters=_params) functions.append(func) - return cls(messages=messages, functions=functions) + return cls( + messages=messages, + functions=functions, + code_interpreter=code_interpreter) def collect_img_urls(self) -> List[str]: img_urls = [] @@ -262,7 +325,10 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: prompt += chat_template.decorate_functions(functions) for msg in self.messages: - prompt += msg.apply_chat_template(chat_template) + if msg.role == 'system': + prompt = msg.apply_chat_template(chat_template) + prompt + else: + prompt += msg.apply_chat_template(chat_template) return prompt From 654f1b157e7e73bac772bb33250d3022f8d8d14c Mon Sep 17 00:00:00 2001 From: pppppM Date: Thu, 28 Mar 2024 13:51:49 +0800 Subject: [PATCH 5/6] BaseTune and BaseEncoder --- .../internlm2_chat_1_8b/hybrid/agent.json | 59 ++-- .../internlm2_chat_1_8b/hybrid/example.py | 26 +- xtuner/model/auto.py | 20 ++ xtuner/model/base.py | 71 +++++ xtuner/model/encoders/__init__.py | 2 + xtuner/model/encoders/base.py | 53 ++++ xtuner/model/encoders/llava.py | 284 ++++++++++++++++++ xtuner/model/hybrid.py | 228 ++++++++------ xtuner/types/__init__.py | 4 +- 9 files changed, 606 insertions(+), 141 deletions(-) create mode 100644 xtuner/model/auto.py create mode 100644 xtuner/model/base.py create mode 100644 xtuner/model/encoders/__init__.py create mode 100644 xtuner/model/encoders/base.py create mode 100644 xtuner/model/encoders/llava.py diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json index 89a82e4aa..667c82d21 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/agent.json @@ -2,61 +2,62 @@ "messages": [ {"role": "system", "content": "You are InternLM2-Chat, a harmless AI assistant"}, { - "role": "user", - "content": "Please help me process and visualize this dataset.", + "role": "user", + "content": "Please help me process and visualize this dataset.", "files": [{"path": "data.csv", "size": "10K"}] - }, + }, 
{ - "role": "assistant", - "content": "I have processed the data and visualized it for you.", + "role": "assistant", + "content": "I have processed the data and visualized it for you.", "code_interpreter_call": "```python\nimport plotly.express as px\nimport pandas as pd\n\n# Load the data into a pandas dataframe\ndf = pd.read_csv('data.csv')\n\n# Create a scatter plot of rainfall vs wind direction\nfig = px.scatter(df, x='WindDir9am', y='Rainfall', color='WindDir3pm',\n labels={'WindDir9am': 'Wind Direction 9am', 'Rainfall': '\n\nRainfall', 'WindDir3pm': 'Wind Direction 3pm'},\n title='Rainfall vs Wind Direction',\n template='plotly_dark',\n width=600, height=500)\n\n# Add a hover effect to show the date\nfig.update_traces(hovertemplate='Date: %{text}
Wind Direction 9am: %{x}
Rainfall: %{y}
Wind Direction 3pm: %{marker.color}')\n\n# Show the plot\nfig.show()\n```" - }, + }, { - "role": "code_interpreter", + "role": "code_interpreter", "content": "![image](xxx.png)" - }, + }, { - "role": "assistant", + "role": "assistant", "content": "Since the code output is not included here, I cannot provide specific chart content. However, if the code executed correctly, it should display a polar plot with two filled areas representing the relationship between wind direction at 9 am and rainfall, and between wind direction at 3 pm and rainfall, respectively. The values for each direction are based on the average rainfall calculated from the provided dataset. The chart should have a clear title, a legend, and be intuitive for comparing rainfall with different wind directions. Given the use of a dark theme, the overall appearance of the chart should be bright lines and filled areas on a dark background." - }, + }, { - "role": "user", + "role": "user", "content": "I want to know today's weather in Shanghai" }, { - "role": "assistant", - "content": "Sure, I will search for the weather of Shanghai.", + "role": "assistant", + "content": "Sure, I will search for the weather of Shanghai.", "function_call": { - "name": "get_current_weather", + "name": "get_current_weather", "parameters": {"location": "Shanghai"} } - }, + }, { - "role": "function", - "name": "get_current_weather", + "role": "function", + "name": "get_current_weather", "content": "{'temperature': 22}" - }, + }, { - "role": "assistant", + "role": "assistant", "content": "The weather in Shanghai is 22 celsius" } - ], - + ], + "functions": [ { - "name": "get_current_weather", - "description": "Get the current weather in a given location", + "name": "get_current_weather", + "description": "Get the current weather in a given location", "parameters": { - "type": "object", + "type": "object", "properties": { "location": { - "type": "string", + "type": "string", "description": "The city and state, e.g. San Francisco, CA", - "unit": {"type": "string"}}, + "unit": {"type": "string"}}, "required": ["location"] } } } - ], - - "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. formats)"} \ No newline at end of file + ], + + "code_interpreter": "You now have access to a Jupyter notebook environment supporting Python code execution. Just send code to python to run in this stateful environment. This feature is suitable for:\n- Data analysis or processing (such as data manipulation and graphic creation)\n- Complex calculations (such as math and physics problems)\n- Programming examples (for understanding programming concepts or language features)\n- Text processing and analysis (including text analysis and natural language processing)\n- Machine learning and data science (model training and data visualization)\n- File operations and data import (handling CSV, JSON, etc. 
formats)" +} diff --git a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py index 8e7a0d0dd..e9d5796bc 100644 --- a/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py +++ b/xtuner/configs/internlm/internlm2_chat_1_8b/hybrid/example.py @@ -2,22 +2,24 @@ from xtuner.types import HybridChatTemplate, TrainingHybridChatMessages - chat_template = HybridChatTemplate( system='<|im_start|>system\n{system}<|im_end|>\n', user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n', assistant='{assistant}<|im_end|>\n', stop_words=['<|im_end|>'], image_token='', - files='<|im_start|>user name=file\n{files}<|im_end|>\n', - function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + files='<|im_start|>user name=file\n{files}<|im_end|>\n', + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n', - code_interpreter_call='{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - code_interpreter_result='<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 - code_interpreter='<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n' - -) + code_interpreter_call= + '{assistant}<|action_start|><|interpreter|>\n{code_interpreter_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + code_interpreter_result= + '<|im_start|>environment name=<|interpreter|>\n{code_interpreter_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + code_interpreter= + '<|im_start|>system name=<|interpreter|>\n{code_interpreter}<|im_end|>\n') agent_data = json.load(open('agent.json')) @@ -25,5 +27,7 @@ print(msg.apply_chat_template(chat_template)) from transformers import AutoTokenizer -tokenizer = AutoTokenizer.from_pretrained('internlm/internlm2-chat-7b', trust_remote_code=True) -print(msg.tokenize(tokenizer, chat_template)) \ No newline at end of file + +tokenizer = AutoTokenizer.from_pretrained( + 'internlm/internlm2-chat-7b', trust_remote_code=True) +print(msg.tokenize(tokenizer, chat_template)) diff --git a/xtuner/model/auto.py b/xtuner/model/auto.py new file mode 100644 index 000000000..d525f80e4 --- /dev/null +++ b/xtuner/model/auto.py @@ -0,0 +1,20 @@ +from mmengine import Config + +from xtuner.model.base import BaseTune +from xtuner.registry import BUILDER + + +class AutoModel(): + + @classmethod + def from_config(cls, config: str): + config = Config.fromfile(config) + model: BaseTune = BUILDER.build(config.model) + return model + + @classmethod + def from_pretrained(cls, config: str, checkpoint: str): + config = Config.fromfile(config) + model: BaseTune = BUILDER.build(config.model) + model.load_checkpoint(checkpoint) + return model diff --git a/xtuner/model/base.py b/xtuner/model/base.py new file mode 100644 index 000000000..84c4c1879 --- /dev/null +++ b/xtuner/model/base.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractmethod
+
+from mmengine.model import BaseModel
+
+from xtuner.types import HybridChatMessages, HybridChatTemplate
+
+
+class BaseTune(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+
+    def init_weights(self):
+        """Parent class method.
+
+        To avoid overwriting the loaded weights, overload it to an empty
+        function.
+        """
+        pass
+
+    def avoid_override_weights(self):
+        self._is_init = True
+
+    @property
+    @abstractmethod
+    def chat_template(self) -> HybridChatTemplate:
+        pass
+
+    @property
+    @abstractmethod
+    def llm(self):
+        pass
+
+    @property
+    @abstractmethod
+    def tokenizer(self):
+        pass
+
+    @abstractmethod
+    def gradient_checkpointing_enable(self):
+        pass
+
+    def forward(self, data, data_samples=None, mode='loss'):
+        """Overload parent class method, only support training."""
+
+        if mode == 'loss':
+            return self.compute_loss(data)
+        else:
+            raise NotImplementedError(
+                f"{type(self)}'s forward is only supported for use during "
+                'training. If you want to get predictions or chat, please '
+                "directly use `llm`'s forward.")
+
+    @abstractmethod
+    def chat(self, messages: HybridChatMessages, sample_params, streamer):
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def load_checkpoint(self, *args, **kwargs) -> 'BaseTune':
+        pass
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.llm, name)
diff --git a/xtuner/model/encoders/__init__.py b/xtuner/model/encoders/__init__.py
new file mode 100644
index 000000000..4a863bf0c
--- /dev/null
+++ b/xtuner/model/encoders/__init__.py
@@ -0,0 +1,2 @@
+from .base import EncoderWrapper
+from .llava import LlavaEncoderWrapper
diff --git a/xtuner/model/encoders/base.py b/xtuner/model/encoders/base.py
new file mode 100644
index 000000000..462c42091
--- /dev/null
+++ b/xtuner/model/encoders/base.py
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractclassmethod, abstractmethod +from typing import List, Union + +import torch +from PIL import Image +from torch import nn + +_ImageType = Union[str, Image.Image] + + +class EncoderWrapper(nn.Module): + + def __init__(self): + super().__init__() + + @abstractmethod + @property + def encoder(self): + pass + + @abstractmethod + @property + def projector(self): + pass + + @abstractmethod + def post_init_proj(self, llm): + pass + + @abstractmethod + def preprocess(self, image: _ImageType) -> torch.Tensor: + pass + + @abstractmethod + def batch_infer(images: List[_ImageType]) -> List[torch.Tensor]: + pass + + @abstractmethod + def gradient_checkpointing_enable(self): + pass + + @abstractclassmethod + def save_checkpoint(self, *args, **kwargs): + pass + + @abstractclassmethod + def load_checkpoint(self, *args, **kwargs) -> 'EncoderWrapper': + pass + + @abstractclassmethod + def only_build_processor(self, *args, **kwargs): + pass diff --git a/xtuner/model/encoders/llava.py b/xtuner/model/encoders/llava.py new file mode 100644 index 000000000..1267fa8ba --- /dev/null +++ b/xtuner/model/encoders/llava.py @@ -0,0 +1,284 @@ +import base64 +import os +from collections import OrderedDict +from io import BytesIO +from typing import List, Literal, Optional, Union + +import requests +import torch +from accelerate import load_checkpoint_in_model +from peft import LoraConfig, PeftModel +from PIL import Image +from torch import nn +from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel + +from xtuner.dataset.utils import expand2square +from xtuner.utils.config import build_from_cfg_or_obj +from ..modules import ProjectorConfig, ProjectorModel +from ..utils import (LoadWoInit, get_peft_model_state_dict, + prepare_for_vision_lora) +from .base import BaseEncoder, _ImageType + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +def load_image(image_url: str) -> Image.Image: + """load image from url, local path or openai GPT4V.""" + + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + if image_url.startswith('http'): + response = requests.get(image_url, headers=headers) + response.raise_for_status() + + # Open the image using PIL + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + img = Image.open(image_url) + + return img + + +class LlavaEncoderWrapper(BaseEncoder): + + def __init__(self, + model_name_or_path: str, + lora=None, + select_layer: int = -2, + freeze_clip: bool = True): + + super().__init__() + + assert not (lora is not None and freeze_clip) + self._projector = None + self.proj_inited = False + self.freeze_clip = freeze_clip + self.select_layer = select_layer + + _res = self.build_processor_and_encoder(model_name_or_path) + self._processor, self._encoder = _res + + if self.freeze_clip: + self._encoder.requires_grad_(False) + + if lora: + self.with_lora = True + lora_conf = build_from_cfg_or_obj(lora, accept=LoraConfig) + self._encoder = prepare_for_vision_lora(self._encoder, lora_conf) + else: + self.with_lora = False + + def post_init_proj(self, config: ProjectorConfig): + self._projector = ProjectorModel(config) + self.proj_inited = True + + def build_processor_and_encoder(self, model_name_or_path: str): + with LoadWoInit: + processor = 
CLIPImageProcessor.from_pretrained(model_name_or_path) + encoder = CLIPVisionModel.from_pretrained( + model_name_or_path, torch_dtype=torch.float16) + return processor, encoder + + @classmethod + def only_build_processor(self, model_name_or_path: str): + return CLIPImageProcessor.from_pretrained(model_name_or_path) + + @property + def encoder(self) -> CLIPVisionModel: + return self._encoder + + @property + def processor(self): + return self._processor + + @property + def projector(self) -> ProjectorModel: + if self._projector: + return self._projector + else: + raise RuntimeError('The projector has not been created yet, ' + 'please execute `post_init_proj` first.') + + def gradient_checkpointing_enable(self): + # For backward compatibility + if hasattr(self.encoder, 'enable_input_require_grads'): + self.encoder.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.encoder.get_input_embeddings().register_forward_hook( + make_inputs_require_grad) + + # enable gradient checkpointing for memory efficiency + self.encoder.gradient_checkpointing_enable() + + self.projector.enable_input_require_grads() + self.projector.gradient_checkpointing_enable() + + def preprocess(self, image: _ImageType) -> List[torch.Tensor]: + """Preprocess the input image, including expanding to square and + normalization. + + Args: + image (Image.Image): The input image need to be preprocessed. + Returns: + torch.Tensor: The preprocessed image tensor. + """ + + if isinstance(image, str): + image = load_image(image) + + if not isinstance(image, Image.Image): + raise TypeError(f"Don't support {type(image).__name__}, " + 'the image type must be `PIL.Image`.') + + processor = self.processor + image_mean = processor.image_mean + + background_color = tuple(int(x * 255) for x in image_mean) + squared_img = expand2square(image, background_color) + + processed = processor.preprocess(squared_img, return_tensors='pt') + img_tensor = processed['pixel_values'][0] # shape: 3, h, w + + # before this line, `img_tensor` is on cpu. + img_tensor = img_tensor.to(self.device).to(self.dtype) + return img_tensor + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + outputs = self.encoder(pixel_values, output_hidden_states=True) + embeddings = self.projector( + outputs.hidden_states[self.select_layer][:, 1:]) + return embeddings + + @torch.no_grad() + def batch_infer(self, images: List[_ImageType]) -> List[torch.Tensor]: + """Obtain the corresponding embeddings based on the images. + + Args: + images (List[Image.Image]): The input images. The data layout + for each image is (c, h, w). + Returns: + List[torch.Tensor]: The list of extracted features from images. + The data layout for each tensor should be (tokens, dims). + """ + + num_imgs = len(images) + + img_tensors = [self.process_img(img) for img in images] + + # Determine if all image sizes are consistent. + # TODO (pppppM): Confirm when the image size will be inconsistent + shape_consistant = all(x.shape == img_tensors[0].shape + for x in img_tensors) + + from transformers.modeling_outputs import BaseModelOutputWithPooling + + if shape_consistant: + # Batch inference when all image sizes are consistent. 
+ # img_tensors[0] shape: (3, h, w) + # tensor shape: (num_imgs, 3, h, w) + tensor = torch.stack(img_tensors, dim=0) + + enc_out = self.visual_encoder(tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + + # feat shape: (num_imgs, tokens, dims) + feat = self.projector(enc_out.hidden_states[self.select_layer][:, + 1:]) + + # Split along the batch dimension + # The feature of each image corresponds to a tensor. + # len(features): num_imgs, features[0] shape:(1, tokens, dims) + features = torch.chunk(feat, num_imgs, dim=0) + + # per image feature's layout should be (tokens, dims) + features = [x.flatten(0, 1) for x in features] + + else: + features = [] + for tensor in img_tensors: + tensor: torch.Tensor + # The visual encoder requires a data layout of (bs, c, h, w). + # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w) + batch_tensor = tensor.unsqueeze(0) + enc_out = self.visual_encoder( + batch_tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + # feat shape: (1, tokens, dims) + feat = self.projector( + enc_out.hidden_states[self.select_layer][:, 1:]) + features.append(feat) + + return features + + def save_checkpoint(self, dir: str): + + if self.with_lora: + _save_dir = os.path.join(dir, 'visual_encoder_adapter') + self.encoder.save_pretrained(_save_dir, safe_serialization=False) + + if not self.freeze_clip: + _save_dir = os.path.join(dir, 'visual_encoder') + self.encoder.save_pretrained(_save_dir, safe_serialization=False) + self.processor.save_pretrained(_save_dir) + + _save_dir = os.path.join(dir, 'projector') + self.projector.save_pretrained(_save_dir) + + def load_checkpoint(self, dir): + + if self.with_lora: + _ckpt_dir = os.path.join(dir, 'visual_encoder_adapter') + self.encoder.load_adapter(_ckpt_dir) + + if not self.freeze_clip: + _ckpt_dir = os.path.join(dir, 'visual_encoder') + load_checkpoint_in_model(self.encoder, _ckpt_dir) + load_checkpoint_in_model(self.processor, _ckpt_dir) + + if self.proj_inited: + _ckpt_dir = os.path.join(dir, 'projector') + load_checkpoint_in_model(self.projector, _ckpt_dir) + else: + ProjectorModel.from_pretrained(_ckpt_dir) + + def state_dict(self, *args, **kwargs): + + state_dict = super().state_dict(*args, **kwargs) + to_return = OrderedDict() + # Step 1. encoder + if self.with_lora: + to_return.update( + get_peft_model_state_dict(self.encoder, state_dict=state_dict)) + elif not self.freeze_clip: + to_return.update( + {k: v + for k, v in state_dict.items() if '_encoder.' in k}) + + # Step 2. Projector + to_return.update( + {k: v + for k, v in state_dict.items() if '_projector.' in k}) + + return to_return + + +# if __name__ == '__main__': +# img = load_image('llava.jpeg') +# model = VisionEncoderForDeploy('xtuner/llava-internlm-7b', +# 'openai/clip-vit-large-patch14-336') + +# model.cuda() +# model.eval() +# outputs = model([img]) diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py index 0f0fc7e76..662db5e7e 100644 --- a/xtuner/model/hybrid.py +++ b/xtuner/model/hybrid.py @@ -1,107 +1,142 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 from collections import OrderedDict
+from typing import Dict, Optional, Union
 
 import torch
 import torch.distributed as dist
 from mmengine.model import BaseModel
 from peft import LoraConfig
 from torch import nn
+from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from xtuner.registry import BUILDER
+from xtuner.types import HybridChatMessages, HybridChatTemplate
 from xtuner.utils.config import build_from_cfg_or_obj
-from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
-from .utils import (LoadWoInit, enable_hf_model_gradient_checkpointing,
-                    get_peft_model_state_dict, prepare_for_llm_lora,
-                    prepare_for_vision_lora,
-                    smart_tokenizer_and_embedding_resize)
+from .base import BaseTune
+from .encoders import EncoderWrapper
+from .modules import ProjectorConfig, dispatch_modules
+from .utils import (LoadWoInit, get_peft_model_state_dict,
+                    prepare_for_llm_lora, smart_tokenizer_and_embedding_resize)
 
 
-class HybridFinetune(BaseModel):
+class HybridFinetune(BaseTune):
 
     def __init__(
         self,
-        llm,
-        visual_encoder=None,
-        visual_select_layer=-2,
-        projector_depth=2,
-        pretrained_pth=None,
-        tokenizer=None,
-        llm_lora=None,
-        visual_encoder_lora=None,
-        freeze_llm=False,
-        freeze_visual_encoder=False,
-        use_activation_checkpointing=True,
-        use_varlen_attn=False,
+        llm: Union[PreTrainedModel, Dict],
+        tokenizer: Union[PreTrainedTokenizer, Dict],
+        chat_template: HybridChatTemplate,
+        visual_encoder: Optional[Union[EncoderWrapper, Dict]] = None,
+        audio_encoder: Optional[Union[EncoderWrapper, Dict]] = None,
+        video_encoder: Optional[Union[EncoderWrapper, Dict]] = None,
+        proj_depth: int = 2,
+        llm_lora: Optional[Union[LoraConfig, Dict]] = None,
+        freeze_llm: bool = False,
+        use_gradient_checkpointing: bool = True,
+        use_varlen_attn: bool = False,
    ):
         super().__init__()
+        self._chat_template = chat_template
+
         # Build the base language model without initialization.
         # This will greatly reduce the time to build the model.
         with LoadWoInit():
-            self.llm = build_from_cfg_or_obj(llm, nn.Module)
-            if visual_encoder:
-                visual_encoder = build_from_cfg_or_obj(visual_encoder,
-                                                       nn.Module)
-                self.visual_encoder = visual_encoder
-        self.visual_select_layer = visual_select_layer
-        self.llm.config.use_cache = False
-        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
-
-        if tokenizer is not None:
-            if isinstance(tokenizer, dict):
-                tokenizer = BUILDER.build(tokenizer)
-            smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
-
-        projector_config = ProjectorConfig(
-            visual_hidden_size=self.visual_encoder.config.hidden_size,
-            llm_hidden_size=self.llm.config.hidden_size,
-            depth=projector_depth)
-        self.projector = ProjectorModel(projector_config).to(
-            self.visual_encoder.dtype)
+            self._llm: PreTrainedModel = build_from_cfg_or_obj(llm, nn.Module)
+            self._llm.config.use_cache = False
+
+        tokenizer = build_from_cfg_or_obj(
+            tokenizer, accept=PreTrainedTokenizer)
+        smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
+        self._tokenizer: PreTrainedTokenizer = tokenizer
 
         self.freeze_llm = freeze_llm
-        self.freeze_visual_encoder = freeze_visual_encoder
         if self.freeze_llm:
             self.llm.requires_grad_(False)
-        if self.freeze_visual_encoder:
-            self.visual_encoder.requires_grad_(False)
-
-        if use_activation_checkpointing:
-            # For backward compatibility
-            enable_hf_model_gradient_checkpointing(self.llm)
-            enable_hf_model_gradient_checkpointing(self.visual_encoder)
-
-            self.projector.enable_input_require_grads()
-            self.projector.gradient_checkpointing_enable()
-
-        self.use_llm_lora = llm_lora is not None
-        self.use_visual_encoder_lora = visual_encoder_lora is not None
+        self.with_lora = llm_lora is not None
 
         # Prepare the model for LoRA if specified
-        if self.use_llm_lora:
+        if self.with_lora:
             lora_conf = build_from_cfg_or_obj(llm_lora, accept=LoraConfig)
-            self.llm = prepare_for_llm_lora(self.llm, lora_conf,
-                                            use_activation_checkpointing)
-
-        if self.use_visual_encoder_lora:
-            lora_conf = build_from_cfg_or_obj(
-                visual_encoder_lora, accept=LoraConfig)
-            self.visual_encoder = prepare_for_vision_lora(
-                self.visual_encoder, lora_conf, use_activation_checkpointing)
-        self._is_init = True
+            self.llm = prepare_for_llm_lora(self.llm, lora_conf)
 
         # Determines whether to calculate attention based on the
         # seq_len dimension (use_varlen_attn = False) or the actual length of
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
+        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
+
+        if visual_encoder:
+            visual_encoder = build_from_cfg_or_obj(visual_encoder,
+                                                   EncoderWrapper)
+            self.visual_encoder: EncoderWrapper = visual_encoder
+            _proj_config = ProjectorConfig(
+                visual_hidden_size=self.visual_encoder.hidden_size,
+                llm_hidden_size=self.llm.config.hidden_size,
+                depth=proj_depth)
+
+            self.visual_encoder.post_init_proj(_proj_config)
+        else:
+            self.visual_encoder = None
+
+        if audio_encoder:
+            audio_encoder = build_from_cfg_or_obj(audio_encoder,
+                                                  EncoderWrapper)
+            self.audio_encoder: EncoderWrapper = audio_encoder
+            _proj_config = ProjectorConfig(
+                visual_hidden_size=self.audio_encoder.hidden_size,
+                llm_hidden_size=self.llm.config.hidden_size,
+                depth=proj_depth)
+
+            self.audio_encoder.post_init_proj(_proj_config)
+        else:
+            self.audio_encoder = None
+
+        if video_encoder:
+            video_encoder = build_from_cfg_or_obj(video_encoder,
+                                                  EncoderWrapper)
+            self.video_encoder: EncoderWrapper = video_encoder
+            _proj_config = ProjectorConfig(
+                visual_hidden_size=self.video_encoder.hidden_size,
+                llm_hidden_size=self.llm.config.hidden_size,
+                depth=proj_depth)
+
+            self.video_encoder.post_init_proj(_proj_config)
+        else:
+            self.video_encoder = None
+
+        if use_gradient_checkpointing:
+            self.gradient_checkpointing_enable()
+
+        self.avoid_override_weights()
+
+    @property
+    def llm(self) -> PreTrainedModel:
+        return self._llm
+
+    @property
+    def tokenizer(self) -> PreTrainedTokenizer:
+        return self._tokenizer
 
-    def init_weights(self):
-        """Parent class method.
+    @property
+    def chat_template(self) -> HybridChatTemplate:
+        return self._chat_template
 
-        To avoid overwriting the loaded weights, overload it to an empty
-        function.
-        """
-        pass
+    def gradient_checkpointing_enable(self):
+        # For backward compatibility
+        if hasattr(self.llm, 'enable_input_require_grads'):
+            self.llm.enable_input_require_grads()
+        else:
+
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+
+            self.llm.get_input_embeddings().register_forward_hook(
+                make_inputs_require_grad)
+
+        # enable gradient checkpointing for memory efficiency
+        self.llm.gradient_checkpointing_enable()
+        self.visual_encoder.gradient_checkpointing_enable()
 
     def forward(self, data, data_samples=None, mode='loss'):
         """Overload parent class method, only support training."""
@@ -132,10 +167,7 @@ def _get_vision_embeds_and_ranges(self, data):
 
         batch_total_imgs = len(img_rngs)
 
-        visual_outputs = self.visual_encoder(
-            pixel_values, output_hidden_states=True)
-        features = self.projector(
-            visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
+        features = self.visual_encoder(pixel_values)
         batch_total_imgs, real_img_tokens, _ = features.shape
 
         for i in range(batch_total_imgs):
@@ -144,12 +176,12 @@ def _get_vision_embeds_and_ranges(self, data):
             img_emb = features[i]
             img_bs_ind = img_belongs[i]
 
+            # Truncation caused by packing
             if real_img_tokens == exp_img_tokens:
                 img_embeds.append(img_emb)
-            elif not real_img_tokens == exp_img_tokens and img_start == 0:
+            elif real_img_tokens != exp_img_tokens and img_start == 0:
                 img_embeds.append(img_emb[real_img_tokens - img_end:])
-            elif (not real_img_tokens == exp_img_tokens
-                  and img_end == tokens):
+            elif (real_img_tokens != exp_img_tokens and img_end == tokens):
                 img_embeds.append(img_emb[:exp_img_tokens])
             else:
                 raise RuntimeError
@@ -176,13 +208,8 @@ def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):
 
         return flat_embeds + _empty_embeds
 
-    def compute_loss(self, data):
-
+
def _compute_postion_ids(self, data): input_ids = data['input_ids'] - labels = data['labels'] - # position_ids = data['position_ids'] - attention_mask = data['attention_mask'] - # breakpoint() bs, tokens = input_ids.shape if self.use_varlen_attn: assert bs == 1 @@ -206,12 +233,23 @@ def compute_loss(self, data): position_ids = torch.arange(0, tokens).unsqueeze(0).repeat(bs, 1) + def compute_loss(self, data): + + input_ids = data['input_ids'] + labels = data['labels'] + attention_mask = data['attention_mask'] + + bs, tokens = input_ids.shape + position_ids = self._compute_postion_ids(data) + input_embeds = self.llm.get_input_embeddings()(input_ids) bs, tokens, dim = input_embeds.shape flat_embeds = input_embeds.flatten(0, 1) img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + # audio_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) + # video_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data) flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs, flat_bs_img_rngs) input_embeds = flat_embeds.reshape(bs, tokens, dim) @@ -229,32 +267,22 @@ def compute_loss(self, data): def state_dict(self, *args, **kwargs): state_dict = super().state_dict(*args, **kwargs) to_return = OrderedDict() - # Step 1. visual_encoder - if self.use_visual_encoder_lora: - to_return.update( - get_peft_model_state_dict( - self.visual_encoder, state_dict=state_dict)) - elif not self.freeze_visual_encoder: - to_return.update({ - k: v - for k, v in state_dict.items() if 'visual_encoder.' in k - }) - # Step 2. LLM + + # Step 1. LLM if self.use_llm_lora: to_return.update( get_peft_model_state_dict(self.llm, state_dict=state_dict)) elif not self.freeze_llm: to_return.update( {k: v - for k, v in state_dict.items() if 'llm.' in k}) - # Step 3. Projector + for k, v in state_dict.items() if '_llm.' in k}) + + # Step 2. Visual Encoder to_return.update( {k: v - for k, v in state_dict.items() if 'projector.' in k}) + for k, v in state_dict.items() if 'visual_encoder.' 
in k}) return to_return - def __getattr__(self, name: str): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.llm, name) + def chat(self, messages: HybridChatMessages, sample_params, streamer): + + prompt = messages.apply_chat_template(self.chat_template) diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py index cc230e8f8..dd0a93ddf 100644 --- a/xtuner/types/__init__.py +++ b/xtuner/types/__init__.py @@ -1,6 +1,8 @@ +from .chat import HybridChatMessages from .chat_template import HybridChatTemplate from .train import RawTrainingData, TrainingHybridChatMessages __all__ = [ - 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages' + 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages', + 'HybridChatMessages' ] From cfb8fa4132635d6f3a438571699e6b30cb58b083 Mon Sep 17 00:00:00 2001 From: pppppM Date: Tue, 26 Mar 2024 21:26:44 +0800 Subject: [PATCH 6/6] support lmdeploy --- xtuner/chat/__init__.py | 0 xtuner/chat/backend/__init__.py | 5 + xtuner/chat/backend/base.py | 26 ++ xtuner/chat/backend/encoder.py | 308 +++++++++++++++++++++++ xtuner/chat/backend/huggingface.py | 224 +++++++++++++++++ xtuner/chat/backend/lmdeploy/__init__.py | 3 + xtuner/chat/backend/lmdeploy/_encoder.py | 122 +++++++++ xtuner/chat/backend/lmdeploy/_engine.py | 87 +++++++ xtuner/chat/backend/lmdeploy/backend.py | 107 ++++++++ xtuner/chat/conversation.py | 147 +++++++++++ xtuner/chat/streamer/__init__.py | 7 + xtuner/chat/streamer/huggingface.py | 37 +++ xtuner/chat/streamer/lmdeploy.py | 49 ++++ xtuner/types/__init__.py | 9 +- xtuner/types/chat.py | 4 +- xtuner/types/sample_params.py | 14 ++ 16 files changed, 1144 insertions(+), 5 deletions(-) create mode 100644 xtuner/chat/__init__.py create mode 100644 xtuner/chat/backend/__init__.py create mode 100644 xtuner/chat/backend/base.py create mode 100644 xtuner/chat/backend/encoder.py create mode 100644 xtuner/chat/backend/huggingface.py create mode 100644 xtuner/chat/backend/lmdeploy/__init__.py create mode 100644 xtuner/chat/backend/lmdeploy/_encoder.py create mode 100644 xtuner/chat/backend/lmdeploy/_engine.py create mode 100644 xtuner/chat/backend/lmdeploy/backend.py create mode 100644 xtuner/chat/conversation.py create mode 100644 xtuner/chat/streamer/__init__.py create mode 100644 xtuner/chat/streamer/huggingface.py create mode 100644 xtuner/chat/streamer/lmdeploy.py create mode 100644 xtuner/types/sample_params.py diff --git a/xtuner/chat/__init__.py b/xtuner/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/chat/backend/__init__.py b/xtuner/chat/backend/__init__.py new file mode 100644 index 000000000..54351fa29 --- /dev/null +++ b/xtuner/chat/backend/__init__.py @@ -0,0 +1,5 @@ +from .encoder import VisionEncoderForDeploy +from .huggingface import HFBackend +from .lmdeploy import LMDeployBackend + +__all__ = ['VisionEncoderForDeploy', 'HFBackend', 'LMDeployBackend'] diff --git a/xtuner/chat/backend/base.py b/xtuner/chat/backend/base.py new file mode 100644 index 000000000..0a0fd4bbe --- /dev/null +++ b/xtuner/chat/backend/base.py @@ -0,0 +1,26 @@ +from abc import abstractmethod + +from xtuner.types import HybridChatTemplate + + +class BaseBackend(): + + @property + def chat_template(self) -> HybridChatTemplate: + pass + + @abstractmethod + def create_streamer(self, iterable=False): + pass + + @abstractmethod + def chat(self, messages, streamer=None, generation_config=None): + pass + + # @abstractmethod + # def response_with_function_call(self, 
response: str): + # pass + + # @abstractmethod + # def response_with_code_interpreter(self, response: str): + # pass diff --git a/xtuner/chat/backend/encoder.py b/xtuner/chat/backend/encoder.py new file mode 100644 index 000000000..af05b78df --- /dev/null +++ b/xtuner/chat/backend/encoder.py @@ -0,0 +1,308 @@ +import base64 +import os +from io import BytesIO +from typing import List, Literal, Optional, Union + +import requests +import torch +from peft import PeftModel +from PIL import Image +from torch import nn +from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel + +from xtuner.dataset.utils import expand2square + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +def load_image(image_url: str) -> Image.Image: + """load image from url, local path or openai GPT4V.""" + + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + if image_url.startswith('http'): + response = requests.get(image_url, headers=headers) + response.raise_for_status() + + # Open the image using PIL + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + img = Image.open(image_url) + + return img + + +ModelHub = Literal['huggingface', 'modelscope'] + + +class VisionEncoderForDeploy(nn.Module): + + def __init__(self, + model_name_or_path: str, + projector_name_or_path: str, + adapter_name_or_path: str = None, + select_layer: int = -2, + hub: ModelHub = 'huggingface', + device='cuda'): + + super().__init__() + + # model_path = self._parse_model_path(xtuner_model_name_or_path, hub) + # visual_encoder_path = self._parse_visual_encoder_path( + # model_path, visual_encoder_name_or_path, hub + # ) + # projector_path = self._parse_projector_path(model_path) + + # # parse visual encoder adapter path. + # vis_enc_adapter_path = self._parse_vis_enc_adapter_path(model_path) + + self.select_layer = select_layer + self.image_processor = CLIPImageProcessor.from_pretrained( + model_name_or_path) + print(f'Load Image Processor From {model_name_or_path}') + + visual_encoder = CLIPVisionModel.from_pretrained( + model_name_or_path, torch_dtype=torch.float16) + print(f'Load Visual Encoder From {model_name_or_path}') + + # when path is None, means without visual encoder adapter + if adapter_name_or_path: + self.visual_encoder = PeftModel.from_pretrained( + visual_encoder, adapter_name_or_path) + print(f'Load Visual Encoder Adapter From {adapter_name_or_path}') + else: + self.visual_encoder = visual_encoder + + self.projector = AutoModel.from_pretrained( + projector_name_or_path, + torch_dtype=torch.float16, + trust_remote_code=True) + print(f'Load Projector from {projector_name_or_path}') + + self.dtype = torch.float16 + self.device = device + self.to(self.device) + self.to(self.dtype) + + def process_img(self, image: Image.Image) -> List[torch.Tensor]: + """Preprocess the input image, including expanding to square and + normalization. + + Args: + image (Image.Image): The input image need to be preprocessed. + + Returns: + torch.Tensor: The preprocessed image tensor. 
+ """ + + if isinstance(image, str): + image = load_image(image) + + if not isinstance(image, Image.Image): + raise TypeError(f"Don't support {type(image).__name__}, " + 'the image type must be `PIL.Image`.') + + processor = self.image_processor + image_mean = processor.image_mean + + background_color = tuple(int(x * 255) for x in image_mean) + squared_img = expand2square(image, background_color) + + processed = processor.preprocess(squared_img, return_tensors='pt') + img_tensor = processed['pixel_values'][0] # shape: 3, h, w + + # before this line, `img_tensor` is on cpu. + img_tensor = img_tensor.to(self.device).to(self.dtype) + return img_tensor + + @torch.no_grad() + def forward(self, images: List[Union[str, + Image.Image]]) -> List[torch.Tensor]: + """Obtain the corresponding embeddings based on the images. + + Args: + images (List[Image.Image]): The input images. The data layout + for each image is (c, h, w). + + Returns: + List[torch.Tensor]: The list of extracted features from images. + The data layout for each tensor should be (tokens, dims). + """ + + num_imgs = len(images) + + img_tensors = [self.process_img(img) for img in images] + + # Determine if all image sizes are consistent. + # TODO (pppppM): Confirm when the image size will be inconsistent + shape_consistant = all(x.shape == img_tensors[0].shape + for x in img_tensors) + + from transformers.modeling_outputs import BaseModelOutputWithPooling + + if shape_consistant: + # Batch inference when all image sizes are consistent. + # img_tensors[0] shape: (3, h, w) + # tensor shape: (num_imgs, 3, h, w) + tensor = torch.stack(img_tensors, dim=0) + + enc_out = self.visual_encoder(tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + + # feat shape: (num_imgs, tokens, dims) + feat = self.projector(enc_out.hidden_states[self.select_layer][:, + 1:]) + + # Split along the batch dimension + # The feature of each image corresponds to a tensor. + # len(features): num_imgs, features[0] shape:(1, tokens, dims) + features = torch.chunk(feat, num_imgs, dim=0) + + # per image feature's layout should be (tokens, dims) + features = [x.flatten(0, 1) for x in features] + + else: + features = [] + for tensor in img_tensors: + tensor: torch.Tensor + # The visual encoder requires a data layout of (bs, c, h, w). + # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w) + batch_tensor = tensor.unsqueeze(0) + enc_out = self.visual_encoder( + batch_tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + # feat shape: (1, tokens, dims) + feat = self.projector( + enc_out.hidden_states[self.select_layer][:, 1:]) + features.append(feat) + + return features + + def _parse_model_path(self, name_or_path: str, hub: ModelHub) -> str: + """Parse and get the directory path of the model. It supports load + model from local directory or download from the hub. + + Args: + name_or_path (str): The directory path or name of the model. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the model. + + Raises: + NotImplementedError: If the input hub is not supported currently. 
+ """ + + if os.path.isdir(name_or_path): + model_path = name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + model_path = snapshot_download(repo_id=name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + model_path = snapshot_download(model_id=name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return model_path + + def _parse_visual_encoder_path(self, model_path: str, + visual_encoder_name_or_path: str, + hub: ModelHub) -> str: + """Parse and get the directory path of the visual encoder. It supports + load visual encoder from local directory, download from the hub, or + find it in the XTuner model directory. + + Args: + model_path (str): The directory path of the model. + visual_encoder_name_or_path (Optional[str]): The directory path or + name of the visual encoder. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the visual encoder. + + Raises: + NotImplementedError: If the input hub is not supported currently. + """ + + if 'visual_encoder' in os.listdir(model_path): + assert visual_encoder_name_or_path is None + visual_encoder_path = os.path.join(model_path, 'visual_encoder') + elif os.path.isdir(visual_encoder_name_or_path): + visual_encoder_path = visual_encoder_name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + visual_encoder_path = snapshot_download( + repo_id=visual_encoder_name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + visual_encoder_path = snapshot_download( + model_id=visual_encoder_name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return visual_encoder_path + + def _parse_projector_path(self, model_path: str) -> Optional[str]: + """Parse the path of the `projector` model according to the model path. + + Args: + model_path (str): The path to the model directory. + + Raises: + ValueError: If the 'projector' directory is not found in the + `model_path`. + + Returns: + Optional[str]: The full path of 'projector' directory if exists, + else raises ValueError. + """ + if 'projector' in os.listdir(model_path): + projector_path = os.path.join(model_path, 'projector') + else: + # Raises exception if 'projector' directory/folder not found + raise ValueError('Projector directory not found in given path') + return projector_path + + def _parse_vis_enc_adapter_path(self, model_path: str) -> Optional[str]: + """Parses the model path and returns the path to + 'visual_encoder_adapter' directory. + + Args: + model_path (str): The path to the model directory. + + Returns: + Optional[str]: The full path of 'visual_encoder_adapter' directory if exists, + else returns None. 
+ """ + if 'visual_encoder_adapter' in os.listdir(model_path): + adapter_path = os.path.join(model_path, 'visual_encoder_adapter') + else: + # Returns None if 'visual_encoder_adapter' directory/folder not found + adapter_path = None + return adapter_path + + +if __name__ == '__main__': + img = load_image('llava.jpeg') + model = VisionEncoderForDeploy('xtuner/llava-internlm-7b', + 'openai/clip-vit-large-patch14-336') + + model.cuda() + model.eval() + outputs = model([img]) diff --git a/xtuner/chat/backend/huggingface.py b/xtuner/chat/backend/huggingface.py new file mode 100644 index 000000000..51e742327 --- /dev/null +++ b/xtuner/chat/backend/huggingface.py @@ -0,0 +1,224 @@ +from typing import Optional + +import torch +from peft import PeftModel +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) +from transformers import GenerationConfig as HFGenerationConfig +from transformers import PreTrainedModel, PreTrainedTokenizer + +from xtuner.chat.streamer import HFTextIteratorStreamer, HFTextStreamer +from xtuner.model.utils import LoadWoInit +from xtuner.tools.utils import get_stop_criteria +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from .base import BaseBackend + + +class _HFBackend(BaseBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.llm = llm + self.llm.cuda() + self.tokenizer = tokenizer + + self._chat_template = chat_template + + @property + def chat_template(self) -> HybridChatTemplate: + return self._chat_template + + @property + def eos_token_id(self): + if self.tokenizer.pad_token_id: + return self.tokenizer.eos_token_id + else: + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + def build_llm_and_tokenizer(self, + model_name_or_path, + adapter=None, + bits=None): + + if bits is None: + quantization_config = None + load_in_8bit = False + elif bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + load_in_8bit = False + elif bits == 8: + quantization_config = None + load_in_8bit = True + + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + trust_remote_code=True, + encode_special_tokens=True) + + with LoadWoInit(): + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map='auto', + load_in_8bit=load_in_8bit, + quantization_config=quantization_config, + trust_remote_code=True, + torch_dtype=torch.float16) + + if adapter is not None: + model = PeftModel.from_pretrained(model, adapter) + + model.eval() + return model, tokenizer + + def response_with_code_interpreter(self, response: str): + return False + + def response_with_function_call(self, response: str): + return False + + def create_streamer(self, chat_template=None, iterable=False): + if iterable: + return HFTextIteratorStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + else: + return HFTextStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + + def parse_sample_params(self, params: SampleParams) -> HFGenerationConfig: + + if params is None: + params = SampleParams() + + hf_gen_config = HFGenerationConfig( + max_new_tokens=params.max_new_tokens, + do_sample=params.temperature > 0, + 
temperature=params.temperature, + top_k=params.top_k, + top_p=params.top_p, + repetition_penalty=params.repetition_penalty, + seed=params.seed, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id) + + stop_words = params.stop_words + stop_words.extend(self.chat_template.stop_words) + + return hf_gen_config, stop_words + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params: Optional[SampleParams] = None): + + prompt = messages.apply_chat_template(self.chat_template) + ids = self.tokenizer.encode(prompt, return_tensors='pt') + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + inputs=ids.cuda(), + streamer=streamer, + generation_config=hf_gen_config, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0][len(ids[0]):], skip_special_tokens=True) + + for word in stop_words: + output = output.rstrip(word) + + return output + + +class HFBackend(_HFBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + vision_tower: Optional[torch.nn.Module] = None, + ) -> None: + super().__init__(chat_template, llm, tokenizer) + + if vision_tower: + self.vision_tower = vision_tower + self.vision_tower.cuda() + self.vision_tower.eval() + else: + self.vision_tower = None + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params=None): + + img_urls = messages.collect_img_urls() + + if self.vision_tower is None or len(img_urls) == 0: + return super().chat(messages, streamer, sample_params) + + prompt = messages.apply_chat_template(self.chat_template) + + img_features = self.vision_tower(img_urls) + + # prompt, img_ranges = _insert_img_pad_tokens( + # prompt, self.chat_template.image_token, img_features, + # self.tokenizer.pad_token + # ) + + chunks = prompt.split(self.chat_template.image_token) + assert len(chunks) - 1 == len(img_urls) + chunk_embeddings = [] + for i in range(len(chunks)): + + chunk_ids = self.tokenizer.encode(chunks[i], return_tensors='pt') + chunk_ids = chunk_ids.to(self.llm.device) + chunk_emb = self.llm.get_input_embeddings()(chunk_ids) + chunk_embeddings.append(chunk_emb) + + if i < len(chunks) - 1: + chunk_embeddings.append(img_features[i].unsqueeze(0)) + + embeddings = torch.cat(chunk_embeddings, dim=1) + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + input_ids=None, + inputs_embeds=embeddings, + streamer=streamer, + generation_config=hf_gen_config, + bos_token_id=self.tokenizer.bos_token_id, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0], skip_special_tokens=True) + + for word in stop_words: + output = output.rstrip(word) + + return output diff --git a/xtuner/chat/backend/lmdeploy/__init__.py b/xtuner/chat/backend/lmdeploy/__init__.py new file mode 100644 index 000000000..139c066fb --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/__init__.py @@ -0,0 +1,3 @@ +from .backend import LMDeployBackend + +__all__ = ['LMDeployBackend'] diff --git a/xtuner/chat/backend/lmdeploy/_encoder.py b/xtuner/chat/backend/lmdeploy/_encoder.py new file mode 100644 index 000000000..3466eb30f --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_encoder.py @@ -0,0 +1,122 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +import asyncio +import queue +import time +from threading import Thread +from typing import List, Union + +import torch +from lmdeploy.utils import get_logger +from PIL.Image import Image + +logger = get_logger('lmdeploy') + + +class Record: + """Batching manager.""" + + def __init__(self): + self.number = [] + self.waiting = [] + self.done = [] + self.res_que = [] + self.total = 0 + + def enqueue(self, images: List[Image], que: Union[queue.Queue, + asyncio.Queue]): + """add ith request to manager.""" + self.number.append(len(images)) + self.waiting.extend(images) + self.res_que.append(que) + self.total += len(images) + self.log('received', len(images)) + + def dequeue(self, max_batch_size): + """try to dequeue max batch size images.""" + inputs = self.waiting[:max_batch_size] + self.waiting = self.waiting[max_batch_size:] + self.total -= len(inputs) + self.log('process', len(inputs)) + return inputs + + def nofify(self): + """set result if request i is finished.""" + if len(self.number) == 0 or self.number[0] > len(self.done): + return False + num_images = self.number.pop(0) + outputs = self.done[:num_images] + self.done = self.done[num_images:] + que = self.res_que.pop(0) + if isinstance(que, queue.Queue): + que.put(outputs) + else: + que._loop.call_soon_threadsafe(que.put_nowait, outputs) + self.log('done', num_images) + return True + + def log(self, task: str, num: int): + logger.info(f'ImageEncoder {task} {num} images, ' + f'left {self.total} images.') + + +class _AsyncEncoderWrapper: + """Image encoder.""" + + def __init__(self, model, max_batch_size: int = 16): + self.model = model + self.max_batch_size = max_batch_size + self.loop = asyncio.new_event_loop() + self.work_thread = self._start_work_thread() + torch.cuda.empty_cache() + + def _start_work_thread(self): + """internal thread.""" + + def _work_thread(): + asyncio.set_event_loop(self.loop) + self.que = asyncio.Queue() + self.loop.run_until_complete(self._forward_loop()) + + thread = Thread(target=_work_thread, daemon=True) + thread.start() + return thread + + async def _forward_loop(self): + """working loop to process images.""" + logger.info('start ImageEncoder._forward_loop') + record = Record() + while True: + while record.total == 0 or (self.que.qsize() and + record.total < self.max_batch_size): + item = await self.que.get() + record.enqueue(item[0], item[1]) + inputs = record.dequeue(self.max_batch_size) + outputs = self.forward(inputs) + record.done.extend(outputs) + while record.nofify(): + pass + + def forward(self, inputs: List[Image]): + """Model forward.""" + time_start = time.perf_counter() + outputs = self.model.forward(inputs) + time_end = time.perf_counter() + logger.info(f'ImageEncoder forward {len(inputs)} images, ' + f'cost {time_end - time_start:.3f}s') + return outputs + + def infer(self, inputs: List[Image]): + """infer.""" + outputs = queue.Queue() + item = (inputs, outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = outputs.get() + return results + + async def async_infer(self, inputs: List[Image]): + """async infer.""" + outputs = asyncio.Queue() + item = (inputs, outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = await outputs.get() + return results diff --git a/xtuner/chat/backend/lmdeploy/_engine.py b/xtuner/chat/backend/lmdeploy/_engine.py new file mode 100644 index 000000000..d81d30c6c --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_engine.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +import numpy as np +from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX + +from xtuner.types import HybridChatMessages, HybridChatTemplate + + +class _MMAsyncEngine(AsyncEngine): + """Visual Language Async inference engine.""" + + def __init__(self, + chat_template: HybridChatTemplate, + *args, + encoder=None, + **kwargs) -> None: + super().__init__(*args, **kwargs) + assert self.model_name == 'base' + self.encoder = encoder + self.chat_template = chat_template + + async def _get_prompt_input(self, prompt: HybridChatMessages, + do_preprocess: bool, sequence_start: bool): + """get input_ids, embeddings and offsets.""" + + decorated = prompt.apply_chat_template(self.chat_template) + segs = decorated.split(self.chat_template.image_token) + + results = {} + input_ids = [] + if len(segs) > 1: + assert self.encoder is not None + img_urls = prompt.collect_img_urls() + features = await self.encoder.async_infer(img_urls) + features = [x.cpu().numpy() for x in features] + input_ids = [] + begins = [] + ends = [] + for i, seg in enumerate(segs): + if i > 0: + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) + seg_ids = self.tokenizer.encode( + seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + results['input_embeddings'] = features + results['input_embedding_ranges'] = ranges + else: + input_ids = self.tokenizer.encode( + decorated, add_bos=sequence_start) + + results['input_ids'] = input_ids + results['prompt'] = decorated + return results + + # def batch_infer(self, prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], List[List[Dict]]], + # **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().batch_infer(prompts, **kwargs) + + # def stream_infer(self, prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], + # List[List[Dict]]], **kwargs): + # """Inference a batch of prompts with stream mode.""" + # # prompts = self._convert_prompts(prompts) + # return super().stream_infer(prompts, **kwargs) + + # def __call__(self, prompts, **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().__call__(prompts, **kwargs) + + # def chat(self, prompts: VLPromptType, **kwargs): + # """chat.""" + # # _prompts = self._convert_prompts(prompts) + # sess = super().chat(_prompts, **kwargs) + + # # recover prompts & history + # sess._prompt = prompts + # last_round = sess.history[-1] + # sess.history[-1] = (prompts, last_round[-1]) + # return sess diff --git a/xtuner/chat/backend/lmdeploy/backend.py b/xtuner/chat/backend/lmdeploy/backend.py new file mode 100644 index 000000000..1df25fe81 --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/backend.py @@ -0,0 +1,107 @@ +import asyncio +import os +from typing import List, Optional, Union + +from lmdeploy.utils import get_logger + +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from ...streamer import LMDeployTextIteratorStreamer, LMDeployTextStreamer +from ..base import BaseBackend +from ._encoder import _AsyncEncoderWrapper +from ._engine import _MMAsyncEngine + +os.environ['TM_LOG_LEVEL'] = 'ERROR' +logger = get_logger('lmdeploy') +logger.setLevel('ERROR') + +_StreamerType = Union[LMDeployTextStreamer, LMDeployTextIteratorStreamer] + 
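+# A minimal usage sketch for the backend defined below (the model path and
+# the `messages` object are placeholders, assumed only for illustration):
+#
+#   backend = LMDeployBackend(chat_template, 'internlm/internlm2-chat-1_8b')
+#   streamer = backend.create_streamer()
+#   reply = backend.chat(messages, streamer=streamer,
+#                        sample_params=SampleParams(max_new_tokens=256))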
+
+class LMDeployBackend(BaseBackend):
+
+    def __init__(self,
+                 chat_template,
+                 llm_name_or_path,
+                 vision_encoder=None) -> None:
+        super().__init__()
+
+        if vision_encoder:
+            encoder = _AsyncEncoderWrapper(vision_encoder)
+        else:
+            encoder = None
+
+        self._engine = _MMAsyncEngine(
+            chat_template,
+            encoder=encoder,
+            model_path=llm_name_or_path,
+            model_name='base')
+
+        self._chat_template = chat_template
+        self.session_id = 0
+
+    @property
+    def chat_template(self) -> HybridChatTemplate:
+        return self._chat_template
+
+    def create_streamer(self, iterable=False):
+
+        if iterable:
+            return LMDeployTextIteratorStreamer()
+        else:
+            return LMDeployTextStreamer()
+
+    def parse_sample_params(self,
+                            params: SampleParams) -> 'LMDGenerationConfig':
+
+        if params is None:
+            params = SampleParams()
+
+        stop_words = list(params.stop_words)
+        stop_words.extend(self.chat_template.stop_words)
+
+        from lmdeploy.messages import GenerationConfig as LMDGenerationConfig
+        lmd_gen_config = LMDGenerationConfig(
+            max_new_tokens=params.max_new_tokens,
+            temperature=params.temperature,
+            top_k=params.top_k,
+            top_p=params.top_p,
+            repetition_penalty=params.repetition_penalty,
+            random_seed=params.seed,
+            stop_words=stop_words)
+
+        return lmd_gen_config
+
+    def chat(self,
+             messages: HybridChatMessages,
+             streamer: Optional[_StreamerType] = None,
+             sample_params: Optional[SampleParams] = None):
+
+        lmd_gen_config = self.parse_sample_params(sample_params)
+        self.session_id += 1
+
+        generator = self._engine.generate(
+            messages, self.session_id, gen_config=lmd_gen_config)
+
+        async def get_response():
+            out = ''
+            async for res in generator:
+                out += res.response
+                if streamer:
+                    streamer.put(res.response)
+            if streamer:
+                streamer.end()
+            return out
+
+        loop = asyncio.new_event_loop()
+        response = loop.run_until_complete(get_response())
+        return response
+
+    def batch_infer(self,
+                    messages: List[HybridChatMessages],
+                    sample_params: Optional[SampleParams] = None):
+
+        lmd_gen_config = self.parse_sample_params(sample_params)
+
+        results = self._engine.batch_infer(messages, gen_config=lmd_gen_config)
+
+        return [r.text for r in results]
diff --git a/xtuner/chat/conversation.py b/xtuner/chat/conversation.py
new file mode 100644
index 000000000..a26616221
--- /dev/null
+++ b/xtuner/chat/conversation.py
@@ -0,0 +1,147 @@
+from xtuner.chat.backend import HFBackend
+from xtuner.types.chat import (ChatMsg, HybridChatMessages, ImageContentItem,
+                               TextContentItem)
+
+
+class Conversation:
+
+    def __init__(self,
+                 backend: HFBackend,
+                 name=None,
+                 system=None,
+                 functions=None,
+                 code_interpreter=None) -> None:
+
+        self.name = name
+        self.backend = backend
+        self.system = system
+        self.functions = functions
+        self.code_interpreter = code_interpreter
+        self._messages = HybridChatMessages()
+
+        if system:
+            msg = ChatMsg(role='system', content=system)
+            self._messages.messages.append(msg)
+
+    @property
+    def messages(self):
+        return self._messages
+
+    def add_message(self, role, content):
+        if role == 'system':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='system', content=content)
+            self._messages.messages.append(msg)
+        elif role == 'user':
+            self._add_user(content)
+        elif role == 'assistant':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='assistant', content=content)
+            self._messages.messages.append(msg)
+
+    def _add_user(self, content):
+
+        if isinstance(content, str):
+            msg = ChatMsg(role='user', content=content)
+            self._messages.messages.append(msg)
+        elif isinstance(content, list):
+            _content = []
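+            # Each raw dict item is normalized into a typed content item
+            # below; items that are already typed are appended unchanged.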
+            for item in content:
+                if isinstance(item, (ImageContentItem, TextContentItem)):
+                    _content.append(item)
+                    continue
+
+                assert isinstance(item, dict)
+                assert 'type' in item
+                assert item['type'] in ('text', 'image_url')
+                if item['type'] == 'image_url':
+                    _item = ImageContentItem(image_url=item['image_url'])
+                    _content.append(_item)
+                elif item['type'] == 'text':
+                    _item = TextContentItem(text=item['text'])
+                    _content.append(_item)
+                else:
+                    raise NotImplementedError
+
+            msg = ChatMsg(role='user', content=_content)
+            self._messages.messages.append(msg)
+        else:
+            raise TypeError
+
+    def chat(self, content, sample_params=None, streamer=None):
+
+        self.add_message(role='user', content=content)
+        response = self.backend.chat(self.messages, streamer, sample_params)
+        self.add_message(role='assistant', content=response)
+        return response
+
+    def regenerate(self):
+
+        assert self._messages.messages[-1].role == 'assistant'
+        self._messages.messages.pop()
+        return self.backend.chat(self.messages)
+
+    def create_streamer(self, iterable=False):
+        return self.backend.create_streamer(iterable=iterable)
+
+
+if __name__ == '__main__':
+
+    from xtuner.types import HybridChatTemplate
+    chat_template = HybridChatTemplate(
+        system='<|im_start|>system\n{system}<|im_end|>\n',
+        user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+        assistant='{assistant}<|im_end|>\n',
+        stop_words=['<|im_end|>'],
+        image_token='<image>',
+        function_call=
+        '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251
+        function_result=
+        '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251
+        functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n'
+    )
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    from xtuner.chat.backend import HFBackend, VisionEncoderForDeploy
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f',
+        trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f',
+        trust_remote_code=True)
+    vision_tower = VisionEncoderForDeploy(
+        model_name_or_path='openai/clip-vit-large-patch14-336',
+        adapter_name_or_path=
+        '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/visual_encoder_adapter',
+        projector_name_or_path=
+        '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/projector'
+    )
+
+    llm.cuda()
+
+    backend = HFBackend(
+        chat_template,
+        llm,
+        tokenizer,
+        vision_tower,
+    )
+
+    conv = Conversation(backend)
+    print(conv.chat('who are you?'))
+
+    from xtuner.chat.backend import LMDeployBackend
+    backend = LMDeployBackend(
+        chat_template,
+        '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f',
+        vision_tower)
+    conv = Conversation(backend)
+    print(conv.chat('who are you?'))
+
+    content = [
+        TextContentItem(text='Please describe this image'),
+        ImageContentItem(image_url='llava.jpeg')
+    ]
+
+    print(conv.chat(content))
diff --git a/xtuner/chat/streamer/__init__.py b/xtuner/chat/streamer/__init__.py new file mode 100644 index 000000000..7f83155fc --- /dev/null +++ b/xtuner/chat/streamer/__init__.py @@ -0,0 +1,7 @@ +from
.huggingface import HFTextIteratorStreamer, HFTextStreamer +from .lmdeploy import LMDeployTextIteratorStreamer, LMDeployTextStreamer + +__all__ = [ + 'HFTextIteratorStreamer', 'HFTextStreamer', 'LMDeployTextIteratorStreamer', + 'LMDeployTextStreamer' +] diff --git a/xtuner/chat/streamer/huggingface.py b/xtuner/chat/streamer/huggingface.py new file mode 100644 index 000000000..91b0f29aa --- /dev/null +++ b/xtuner/chat/streamer/huggingface.py @@ -0,0 +1,37 @@ +from transformers import TextIteratorStreamer, TextStreamer +from transformers.models.auto import AutoTokenizer + + +class HFTextIteratorStreamer(TextIteratorStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + timeout=None, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs) + self.chat_template = chat_template + + def on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) + + +class HFTextStreamer(TextStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.chat_template = chat_template + + def on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) diff --git a/xtuner/chat/streamer/lmdeploy.py b/xtuner/chat/streamer/lmdeploy.py new file mode 100644 index 000000000..2ec03e482 --- /dev/null +++ b/xtuner/chat/streamer/lmdeploy.py @@ -0,0 +1,49 @@ +from queue import Queue +from typing import Optional + +from transformers.generation.streamers import BaseStreamer + + +class LMDeployTextStreamer(BaseStreamer): + + def put(self, text): + self.on_finalized_text(text) + + def end(self): + """Flushes any remaining cache and prints a newline to stdout.""" + self.on_finalized_text('', stream_end=True) + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Prints the new text to stdout. + + If the stream is ending, also prints a newline. + """ + print(text, flush=True, end='' if not stream_end else None) + + +class LMDeployTextIteratorStreamer(LMDeployTextStreamer): + + def __init__(self, timeout: Optional[float] = None): + super().__init__() + self.text_queue = Queue() + self.stop_signal = None + self.timeout = timeout + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. + + If the stream is ending, also put a stop signal in the queue. 
+ """ + self.text_queue.put(text, timeout=self.timeout) + if stream_end: + self.text_queue.put(self.stop_signal, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py index dd0a93ddf..79ea745af 100644 --- a/xtuner/types/__init__.py +++ b/xtuner/types/__init__.py @@ -1,8 +1,11 @@ -from .chat import HybridChatMessages +from .chat import (ChatMsg, HybridChatMessages, ImageContentItem, + TextContentItem) from .chat_template import HybridChatTemplate +from .sample_params import SampleParams from .train import RawTrainingData, TrainingHybridChatMessages __all__ = [ - 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages', - 'HybridChatMessages' + 'ChatMsg', 'HybridChatMessages', 'ImageContentItem', 'TextContentItem', + 'HybridChatTemplate', 'SampleParams', 'RawTrainingData', + 'TrainingHybridChatMessages' ] diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index cd0a4d4a7..0e48391f6 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -6,7 +6,7 @@ class TextContentItem(BaseModel): - type: Literal['text'] + type: Literal['text'] = 'text' text: str def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: @@ -14,7 +14,7 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: class ImageContentItem(BaseModel): - type: Literal['image_url'] + type: Literal['image_url'] = 'image_url' image_url: str def apply_chat_template(self, chat_template: HybridChatTemplate) -> str: diff --git a/xtuner/types/sample_params.py b/xtuner/types/sample_params.py new file mode 100644 index 000000000..137809648 --- /dev/null +++ b/xtuner/types/sample_params.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel + + +class SampleParams(BaseModel): + + max_new_tokens: int = 512 + temperature: float = 0.1 + top_k: int = 40 + top_p: float = 0.75 + repetition_penalty: float = 1.0 + stop_words: list = [] + seed: Optional[int] = None