Dev llm enhance ctr gwn #293

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions alignment/__init__.py
@@ -0,0 +1 @@
__version__ = '1.0.0'
125 changes: 125 additions & 0 deletions alignment/arguments.py
@@ -0,0 +1,125 @@
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments


@dataclass
class ModelArguments:
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
target_model_path: str = field(
default=None,
metadata={"help": "Path to pretrained reranker target model"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)

# modeling
untie_encoder: bool = field(
default=False,
metadata={"help": "no weight sharing between qry passage encoders"}
)
fine_tuning: bool = field(
default=False,
metadata={"help": "whether to fix plm parameters"}
)

# out projection
add_pooler: bool = field(default=False)
projection_in_dim: int = field(default=768)
projection_out_dim: int = field(default=1)

prefix_len: int = field(
default=32,
metadata={
"help": "The length of prompt"
}
)
prefix_projection: bool = field(
default=False,
metadata={
"help": "Apply a two-layer MLP head over the prefix embeddings"
}
)
prefix_hidden_size: int = field(
default=512,
metadata={
"help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used"
}
)
hidden_dropout_prob: float = field(
default=0.1,
metadata={
"help": "The dropout probability used in the models"
}
)

ctr_hidden_dim: int = field(
default=128,
metadata={
"help": "The hidden size of the CTR model last represente layer output"
}
)


@dataclass
class DataArguments:
train_dir: str = field(
default=None, metadata={"help": "Path to train directory"}
)

train_file: str = field(
default=None, metadata={"help": "the whole data for train & eval & test"}
)

feature_config_file: str = field(
default=None, metadata={"help": "feature config file path for ctr model"}
)

dataset_name: str = field(
default=None, metadata={"help": "huggingface dataset name"}
)
dataset_proc_num: int = field(
default=12, metadata={"help": "number of proc used in dataset preprocess"}
)
train_n_passages: int = field(default=8)
positive_passage_no_shuffle: bool = field(
default=False, metadata={"help": "always use the first positive passage"})
negative_passage_no_shuffle: bool = field(
default=False, metadata={"help": "always use the first negative passages"})


encode_in_path: str = field(default=None, metadata={"help": "Path to data to encode"})
    encoded_save_path: str = field(default=None, metadata={"help": "where to save the encoded output"})
encode_is_qry: bool = field(default=False)
encode_num_shard: int = field(default=1)
encode_shard_index: int = field(default=0)

q_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for query. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
p_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)


@dataclass
class AlignmentTrainingArguments(TrainingArguments):
alignment_mode: str = field(default="contrastive", metadata={"help": "contrastive or mlm"})
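For context, a minimal sketch of how these dataclasses would typically be parsed in a training script. The launcher snippet below is not part of this PR; HfArgumentParser and parse_args_into_dataclasses are standard transformers APIs.

# Hypothetical launcher snippet, not part of this diff.
from transformers import HfArgumentParser

from alignment.arguments import (
    ModelArguments,
    DataArguments,
    AlignmentTrainingArguments,
)

parser = HfArgumentParser((ModelArguments, DataArguments, AlignmentTrainingArguments))
# Reads CLI flags such as --model_name_or_path, --train_file, --feature_config_file,
# --alignment_mode, plus the usual TrainingArguments options.
model_args, data_args, training_args = parser.parse_args_into_dataclasses()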
172 changes: 172 additions & 0 deletions alignment/data.py
@@ -0,0 +1,172 @@
import json
import logging
from dataclasses import dataclass
from typing import Optional, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, DataCollatorWithPadding
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .arguments import DataArguments
from ..deepctr_torch.inputs import get_feature_names, SparseFeat, VarLenSparseFeat

logger = logging.getLogger(__name__)


class AlignmentDataset(Dataset):
"""
    Data loader for the alignment task.
"""
def __init__(
self,
data_args: DataArguments,
tokenizer: PreTrainedTokenizer,
):
self.tok = tokenizer
self.data_args = data_args


        # json.load expects a file object, so open the config file first
        with open(data_args.feature_config_file) as f:
            self.feature_config = json.load(f)
self.df = pd.read_csv(data_args.train_file)

# construct nlp model input
self.text_feature_fields = self.feature_config["text_feature_list"] # ["gender", "age", "occupation"]
# list of string
self.train_data_text = []
for _, row in self.df.iterrows():
# same as paper setting https://arxiv.org/pdf/2310.09234.pdf
text = " ".join(["{} is {}".format(feat, row[feat]) for feat in self.text_feature_fields])
self.train_data_text.append(text)

# construct ctr model input
self.sparse_features = self.feature_config["sparse_feature_list"] # ["movie_id", "user_id", "gender", "age", "occupation", "zip"]
self.target = self.feature_config["label"] # rating
        # feature value transform: label-encode each sparse feature
for feat in self.sparse_features:
lbe = LabelEncoder()
self.df[feat] = lbe.fit_transform(self.df[feat])
# feature columns construct
fixlen_feature_columns = [SparseFeat(feat, self.df[feat].nunique()) for feat in self.sparse_features]
# for DeepFM input
self.linear_feature_columns = fixlen_feature_columns
self.dnn_feature_columns = fixlen_feature_columns
feature_names = get_feature_names(self.linear_feature_columns + self.dnn_feature_columns)
# ctr_model input: list of list
self.train_data_ctr = np.array([self.df[name].values.tolist() for name in feature_names]).T.tolist()

# must correspond
assert len(self.train_data_text) == len(self.train_data_ctr)

self.total_len = len(self.train_data_ctr)

def create_one_example(self, text_encoding: List[int]):
item = self.tok.encode_plus(
text_encoding,
truncation='only_first',
max_length=self.data_args.prompt_max_len,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
)
return item

def __len__(self):
return self.total_len

def __getitem__(self, item):
# text model input
text_feature = self.train_data_text[item]
text_model_input = self.create_one_example(text_feature)
# ctr model input
ctr_model_input = self.train_data_ctr[item]
return {"text_model_input" : text_model_input,
"ctr_model_input" : ctr_model_input}


# data collector for contrastive alignment
@dataclass
class ContrastiveAlignmentCollator(DataCollatorWithPadding):

tokenizer: PreTrainedTokenizerBase
max_len: int = 64

def __call__(self, features):

# batch inputs for nlp model
text_input = [feat_map["text_model_input"] for feat_map in features]
text_input_batch = self.tokenizer.pad(
text_input,
padding='max_length',
max_length=self.max_len,
return_tensors="pt",
)
# batch inputs for ID Model
ctr_input_batch = [feat_map["ctr_model_input"] for feat_map in features]
        # TensorDataset expects tensors, so convert the list of feature rows first
        ctr_input_batch = Data.TensorDataset(torch.tensor(ctr_input_batch, dtype=torch.long))
return ctr_input_batch, text_input_batch


# data collector for mask language modeling alignment
@dataclass
class MlmAlignmentCollator(DataCollatorWithPadding):

tokenizer: PreTrainedTokenizerBase
max_len: Optional[int] = 64
mlm_probability: float = 0.15

def __call__(self, features):

# batch inputs for nlp model
text_input = [feat_map["text_model_input"] for feat_map in features]
batch = self.tokenizer.pad(
text_input,
padding='max_length',
max_length=self.max_len,
return_tensors="pt",
)
# generate input & label for mlm train
batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"])
# batch inputs for CTR Model
ctr_input_batch = [feat_map["ctr_model_input"] for feat_map in features]
batch["ctr_input_ids"] = Data.TensorDataset(ctr_input_batch)
return batch

def mask_tokens(
self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

inputs = inputs.clone()
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()

probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
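For reference, a minimal usage sketch of the dataset and the contrastive collator. The tokenizer checkpoint, file paths, and batch size are placeholders, and prompt_max_len is set explicitly because create_one_example reads it from the data arguments.

# Hypothetical usage sketch; checkpoint, paths, and batch size are placeholders.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from alignment.arguments import DataArguments
from alignment.data import AlignmentDataset, ContrastiveAlignmentCollator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

data_args = DataArguments(
    train_file="movielens_train.csv",  # placeholder path
    feature_config_file="alignment/data_example/ctr_model/feature_config.json",
)
data_args.prompt_max_len = 64  # assumed attribute read by create_one_example

dataset = AlignmentDataset(data_args, tokenizer)
loader = DataLoader(
    dataset,
    batch_size=32,
    collate_fn=ContrastiveAlignmentCollator(tokenizer=tokenizer, max_len=64),
)

ctr_batch, text_batch = next(iter(loader))  # CTR feature batch and padded text batch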
5 changes: 5 additions & 0 deletions alignment/data_example/ctr_model/feature_config.json
@@ -0,0 +1,5 @@
{
"sparse_feature_list" : ["movie_id", "user_id", "gender", "age", "occupation", "zip"],
"label" : "rating",
"text_feature_list" : ["gender", "age", "occupation"]
}
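A small sketch of how AlignmentDataset consumes this config; the sample row values below are made up for illustration.

import json

with open("alignment/data_example/ctr_model/feature_config.json") as f:
    cfg = json.load(f)

# "sparse_feature_list" -> label-encoded SparseFeat columns for the CTR model
# "label"               -> the rating target
# "text_feature_list"   -> fields verbalized into the text-model input
row = {"gender": "F", "age": 1, "occupation": 10}  # made-up example row
text = " ".join("{} is {}".format(feat, row[feat]) for feat in cfg["text_feature_list"])
print(text)  # -> gender is F age is 1 occupation is 10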