add validation set to fine-tuning
tigranfah committed May 20, 2024
1 parent bfa32d7 commit d51cbfb
Showing 3 changed files with 120 additions and 44 deletions.
28 changes: 21 additions & 7 deletions chemlactica/mol_opt/optimization.py
@@ -68,7 +68,7 @@ def optimize(
file = open(config["log_dir"], "w")
print("config", config)
# print("molecule generation arguments", config["generation_config"])
pool = Pool(config["pool_size"])
pool = Pool(config["pool_size"], validation_perc=config["validation_perc"])

max_score = 0
tol_level = 0
@@ -129,7 +129,7 @@ def optimize(
current_optim_entries[i].last_entry = entry
iter_optim_entries.append(current_optim_entries[i])
file.write(f"generated smiles: {entry.smiles}, score: {entry.score:.4f}\n")
- if entry.score > max_score + 0.01:
+ if entry.score > max_score:
max_score = entry.score
tol_level = 0
if oracle.finish or len(iter_optim_entries) >= config["num_gens_per_iter"]:
@@ -158,21 +158,35 @@ def optimize(
# top_k = int(len(all_entries) * config["rej_sample_config"]["rej_perc"])
# if top_k >= config["rej_sample_config"]["num_samples_per_round"]:
if config["rej_sample_config"]["should_train"](num_iter, tol_level, prev_train_iter):
- training_entries = pool.optim_entries
- print(f"Num of train examples {len(training_entries)}.")
+ train_entries, validation_entries = pool.get_train_valid_entries()
+ print(f"Num of training examples: {len(train_entries)}, num of validation examples: {len(validation_entries)}.")
file.write("Training entries\n")
- for i, optim_entry in enumerate(training_entries):
+ for i, optim_entry in enumerate(train_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")
file.write("Validation entries\n")
for i, optim_entry in enumerate(validation_entries):
file.write(f"\t{i} smiles: {optim_entry.last_entry.smiles}, score: {optim_entry.last_entry.score:.4f}\n")

train_dataset = Dataset.from_dict({
"sample": [
optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
- for optim_entry in training_entries
+ for optim_entry in train_entries
]
})
+ validation_dataset = Dataset.from_dict({
+ "sample": [
+ optim_entry.to_prompt(is_generation=False, include_oracle_score=True, config=config)
+ for optim_entry in validation_entries
+ ]
+ })
train_dataset.shuffle(seed=42)
+ validation_dataset.shuffle(seed=42)
config["rej_sample_config"]["formatting_func"] = lambda x: x["sample"]
- supervised_fine_tune(model, tokenizer, train_dataset, config["rej_sample_config"])
+ supervised_fine_tune(
+ model, tokenizer,
+ train_dataset, validation_dataset,
+ config["rej_sample_config"]
+ )
gc.collect()
torch.cuda.empty_cache()
prev_train_iter = num_iter
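For context: the fine-tuning block above fires only when the config-supplied should_train predicate returns True. A minimal sketch of such a predicate (hypothetical, not from this repo; the real one arrives via config["rej_sample_config"]):

def should_train(num_iter: int, tol_level: int, prev_train_iter: int) -> bool:
    # Fine-tune when the best score has stagnated (tol_level counts
    # generations without a new max_score) and we have not trained recently.
    stagnated = tol_level >= 5
    cooled_down = num_iter - prev_train_iter >= 3
    return stagnated and cooled_down

rej_sample_config = {"should_train": should_train}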
88 changes: 59 additions & 29 deletions chemlactica/mol_opt/tunning.py
@@ -1,53 +1,79 @@
- from transformers.trainer_callback import TrainerControl, TrainerState
+ from transformers.trainer_callback import TrainerControl, TrainerState, TrainerCallback
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
- from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, TrainerCallback
+ from transformers import TrainingArguments, get_polynomial_decay_schedule_with_warmup, EarlyStoppingCallback
from torch.optim.lr_scheduler import ConstantLR
import torch
import math


- class CustomSFTTrainer(SFTTrainer):
+ class CustomEarlyStopCallback(TrainerCallback):

- def __init__(self, *args, patience, toll, **kwargs):
- super().__init__(*args, **kwargs)
- self.patience = patience
- self.initial_pat = patience
- self.toll = toll
- self.best_loss = math.inf
+ def __init__(self, early_stopping_patience: int, early_stopping_threshold: float) -> None:
+ super().__init__()
+ self.best_valid_loss = math.inf
+ self.early_stopping_patience = early_stopping_patience
+ self.current_patience = 0
+ self.early_stopping_threshold = early_stopping_threshold

- def log(self, logs) -> None:
- if logs.get("loss"):
- curr_loss = logs["loss"]
- if curr_loss > self.best_loss - self.toll:
- self.patience -= 1
- print(f"loss did not improve, patience {self.patience}")
- else:
- print("loss improved")
- self.best_loss = curr_loss
- self.patience = self.initial_pat
- if self.patience == 0:
- print("The loss does not improve, stop training.")
- self.control.should_training_stop = True
- return super().log(logs)
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ self.best_valid_loss = math.inf
+ self.current_patience = 0
+ return super().on_train_begin(args, state, control, **kwargs)

+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+ if metrics["eval_loss"] >= self.best_valid_loss - self.early_stopping_threshold:
+ self.current_patience += 1
+ else:
+ self.current_patience = 0
+ self.best_valid_loss = metrics["eval_loss"]
+ print(f"Early stopping patience: {self.current_patience}/{self.early_stopping_patience}")
+ if self.current_patience >= self.early_stopping_patience:
+ control.should_training_stop = True
+ return super().on_evaluate(args, state, control, **kwargs)
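The callback can be sanity-checked outside a training run. A small sketch with made-up eval losses (in real use the Trainer calls on_evaluate after each epoch, given the evaluation_strategy="epoch" setting below):

from transformers import TrainingArguments, TrainerState, TrainerControl

callback = CustomEarlyStopCallback(early_stopping_patience=1, early_stopping_threshold=0.001)
args = TrainingArguments(output_dir="/tmp/ckpt")  # placeholder output dir
state, control = TrainerState(), TrainerControl()

callback.on_train_begin(args, state, control)
for loss in [0.90, 0.85, 0.8495]:  # third eval improves by less than the threshold
    callback.on_evaluate(args, state, control, metrics={"eval_loss": loss})
print(control.should_training_stop)  # True: patience of 1 is exhausted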


+ # class CustomSFTTrainer(SFTTrainer):
+ #
+ # def __init__(self, *args, patience, toll, **kwargs):
+ # super().__init__(*args, **kwargs)
+ # self.patience = patience
+ # self.initial_pat = patience
+ # self.toll = toll
+ # self.best_loss = math.inf
+
+ # def log(self, logs) -> None:
+ # if logs.get("loss"):
+ # curr_loss = logs["loss"]
+ # if curr_loss > self.best_loss - self.toll:
+ # self.patience -= 1
+ # print(f"loss did not improve, patience {self.patience}")
+ # else:
+ # print("loss improved")
+ # self.best_loss = curr_loss
+ # self.patience = self.initial_pat
+ # if self.patience == 0:
+ # print("The loss does not improve, stop training.")
+ # self.control.should_training_stop = True
+ # return super().log(logs)


def supervised_fine_tune(
model, tokenizer,
- train_dataset, config
+ train_dataset, validation_dataset, config
):
model.train()
training_args = TrainingArguments(
output_dir=config["checkpoints_dir"],
per_device_train_batch_size=config["train_batch_size"],
max_grad_norm=config["global_gradient_norm"],
num_train_epochs=config["num_train_epochs"],
evaluation_strategy="no",
evaluation_strategy="epoch",
dataloader_drop_last=False,
dataloader_pin_memory=True,
dataloader_num_workers=config["dataloader_num_workers"],
gradient_accumulation_steps=config["gradient_accumulation_steps"],
logging_steps=1,
metric_for_best_model="loss",
metric_for_best_model="loss"
)
optimizer = torch.optim.AdamW(
model.parameters(),
@@ -65,17 +91,21 @@ def supervised_fine_tune(
collator = DataCollatorForCompletionOnlyLM(
config["response_template"], tokenizer=tokenizer
)
- trainer = CustomSFTTrainer(
+ early_stopping_callback = CustomEarlyStopCallback(
+ early_stopping_patience=1,
+ early_stopping_threshold=0.001
+ )
+ trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
+ eval_dataset=validation_dataset,
formatting_func=config["formatting_func"],
args=training_args,
packing=config["packing"],
tokenizer=tokenizer,
max_seq_length=config["max_seq_length"],
# data_collator=collator,
optimizers=[optimizer, lr_scheduler],
- patience=2,
- toll=0.0001
+ callbacks=[early_stopping_callback],
)
trainer.train()
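An illustrative rej_sample_config covering the keys this function reads, assuming model, tokenizer, and the two datasets built in optimization.py above; the values are placeholders, and keys hidden in the collapsed hunk (optimizer and scheduler hyperparameters) are omitted:

rej_sample_config = {
    "checkpoints_dir": "checkpoints/sft",    # placeholder path
    "train_batch_size": 4,
    "global_gradient_norm": 1.0,
    "num_train_epochs": 3,
    "dataloader_num_workers": 2,
    "gradient_accumulation_steps": 1,
    "packing": False,
    "max_seq_length": 2048,
    "response_template": "[START_SMILES]",  # placeholder template
    "formatting_func": lambda x: x["sample"],
}
supervised_fine_tune(model, tokenizer, train_dataset, validation_dataset, rej_sample_config)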
48 changes: 40 additions & 8 deletions chemlactica/mol_opt/utils.py
@@ -82,22 +82,22 @@ def __repr__(self):


class Pool:
- def __init__(self, size):
+ def __init__(self, size, validation_perc: float):
self.size = size
self.optim_entries: List[OptimEntry] = []
+ self.num_validation_entries = int(size * validation_perc)

# def random_dump(self, num):
# for _ in range(num):
# rand_ind = random.randint(0, num - 1)
# self.molecule_entries.pop(rand_ind)
# print(f"Dump {num} random elements from pool, num pool mols {len(self)}")

- def add(self, entries: List[MoleculeEntry], diversity_score=1.0):
+ def add(self, entries: List, diversity_score=1.0):
assert type(entries) == list
self.optim_entries.extend(entries)
self.optim_entries.sort(key=lambda x: x.last_entry, reverse=True)

# print(f"Updating with div_score {diversity_score:.4f}")
# remove duplicates
new_optim_entries = []
for entry in self.optim_entries:
@@ -116,6 +116,34 @@ def add(self, entries: List[MoleculeEntry], diversity_score=1.0):
new_optim_entries.append(entry)

self.optim_entries = new_optim_entries[: min(len(new_optim_entries), self.size)]
+ curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries])

+ i = 0
+ # Promote unassigned entries to the validation split until the quota is met;
+ # bounding i keeps a pool smaller than the quota from running past the list.
+ while curr_num_validation_entries < self.num_validation_entries and i < len(self.optim_entries):
+ if self.optim_entries[i].entry_status == EntryStatus.none:
+ self.optim_entries[i].entry_status = EntryStatus.valid
+ curr_num_validation_entries += 1
+ i += 1

+ # Everything not claimed by the validation split becomes training data.
+ for j in range(i, len(self.optim_entries)):
+ if self.optim_entries[j].entry_status == EntryStatus.none:
+ self.optim_entries[j].entry_status = EntryStatus.train

+ curr_num_validation_entries = sum([entry.entry_status == EntryStatus.valid for entry in self.optim_entries])
+ assert curr_num_validation_entries == min(len(self.optim_entries), self.num_validation_entries)

+ def get_train_valid_entries(self):
+ train_entries = []
+ valid_entries = []
+ for entry in self.optim_entries:
+ if entry.entry_status == EntryStatus.train:
+ train_entries.append(entry)
+ elif entry.entry_status == EntryStatus.valid:
+ valid_entries.append(entry)
+ else:
+ raise Exception(f"EntryStatus of an entry in pool cannot be {entry.entry_status}.")
+ assert min(len(self.optim_entries), self.num_validation_entries) == len(valid_entries)
+ return train_entries, valid_entries
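Rough arithmetic of the split, assuming the Pool as patched above: a pool of 50 with validation_perc=0.2 reserves int(50 * 0.2) == 10 slots, and get_train_valid_entries partitions whatever the pool currently holds:

pool = Pool(50, validation_perc=0.2)
assert pool.num_validation_entries == 10
train_entries, valid_entries = pool.get_train_valid_entries()
assert len(valid_entries) == min(len(pool), pool.num_validation_entries)

Note that entry_status is sticky: once add() marks an entry as valid, it stays in the validation split for the rest of the run, so prompts never migrate from validation into training.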

def random_subset(self, subset_size):
rand_inds = np.random.permutation(min(len(self.optim_entries), subset_size))
@@ -124,9 +152,6 @@ def random_subset(self, subset_size):
def __len__(self):
return len(self.optim_entries)

- def __hash__(self):
- return hash(self.smiles)


def make_output_files_base(input_path, results_dir, run_name, config):
formatted_date_time = datetime.datetime.now().strftime("%Y-%m-%d")
@@ -151,10 +176,17 @@ def create_prompt_with_similars(mol_entry: MoleculeEntry, sim_range=None):
return prompt


+ class EntryStatus:
+ none = 0
+ train = 1
+ valid = 2


class OptimEntry:
def __init__(self, last_entry, mol_entries):
self.last_entry: MoleculeEntry = last_entry
self.mol_entries: List[MoleculeEntry] = mol_entries
+ self.entry_status: EntryStatus = EntryStatus.none

def to_prompt(self, is_generation: bool, include_oracle_score: bool, config):
prompt = ""
@@ -192,8 +224,8 @@ def to_prompt(self, is_generation: bool, include_oracle_score: bool, config):
else 0
)
desired_oracle_score = generate_random_number(
- q_0_9, 1.0
- ) # TODO: change the hard coded 1.0
+ q_0_9, config["max_possible_oracle_score"]
+ )
oracle_score = desired_oracle_score
else:
oracle_score = self.last_entry.score
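generate_random_number is defined elsewhere in utils.py; a hypothetical stand-in to make the patched call concrete (the upper bound now comes from config instead of a hard-coded 1.0):

import random

def generate_random_number(lower: float, upper: float) -> float:
    # Hypothetical implementation: sample a target oracle score uniformly.
    return random.uniform(lower, upper)

config = {"max_possible_oracle_score": 1.0}
q_0_9 = 0.75  # illustrative 0.9-quantile of the scores seen so far
desired = generate_random_number(q_0_9, config["max_possible_oracle_score"])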
