From a3358878b398d80e8041684eece07e31635ce24b Mon Sep 17 00:00:00 2001
From: liukaiwen
Date: Tue, 8 Oct 2024 18:05:30 +0800
Subject: [PATCH 1/4] feat: merge formula update

---
 magic_pdf/model/mfr_cudagraph.py   | 899 +++++++++++++++++++++++++++++
 magic_pdf/model/pdf_extract_kit.py |   6 +
 2 files changed, 905 insertions(+)
 create mode 100644 magic_pdf/model/mfr_cudagraph.py

diff --git a/magic_pdf/model/mfr_cudagraph.py b/magic_pdf/model/mfr_cudagraph.py
new file mode 100644
index 00000000..59b45c52
--- /dev/null
+++ b/magic_pdf/model/mfr_cudagraph.py
@@ -0,0 +1,899 @@
+from typing import Optional, Tuple, Union
+import torch
+from torch import nn
+import os
+from unimernet.common.config import Config
+import unimernet.tasks as tasks
+import argparse
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+
+class PatchedMBartLearnedPositionalEmbedding(nn.Module):
+
+    def __init__(self, origin: nn.Module):
+        super().__init__()
+        self.offset = origin.offset
+        self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim)
+        self.embedding.weight.data = origin.weight.data
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(0, seq_len, dtype=torch.long, device=self.embedding.weight.device
+                                 )
+        positions += past_key_values_length
+        positions = positions.expand(bsz, -1)
+
+        return self.embedding(positions + self.offset)
+
+
+class PatchedMBartDecoder(nn.Module):
+    def __init__(self, origin: nn.Module, kvlen: torch.LongTensor):
+        super().__init__()
+        self.origin = origin
+        self.kvlen = kvlen
+
+        self.config = origin.config
+        self.embed_tokens = origin.embed_tokens
+        self.embed_scale = origin.embed_scale
+        self._use_flash_attention_2 = origin._use_flash_attention_2
+        self.embed_positions = origin.embed_positions
+        self.counting_context_weight = getattr(origin, 'counting_context_weight', None)
+        self.layernorm_embedding = origin.layernorm_embedding
+        self.layers = origin.layers
+        self.layer_norm = origin.layer_norm
+
+        self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        run_origin = False
+        if past_key_values is None:
+            run_origin = True
+        elif past_key_values[0][0].size(-2) < attention_mask.size(-1):
+            run_origin = True
+
+        if run_origin:
+            return self.origin(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.patched_embed_positions(input, self.kvlen) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + + # TODO: add counting context weight to hidden_states + if count_pred is not None: + count_context_weight = self.counting_context_weight(count_pred) + hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1) + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class PatchedMBartAttention(nn.Module): + + def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): + super().__init__() + self.embed_dim = origin.embed_dim + self.num_heads = origin.num_heads + self.dropout = origin.dropout + self.head_dim = origin.head_dim + self.config = origin.config + + self.scaling = origin.scaling + self.is_decoder = origin.is_decoder + self.is_causal = origin.is_causal + + self.k_proj = origin.k_proj + self.v_proj = origin.v_proj + self.q_proj = origin.q_proj + self.out_proj = origin.out_proj + self.kvlen = kvlen + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # 
cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if past_key_value[0].size(-2) < attention_mask.size(-1): + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0][:, :, self.kvlen[None]] = key_states + past_key_value[1][:, :, self.kvlen[None]] = value_states + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = attn_weights + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + # attn_output = self.out_proj(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class PatchedMBartSqueezeAttention(nn.Module): + + def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): + super().__init__() + self.embed_dim = origin.embed_dim + self.num_heads = origin.num_heads + self.dropout = origin.dropout + self.head_dim = origin.head_dim + self.squeeze_head_dim=origin.squeeze_head_dim + self.config = origin.config + + self.scaling = origin.scaling + self.is_decoder = origin.is_decoder + self.scaling = origin.scaling + + self.q_proj = origin.q_proj + self.k_proj = origin.k_proj + self.v_proj = origin.v_proj + self.out_proj = origin.out_proj + self.kvlen = kvlen + + def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous() + + def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) + + if past_key_value[0].size(-2) < attention_mask.size(-1): + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0][:, :, self.kvlen[None]] = key_states + past_key_value[1][:, :, self.kvlen[None]] = value_states + key_states = past_key_value[0] + value_states = past_key_value[1] + else: + # self_attention + key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim) + value_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*value_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + +def patch_model(model: nn.Module, kvlen: torch.LongTensor): + for name, child in model.named_children(): + cls_name = type(child).__name__ + if cls_name == 'MBartAttention': + patched_child = PatchedMBartAttention(child, kvlen) + model.register_module(name, patched_child) + elif cls_name == 'MBartSqueezeAttention': + patched_child = PatchedMBartSqueezeAttention(child, kvlen) + model.register_module(name, patched_child) + else: + patch_model(child, kvlen) + + cls_name = type(model).__name__ + if cls_name == 'CustomMBartDecoder': + model = PatchedMBartDecoder(model, kvlen) + return model + + +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n.""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def get_graph_key(batch_size: int, kvlens: int): + batch_size = next_power_of_2(batch_size) + kvlens = next_power_of_2(kvlens) + + batch_size = max(8, batch_size) + kvlens = max(32, kvlens) + + return batch_size, kvlens + + +class GraphRunnerImpl: + + def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict): + self.model = model + self.graph = graph + self.input_buffers = input_buffers + self.output_buffers = output_buffers + + @staticmethod + def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int): + input_ids = input_buffers['input_ids'][:batch_size] + attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens] + encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size] + kvlen=input_buffers['kvlen'] + + past_key_values = [] + for past_key_value in input_buffers['past_key_values']: + k0 = past_key_value[0][:batch_size, :, :kvlens] + v0 = past_key_value[1][:batch_size, :, :kvlens] + k1 = past_key_value[2][:batch_size] + v1 = past_key_value[3][:batch_size] + past_key_values.append((k0, v0, k1, v1)) + + input_buffers = dict( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + kvlen=kvlen, + ) + return input_buffers + + @staticmethod + def fill_input_buffers( + input_buffer: dict, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + ): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + + input_buffer['input_ids'][:batch_size] = input_ids + + if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr(): + input_buffer['attention_mask'].fill_(0) + input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask + input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states + + if past_key_values is not None: + for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values): + idx = 0 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx] + idx = 1 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx] + + idx = 2 + if buf_kv[idx].data_ptr() != kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size] = kv[idx] + idx = 3 + if buf_kv[idx].data_ptr() != 
kv[idx].data_ptr(): + buf_kv[idx].fill_(0) + buf_kv[idx][:batch_size] = kv[idx] + + input_buffer['kvlen'].fill_(kvlens - 1) + + @classmethod + @torch.inference_mode() + def capture(cls, + model: nn.Module, + input_buffers: dict, + pool, + warmup: bool = False, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None,): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + + graph_key = get_graph_key(batch_size, kvlens) + batch_size = graph_key[0] + kvlens = graph_key[1] + + input_buffers = cls.extract_input_buffers(input_buffers, + batch_size=batch_size, + kvlens=kvlens) + cls.fill_input_buffers(input_buffers, + input_ids, + attention_mask, + encoder_hidden_states, + past_key_values) + + input_ids = input_buffers['input_ids'] + attention_mask = input_buffers['attention_mask'] + encoder_hidden_states = input_buffers['encoder_hidden_states'] + past_key_values = input_buffers['past_key_values'] + + if warmup: + # warmup + model( + input_ids=input_ids, + attention_mask=attention_mask, + count_pred=count_pred, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, + pool=pool): + outputs = model( + input_ids=input_ids, + attention_mask=attention_mask, + count_pred=count_pred, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict) + + output_buffers = dict( + last_hidden_state=outputs['last_hidden_state'], + past_key_values=outputs['past_key_values'], + ) + + return GraphRunnerImpl(model, graph, input_buffers, output_buffers) + + def __call__(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size = input_ids.size(0) + kvlens = attention_mask.size(1) + self.fill_input_buffers(self.input_buffers, + input_ids, + attention_mask, + 
encoder_hidden_states, + past_key_values) + + self.graph.replay() + + last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size] + + past_key_values = [] + for past_key_value in self.output_buffers['past_key_values']: + k0 = past_key_value[0][:batch_size, :, :kvlens] + v0 = past_key_value[1][:batch_size, :, :kvlens] + k1 = past_key_value[2][:batch_size] + v1 = past_key_value[3][:batch_size] + past_key_values.append((k0, v0, k1, v1)) + + outputs = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + past_key_values=past_key_values, + ) + return outputs + +class GraphRunner(nn.Module): + + def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'): + super().__init__() + + self.kvlen = torch.tensor(0, dtype=torch.long, device=device) + model = patch_model(model.to(dtype), self.kvlen) + self.model = model + self.max_batchs = max_batchs + self.max_kvlens = max_kvlens + self.device = device + + self.input_buffers = None + + self.impl_map = dict() + self.graph_pool_handle = torch.cuda.graph_pool_handle() + self.warmuped = False + + def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype): + max_batchs = self.max_batchs + max_kvlens = self.max_kvlens + device = self.device + config = self.model.config + + d_model = config.d_model + decoder_layers = config.decoder_layers + num_heads = config.decoder_attention_heads + + head_dim = d_model // num_heads + self_attn = self.model.layers[0].self_attn + qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim) + + input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device) + attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device) + encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device) + + past_key_values = [] + for _ in range(decoder_layers): + k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device) + v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device) + k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device) + v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device) + + past_key_values.append((k0, v0, k1, v1)) + + self.input_buffers = dict( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + kvlen=self.kvlen + ) + + @torch.inference_mode() + def forward(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + count_pred: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + batch_size, qlens = input_ids.size() + kvlens = attention_mask.size(1) + + eager_mode = False + + if qlens != 1: + eager_mode = True + + if past_key_values is None: + eager_mode = True + else: + for past_key_value in past_key_values: + if past_key_value is None: + eager_mode = True + 
+                    break
+
+        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
+            eager_mode = True
+
+        if eager_mode:
+            return self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,)
+
+        # create buffer if not exists.
+        if self.input_buffers is None:
+            encoder_kvlens = encoder_hidden_states.size(1)
+            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
+
+        graph_key = get_graph_key(batch_size, kvlens)
+        if graph_key not in self.impl_map:
+            warmup = False
+            if not self.warmuped:
+                warmup = True
+                self.warmuped = True
+            impl = GraphRunnerImpl.capture(
+                self.model,
+                self.input_buffers,
+                self.graph_pool_handle,
+                warmup=warmup,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                count_pred=count_pred,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                head_mask=head_mask,
+                cross_attn_head_mask=cross_attn_head_mask,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            self.impl_map[graph_key] = impl
+        impl = self.impl_map[graph_key]
+
+        ret = impl(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            count_pred=count_pred,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        return ret
\ No newline at end of file
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index 1235a0a8..109fd53c 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -5,6 +5,7 @@
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
+from .mfr_cudagraph import GraphRunner
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
@@ -67,6 +68,11 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model.to(_device_)
     model.eval()
+    model = model.to(_device_)
+    if 'cuda' in _device_:
+        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
+                                     device=_device_)
+        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
     return [model, mfr_transform]

From 51f56aa32f5e8c48f196fd24146ea364507326ff Mon Sep 17 00:00:00 2001
From: liukaiwen
Date: Thu, 17 Oct 2024 17:16:28 +0800
Subject: [PATCH 2/4] feat: merge formula update

---
 magic_pdf/libs/Constants.py        | 4 ++--
 magic_pdf/model/pdf_extract_kit.py | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/magic_pdf/libs/Constants.py b/magic_pdf/libs/Constants.py
index 4e132290..d9e379e9 100644
--- a/magic_pdf/libs/Constants.py
+++ b/magic_pdf/libs/Constants.py
@@ -29,10 +29,10 @@
 TABLE_MASTER_DIR = "table_structure_tablemaster_infer/"
 
 # pp detect model dir
-DETECT_MODEL_DIR = "ch_PP-OCRv3_det_infer"
+DETECT_MODEL_DIR = "ch_PP-OCRv4_det_infer"
 
 # pp rec model dir
-REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
+REC_MODEL_DIR = "ch_PP-OCRv4_rec_infer"
 
 # pp rec char dict path
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index bca9b987..1719264f 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -433,3 +433,7 @@ def __call__(self, image):
         logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
+if __name__ == '__main__':
+    print()
+
+

From a0eff3be5c946e49ee726a6204b5821bda080645 Mon Sep 17 00:00:00 2001
From: liukaiwen
Date: Mon, 28 Oct 2024 16:34:16 +0800
Subject: [PATCH 3/4] feat: table model update with paddle recognition v4

---
 magic_pdf/libs/Constants.py        |  6 ++++++
 magic_pdf/model/pdf_extract_kit.py | 16 +++++++++++++---
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/magic_pdf/libs/Constants.py b/magic_pdf/libs/Constants.py
index d9e379e9..e6fa4b78 100644
--- a/magic_pdf/libs/Constants.py
+++ b/magic_pdf/libs/Constants.py
@@ -37,4 +37,10 @@
 # pp rec char dict path
 REC_CHAR_DICT = "ppocr_keys_v1.txt"
 
+# pp rec copy rec directory
+PP_REC_DIRECTORY = ".paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer"
+
+# pp rec copy det directory
+PP_DET_DIRECTORY = ".paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer"
+
 
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index 1719264f..0c296fba 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -1,7 +1,8 @@
 from loguru import logger
 import os
 import time
-
+from pathlib import Path
+import shutil
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
@@ -271,6 +272,17 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
                 device=self.device
             )
 
+        home_directory = Path.home()
+        det_source = os.path.join(models_dir, table_model_dir, DETECT_MODEL_DIR)
+        rec_source = os.path.join(models_dir, table_model_dir, REC_MODEL_DIR)
+        det_dest_dir = os.path.join(home_directory, PP_DET_DIRECTORY)
+        rec_dest_dir = os.path.join(home_directory, PP_REC_DIRECTORY)
+
+        if not os.path.exists(det_dest_dir):
+            shutil.copytree(det_source, det_dest_dir)
+        if not os.path.exists(rec_dest_dir):
+            shutil.copytree(rec_source, rec_dest_dir)
+
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -433,7 +445,5 @@
         logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
 
         return layout_res
-if __name__ == '__main__':
-    print()
 
 

From 4949408c9d7d5c0a7a42991a46a169a814ad9d66 Mon Sep 17 00:00:00 2001
From: liukaiwen
Date: Mon, 28 Oct 2024 17:09:46 +0800
Subject: [PATCH 4/4] perf: table model update with PP OCRv4

---
 magic_pdf/model/mfr_cudagraph.py   | 899 -----------------------------
 magic_pdf/model/pdf_extract_kit.py |   6 -
 2 files changed, 905 deletions(-)
 delete mode 100644 magic_pdf/model/mfr_cudagraph.py

diff --git a/magic_pdf/model/mfr_cudagraph.py b/magic_pdf/model/mfr_cudagraph.py
deleted file mode 100644
index 59b45c52..00000000
--- a/magic_pdf/model/mfr_cudagraph.py
+++ /dev/null
@@ -1,899 +0,0 @@
-from typing import
Optional, Tuple, Union -import torch -from torch import nn -import os -from unimernet.common.config import Config -import unimernet.tasks as tasks -import argparse -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask - -class PatchedMBartLearnedPositionalEmbedding(nn.Module): - - def __init__(self, origin: nn.Module): - super().__init__() - self.offset = origin.offset - self.embedding = nn.Embedding(origin.num_embeddings, origin.embedding_dim) - self.embedding.weight.data = origin.weight.data - - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): - """`input_ids' shape is expected to be [bsz x seqlen].""" - - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange(0, seq_len, dtype=torch.long, device=self.embedding.weight.device - ) - positions += past_key_values_length - positions = positions.expand(bsz, -1) - - return self.embedding(positions + self.offset) - - -class PatchedMBartDecoder(nn.Module): - def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): - super().__init__() - self.origin = origin - self.kvlen = kvlen - - self.config = origin.config - self.embed_tokens = origin.embed_tokens - self.embed_scale = origin.embed_scale - self._use_flash_attention_2 = origin._use_flash_attention_2 - self.embed_positions = origin.embed_positions - self.counting_context_weight = getattr(origin, 'counting_context_weight', None) - self.layernorm_embedding = origin.layernorm_embedding - self.layers = origin.layers - self.layer_norm = origin.layer_norm - - self.patched_embed_positions = PatchedMBartLearnedPositionalEmbedding(self.embed_positions) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - count_pred: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - run_origin = False - if past_key_values is None: - run_origin = True - elif past_key_values[0][0].size(-2) < attention_mask.size(-1): - run_origin = True - - if run_origin: - return self.origin( - input_ids=input_ids, - attention_mask=attention_mask, - count_pred=count_pred, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - head_mask=head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve 
input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input = input_ids - input_shape = input.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - # embed positions - positions = self.patched_embed_positions(input, self.kvlen) - - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) - - # TODO: add counting context weight to hidden_states - if count_pred is not None: - count_context_weight = self.counting_context_weight(count_pred) - hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1) - hidden_states = self.layernorm_embedding(hidden_states) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): - if attn_mask is not None: - if attn_mask.size()[0] != len(self.layers): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {attn_mask.size()[0]}." 
- ) - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - - hidden_states = self.layer_norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -class PatchedMBartAttention(nn.Module): - - def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): - super().__init__() - self.embed_dim = origin.embed_dim - self.num_heads = origin.num_heads - self.dropout = origin.dropout - self.head_dim = origin.head_dim - self.config = origin.config - - self.scaling = origin.scaling - self.is_decoder = origin.is_decoder - self.is_causal = origin.is_causal - - self.k_proj = origin.k_proj - self.v_proj = origin.v_proj - self.q_proj = origin.q_proj - self.out_proj = origin.out_proj - self.kvlen = kvlen - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # 
cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if past_key_value[0].size(-2) < attention_mask.size(-1): - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - past_key_value[0][:, :, self.kvlen[None]] = key_states - past_key_value[1][:, :, self.kvlen[None]] = value_states - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = attn_weights - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - # attn_output = self.out_proj(attn_output) - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class PatchedMBartSqueezeAttention(nn.Module): - - def __init__(self, origin: nn.Module, kvlen: torch.LongTensor): - super().__init__() - self.embed_dim = origin.embed_dim - self.num_heads = origin.num_heads - self.dropout = origin.dropout - self.head_dim = origin.head_dim - self.squeeze_head_dim=origin.squeeze_head_dim - self.config = origin.config - - self.scaling = origin.scaling - self.is_decoder = origin.is_decoder - self.scaling = origin.scaling - - self.q_proj = origin.q_proj - self.k_proj = origin.k_proj - self.v_proj = origin.v_proj - self.out_proj = origin.out_proj - self.kvlen = kvlen - - def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous() - - def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) - - if past_key_value[0].size(-2) < attention_mask.size(-1): - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - past_key_value[0][:, :, self.kvlen[None]] = key_states - past_key_value[1][:, :, self.kvlen[None]] = value_states - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - # self_attention - key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim) - value_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*value_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" - ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. 
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, attn_weights_reshaped, past_key_value
-
-
-def patch_model(model: nn.Module, kvlen: torch.LongTensor):
-    for name, child in model.named_children():
-        cls_name = type(child).__name__
-        if cls_name == 'MBartAttention':
-            patched_child = PatchedMBartAttention(child, kvlen)
-            model.register_module(name, patched_child)
-        elif cls_name == 'MBartSqueezeAttention':
-            patched_child = PatchedMBartSqueezeAttention(child, kvlen)
-            model.register_module(name, patched_child)
-        else:
-            patch_model(child, kvlen)
-
-    cls_name = type(model).__name__
-    if cls_name == 'CustomMBartDecoder':
-        model = PatchedMBartDecoder(model, kvlen)
-    return model
-
-
-def next_power_of_2(n: int):
-    """Return the smallest power of 2 greater than or equal to n."""
-    n -= 1
-    n |= n >> 1
-    n |= n >> 2
-    n |= n >> 4
-    n |= n >> 8
-    n |= n >> 16
-    n |= n >> 32
-    n += 1
-    return n
-
-
-def get_graph_key(batch_size: int, kvlens: int):
-    batch_size = next_power_of_2(batch_size)
-    kvlens = next_power_of_2(kvlens)
-
-    batch_size = max(8, batch_size)
-    kvlens = max(32, kvlens)
-
-    return batch_size, kvlens
-
-
-class GraphRunnerImpl:
-
-    def __init__(self, model: nn.Module, graph: torch.cuda.CUDAGraph, input_buffers: dict, output_buffers: dict):
-        self.model = model
-        self.graph = graph
-        self.input_buffers = input_buffers
-        self.output_buffers = output_buffers
-
-    @staticmethod
-    def extract_input_buffers(input_buffers: dict, batch_size: int, kvlens: int):
-        input_ids = input_buffers['input_ids'][:batch_size]
-        attention_mask = input_buffers['attention_mask'][:batch_size, :kvlens]
-        encoder_hidden_states = input_buffers['encoder_hidden_states'][:batch_size]
-        kvlen=input_buffers['kvlen']
-
-        past_key_values = []
-        for past_key_value in input_buffers['past_key_values']:
-            k0 = past_key_value[0][:batch_size, :, :kvlens]
-            v0 = past_key_value[1][:batch_size, :, :kvlens]
-            k1 = past_key_value[2][:batch_size]
-            v1 = past_key_value[3][:batch_size]
-            past_key_values.append((k0, v0, k1, v1))
-
-        input_buffers = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-            kvlen=kvlen,
-        )
-        return input_buffers
-
-    @staticmethod
-    def fill_input_buffers(
-        input_buffer: dict,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-    ):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-
-        input_buffer['input_ids'][:batch_size] = input_ids
-
-        if input_buffer['attention_mask'].data_ptr() != attention_mask.data_ptr():
-            input_buffer['attention_mask'].fill_(0)
-            input_buffer['attention_mask'][:batch_size, :kvlens] = attention_mask
-        input_buffer['encoder_hidden_states'][:batch_size] = encoder_hidden_states
-
-        if past_key_values is not None:
-            for buf_kv, kv in zip(input_buffer['past_key_values'], past_key_values):
-                idx = 0
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
-                idx = 1
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size, :, :kvlens-1] = kv[idx]
-
-                idx = 2
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size] = kv[idx]
-                idx = 3
-                if buf_kv[idx].data_ptr() != kv[idx].data_ptr():
-                    buf_kv[idx].fill_(0)
-                    buf_kv[idx][:batch_size] = kv[idx]
-
-        input_buffer['kvlen'].fill_(kvlens - 1)
-
-    @classmethod
-    @torch.inference_mode()
-    def capture(cls,
-                model: nn.Module,
-                input_buffers: dict,
-                pool,
-                warmup: bool = False,
-                input_ids: torch.LongTensor = None,
-                attention_mask: Optional[torch.Tensor] = None,
-                count_pred: Optional[torch.FloatTensor] = None,
-                encoder_hidden_states: Optional[torch.FloatTensor] = None,
-                encoder_attention_mask: Optional[torch.LongTensor] = None,
-                head_mask: Optional[torch.Tensor] = None,
-                cross_attn_head_mask: Optional[torch.Tensor] = None,
-                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                inputs_embeds: Optional[torch.FloatTensor] = None,
-                use_cache: Optional[bool] = None,
-                output_attentions: Optional[bool] = None,
-                output_hidden_states: Optional[bool] = None,
-                return_dict: Optional[bool] = None,):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-
-        graph_key = get_graph_key(batch_size, kvlens)
-        batch_size = graph_key[0]
-        kvlens = graph_key[1]
-
-        input_buffers = cls.extract_input_buffers(input_buffers,
-                                                  batch_size=batch_size,
-                                                  kvlens=kvlens)
-        cls.fill_input_buffers(input_buffers,
-                               input_ids,
-                               attention_mask,
-                               encoder_hidden_states,
-                               past_key_values)
-
-        input_ids = input_buffers['input_ids']
-        attention_mask = input_buffers['attention_mask']
-        encoder_hidden_states = input_buffers['encoder_hidden_states']
-        past_key_values = input_buffers['past_key_values']
-
-        if warmup:
-            # warmup
-            model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict)
-
-        graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph,
-                              pool=pool):
-            outputs = model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict)
-
-        output_buffers = dict(
-            last_hidden_state=outputs['last_hidden_state'],
-            past_key_values=outputs['past_key_values'],
-        )
-
-        return GraphRunnerImpl(model, graph, input_buffers, output_buffers)
-
-    def __call__(self,
-                 input_ids: torch.LongTensor = None,
-                 attention_mask: Optional[torch.Tensor] = None,
-                 count_pred: Optional[torch.FloatTensor] = None,
-                 encoder_hidden_states: Optional[torch.FloatTensor] = None,
-                 encoder_attention_mask: Optional[torch.LongTensor] = None,
-                 head_mask: Optional[torch.Tensor] = None,
-                 cross_attn_head_mask: Optional[torch.Tensor] = None,
-                 past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                 inputs_embeds: Optional[torch.FloatTensor] = None,
-                 use_cache: Optional[bool] = None,
-                 output_attentions: Optional[bool] = None,
-                 output_hidden_states: Optional[bool] = None,
-                 return_dict: Optional[bool] = None,
-                 ):
-        batch_size = input_ids.size(0)
-        kvlens = attention_mask.size(1)
-        self.fill_input_buffers(self.input_buffers,
-                                input_ids,
-                                attention_mask,
-                                encoder_hidden_states,
-                                past_key_values)
-
-        self.graph.replay()
-
-        last_hidden_state = self.output_buffers['last_hidden_state'][:batch_size]
-
-        past_key_values = []
-        for past_key_value in self.output_buffers['past_key_values']:
-            k0 = past_key_value[0][:batch_size, :, :kvlens]
-            v0 = past_key_value[1][:batch_size, :, :kvlens]
-            k1 = past_key_value[2][:batch_size]
-            v1 = past_key_value[3][:batch_size]
-            past_key_values.append((k0, v0, k1, v1))
-
-        outputs = BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=last_hidden_state,
-            past_key_values=past_key_values,
-        )
-        return outputs
-
-class GraphRunner(nn.Module):
-
-    def __init__(self, model: nn.Module, max_batchs: int, max_kvlens: int, dtype:torch.dtype = torch.float16, device: torch.device = 'cuda'):
-        super().__init__()
-
-        self.kvlen = torch.tensor(0, dtype=torch.long, device=device)
-        model = patch_model(model.to(dtype), self.kvlen)
-        self.model = model
-        self.max_batchs = max_batchs
-        self.max_kvlens = max_kvlens
-        self.device = device
-
-        self.input_buffers = None
-
-        self.impl_map = dict()
-        self.graph_pool_handle = torch.cuda.graph_pool_handle()
-        self.warmuped = False
-
-    def create_buffers(self, encoder_kvlens: int, dtype: torch.dtype):
-        max_batchs = self.max_batchs
-        max_kvlens = self.max_kvlens
-        device = self.device
-        config = self.model.config
-
-        d_model = config.d_model
-        decoder_layers = config.decoder_layers
-        num_heads = config.decoder_attention_heads
-
-        head_dim = d_model // num_heads
-        self_attn = self.model.layers[0].self_attn
-        qk_head_dim = getattr(self_attn, 'squeeze_head_dim', head_dim)
-
-        input_ids = torch.ones((max_batchs, 1), dtype=torch.int64, device=device)
-        attention_mask = torch.zeros((max_batchs, max_kvlens), dtype=torch.int64, device=device)
-        encoder_hidden_states = torch.zeros((max_batchs, encoder_kvlens, d_model), dtype=dtype, device=device)
-
-        past_key_values = []
-        for _ in range(decoder_layers):
-            k0 = torch.zeros((max_batchs, num_heads, max_kvlens, qk_head_dim), dtype=dtype, device=device)
-            v0 = torch.zeros((max_batchs, num_heads, max_kvlens, head_dim), dtype=dtype, device=device)
-            k1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, qk_head_dim), dtype=dtype, device=device)
-            v1 = torch.zeros((max_batchs, num_heads, encoder_kvlens, head_dim), dtype=dtype, device=device)
-
-            past_key_values.append((k0, v0, k1, v1))
-
-        self.input_buffers = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-            kvlen=self.kvlen
-        )
-
-    @torch.inference_mode()
-    def forward(self,
-                input_ids: torch.LongTensor = None,
-                attention_mask: Optional[torch.Tensor] = None,
-                count_pred: Optional[torch.FloatTensor] = None,
-                encoder_hidden_states: Optional[torch.FloatTensor] = None,
-                encoder_attention_mask: Optional[torch.LongTensor] = None,
-                head_mask: Optional[torch.Tensor] = None,
-                cross_attn_head_mask: Optional[torch.Tensor] = None,
-                past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                inputs_embeds: Optional[torch.FloatTensor] = None,
-                use_cache: Optional[bool] = None,
-                output_attentions: Optional[bool] = None,
-                output_hidden_states: Optional[bool] = None,
-                return_dict: Optional[bool] = None,
-                ):
-        batch_size, qlens = input_ids.size()
-        kvlens = attention_mask.size(1)
-
-        eager_mode = False
-
-        if qlens != 1:
-            eager_mode = True
-
-        if past_key_values is None:
-            eager_mode = True
-        else:
-            for past_key_value in past_key_values:
-                if past_key_value is None:
-                    eager_mode = True
-                    break
-
-        if batch_size >= self.max_batchs or kvlens >= self.max_kvlens:
-            eager_mode = True
-
-        if eager_mode:
-            return self.model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,)
-
-        # create buffer if not exists.
-        if self.input_buffers is None:
-            encoder_kvlens = encoder_hidden_states.size(1)
-            self.create_buffers(encoder_kvlens=encoder_kvlens, dtype=encoder_hidden_states.dtype)
-
-        graph_key = get_graph_key(batch_size, kvlens)
-        if graph_key not in self.impl_map:
-            warmup = False
-            if not self.warmuped:
-                warmup = True
-                self.warmuped = True
-            impl = GraphRunnerImpl.capture(
-                self.model,
-                self.input_buffers,
-                self.graph_pool_handle,
-                warmup=warmup,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                count_pred=count_pred,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                head_mask=head_mask,
-                cross_attn_head_mask=cross_attn_head_mask,
-                past_key_values=past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-            self.impl_map[graph_key] = impl
-        impl = self.impl_map[graph_key]
-
-        ret = impl(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            count_pred=count_pred,
-            encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-            head_mask=head_mask,
-            cross_attn_head_mask=cross_attn_head_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        return ret
\ No newline at end of file
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index fb3a5f79..f1478b10 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -6,7 +6,6 @@
 from magic_pdf.libs.Constants import *
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.model_list import AtomicModel
-from .mfr_cudagraph import GraphRunner
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
@@ -70,11 +69,6 @@ def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     model = task.build_model(cfg)
     model.to(_device_)
     model.eval()
-    model = model.to(_device_)
-    if 'cuda' in _device_:
-        decoder_runner = GraphRunner(model.model.model.decoder.model.decoder, max_batchs=128, max_kvlens=256,
-                                     device=_device_)
-        model.model.model.decoder.model.decoder = decoder_runner
     vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
     mfr_transform = transforms.Compose([vis_processor, ])
    return [model, mfr_transform]
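
Note: the decoder wrapper removed by this patch follows the standard CUDA-graph capture/replay recipe: allocate static input/output buffers once, record a single fixed-shape decode step with torch.cuda.CUDAGraph, then on every later step copy fresh data into the same buffers and replay the recorded kernels, with one graph per next_power_of_2-padded (batch, kv-length) bucket and an eager fallback for anything that does not fit (prefill, missing caches, oversized batches). A minimal, self-contained sketch of that pattern is below; the toy model, buffer names, and shapes are illustrative only and are not part of this patch.

    import torch
    from torch import nn

    def build_graph_runner(model: nn.Module, example: torch.Tensor):
        """Capture one forward pass of `model` into a CUDA graph and return a
        callable that replays it with new inputs. Toy sketch, not the patch code."""
        model = model.cuda().eval()
        static_in = example.clone().cuda()          # fixed input buffer reused by every replay

        with torch.no_grad():
            model(static_in)                        # warm-up so lazy init is not captured
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.no_grad(), torch.cuda.graph(graph):
            static_out = model(static_in)           # record exactly one forward pass

        def run(x: torch.Tensor) -> torch.Tensor:
            static_in.copy_(x)                      # refill the captured address with new data
            graph.replay()                          # relaunch the recorded kernels
            return static_out.clone()               # copy out before the next replay overwrites it
        return run

    # Hypothetical usage:
    # runner = build_graph_runner(nn.Linear(256, 256), torch.randn(8, 256))
    # y = runner(torch.randn(8, 256, device='cuda'))

Replay only pays off for the fixed-shape single-token decode step, which is why the removed GraphRunner.forward falls back to eager execution whenever qlens != 1, a past key/value entry is missing, or the request exceeds max_batchs/max_kvlens.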