[20240323] Completed the basic modules for the two alignment modes (contrastive & mlm)
HitAgain committed Mar 23, 2024
1 parent 2d47f6e commit b47b36c
Showing 4 changed files with 36 additions and 108 deletions.
58 changes: 5 additions & 53 deletions alignment/arguments.py
@@ -1,6 +1,6 @@
import os

Codacy Static Code Analysis notice on line 1 in alignment/arguments.py: 'os' imported but unused (F401)
Codacy Static Code Analysis warning on line 1 in alignment/arguments.py: Unused import os
from dataclasses import dataclass, field
from typing import Optional, List
from typing import Optional
from transformers import TrainingArguments


@@ -38,34 +38,8 @@ class ModelArguments:
projection_in_dim: int = field(default=768)
projection_out_dim: int = field(default=1)

# p*-tuning
model_type: str = field(
default="bert",
metadata={
"help": "The type of model, where we currently support bert, roberta, deberta"
}
)
prefix: bool = field(
default=False,
metadata={
"help": "Will use P-tuning v2 during training"
}
)
prompt: bool = field(
default=False,
metadata={
"help": "Will use prompt tuning during training"
}
)
prompt_from_vocab: bool = field(
default=True,
metadata={
"help": "Will prompt embeddings initalized from plm's word embeddings"
}
)
prompt_encoder_type: str = field(default=None)
pre_seq_len: int = field(
default=100,
prefix_len: int = field(
default=32,
metadata={
"help": "The length of prompt"
}
@@ -145,29 +119,7 @@ class DataArguments:
},
)

def __post_init__(self):
if self.dataset_name is not None:
info = self.dataset_name.split('/')
self.dataset_split = info[-1] if len(info) == 3 else 'train'
self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
self.dataset_language = 'default'
if ':' in self.dataset_name:
self.dataset_name, self.dataset_language = self.dataset_name.split(':')
if self.train_dir is not None:
files = os.listdir(self.train_dir)
self.train_path = [
os.path.join(self.train_dir, f)
for f in files
if f.endswith('tsv') or f.endswith('json')
]


@dataclass
class DenseTrainingArguments(TrainingArguments):
warmup_ratio: float = field(default=0.1)
negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})

grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
gc_q_chunk_size: int = field(default=4)
gc_p_chunk_size: int = field(default=32)
class AlignmentTrainingArguments(TrainingArguments):
alignment_mode: str = field(default="contrastive", metadata={"help": "contrastive or mlm"})
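
Note: as a minimal sketch of how the new `alignment_mode` field could be selected at launch time, the snippet below re-declares the dataclass shown above for self-containedness and parses it with `HfArgumentParser`; the `--output_dir` value is a placeholder and not part of this commit.

```python
# Minimal sketch, assuming only the AlignmentTrainingArguments defined above.
from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments


@dataclass
class AlignmentTrainingArguments(TrainingArguments):
    alignment_mode: str = field(default="contrastive", metadata={"help": "contrastive or mlm"})


parser = HfArgumentParser(AlignmentTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./alignment_out", "--alignment_mode", "mlm"]
)
print(training_args.alignment_mode)  # -> "mlm"
```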
26 changes: 9 additions & 17 deletions alignment/data.py
@@ -1,26 +1,20 @@
import os
import random
import json
from dataclasses import dataclass
from typing import Optional, Union, List, Dict, Tuple, Any
import itertools
from typing import Optional, List, Tuple
import numpy as np

import datasets
import torch
import torch.utils.data as Data
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers import PreTrainedTokenizer, DataCollatorWithPadding
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import pandas as pd
import torch
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


from .arguments import DataArguments
from ..deepctr_torch.inputs import build_input_features, get_feature_names
from ..deepctr_torch.inputs import get_feature_names
from ..deepctr_torch.inputs import (DenseFeat, SparseFeat, VarLenSparseFeat)

Codacy Static Code Analysis notice on line 18 in alignment/data.py: '..deepctr_torch.inputs.DenseFeat' imported but unused (F401)
Codacy Static Code Analysis warning on line 18 in alignment/data.py: Unused DenseFeat imported from deepctr_torch.inputs

import logging
@@ -101,11 +95,12 @@ def __getitem__(self, item):
@dataclass
class ContrastiveAlignmentCollator(DataCollatorWithPadding):

tokenizer: PreTrainedTokenizerBase
max_len: int = 64

def __call__(self, features):

# batch inputs for PLM\LLM
# batch inputs for nlp model
text_input = [feat_map["text_model_input"] for feat_map in features]
text_input_batch = self.tokenizer.pad(
text_input,
@@ -124,20 +119,17 @@ def __call__(self, features):
class MlmAlignmentCollator(DataCollatorWithPadding):

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = 64
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
max_len: Optional[int] = 64
mlm_probability: float = 0.15

def __call__(self, features):

# batch inputs for PLM\LLM
# batch inputs for nlp model
text_input = [feat_map["text_model_input"] for feat_map in features]
batch = self.tokenizer.pad(
text_input,
padding='max_length',
max_length=self.max_length,
max_length=self.max_len,
return_tensors="pt",
)
# generate input & label for mlm train
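
Note: the hunk above is truncated right at the "# generate input & label for mlm train" step. A hedged sketch of what that masking step typically looks like is given below; it follows the standard BERT-style recipe (the 80/10/10 replacement split is omitted) and may differ from the repository's exact implementation.

```python
# Sketch of BERT-style token masking for an MLM collator; assumes a tokenizer
# with a mask token. Not the repository's exact code.
import torch


def mask_tokens(input_ids: torch.Tensor, tokenizer, mlm_probability: float = 0.15):
    labels = input_ids.clone()
    # Sample positions to mask, never masking special tokens or padding.
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
            for ids in labels.tolist()
        ],
        dtype=torch.bool,
    )
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Loss is computed only on masked positions; every other label becomes -100.
    labels[~masked_indices] = -100
    masked_inputs = input_ids.clone()
    masked_inputs[masked_indices] = tokenizer.mask_token_id
    return masked_inputs, labels
```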
51 changes: 22 additions & 29 deletions alignment/run_train.py → alignment/main.py
@@ -1,20 +1,18 @@
import logging
import os
import sys
import json
sys.path.insert(0, '..')

import datasets
from transformers import AutoConfig, AutoTokenizer, AutoModel

Codacy Static Code Analysis notice on line 6 in alignment/main.py: 'transformers.AutoConfig' imported but unused (F401)
Codacy Static Code Analysis warning on line 6 in alignment/main.py: Unused AutoConfig imported from transformers
from transformers import (
HfArgumentParser,
set_seed,
)

from .arguments import ModelArguments, DataArguments, \
DenseTrainingArguments as TrainingArguments
from .data import AlignmentDataset, ContrastiveAlignmentCollator
from .model import ContrastiveAlignmentModel
AlignmentTrainingArguments as TrainingArguments
from .data import AlignmentDataset, ContrastiveAlignmentCollator, MlmAlignmentCollator
from .model import ContrastiveAlignmentModel, MlmAlignmentModel
from .trainer import AlignmentTrainer as Trainer

from deepctr_torch.models.deepfm import DeepFM
@@ -62,21 +60,6 @@ def main():

set_seed(training_args.seed)

# config = AutoConfig.from_pretrained(
# model_args.config_name if model_args.config_name else model_args.model_name_or_path,
# cache_dir=model_args.cache_dir,
# num_labels=model_args.projection_out_dim
# )
# # p*-tuning
# config.fine_tuning = model_args.fine_tuning
# config.prefix = model_args.prefix
# config.prompt = model_args.prompt
# config.prompt_from_vocab = model_args.prompt_from_vocab
# config.prompt_encoder_type = model_args.prompt_encoder_type
# config.pre_seq_len = model_args.pre_seq_len
# config.prefix_projection = model_args.prefix_projection
# config.prefix_hidden_size = model_args.prefix_hidden_size
# config.hidden_dropout_prob = model_args.hidden_dropout_prob

tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -97,20 +80,30 @@ def main():
# build text model

Codacy Static Code Analysis notice on line 80 in alignment/main.py: Trailing whitespace
text_model = AutoModel.from_pretrained(model_args.model_name_or_path, add_pooling_layer=False)

alignment_model = ContrastiveAlignmentModel(ctr_model = ctr_model,
text_model = text_model,
model_args = model_args,
data_args = data_args,
train_args = training_args)

# build alignment train model
if training_args.alignment_mode == "contrastive":
alignment_model = ContrastiveAlignmentModel(ctr_model = ctr_model,
text_model = text_model,
model_args = model_args,
data_args = data_args,
train_args = training_args)
data_collator=ContrastiveAlignmentCollator(tokenizer=tokenizer)

elif training_args.alignment_mode == "mlm":
alignment_model = MlmAlignmentModel(ctr_model = ctr_model,
text_model = text_model,
model_args = model_args,
data_args = data_args,
train_args = training_args)
data_collator=MlmAlignmentCollator(tokenizer=tokenizer)
else:
raise ValueError("Alignment mode must be in [contrastive, mlm]")

trainer = Trainer(
model=alignment_model,
args=training_args,
train_dataset=train_dataset,
data_collator=ContrastiveAlignmentCollator(
tokenizer
),
data_collator=data_collator,
)

trainer.train()
9 changes: 0 additions & 9 deletions alignment/model.py
@@ -154,21 +154,12 @@ def __init__(
self.train_args = train_args
self.data_args = data_args

# determine whether dimension alignment is needed
self.text_model_config = AutoConfig.from_pretrained(
self.model_args.model_name_or_path,
cache_dir=self.model_args.cache_dir,
revision=self.model_args.model_revision,
use_auth_token=True if self.model_args.use_auth_token else None,
)
# if self.model_args.ctr_hidden_dim == text_model_config.hidden_size:
# logger.info("CTR hidden size equal to Text model hidden size")
# self.add_pooler = False
# else:
# logger.warning("CTR hidden size not equal to Text model hidden size, add pooler layer")
# self.add_pooler = True
# self.pooler = LinearPooler(input_dim=text_model_config.hidden_size,
# output_dim=self.model_args.ctr_hidden_dim)

self.prompt_layers = nn.Sequential(
nn.Linear(self.model_args.ctr_hidden_dim, self.text_model_config.hidden_size),
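
Note: model.py above shows a prompt_layers projection from ctr_hidden_dim to the text model's hidden size but the hunk cuts off before any loss. As an assumption-laden sketch (not the repository's code), a contrastive alignment step between projected CTR embeddings and text embeddings could be an in-batch InfoNCE loss; the names ctr_emb and text_emb and the temperature value are illustrative.

```python
# Hedged sketch of an in-batch InfoNCE alignment loss between two embedding
# spaces that have already been projected to the same dimension.
import torch
import torch.nn.functional as F


def infonce_alignment_loss(ctr_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.05):
    """ctr_emb, text_emb: (batch, hidden) tensors of matching shape."""
    ctr_emb = F.normalize(ctr_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = ctr_emb @ text_emb.t() / temperature  # (batch, batch) cosine similarities
    targets = torch.arange(ctr_emb.size(0), device=ctr_emb.device)
    # Symmetric cross-entropy: each CTR row should match its own text row and vice versa.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```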
