From 696fe7af33d6d93a8f2dc17ce13fed6c0038510f Mon Sep 17 00:00:00 2001 From: Junhao Shen Date: Sun, 15 Sep 2024 16:23:59 -0500 Subject: [PATCH 1/4] add GLEM model, TAGDataset and example of GLEM --- CHANGELOG.md | 3 +- examples/llm/README.md | 5 + examples/llm/glem.py | 879 ++++++++++++++++++++++++ torch_geometric/datasets/__init__.py | 2 + torch_geometric/datasets/tag_dataset.py | 344 ++++++++++ torch_geometric/nn/models/__init__.py | 3 +- torch_geometric/nn/models/glem.py | 366 ++++++++++ 7 files changed, 1600 insertions(+), 2 deletions(-) create mode 100644 examples/llm/glem.py create mode 100644 torch_geometric/datasets/tag_dataset.py create mode 100644 torch_geometric/nn/models/glem.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 84c50c4d411b..57b910d89e72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## \[2.7.0\] - 2024-MM-DD ### Added - +- Added `nn.Models.GLEM` ([#9661](https://github.com/pyg-team/pytorch_geometric/pull/9661)) +- Added `TAGDataset` ([#9661](https://github.com/pyg-team/pytorch_geometric/pull/9661)) ### Changed ### Deprecated diff --git a/examples/llm/README.md b/examples/llm/README.md index f1f01428d991..2f28dea9c66a 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -3,3 +3,8 @@ | Example | Description | | ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | + +## Run GLEM for getting SOTA result on ogbn-products dataset + +`python glem.py` \ No newline at end of file diff --git a/examples/llm/glem.py b/examples/llm/glem.py new file mode 100644 index 000000000000..46daeeddf5f4 --- /dev/null +++ b/examples/llm/glem.py @@ -0,0 +1,879 @@ +"""This example run GLEM model using PyG. +Original Paper: https://arxiv.org/abs/2210.14709 +“Learning on Large-scale Text-attributed Graphs via Variational Inference“. +Requirements on top of basic PyG: +`pip install ogb transformers peft tqdm`. +GLEM is a data augmentation co-training strategy for LM and GNN, our +implementation extended original implementation from LM to LLM and opt for LoRA +from peft. 
+ +``note:: + use addtional trick, please add your external prediction by assigning + `ext_pred_path` and combine it into pretraining phase and node features +""" + +import argparse +import os +import os.path as osp +import sys +import time + +import torch + +from torch_geometric import seed_everything +from torch_geometric.data import download_google_url +from torch_geometric.datasets import TAGDataset +from torch_geometric.nn.models import GAT, GCN, GLEM, GraphSAGE + +# Add the parent directory to sys.path +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(parent_dir) + + +def get_n_params(model): + pp = 0 + for p in list(model.parameters()): + nn = 1 + for s in list(p.size()): + nn = nn * s + pp += nn + return pp + + +def main(args): + gpu = args.gpu + dataset_name = args.dataset + root = osp.join(parent_dir, 'data', 'ogb') + hf_model = args.hf_model + pl_ratio = args.pl_ratio + gnn_lr = args.gnn_lr + lm_lr = args.lm_lr + em_order = args.em_order + gnn_epochs = args.gnn_epochs + lm_epochs = args.lm_epochs + patience = args.patience + verbose = args.verbose + out_dir = args.out_dir + lm_batch_size = args.lm_batch_size + gnn_batch_size = args.gnn_batch_size + lm_use_lora = args.lm_use_lora + token_on_disk = args.token_on_disk + num_em_iters = args.num_em_iters + start_time = time.time() + train_without_ext_pred = args.train_without_ext_pred + ext_pred = None + pretrain_augmented = False + ext_pseudo_labels = None + device = torch.device( + f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') + print(f'Running on: {torch.cuda.get_device_name({gpu})}') + torch.cuda.empty_cache() + + if not train_without_ext_pred: + ext_pred_path = download_google_url( + id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY', + folder='/work/users/junhaos/glem_data/ogbn_products/ext_preds', + filename='giant_sagn_scr.pt', log=True) + ext_pred = torch.load(ext_pred_path, map_location=device) + ext_pseudo_labels = ext_pred.argmax(dim=-1) + pretrain_augmented = True + + seed_everything(42) + from ogb.nodeproppred import PygNodePropPredDataset + dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) + split_idx = dataset.get_idx_split() + data = dataset.data + + tag_dataset = TAGDataset(root, dataset, hf_model, + token_on_disk=token_on_disk) + text_dataset = tag_dataset.to_text_dataset() + print(tag_dataset.num_classes, tag_dataset.raw_file_names) + + num_classes = tag_dataset.num_classes + num_features = data.num_features + # =========================== LM Data split =============================== + split_idx = tag_dataset.get_idx_split() + + # GLEM train with augmented data, mark original train data as gold data, + gold_idx = split_idx['train'] + test_idx = split_idx['test'] + + # randome sample pseudo labels nodes, generate their index + num_pseudo_labels = int(gold_idx.numel() * pl_ratio) + idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels] + pseudo_labels_idx = test_idx[idx_to_select] + train_idx = torch.cat( + (gold_idx, pseudo_labels_idx)) # augmented train_indx + + print(f'train_idx: {train_idx.size(0)}, ' + f'gold_idx: {gold_idx.size(0)}, ' + f'pseudo labels ratio: {pl_ratio}, ' + f'{train_idx.size(0)/gold_idx.size(0) - 1.0}') + gold_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=gold_idx) + train_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=train_idx) + # ========================== LM Data Loader =============================== + + print('Building language model dataloader...', end='-->') + from 
torch_geometric.loader import DataLoader + + # if set train_without_ext_pred == True, use this for pretrain + text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + # training with augmented data, + text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + text_data_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4, + drop_last=False, pin_memory=True, + shuffle=False) + print('done') + + # =========================== GNN Data Loader ============================= + initial_memory = torch.cuda.memory_allocated() + data = data.to(device) + if ext_pred is not None: + data.x = torch.cat((data.x, ext_pred), dim=1) + num_features += ext_pred.size(1) + current_memory_1 = torch.cuda.max_memory_allocated() + # 1 GB = 1073741824 Byte + gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 + # Print the maximum memory usage after running the model + print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') + + print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') + from torch_geometric.loader import NeighborLoader + + # train on gold data w/o pseudo labels + graph_pretrain_loader = NeighborLoader( + data, + input_nodes=gold_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # graph data loader w/ pseudo labels in M-step + graph_train_loader = NeighborLoader( + data, + input_nodes=train_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # for gnn inference + subgraph_loader = NeighborLoader( + data, + input_nodes=None, + num_neighbors=[-1], + batch_size=gnn_batch_size * 4, + num_workers=12, + persistent_workers=True, + ) + # =========================== internal function =========================== + + from ogb.nodeproppred import Evaluator + evaluator = Evaluator(name=f'ogbn-{dataset_name}') + + def evaluate(out, split): + y_true = data.y.cpu() + y_pred = out.argmax(dim=-1, keepdim=True) + train_acc, val_acc, test_acc = None, None, None + if 'train' in split: + train_acc = evaluator.eval({ + 'y_true': y_true[split_idx['train']], + 'y_pred': y_pred[split_idx['train']], + })['acc'] + if 'valid' in split: + val_acc = evaluator.eval({ + 'y_true': y_true[split_idx['valid']], + 'y_pred': y_pred[split_idx['valid']], + })['acc'] + if 'test' in split: + test_acc = evaluator.eval({ + 'y_true': y_true[split_idx['test']], + 'y_pred': y_pred[split_idx['test']], + })['acc'] + + return train_acc, val_acc, test_acc + + # =========================== Build GNN Model ============================= + gnn = None + if args.gnn_model == 'SAGE': + gnn = GraphSAGE( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + elif args.gnn_model == 'GAT': + gnn = GAT(in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, heads=args.gat_heads) + else: + gnn = GCN( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + + print("# GNN Params:", get_n_params(gnn)) + # =========================== Build LM Model ============================== + + model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, + 
lm_use_lora=lm_use_lora, device=device) + lm = model.lm + print("# LM Params:", get_n_params(lm)) + gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) + lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) + + def load_model(em_phase): + print(f'Move {em_phase} model from cpu memory') + if em_phase == 'lm': + model.lm = model.lm.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) + if em_phase == 'gnn': + model.gnn = model.gnn.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) + return optimizer + + # ================================= Run GLEM ============================== + preds_filename = 'lm_pretrain' + preds_dir = f'{out_dir}preds/{dataset_name}/' + gnn_test_acc = 0.0 + lm_test_acc = 0.0 + # =============================== GLEM pretraining ======================== + pretrain_phase = 'lm' + if em_order == 'lm': + pretrain_phase = 'gnn' + pretrain_start_time = time.time() + # pretraining + pretrain_loader = graph_pretrain_loader + test_loader = subgraph_loader + pretrain_num_epochs = gnn_epochs + pretrain_opt = gnn_opt + if pretrain_phase == 'gnn': + model.gnn = model.gnn.to(device) + print('pretraining gnn to generate pseudo labels') + if not train_without_ext_pred: + pretrain_loader = graph_train_loader + preds_filename = 'gnn_pretrain' + elif pretrain_phase == 'lm': + model.lm = model.lm.to(device) + print('pretraining lm to generate pseudo labels') + pretrain_num_epochs = lm_epochs + pretrain_loader = text_pretrain_loader + test_loader = text_data_loader + pretrain_opt = lm_opt + if not train_without_ext_pred: + pretrain_loader = text_train_loader + preds_filename = 'lm_pretrain' + + early_stopping = 0 + best_val_acc = final_test_acc = 0.0 + for epoch in range(1, pretrain_num_epochs + 1): + acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, + ext_pseudo_labels, epoch, pretrain_augmented, + verbose) + if epoch >= 5 or epoch == pretrain_num_epochs: + pretrain_preds = model.inference(pretrain_phase, test_loader, + verbose=verbose) + train_acc, val_acc, test_acc = evaluate(pretrain_preds, + ['train', 'valid', 'test']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Pretrain Early stopped by Epoch: {epoch}') + break + else: + best_val_acc = val_acc + final_test_acc = test_acc + preds = pretrain_preds + + if pretrain_phase == 'gnn': + gnn_test_acc = max(gnn_test_acc, final_test_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + else: + lm_test_acc = max(lm_test_acc, final_test_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + torch.cuda.empty_cache() + + pretrain_phase_time = time.time() - pretrain_start_time + print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') + os.makedirs(osp.dirname(preds_dir), exist_ok=True) + torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) + print( + f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + # EM iterations + + em_phase = em_order + """ + We run E-step(LM training) and M-Step(GNN training) alternatively in each + em iterations, so the total number of iterations is num_em_iter * 2 and + we switch the em_phase at end of each iteration in following loop + """ + for em_it in 
range(1, num_em_iters * 2 + 1):
+        pseudo_labels = preds.argmax(dim=-1)
+        best_val_acc = final_test_acc = 0.0
+        print(f'EM iteration: {em_it}, EM phase: {em_phase}')
+        optimizer = load_model(em_phase)
+        num_epochs = lm_epochs
+        train_loader = text_train_loader
+        test_loader = text_data_loader
+        early_stopping = 0
+        if em_phase == 'gnn':
+            train_loader = graph_train_loader
+            num_epochs = gnn_epochs
+            test_loader = subgraph_loader
+        for epoch in range(1, num_epochs + 1):
+            acc, loss = model.train(em_phase, train_loader, optimizer,
+                                    pseudo_labels, epoch, True, verbose)
+            if epoch >= 5 or epoch == num_epochs:
+                cur_preds = model.inference(em_phase, test_loader,
+                                            verbose=verbose)
+                train_acc, val_acc, test_acc = evaluate(
+                    cur_preds, ['train', 'valid', 'test'])
+
+                print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
+                      f'Test: {test_acc:.4f}')
+
+                if val_acc <= best_val_acc:
+                    early_stopping += 1
+                    if early_stopping > patience:
+                        print(f'''Early stopped by Epoch: {epoch}, \
+                            Best acc: {final_test_acc}''')
+                        break
+                else:
+                    best_val_acc = val_acc
+                    final_test_acc = test_acc
+                    preds = cur_preds
+
+        if em_phase == 'gnn':
+            gnn_test_acc = max(gnn_test_acc, final_test_acc)
+            model.gnn = model.gnn.to('cpu', non_blocking=True)
+            em_phase = 'lm'
+        else:
+            lm_test_acc = max(lm_test_acc, final_test_acc)
+            model.lm = model.lm.to('cpu', non_blocking=True)
+            em_phase = 'gnn'
+        torch.cuda.empty_cache()
+    print(f'Best GNN acc: {gnn_test_acc}, LM acc: {lm_test_acc}')
+    print('============================')
+    end_time = time.time()
+    running_time = (end_time - start_time) / 3600
+    print(f'Total running time: {running_time:.2f} hours')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='GLEM Example:')
+    parser.add_argument('--gpu', type=int, default=0)
+    parser.add_argument('--num_runs', type=int, default=10,
+                        help='number of runs')
+    parser.add_argument('--num_em_iters', type=int, default=1,
+                        help='number of iterations')
+    parser.add_argument("--dataset", type=str, default='products',
+                        help='arxiv or products')
+    parser.add_argument("--pl_ratio", type=float, default=0.5,
+                        help="pseudo labels ratio")
+    parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',
+                        help='huggingface model repo id')
+    parser.add_argument(
+        '--gnn_model', type=str, default='SAGE',
help='gnn model for node classification,' + 'options: SAGE, GAT, GCN') + parser.add_argument('--gnn_hidden_channels', type=int, default=256) + parser.add_argument('--gnn_num_layers', type=int, default=3) + parser.add_argument('--gat_heads', type=int, default=4, + help='Number of multi-head-attentions for GAT ') + parser.add_argument('--lm_batch_size', type=int, default=256) + parser.add_argument('--gnn_batch_size', type=int, default=1024) + parser.add_argument( + '--external_pred_path', type=str, default=None, + help="Other model's output logits during the " + "pretraining phase or simply concatenate it with" + "node features as augmented data for gnn") + parser.add_argument('--alpha', type=float, default=0.5, + help='pseudo label weight in E-step') + parser.add_argument('--beta', type=float, default=0.5, + help='pseudo label weight in M-step') + parser.add_argument('--lm_epochs', type=int, default=10) + parser.add_argument('--gnn_epochs', type=int, default=50) + parser.add_argument('--gnn_lr', type=float, default=0.002) + parser.add_argument('--lm_lr', type=float, default=0.001) + parser.add_argument('--patience', type=int, default=3, + help='Patience for early stopping') + parser.add_argument('--verbose', action='store_true', + help='show progress bar during training or not') + parser.add_argument('--em_order', type=str, default='lm', + help='decide train LM first or GNN first') + parser.add_argument('--lm_use_lora', action='store_true', + help='use Lora to fine-tune model or not') + parser.add_argument( + '--token_on_disk', action='store_true', + help='save token on disk and load token from disk' + 'for reducing duplicated tokenizing') + parser.add_argument('--out_dir', type=str, default='output/', + help='output directory') + parser.add_argument( + '--train_without_ext_pred', action='store_true', + help='train glem without using additional pseudo labels ' + 'for augmenting data only available for ogbn-products') + args = parser.parse_args() + print(args) + main(args) \ No newline at end of file diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 96d51032d818..0b6569d3f92b 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,7 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .tag_dataset import TAGDataset from .dbp15k import DBP15K from .aminer import AMiner @@ -190,6 +191,7 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'TAGDataset', ] hetero_datasets = [ diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py new file mode 100644 index 000000000000..2171e0f364c6 --- /dev/null +++ b/torch_geometric/datasets/tag_dataset.py @@ -0,0 +1,344 @@ +import os +import os.path as osp +from collections.abc import Sequence +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor +from tqdm import tqdm + +from torch_geometric.data import InMemoryDataset, download_google_url +from torch_geometric.data.data import BaseData + +try: + from pandas import DataFrame, read_csv + WITH_PANDAS = True +except ImportError: + WITH_PANDAS = False + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class TAGDataset(InMemoryDataset): + r"""The Text Attributed Graph datasets from the + `"Learning on Large-scale Text-attributed Graphs via Variational Inference + " `_ paper. 
+    This dataset transforms `ogbn-products` and `ogbn-arxiv` into Text
+    Attributed Graphs, in which each node is associated with a raw text
+    string. The resulting dataset can be fed to a DataLoader (for LM
+    training) and a NeighborLoader (for GNN training). In addition, this
+    class can be used as a wrapper that converts an InMemoryDataset together
+    with a tokenizer and raw text into a Text Attributed Graph.
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        dataset (InMemoryDataset): The dataset to be wrapped
+            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+        tokenizer_name (str): The tokenizer name for the language model.
+            Be sure to use the same tokenizer name as the `model id` of the
+            model repo on huggingface.co.
+        text (List[str]): A list of raw text associated with the nodes. The
+            order of the list must be aligned with the node list.
+        split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary of
+            split indices. It is required if your dataset does not provide a
+            get_idx_split function.
+        tokenize_batch_size (int): The batch size used when tokenizing the
+            text. Tokenization runs on CPU. default: 256
+        token_on_disk (bool): Whether to save the tokens as .pt files on
+            disk. default: False
+        text_on_disk (bool): Whether to save the given text (list of str) as
+            a dataframe on disk. default: False
+        force_reload (bool): Whether to re-process the dataset.
+            default: False
+    .. note::
+        See `examples/llm/glem.py` for example usage
+    """
+    raw_text_id = {
+        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+    }
+
+    def __init__(self, root: str, dataset: InMemoryDataset,
+                 tokenizer_name: str, text: Optional[List[str]] = None,
+                 split_idx: Optional[Dict[str, Tensor]] = None,
+                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
+                 text_on_disk: bool = False,
+                 force_reload: bool = False) -> None:
+        # list the vars you want to pass in before run download & process
+        self.name = dataset.name
+        self.text = text
+        self.tokenizer_name = tokenizer_name
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.dir_name = '_'.join(dataset.name.split('-'))
+        self.root = osp.join(root, self.dir_name)
+        missing_str_list = []
+        if not WITH_PANDAS:
+            missing_str_list.append('pandas')
+        if len(missing_str_list) > 0:
+            missing_str = ' '.join(missing_str_list)
+            error_out = f"`pip install {missing_str}` to use this dataset."
+ raise ImportError(error_out) + if hasattr(dataset, 'get_idx_split'): + self.split_idx = dataset.get_idx_split() + elif split_idx is not None: + self.split_idx = split_idx + else: + raise ValueError("TAGDataset need split idx for generating " + "is_gold mask, please pass splited index " + "in format of dictionaty with 'train', 'valid' " + "'test' index tensor to 'split_idx'") + if text is not None and text_on_disk: + self.save_node_text(text) + self.text_on_disk = text_on_disk + # init will call download and process + super().__init__(self.root, transform=None, pre_transform=None, + pre_filter=None, force_reload=force_reload) + # after processing and download + # Dataset has to have BaseData as _data + assert dataset._data is not None + self._data = dataset._data # reassign reference + assert self._data is not None + assert dataset._data.y is not None + assert isinstance(self._data, BaseData) + assert self._data.num_nodes is not None + assert isinstance(dataset._data.num_nodes, int) + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + is_good_tensor = self.load_gold_mask() + self._is_gold = is_good_tensor.squeeze() + self._data['is_gold'] = is_good_tensor + if self.text is not None and len(self.text) != self._data.num_nodes: + raise ValueError("The number of text sequence in 'text' should be " + "equal to number of nodes!") + self.token_on_disk = token_on_disk + self.tokenize_batch_size = tokenize_batch_size + self._token = self.tokenize_graph(self.tokenize_batch_size) + self.__num_classes__ = dataset.num_classes + + @property + def num_classes(self) -> int: + return self.__num_classes__ + + @property + def raw_file_names(self) -> List[str]: + file_names = [] + for root, _, files in os.walk(osp.join(self.root, 'raw')): + for file in files: + file_names.append(file) + return file_names + + @property + def processed_file_names(self) -> List[str]: + return [ + 'geometric_data_processed.pt', 'pre_filter.pt', + 'pre_transformed.pt' + ] + + @property + def token(self) -> Dict[str, Tensor]: + if self._token is None: # lazy load + self._token = self.tokenize_graph() + return self._token + + # load is_gold after init + @property + def is_gold(self) -> Tensor: + if self._is_gold is None: + print('lazy load is_gold!!') + self._is_gold = self.load_gold_mask() + return self._is_gold + + def get_n_id(self, node_idx: IndexType) -> Tensor: + if self._n_id is None: + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + return self._n_id[node_idx] + + def load_gold_mask(self) -> Tensor: + r"""Use original train split as gold split, generating is_gold mask + for picking ground truth labels and pseudo labels. + """ + train_split_idx = self.get_idx_split()['train'] + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + is_good_tensor = torch.zeros(self._data.num_nodes, + dtype=torch.bool).view(-1, 1) + is_good_tensor[train_split_idx] = True + return is_good_tensor + + def get_gold(self, node_idx: IndexType) -> Tensor: + r"""Get gold mask for given node_idx. 
+ Args: + node_idx (torch.tensor): a tensor contain node idx + """ + if self._is_gold is None: + self._is_gold = self.is_gold + return self._is_gold[node_idx] + + def get_idx_split(self) -> Dict[str, Tensor]: + return self.split_idx + + def download(self) -> None: + print('downloading raw text') + raw_text_path = download_google_url(id=self.raw_text_id[self.name], + folder=f'{self.root}/raw', + filename='node-text.csv.gz', + log=True) + text_df = read_csv(raw_text_path) + self.text = list(text_df['text']) + + def process(self) -> None: + if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): + text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) + self.text = list(text_df['text']) + elif self.name in self.raw_text_id: + self.download() + else: + print('The dataset is not ogbn-products nor ogbn-arxiv,' + 'please pass in your raw text string list to `text`') + if self.text is None: + raise ValueError("The TAGDataset only have ogbn-products and " + "ogbn-arxiv raw text in default " + "The raw text of each node is not specified" + "Please pass in 'text' when convert your dataset " + "to Text Attribute Graph Dataset") + + def save_node_text(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') + if osp.exists(node_text_path): + print(f'The raw text is existed at {node_text_path}') + else: + print(f'Saving raw text file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + + def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + r"""Tokenizing the text associate with each node, running in cpu. + Args: + batch_size (Optional[int]): batch size of list of text for + generating emebdding + Returns: + Dict[str, torch.Tensor]: tokenized graph + """ + data_len = 0 + if self.text is not None: + data_len = len(self.text) + else: + raise ValueError("The TAGDataset need text for tokenization") + token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] + path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + # Check if the .pt files already exist + token_files_exist = any( + os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) + + if token_files_exist and self.token_on_disk: + print('Found tokenized file, loading may take several minutes...') + all_encoded_token = { + k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True) + for k in token_keys + if os.path.exists(os.path.join(path, f'{k}.pt')) + } + return all_encoded_token + + all_encoded_token = {k: [] for k in token_keys} + pbar = tqdm(total=data_len) + + pbar.set_description('Tokenizing Text Attributed Graph') + for i in range(0, data_len, batch_size): + end_index = min(data_len, i + batch_size) + token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], + padding='max_length', truncation=True, + max_length=512, return_tensors="pt") + for k in token.keys(): + all_encoded_token[k].append(token[k]) + pbar.update(end_index - i) + pbar.close() + + all_encoded_token = { + k: torch.cat(v) + for k, v in all_encoded_token.items() if len(v) > 0 + } + if self.token_on_disk: + os.makedirs(path, exist_ok=True) + print('Saving tokens on Disk') + for k, tensor in all_encoded_token.items(): + torch.save(tensor, os.path.join(path, f'{k}.pt')) + print('Token saved:', os.path.join(path, f'{k}.pt')) + os.environ["TOKENIZERS_PARALLELISM"] = 'true' # supressing warning + return all_encoded_token + 
+ def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + class TextDataset(torch.utils.data.Dataset): + r"""This nested dataset provides textual data for each node in + the graph. Factory method to create TextDataset from TAGDataset. + Args: + tag_dataset (TAGDataset): the parent dataset + """ + def __init__(self, tag_dataset: 'TAGDataset') -> None: + self.tag_dataset = tag_dataset + self.token = tag_dataset.token + assert tag_dataset._data is not None + self._data = tag_dataset._data + + assert tag_dataset._data.y is not None + self.labels = tag_dataset._data.y + + def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: + r"""This function will be called in __getitem__(). + Args: + node_idx (IndexType): selected node idx in each batch + Returns: + items (Dict[str, Tensor]): input for LM + """ + items = {k: v[node_idx] for k, v in self.token.items()} + return items + + # for LM training + def __getitem__( + self, node_id: IndexType + ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + r"""This function will override the function in + torch.utils.data.Dataset, and will be called when you + iterate batch in the dataloader, make sure all following + key value pairs are present in the return dict. + Args: + node_id (List[int]): list of node idx for selecting tokens, + labels etc. when iterating data loader for LM + Returns: + items (dict): input k,v pairs for Language model training and + inference + """ + item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {} + item['input'] = self.get_token(node_id) + item['labels'] = self.labels[node_id] + item['is_gold'] = self.tag_dataset.get_gold(node_id) + item['n_id'] = self.tag_dataset.get_n_id(node_id) + return item + + def __len__(self) -> int: + assert self._data.num_nodes is not None + return self._data.num_nodes + + def get(self, idx: int) -> BaseData: + return self._data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + def to_text_dataset(self) -> TextDataset: + r"""Factory Build text dataset from Text Attributed Graph Dataset + each data point is node's associated text token. + """ + return TAGDataset.TextDataset(self) \ No newline at end of file diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 7cfadf0143b2..5860db311ac3 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,7 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever - +from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -77,4 +77,5 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'GLEM', ] diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py new file mode 100644 index 000000000000..1ad69b32e58b --- /dev/null +++ b/torch_geometric/nn/models/glem.py @@ -0,0 +1,366 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm + +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GraphSAGE, basic_gnn + + +class GLEM(torch.nn.Module): + r"""This GNN+LM co-training model is based on GLEM from the `"Learning on + Large-scale Text-attributed Graphs via Variational Inference" + `_ paper. 
+    Args:
+        lm_to_use (str): A text encoder from the Hugging Face model repo
+            with a classification head (default: TinyBERT)
+        gnn_to_use (torch_geometric.nn.models): The GNN model to co-train
+            (default: GraphSAGE)
+        out_channels (int): Output channels for both the LM and the GNN;
+            they must be the same
+        num_gnn_heads (Optional[int]): Number of attention heads, if needed
+        num_gnn_layers (int): Number of GNN layers
+        gnn_loss: Loss function for the GNN (default: CrossEntropyLoss)
+        lm_loss: Loss function for the language model
+            (default: CrossEntropyLoss)
+        alpha (float): Pseudo-label weight in the E-step (LM optimization)
+            (default: 0.5)
+        beta (float): Pseudo-label weight in the M-step (GNN optimization)
+            (default: 0.5)
+        lm_dtype (torch.dtype): The data type used to load the LM into
+            memory (default: torch.bfloat16)
+        lm_use_lora (bool): Whether to fine-tune the LM with LoRA via peft
+            (default: True)
+        lora_target_modules: The names of the target modules to apply the
+            LoRA adapter to, e.g., ['q_proj', 'v_proj'] for an LLM
+            (default: None)
+    .. note::
+        See `examples/llm/glem.py` for example usage.
+    """
+    def __init__(
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
+        lm_loss=nn.CrossEntropyLoss(reduction='mean'),
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Union[str, torch.device] = torch.device('cpu'),
+    ):
+        super().__init__()
+        self.device = device
+        self.lm_loss = lm_loss
+        self.gnn = gnn_to_use
+        self.gnn_loss = gnn_loss
+        self.alpha = alpha
+        self.beta = beta
+        self.lm = lm_to_use
+        from transformers import AutoModelForSequenceClassification
+        self.lm = AutoModelForSequenceClassification.from_pretrained(
+            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+            offload_folder="offload", trust_remote_code=True)
+        if lm_use_lora:
+            from peft import (
+                LoraConfig,
+                TaskType,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            print("Training LM with LORA!")
+            self.lm = prepare_model_for_kbit_training(self.lm)
+            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                lora_alpha=16, lora_dropout=0.05, bias="none",
+                                target_modules=lora_target_modules)
+            self.lm = get_peft_model(self.lm, config)
+            self.lm.print_trainable_parameters()
+        self.lm.config.pad_token_id = self.lm.config.eos_token_id
+        self.lm_device = self.lm.device
+
+        if self.lm.num_labels != self.gnn.out_channels:
+            raise ValueError('''The output channel of language model \
+                             and gnn should be the same''')
+
+    def pre_train_gnn(self, train_loader: NeighborLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+        # Pretrain GNN, optional steps if you do not have pseudo labels.
+ best_acc = 0 + early_stopping = 0 + # training only based on gold data + for epoch in range(0, num_epochs): + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, + verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def pre_train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain language model + best_acc = 0 + early_stopping = 0 + for epoch in range(1, num_epochs + 1): + acc, loss = self.train_lm(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def train(self, em_phase: str, train_loader: Union[DataLoader, + NeighborLoader], + optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, + epoch: int, is_augmented: bool = False, verbose: bool = False): + r"""GLEM training step, EM steps. + Args: + em_phase(str): 'gnn' or 'lm' choose which phase you are training on + train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for + lm training, include tokenized data, labels is_gold mask. + use NeighborLoader for gnn training, include x, edge_index. + optimizer (torch.optim.Optimizer): optimizer for training + pseudo_labels(torch.Tensor): the predicted labels used as pseudo + labels + epoch (int): current epoch + is_augmented (bool): will use pseudo_labels or not + verbose (bool): print training progress bar or not + Returns: + acc (float): training accuracy + loss (float): loss value + """ + pseudo_labels = pseudo_labels.to(self.device) + if em_phase == 'gnn': + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + if em_phase == 'lm': + acc, loss = self.train_lm(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + return acc, loss + + def train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""Language model Training in every epoch. 
+ Args: + train_loader (loader.dataloader.DataLoader): text token dataloader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn + is_augmented (bool): train with pseudo labels or not + verbose (bool): print training progress bar or not + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + """ + all_out = [] + total_loss = total_correct = 0 + num_nodes = train_loader.dataset.indices.size(0) + self.lm.train() + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + for batch in train_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + out = self.lm(**inputs).logits + labels = batch['labels'].to(self.device).squeeze() + # training with pseudo labels or not + if is_augmented: + pl_batch = pseudo_labels[batch['n_id']].to(self.device) + else: + pl_batch = None + loss = self.loss(out, labels, self.lm_loss, + batch['is_gold'].to(self.device), pl_batch, + self.alpha, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + all_out.append(out) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + total_loss += float(loss) + if verbose: + pbar.update(batch['n_id'].size(0)) + + all_out = torch.cat(all_out, dim=0) + approx_acc = total_correct / num_nodes + loss = total_loss / len(train_loader) + if verbose: + pbar.close() + print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + def train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""GNN training step in every epoch. + Args: + train_loader (loader.NeighborLoader): gnn Neighbor node loader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels(torch.tensor): 1-D tensor, predictions from lm + is_augmented(bool): use pseudo labeled node or not + verbose (bool): print training progress or not + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + """ + self.gnn.train() + num_nodes = train_loader.input_nodes.size(0) + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + total_loss = total_correct = 0 + all_out = [] + for batch in train_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + all_out.append(out) + labels = batch.y[:batch.batch_size].squeeze() + is_gold_batch = batch.is_gold[:batch.batch_size].squeeze() + # training with pseudo labels or not + if is_augmented and pseudo_labels is not None: + pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]] + else: + pl_batch = None + loss = self.loss(out, labels, self.gnn_loss, is_gold_batch, + pl_batch, self.beta, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + total_loss += float(loss) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + if verbose: + pbar.update(batch.batch_size) + + all_out = torch.cat(all_out, dim=0) + loss = total_loss / len(train_loader) + approx_acc = total_correct / num_nodes + if verbose: + pbar.close() + print(f'Epoch: {epoch:02d} Loss: {loss:.4f} ' + f'Approx. 
Train: {approx_acc:.4f}') + return approx_acc, loss + + @torch.no_grad() + def inference(self, em_phase: str, data_loader: Union[NeighborLoader, + DataLoader], + verbose: bool = False): + r"""GLEM inference step. + Args: + em_phase(str): 'gnn' or 'lm' + data_loader(dataloader or Neighborloader): + dataloader: for lm training, include tokenized data + nodeloader: for gnn training, include x, edge_index + verbose(bool): print inference progress or not + Returns: + out (torch.Tensor): n * m tensor, m is number of classes, + n is number of nodes + """ + out = None + if em_phase == 'gnn': + self.gnn.eval() + out = self.inference_gnn(data_loader, verbose) + elif em_phase == 'lm': + self.lm.eval() + out = self.inference_lm(data_loader, verbose) + return out + + @torch.no_grad() + def inference_lm(self, data_loader: DataLoader, verbose: bool = True): + r"""LM inference step. + Args: + data_loader (Dataloader): include token, labels, and gold mask + verbose (bool): print progress bar or not + Returns: + preds (tensor): prediction from GNN, convert to pseudo labels + by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.dataset._data.num_nodes) + pbar.set_description('LM inference stage') + self.lm.eval() + preds = [] + for batch in data_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + logits = self.lm(**inputs).logits + preds.append(logits) + if verbose: + pbar.update(batch['n_id'].size(0)) + if verbose: + pbar.close() + preds = torch.cat(preds) + return preds + + @torch.no_grad() + def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): + r"""GNN inference step. + Args: + data_loader(NeighborLoader): include x, edge_index, + verbose (bool): print progress bar or not + Returns: + preds (tensor): prediction from GNN, + convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.data.num_nodes) + pbar.set_description('GNN inference stage') + preds = [] + self.gnn.eval() + for batch in data_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + preds.append(out) + if verbose: + pbar.update(batch.batch_size) + if verbose: + pbar.close() + preds = torch.cat(preds, dim=0) + return preds + + def loss(self, logits: torch.Tensor, labels: torch.Tensor, + loss_func: torch.nn.functional, is_gold: torch.Tensor, + pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5, + is_augmented: bool = True): + r"""Core function of variational EM inference, this function is aming + on combining loss value on gold(original train) and loss value on + pseudo labels. 
+ Reference: + # noqa + Args: + logits(torch.tensor): predict results from LM or GNN + labels(torch.tensor): combined node labels from ground truth and + pseudo labels(if provided) + loss_func(torch.nn.modules.loss): loss function for classification + is_gold(tensor): a tensor with bool value that mask ground truth + label and during training, thus ~is_gold mask pseudo labels + pseudo_labels(torch.tensor): predictions from other model + pl_weight: the pseudo labels used in E-step and M-step optimization + alpha in E-step, beta in M-step respectively + is_augmented: use EM or just train GNN and LM with gold data + """ + def deal_nan(x): + return 0 if torch.isnan(x) else x + + if is_augmented and (sum(~is_gold) > 0): + mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold])) + # all other labels beside from ground truth(gold labels) + pseudo_label_loss = deal_nan( + loss_func(logits[~is_gold], pseudo_labels[~is_gold])) + loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss + else: + loss = loss_func(logits, labels) + return loss \ No newline at end of file From 430c4fd605ff8151061904889c4aa9cd91deab16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Sep 2024 21:28:30 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- CHANGELOG.md | 2 ++ examples/llm/README.md | 10 +++++----- examples/llm/glem.py | 2 +- torch_geometric/datasets/tag_dataset.py | 8 +++++++- torch_geometric/nn/models/glem.py | 9 ++++++++- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57b910d89e72..3e1be94a8fe6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
## \[2.7.0\] - 2024-MM-DD ### Added + - Added `nn.Models.GLEM` ([#9661](https://github.com/pyg-team/pytorch_geometric/pull/9661)) - Added `TAGDataset` ([#9661](https://github.com/pyg-team/pytorch_geometric/pull/9661)) + ### Changed ### Deprecated diff --git a/examples/llm/README.md b/examples/llm/README.md index 2f28dea9c66a..1f7a5cb09100 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,10 +1,10 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | -| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| Example | Description | +| ------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | ## Run GLEM for getting SOTA result on ogbn-products dataset -`python glem.py` \ No newline at end of file +`python glem.py` diff --git a/examples/llm/glem.py b/examples/llm/glem.py index 46daeeddf5f4..3d21bef4c03c 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -876,4 +876,4 @@ def load_model(em_phase): 'for augmenting data only available for ogbn-products') args = parser.parse_args() print(args) - main(args) \ No newline at end of file + main(args) diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index 2171e0f364c6..f25992ced989 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -30,6 +30,7 @@ class TAGDataset(InMemoryDataset): NeighborLoader(for GNN training). In addition, this class can be use as a wrapper class by convert a InMemoryDataset with Tokenizer and text into Text Attributed Graph. + Args: root (str): Root directory where the dataset should be saved. dataset (InMemoryDataset): The name of the dataset @@ -176,6 +177,7 @@ def load_gold_mask(self) -> Tensor: def get_gold(self, node_idx: IndexType) -> Tensor: r"""Get gold mask for given node_idx. + Args: node_idx (torch.tensor): a tensor contain node idx """ @@ -224,6 +226,7 @@ def save_node_text(self, text: List[str]) -> None: def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: r"""Tokenizing the text associate with each node, running in cpu. + Args: batch_size (Optional[int]): batch size of list of text for generating emebdding @@ -283,6 +286,7 @@ def __repr__(self) -> str: class TextDataset(torch.utils.data.Dataset): r"""This nested dataset provides textual data for each node in the graph. 
Factory method to create TextDataset from TAGDataset. + Args: tag_dataset (TAGDataset): the parent dataset """ @@ -297,6 +301,7 @@ def __init__(self, tag_dataset: 'TAGDataset') -> None: def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: r"""This function will be called in __getitem__(). + Args: node_idx (IndexType): selected node idx in each batch Returns: @@ -313,6 +318,7 @@ def __getitem__( torch.utils.data.Dataset, and will be called when you iterate batch in the dataloader, make sure all following key value pairs are present in the return dict. + Args: node_id (List[int]): list of node idx for selecting tokens, labels etc. when iterating data loader for LM @@ -341,4 +347,4 @@ def to_text_dataset(self) -> TextDataset: r"""Factory Build text dataset from Text Attributed Graph Dataset each data point is node's associated text token. """ - return TAGDataset.TextDataset(self) \ No newline at end of file + return TAGDataset.TextDataset(self) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index 1ad69b32e58b..b8068cfb032f 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -12,6 +12,7 @@ class GLEM(torch.nn.Module): r"""This GNN+LM co-training model is based on GLEM from the `"Learning on Large-scale Text-attributed Graphs via Variational Inference" `_ paper. + Args: lm_to_use (str): A TextEncoder from huggingface model repo with a classifier(default: TinyBERT) @@ -125,6 +126,7 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False): r"""GLEM training step, EM steps. + Args: em_phase(str): 'gnn' or 'lm' choose which phase you are training on train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for @@ -154,6 +156,7 @@ def train_lm(self, train_loader: DataLoader, pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): r"""Language model Training in every epoch. + Args: train_loader (loader.dataloader.DataLoader): text token dataloader optimizer (torch.optim.Optimizer): model optimizer @@ -207,6 +210,7 @@ def train_gnn(self, train_loader: NeighborLoader, pseudo_labels: torch.Tensor = None, is_augmented: bool = False, verbose: bool = True): r"""GNN training step in every epoch. + Args: train_loader (loader.NeighborLoader): gnn Neighbor node loader optimizer (torch.optim.Optimizer): model optimizer @@ -260,6 +264,7 @@ def inference(self, em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False): r"""GLEM inference step. + Args: em_phase(str): 'gnn' or 'lm' data_loader(dataloader or Neighborloader): @@ -282,6 +287,7 @@ def inference(self, em_phase: str, data_loader: Union[NeighborLoader, @torch.no_grad() def inference_lm(self, data_loader: DataLoader, verbose: bool = True): r"""LM inference step. + Args: data_loader (Dataloader): include token, labels, and gold mask verbose (bool): print progress bar or not @@ -308,6 +314,7 @@ def inference_lm(self, data_loader: DataLoader, verbose: bool = True): @torch.no_grad() def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): r"""GNN inference step. 
+ Args: data_loader(NeighborLoader): include x, edge_index, verbose (bool): print progress bar or not @@ -363,4 +370,4 @@ def deal_nan(x): loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss else: loss = loss_func(logits, labels) - return loss \ No newline at end of file + return loss From 78d9781cef04a5f92ce8ce727aaa9efbc1c1fe28 Mon Sep 17 00:00:00 2001 From: Junhao Shen Date: Sun, 15 Sep 2024 17:00:51 -0500 Subject: [PATCH 3/4] fix docstring unexpected intentation --- examples/llm/glem.py | 441 +----------------------------- torch_geometric/nn/models/glem.py | 13 +- 2 files changed, 13 insertions(+), 441 deletions(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index 3d21bef4c03c..e01985744f6c 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -364,446 +364,7 @@ def load_model(em_phase): best_val_acc = val_acc final_test_acc = test_acc preds = cur_preds -"""This example run GLEM model using PyG. -Original Paper: https://arxiv.org/abs/2210.14709 -“Learning on Large-scale Text-attributed Graphs via Variational Inference“. -Requirements on top of basic PyG: -`pip install ogb transformers peft tqdm`. -GLEM is a data augmentation co-training strategy for LM and GNN, our -implementation extended original implementation from LM to LLM and opt for LoRA -from peft. - -``note:: - use addtional trick, please add your external prediction by assigning - `ext_pred_path` and combine it into pretraining phase and node features -""" - -import argparse -import os -import os.path as osp -import sys -import time - -import torch - -from torch_geometric import seed_everything -from torch_geometric.data import download_google_url -from torch_geometric.datasets import TAGDataset -from torch_geometric.nn.models import GAT, GCN, GLEM, GraphSAGE - -# Add the parent directory to sys.path -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(parent_dir) - - -def get_n_params(model): - pp = 0 - for p in list(model.parameters()): - nn = 1 - for s in list(p.size()): - nn = nn * s - pp += nn - return pp - - -def main(args): - gpu = args.gpu - dataset_name = args.dataset - root = osp.join(parent_dir, 'data', 'ogb') - hf_model = args.hf_model - pl_ratio = args.pl_ratio - gnn_lr = args.gnn_lr - lm_lr = args.lm_lr - em_order = args.em_order - gnn_epochs = args.gnn_epochs - lm_epochs = args.lm_epochs - patience = args.patience - verbose = args.verbose - out_dir = args.out_dir - lm_batch_size = args.lm_batch_size - gnn_batch_size = args.gnn_batch_size - lm_use_lora = args.lm_use_lora - token_on_disk = args.token_on_disk - num_em_iters = args.num_em_iters - start_time = time.time() - train_without_ext_pred = args.train_without_ext_pred - ext_pred = None - pretrain_augmented = False - ext_pseudo_labels = None - device = torch.device( - f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') - print(f'Running on: {torch.cuda.get_device_name({gpu})}') - torch.cuda.empty_cache() - - if not train_without_ext_pred: - ext_pred_path = download_google_url( - id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY', - folder='/work/users/junhaos/glem_data/ogbn_products/ext_preds', - filename='giant_sagn_scr.pt', log=True) - ext_pred = torch.load(ext_pred_path, map_location=device) - ext_pseudo_labels = ext_pred.argmax(dim=-1) - pretrain_augmented = True - - seed_everything(42) - from ogb.nodeproppred import PygNodePropPredDataset - dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) - split_idx = dataset.get_idx_split() - data = dataset.data - - tag_dataset = 
TAGDataset(root, dataset, hf_model, - token_on_disk=token_on_disk) - text_dataset = tag_dataset.to_text_dataset() - print(tag_dataset.num_classes, tag_dataset.raw_file_names) - - num_classes = tag_dataset.num_classes - num_features = data.num_features - # =========================== LM Data split =============================== - split_idx = tag_dataset.get_idx_split() - - # GLEM train with augmented data, mark original train data as gold data, - gold_idx = split_idx['train'] - test_idx = split_idx['test'] - - # randome sample pseudo labels nodes, generate their index - num_pseudo_labels = int(gold_idx.numel() * pl_ratio) - idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels] - pseudo_labels_idx = test_idx[idx_to_select] - train_idx = torch.cat( - (gold_idx, pseudo_labels_idx)) # augmented train_indx - - print(f'train_idx: {train_idx.size(0)}, ' - f'gold_idx: {gold_idx.size(0)}, ' - f'pseudo labels ratio: {pl_ratio}, ' - f'{train_idx.size(0)/gold_idx.size(0) - 1.0}') - gold_dataset = torch.utils.data.Subset(dataset=text_dataset, - indices=gold_idx) - train_dataset = torch.utils.data.Subset(dataset=text_dataset, - indices=train_idx) - # ========================== LM Data Loader =============================== - - print('Building language model dataloader...', end='-->') - from torch_geometric.loader import DataLoader - - # if set train_without_ext_pred == True, use this for pretrain - text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, - drop_last=False, pin_memory=True, - shuffle=True) - # training with augmented data, - text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, - drop_last=False, pin_memory=True, - shuffle=True) - text_data_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4, - drop_last=False, pin_memory=True, - shuffle=False) - print('done') - - # =========================== GNN Data Loader ============================= - initial_memory = torch.cuda.memory_allocated() - data = data.to(device) - if ext_pred is not None: - data.x = torch.cat((data.x, ext_pred), dim=1) - num_features += ext_pred.size(1) - current_memory_1 = torch.cuda.max_memory_allocated() - # 1 GB = 1073741824 Byte - gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 - # Print the maximum memory usage after running the model - print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') - - print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') - from torch_geometric.loader import NeighborLoader - - # train on gold data w/o pseudo labels - graph_pretrain_loader = NeighborLoader( - data, - input_nodes=gold_idx, - num_neighbors=[15, 10, 5], - batch_size=gnn_batch_size, - shuffle=True, - num_workers=12, - persistent_workers=True, - ) - - # graph data loader w/ pseudo labels in M-step - graph_train_loader = NeighborLoader( - data, - input_nodes=train_idx, - num_neighbors=[15, 10, 5], - batch_size=gnn_batch_size, - shuffle=True, - num_workers=12, - persistent_workers=True, - ) - - # for gnn inference - subgraph_loader = NeighborLoader( - data, - input_nodes=None, - num_neighbors=[-1], - batch_size=gnn_batch_size * 4, - num_workers=12, - persistent_workers=True, - ) - # =========================== internal function =========================== - - from ogb.nodeproppred import Evaluator - evaluator = Evaluator(name=f'ogbn-{dataset_name}') - - def evaluate(out, split): - y_true = data.y.cpu() - y_pred = out.argmax(dim=-1, keepdim=True) - train_acc, val_acc, test_acc = None, None, None - if 'train' in split: - train_acc = 
evaluator.eval({ - 'y_true': y_true[split_idx['train']], - 'y_pred': y_pred[split_idx['train']], - })['acc'] - if 'valid' in split: - val_acc = evaluator.eval({ - 'y_true': y_true[split_idx['valid']], - 'y_pred': y_pred[split_idx['valid']], - })['acc'] - if 'test' in split: - test_acc = evaluator.eval({ - 'y_true': y_true[split_idx['test']], - 'y_pred': y_pred[split_idx['test']], - })['acc'] - - return train_acc, val_acc, test_acc - - # =========================== Build GNN Model ============================= - gnn = None - if args.gnn_model == 'SAGE': - gnn = GraphSAGE( - in_channels=num_features, - hidden_channels=args.gnn_hidden_channels, - num_layers=args.gnn_num_layers, - out_channels=dataset.num_classes, - ) - elif args.gnn_model == 'GAT': - gnn = GAT(in_channels=num_features, - hidden_channels=args.gnn_hidden_channels, - num_layers=args.gnn_num_layers, - out_channels=dataset.num_classes, heads=args.gat_heads) - else: - gnn = GCN( - in_channels=num_features, - hidden_channels=args.gnn_hidden_channels, - num_layers=args.gnn_num_layers, - out_channels=dataset.num_classes, - ) - - print("# GNN Params:", get_n_params(gnn)) - # =========================== Build LM Model ============================== - - model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, - lm_use_lora=lm_use_lora, device=device) - lm = model.lm - print("# LM Params:", get_n_params(lm)) - gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) - lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) - - def load_model(em_phase): - print(f'Move {em_phase} model from cpu memory') - if em_phase == 'lm': - model.lm = model.lm.to(device, non_blocking=True) - optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) - if em_phase == 'gnn': - model.gnn = model.gnn.to(device, non_blocking=True) - optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) - return optimizer - - # ================================= Run GLEM ============================== - preds_filename = 'lm_pretrain' - preds_dir = f'{out_dir}preds/{dataset_name}/' - gnn_test_acc = 0.0 - lm_test_acc = 0.0 - # =============================== GLEM pretraining ======================== - pretrain_phase = 'lm' - if em_order == 'lm': - pretrain_phase = 'gnn' - pretrain_start_time = time.time() - # pretraining - pretrain_loader = graph_pretrain_loader - test_loader = subgraph_loader - pretrain_num_epochs = gnn_epochs - pretrain_opt = gnn_opt - if pretrain_phase == 'gnn': - model.gnn = model.gnn.to(device) - print('pretraining gnn to generate pseudo labels') - if not train_without_ext_pred: - pretrain_loader = graph_train_loader - preds_filename = 'gnn_pretrain' - elif pretrain_phase == 'lm': - model.lm = model.lm.to(device) - print('pretraining lm to generate pseudo labels') - pretrain_num_epochs = lm_epochs - pretrain_loader = text_pretrain_loader - test_loader = text_data_loader - pretrain_opt = lm_opt - if not train_without_ext_pred: - pretrain_loader = text_train_loader - preds_filename = 'lm_pretrain' - - early_stopping = 0 - best_val_acc = final_test_acc = 0.0 - for epoch in range(1, pretrain_num_epochs + 1): - acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, - ext_pseudo_labels, epoch, pretrain_augmented, - verbose) - if epoch >= 5 or epoch == pretrain_num_epochs: - pretrain_preds = model.inference(pretrain_phase, test_loader, - verbose=verbose) - train_acc, val_acc, test_acc = evaluate(pretrain_preds, - ['train', 'valid', 'test']) - print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' - f'Test: {test_acc:.4f}') 
- - if val_acc <= best_val_acc: - early_stopping += 1 - if early_stopping > patience: - print(f'Pretrain Early stopped by Epoch: {epoch}') - break - else: - best_val_acc = val_acc - final_test_acc = test_acc - preds = pretrain_preds - - if pretrain_phase == 'gnn': - gnn_test_acc = max(gnn_test_acc, final_test_acc) - model.gnn = model.gnn.to('cpu', non_blocking=True) - else: - lm_test_acc = max(lm_test_acc, final_test_acc) - model.lm = model.lm.to('cpu', non_blocking=True) - torch.cuda.empty_cache() - - pretrain_phase_time = time.time() - pretrain_start_time - print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') - os.makedirs(osp.dirname(preds_dir), exist_ok=True) - torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) - print( - f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') - train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) - print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' - f'Test: {test_acc:.4f}') - - # EM iterations - - em_phase = em_order - """ - We run E-step(LM training) and M-Step(GNN training) alternatively in each - em iterations, so the total number of iterations is num_em_iter * 2 and - we switch the em_phase at end of each iteration in following loop - """ - for em_it in range(1, num_em_iters * 2 + 1): - pseudo_labels = preds.argmax(dim=-1) - best_val_acc = final_test_acc = 0.0 - print(f'EM iteration: {em_it}, EM phase: {em_phase}') - optimizer = load_model(em_phase) - num_epochs = lm_epochs - train_loader = text_train_loader - test_loader = text_data_loader - early_stopping = 0 - if em_phase == 'gnn': - train_loader = graph_train_loader - num_epochs = gnn_epochs - test_loader = subgraph_loader - for epoch in range(1, num_epochs + 1): - acc, loss = model.train(em_phase, train_loader, optimizer, - pseudo_labels, epoch, True, verbose) - if epoch >= 5 or epoch == num_epochs: - cur_preds = model.inference(em_phase, test_loader, - verbose=verbose) - train_acc, val_acc, test_acc = evaluate( - cur_preds, ['train', 'valid', 'test']) - - print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' - f'Test: {test_acc:.4f}') - - if val_acc <= best_val_acc: - early_stopping += 1 - if early_stopping > patience: - print(f'''Early stopped by Epoch: {epoch}, \ - Best acc: {final_test_acc}''') - break - else: - best_val_acc = val_acc - final_test_acc = test_acc - preds = cur_preds - - if em_phase == 'gnn': - gnn_test_acc = max(gnn_test_acc, final_test_acc) - model.gnn = model.gnn.to('cpu', non_blocking=True) - em_phase = 'lm' - else: - lm_test_acc = max(lm_test_acc, final_test_acc) - model.lm = model.lm.to('cpu', non_blocking=True) - em_phase = 'gnn' - torch.cuda.empty_cache() - print(f'Best GNN acc: {gnn_test_acc}, LM acc: {lm_test_acc}') - print('============================') - end_time = time.time() - running_time = (end_time - start_time) / 3600 - print(f'Total running time: {running_time:.2f} hours') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='GLEM Example:') - parser.add_argument('--gpu', type=int, default=0) - parser.add_argument('--num_runs', type=int, default=10, - help='number of runs') - parser.add_argument('--num_em_iters', type=int, default=1, - help='number of iterations') - parser.add_argument("--dataset", type=str, default='products', - help='arxiv or products') - parser.add_argument("--pl_ratio", type=float, default=0.5, - help="pseudo labels ratio") - parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', - help='huggingface model repo 
id') - parser.add_argument( - '--gnn_model', type=str, default='SAGE', - help='gnn model for node classification,' - 'options: SAGE, GAT, GCN') - parser.add_argument('--gnn_hidden_channels', type=int, default=256) - parser.add_argument('--gnn_num_layers', type=int, default=3) - parser.add_argument('--gat_heads', type=int, default=4, - help='Number of multi-head-attentions for GAT ') - parser.add_argument('--lm_batch_size', type=int, default=256) - parser.add_argument('--gnn_batch_size', type=int, default=1024) - parser.add_argument( - '--external_pred_path', type=str, default=None, - help="Other model's output logits during the " - "pretraining phase or simply concatenate it with" - "node features as augmented data for gnn") - parser.add_argument('--alpha', type=float, default=0.5, - help='pseudo label weight in E-step') - parser.add_argument('--beta', type=float, default=0.5, - help='pseudo label weight in M-step') - parser.add_argument('--lm_epochs', type=int, default=10) - parser.add_argument('--gnn_epochs', type=int, default=50) - parser.add_argument('--gnn_lr', type=float, default=0.002) - parser.add_argument('--lm_lr', type=float, default=0.001) - parser.add_argument('--patience', type=int, default=3, - help='Patience for early stopping') - parser.add_argument('--verbose', action='store_true', - help='show progress bar during training or not') - parser.add_argument('--em_order', type=str, default='lm', - help='decide train LM first or GNN first') - parser.add_argument('--lm_use_lora', action='store_true', - help='use Lora to fine-tune model or not') - parser.add_argument( - '--token_on_disk', action='store_true', - help='save token on disk and load token from disk' - 'for reducing duplicated tokenizing') - parser.add_argument('--out_dir', type=str, default='output/', - help='output directory') - parser.add_argument( - '--train_without_ext_pred', action='store_true', - help='train glem without using additional pseudo labels ' - 'for augmenting data only available for ogbn-products') - args = parser.parse_args() - print(args) - main(args) if em_phase == 'gnn': gnn_test_acc = max(gnn_test_acc, final_test_acc) model.gnn = model.gnn.to('cpu', non_blocking=True) @@ -876,4 +437,4 @@ def load_model(em_phase): 'for augmenting data only available for ogbn-products') args = parser.parse_args() print(args) - main(args) + main(args) \ No newline at end of file diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index b8068cfb032f..f3e1cb6ca404 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -32,6 +32,7 @@ class GLEM(torch.nn.Module): (default: True) lora_target_modules: The names of the target modules to apply the lora adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None) + .. note:: See `examples/llm_plus_gnn/glem.py` for example usage. 
""" @@ -138,6 +139,7 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, epoch (int): current epoch is_augmented (bool): will use pseudo_labels or not verbose (bool): print training progress bar or not + Returns: acc (float): training accuracy loss (float): loss value @@ -164,9 +166,11 @@ def train_lm(self, train_loader: DataLoader, pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn is_augmented (bool): train with pseudo labels or not verbose (bool): print training progress bar or not + Returns: approx_acc (torch.tensor): training accuracy loss (torch.float): loss value + """ all_out = [] total_loss = total_correct = 0 @@ -218,6 +222,7 @@ def train_gnn(self, train_loader: NeighborLoader, pseudo_labels(torch.tensor): 1-D tensor, predictions from lm is_augmented(bool): use pseudo labeled node or not verbose (bool): print training progress or not + Returns: approx_acc (torch.tensor): training accuracy loss (torch.float): loss value @@ -271,6 +276,7 @@ def inference(self, em_phase: str, data_loader: Union[NeighborLoader, dataloader: for lm training, include tokenized data nodeloader: for gnn training, include x, edge_index verbose(bool): print inference progress or not + Returns: out (torch.Tensor): n * m tensor, m is number of classes, n is number of nodes @@ -291,6 +297,7 @@ def inference_lm(self, data_loader: DataLoader, verbose: bool = True): Args: data_loader (Dataloader): include token, labels, and gold mask verbose (bool): print progress bar or not + Returns: preds (tensor): prediction from GNN, convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) @@ -318,6 +325,7 @@ def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): Args: data_loader(NeighborLoader): include x, edge_index, verbose (bool): print progress bar or not + Returns: preds (tensor): prediction from GNN, convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) @@ -345,8 +353,10 @@ def loss(self, logits: torch.Tensor, labels: torch.Tensor, r"""Core function of variational EM inference, this function is aming on combining loss value on gold(original train) and loss value on pseudo labels. 
+ Reference: # noqa + Args: logits(torch.tensor): predict results from LM or GNN labels(torch.tensor): combined node labels from ground truth and @@ -358,6 +368,7 @@ def loss(self, logits: torch.Tensor, labels: torch.Tensor, pl_weight: the pseudo labels used in E-step and M-step optimization alpha in E-step, beta in M-step respectively is_augmented: use EM or just train GNN and LM with gold data + """ def deal_nan(x): return 0 if torch.isnan(x) else x @@ -370,4 +381,4 @@ def deal_nan(x): loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss else: loss = loss_func(logits, labels) - return loss + return loss \ No newline at end of file From a22742cf32c5504834ecfde1e9df4b38d96ca7d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Sep 2024 22:02:09 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/llm/glem.py | 2 +- torch_geometric/nn/models/glem.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index e01985744f6c..6dbb07a30a77 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -437,4 +437,4 @@ def load_model(em_phase): 'for augmenting data only available for ogbn-products') args = parser.parse_args() print(args) - main(args) \ No newline at end of file + main(args) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index f3e1cb6ca404..afc8b09d77c7 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -381,4 +381,4 @@ def deal_nan(x): loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss else: loss = loss_func(logits, labels) - return loss \ No newline at end of file + return loss
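
The pseudo-label weighting implemented by GLEM.loss above is the heart of the variational EM objective: a convex combination of the cross-entropy on gold (original train) nodes and the cross-entropy against the other model's pseudo labels on the remaining nodes, with pl_weight acting as alpha in the E-step and beta in the M-step. A minimal standalone sketch, assuming torch.nn.functional.cross_entropy as the classification loss (the NaN guard mirrors deal_nan, since cross-entropy over an empty node selection is NaN):

import torch
import torch.nn.functional as F

def glem_loss(logits, labels, is_gold, pseudo_labels=None,
              pl_weight=0.5, is_augmented=True):
    # Convex combination of gold-label loss and pseudo-label loss;
    # pl_weight is alpha in the E-step (LM) and beta in the M-step (GNN).
    def nan_to_zero(x):
        # cross_entropy over an empty selection yields NaN; treat it as 0,
        # mirroring deal_nan in GLEM.loss.
        return x.new_zeros(()) if torch.isnan(x) else x

    if is_augmented and (~is_gold).sum() > 0:
        mle_loss = nan_to_zero(F.cross_entropy(logits[is_gold], labels[is_gold]))
        pl_loss = nan_to_zero(
            F.cross_entropy(logits[~is_gold], pseudo_labels[~is_gold]))
        return pl_weight * pl_loss + (1.0 - pl_weight) * mle_loss
    return F.cross_entropy(logits, labels)

# Toy check: 6 nodes, 3 classes, half gold and half pseudo-labeled.
logits = torch.randn(6, 3)
labels = torch.tensor([0, 1, 2, 0, 1, 2])
is_gold = torch.tensor([True, True, True, False, False, False])
pseudo = torch.tensor([2, 0, 2, 1, 0, 2])
print(glem_loss(logits, labels, is_gold, pseudo, pl_weight=0.5))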
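
For orientation, the pieces added by this patch fit together as follows. This is a condensed, untested sketch of examples/llm/glem.py, not a drop-in script: wrap an OGB dataset in TAGDataset, derive the tokenized TextDataset, build GLEM around a GNN and a Hugging Face encoder, pretrain one model, then alternate E-steps (LM) and M-steps (GNN) on each other's pseudo labels. External GIANT/SAGN predictions, LoRA fine-tuning, early stopping, and the CPU/GPU shuttling of the inactive model are omitted; hyper-parameters are illustrative, and the graph-side loaders assume TAGDataset has attached the is_gold mask it builds for the underlying data.

import torch
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.datasets import TAGDataset
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn.models import GLEM, GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
root = 'data/ogb'
dataset = PygNodePropPredDataset('ogbn-products', root=root)

tag_dataset = TAGDataset(root, dataset, 'prajjwal1/bert-tiny',
                         token_on_disk=True)
text_dataset = tag_dataset.to_text_dataset()
split_idx = tag_dataset.get_idx_split()

# Augmented train set: gold nodes plus randomly sampled pseudo-labeled test nodes.
gold_idx, test_idx = split_idx['train'], split_idx['test']
num_pl = int(gold_idx.numel() * 0.5)  # pl_ratio = 0.5, as in the example
pl_idx = test_idx[torch.randperm(test_idx.numel())[:num_pl]]
train_idx = torch.cat([gold_idx, pl_idx])

data = dataset.data.to(device)
gnn = GraphSAGE(in_channels=data.num_features, hidden_channels=256,
                num_layers=3, out_channels=tag_dataset.num_classes)
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
             out_channels=tag_dataset.num_classes, lm_use_lora=False,
             device=device)
# The full example keeps only the active model on the GPU; here both stay resident.
model.lm = model.lm.to(device)
model.gnn = model.gnn.to(device)
lm_opt = torch.optim.Adam(model.lm.parameters(), lr=0.001)
gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=0.002)

text_gold_loader = DataLoader(torch.utils.data.Subset(text_dataset, gold_idx),
                              batch_size=256, shuffle=True)
text_train_loader = DataLoader(torch.utils.data.Subset(text_dataset, train_idx),
                               batch_size=256, shuffle=True)
text_all_loader = DataLoader(text_dataset, batch_size=1024, shuffle=False)
graph_train_loader = NeighborLoader(data, input_nodes=train_idx,
                                    num_neighbors=[15, 10, 5],
                                    batch_size=1024, shuffle=True)
subgraph_loader = NeighborLoader(data, input_nodes=None, num_neighbors=[-1],
                                 batch_size=4096)

# Pretrain the LM on gold labels only, then run one EM iteration:
# the M-step trains the GNN on the LM's pseudo labels, and the E-step
# trains the LM on the GNN's pseudo labels.
model.train('lm', text_gold_loader, lm_opt, None, epoch=1, is_augmented=False)
preds = model.inference('lm', text_all_loader)
for em_phase, loader, opt, infer_loader in (
        ('gnn', graph_train_loader, gnn_opt, subgraph_loader),
        ('lm', text_train_loader, lm_opt, text_all_loader)):
    pseudo_labels = preds.argmax(dim=-1)
    model.train(em_phase, loader, opt, pseudo_labels, epoch=1,
                is_augmented=True)
    preds = model.inference(em_phase, infer_loader)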
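
On the language-model side, the loop consumes plain dictionaries rather than Data objects. Assuming the text_dataset from the sketch above, each batch produced by the text DataLoader carries the four fields that train_lm and inference_lm read (the exact tokenizer keys inside batch['input'] depend on the chosen Hugging Face model):

loader = DataLoader(text_dataset, batch_size=4)
batch = next(iter(loader))
tokens = batch['input']     # dict of tokenizer tensors, unpacked into the LM as **inputs
labels = batch['labels']    # node labels; ignored on non-gold nodes during augmented training
is_gold = batch['is_gold']  # True for original train nodes, False for pseudo-labeled nodes
n_id = batch['n_id']        # global node indices, used to index the pseudo-label tensor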