diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6c3b896672fd..3fa5805be5d6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
+- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
 - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
 
 ### Changed
diff --git a/examples/llm/README.md b/examples/llm/README.md
index f1f01428d991..1f7a5cb09100 100644
--- a/examples/llm/README.md
+++ b/examples/llm/README.md
@@ -1,5 +1,10 @@
 # Examples for Co-training LLMs and GNNs
 
-| Example                              | Description                                                                                                                                                  |
-| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
+| Example                              | Description                                                                                                                                                                                        |
+| ------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information                                       |
+| [`glem.py`](./glem.py)               | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model that uses a variational Expectation-Maximization (EM) framework for node classification and achieves SOTA results |
+
+## Run GLEM to reproduce the SOTA result on the ogbn-products dataset
+
+`python glem.py`
diff --git a/examples/llm/glem.py b/examples/llm/glem.py
new file mode 100644
index 000000000000..6dbb07a30a77
--- /dev/null
+++ b/examples/llm/glem.py
@@ -0,0 +1,440 @@
+"""This example runs the GLEM model using PyG.
+Original Paper: https://arxiv.org/abs/2210.14709
+“Learning on Large-scale Text-attributed Graphs via Variational Inference”.
+Requirements on top of basic PyG:
+`pip install ogb transformers peft tqdm`.
+GLEM is a data augmentation co-training strategy for LMs and GNNs. Our
+implementation extends the original implementation from LM to LLM and supports
+optional LoRA fine-tuning via peft.
+
+.. note::
+    To use the additional trick, add your external predictions by assigning
+    `ext_pred_path`; they are combined into the pretraining phase and the
+    node features.
+"""
+
+import argparse
+import os
+import os.path as osp
+import sys
+import time
+
+import torch
+
+from torch_geometric import seed_everything
+from torch_geometric.data import download_google_url
+from torch_geometric.datasets import TAGDataset
+from torch_geometric.nn.models import GAT, GCN, GLEM, GraphSAGE
+
+# Add the parent directory to sys.path
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+sys.path.append(parent_dir)
+
+
+def get_n_params(model):
+    pp = 0
+    for p in list(model.parameters()):
+        nn = 1
+        for s in list(p.size()):
+            nn = nn * s
+        pp += nn
+    return pp
+
+
+def main(args):
+    gpu = args.gpu
+    dataset_name = args.dataset
+    root = osp.join(parent_dir, 'data', 'ogb')
+    hf_model = args.hf_model
+    pl_ratio = args.pl_ratio
+    gnn_lr = args.gnn_lr
+    lm_lr = args.lm_lr
+    em_order = args.em_order
+    gnn_epochs = args.gnn_epochs
+    lm_epochs = args.lm_epochs
+    patience = args.patience
+    verbose = args.verbose
+    out_dir = args.out_dir
+    lm_batch_size = args.lm_batch_size
+    gnn_batch_size = args.gnn_batch_size
+    lm_use_lora = args.lm_use_lora
+    token_on_disk = args.token_on_disk
+    num_em_iters = args.num_em_iters
+    start_time = time.time()
+    train_without_ext_pred = args.train_without_ext_pred
+    ext_pred = None
+    pretrain_augmented = False
+    ext_pseudo_labels = None
+    device = torch.device(
+        f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')
+    print(f'Running on: {torch.cuda.get_device_name(gpu)}')
+    torch.cuda.empty_cache()
+
+    if not train_without_ext_pred:
+        ext_pred_path = download_google_url(
+            id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY',
+            folder='/work/users/junhaos/glem_data/ogbn_products/ext_preds',
+            filename='giant_sagn_scr.pt', log=True)
+        ext_pred = torch.load(ext_pred_path, map_location=device)
+        ext_pseudo_labels = ext_pred.argmax(dim=-1)
+        pretrain_augmented = True
+
+    seed_everything(42)
+    from ogb.nodeproppred import PygNodePropPredDataset
+    dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root)
+    split_idx = dataset.get_idx_split()
+    data = dataset.data
+
+    tag_dataset = TAGDataset(root, dataset, hf_model,
+                             token_on_disk=token_on_disk)
+    text_dataset = tag_dataset.to_text_dataset()
+    print(tag_dataset.num_classes, tag_dataset.raw_file_names)
+
+    num_classes = tag_dataset.num_classes
+    num_features = data.num_features
+    # =========================== LM Data split ===============================
+    split_idx = tag_dataset.get_idx_split()
+
+    # GLEM trains with augmented data; mark the original train data as gold:
+    gold_idx = split_idx['train']
+    test_idx = split_idx['test']
+
+    # Randomly sample pseudo-label nodes and generate their indices:
+    num_pseudo_labels = int(gold_idx.numel() * pl_ratio)
+    idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels]
+    pseudo_labels_idx = test_idx[idx_to_select]
+    train_idx = torch.cat(
+        (gold_idx, pseudo_labels_idx))  # augmented train_idx
+
+    print(f'train_idx: {train_idx.size(0)}, '
+          f'gold_idx: {gold_idx.size(0)}, '
+          f'pseudo labels ratio: {pl_ratio}, '
+          f'{train_idx.size(0)/gold_idx.size(0) - 1.0}')
+    gold_dataset = torch.utils.data.Subset(dataset=text_dataset,
+                                           indices=gold_idx)
+    train_dataset = torch.utils.data.Subset(dataset=text_dataset,
+                                            indices=train_idx)
+    # ========================== LM Data Loader ===============================
+
+    print('Building language model dataloader...', end='-->')
+    from
torch_geometric.loader import DataLoader + + # if set train_without_ext_pred == True, use this for pretrain + text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + # training with augmented data, + text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + text_data_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4, + drop_last=False, pin_memory=True, + shuffle=False) + print('done') + + # =========================== GNN Data Loader ============================= + initial_memory = torch.cuda.memory_allocated() + data = data.to(device) + if ext_pred is not None: + data.x = torch.cat((data.x, ext_pred), dim=1) + num_features += ext_pred.size(1) + current_memory_1 = torch.cuda.max_memory_allocated() + # 1 GB = 1073741824 Byte + gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 + # Print the maximum memory usage after running the model + print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') + + print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') + from torch_geometric.loader import NeighborLoader + + # train on gold data w/o pseudo labels + graph_pretrain_loader = NeighborLoader( + data, + input_nodes=gold_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # graph data loader w/ pseudo labels in M-step + graph_train_loader = NeighborLoader( + data, + input_nodes=train_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # for gnn inference + subgraph_loader = NeighborLoader( + data, + input_nodes=None, + num_neighbors=[-1], + batch_size=gnn_batch_size * 4, + num_workers=12, + persistent_workers=True, + ) + # =========================== internal function =========================== + + from ogb.nodeproppred import Evaluator + evaluator = Evaluator(name=f'ogbn-{dataset_name}') + + def evaluate(out, split): + y_true = data.y.cpu() + y_pred = out.argmax(dim=-1, keepdim=True) + train_acc, val_acc, test_acc = None, None, None + if 'train' in split: + train_acc = evaluator.eval({ + 'y_true': y_true[split_idx['train']], + 'y_pred': y_pred[split_idx['train']], + })['acc'] + if 'valid' in split: + val_acc = evaluator.eval({ + 'y_true': y_true[split_idx['valid']], + 'y_pred': y_pred[split_idx['valid']], + })['acc'] + if 'test' in split: + test_acc = evaluator.eval({ + 'y_true': y_true[split_idx['test']], + 'y_pred': y_pred[split_idx['test']], + })['acc'] + + return train_acc, val_acc, test_acc + + # =========================== Build GNN Model ============================= + gnn = None + if args.gnn_model == 'SAGE': + gnn = GraphSAGE( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + elif args.gnn_model == 'GAT': + gnn = GAT(in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, heads=args.gat_heads) + else: + gnn = GCN( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + + print("# GNN Params:", get_n_params(gnn)) + # =========================== Build LM Model ============================== + + model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, + 
lm_use_lora=lm_use_lora, device=device) + lm = model.lm + print("# LM Params:", get_n_params(lm)) + gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) + lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) + + def load_model(em_phase): + print(f'Move {em_phase} model from cpu memory') + if em_phase == 'lm': + model.lm = model.lm.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) + if em_phase == 'gnn': + model.gnn = model.gnn.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) + return optimizer + + # ================================= Run GLEM ============================== + preds_filename = 'lm_pretrain' + preds_dir = f'{out_dir}preds/{dataset_name}/' + gnn_test_acc = 0.0 + lm_test_acc = 0.0 + # =============================== GLEM pretraining ======================== + pretrain_phase = 'lm' + if em_order == 'lm': + pretrain_phase = 'gnn' + pretrain_start_time = time.time() + # pretraining + pretrain_loader = graph_pretrain_loader + test_loader = subgraph_loader + pretrain_num_epochs = gnn_epochs + pretrain_opt = gnn_opt + if pretrain_phase == 'gnn': + model.gnn = model.gnn.to(device) + print('pretraining gnn to generate pseudo labels') + if not train_without_ext_pred: + pretrain_loader = graph_train_loader + preds_filename = 'gnn_pretrain' + elif pretrain_phase == 'lm': + model.lm = model.lm.to(device) + print('pretraining lm to generate pseudo labels') + pretrain_num_epochs = lm_epochs + pretrain_loader = text_pretrain_loader + test_loader = text_data_loader + pretrain_opt = lm_opt + if not train_without_ext_pred: + pretrain_loader = text_train_loader + preds_filename = 'lm_pretrain' + + early_stopping = 0 + best_val_acc = final_test_acc = 0.0 + for epoch in range(1, pretrain_num_epochs + 1): + acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, + ext_pseudo_labels, epoch, pretrain_augmented, + verbose) + if epoch >= 5 or epoch == pretrain_num_epochs: + pretrain_preds = model.inference(pretrain_phase, test_loader, + verbose=verbose) + train_acc, val_acc, test_acc = evaluate(pretrain_preds, + ['train', 'valid', 'test']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Pretrain Early stopped by Epoch: {epoch}') + break + else: + best_val_acc = val_acc + final_test_acc = test_acc + preds = pretrain_preds + + if pretrain_phase == 'gnn': + gnn_test_acc = max(gnn_test_acc, final_test_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + else: + lm_test_acc = max(lm_test_acc, final_test_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + torch.cuda.empty_cache() + + pretrain_phase_time = time.time() - pretrain_start_time + print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') + os.makedirs(osp.dirname(preds_dir), exist_ok=True) + torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) + print( + f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + # EM iterations + + em_phase = em_order + """ + We run E-step(LM training) and M-Step(GNN training) alternatively in each + em iterations, so the total number of iterations is num_em_iter * 2 and + we switch the em_phase at end of each iteration in following loop + """ + for em_it in 
range(1, num_em_iters * 2 + 1): + pseudo_labels = preds.argmax(dim=-1) + best_val_acc = final_test_acc = 0.0 + print(f'EM iteration: {em_it}, EM phase: {em_phase}') + optimizer = load_model(em_phase) + num_epochs = lm_epochs + train_loader = text_train_loader + test_loader = text_data_loader + early_stopping = 0 + if em_phase == 'gnn': + train_loader = graph_train_loader + num_epochs = gnn_epochs + test_loader = subgraph_loader + for epoch in range(1, num_epochs + 1): + acc, loss = model.train(em_phase, train_loader, optimizer, + pseudo_labels, epoch, True, verbose) + if epoch >= 5 or epoch == num_epochs: + cur_preds = model.inference(em_phase, test_loader, + verbose=verbose) + train_acc, val_acc, test_acc = evaluate( + cur_preds, ['train', 'valid', 'test']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'''Early stopped by Epoch: {epoch}, \ + Best acc: {final_test_acc}''') + break + else: + best_val_acc = val_acc + final_test_acc = test_acc + preds = cur_preds + + if em_phase == 'gnn': + gnn_test_acc = max(gnn_test_acc, final_test_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + em_phase = 'lm' + else: + lm_test_acc = max(lm_test_acc, final_test_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + em_phase = 'gnn' + torch.cuda.empty_cache() + print(f'Best GNN acc: {gnn_test_acc}, LM acc: {lm_test_acc}') + print('============================') + end_time = time.time() + running_time = (end_time - start_time) / 3600 + print(f'Total running time: {running_time:.2f} hours') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GLEM Example:') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--num_runs', type=int, default=10, + help='number of runs') + parser.add_argument('--num_em_iters', type=int, default=1, + help='number of iterations') + parser.add_argument("--dataset", type=str, default='products', + help='arxiv or products') + parser.add_argument("--pl_ratio", type=float, default=0.5, + help="pseudo labels ratio") + parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', + help='huggingface model repo id') + parser.add_argument( + '--gnn_model', type=str, default='SAGE', + help='gnn model for node classification,' + 'options: SAGE, GAT, GCN') + parser.add_argument('--gnn_hidden_channels', type=int, default=256) + parser.add_argument('--gnn_num_layers', type=int, default=3) + parser.add_argument('--gat_heads', type=int, default=4, + help='Number of multi-head-attentions for GAT ') + parser.add_argument('--lm_batch_size', type=int, default=256) + parser.add_argument('--gnn_batch_size', type=int, default=1024) + parser.add_argument( + '--external_pred_path', type=str, default=None, + help="Other model's output logits during the " + "pretraining phase or simply concatenate it with" + "node features as augmented data for gnn") + parser.add_argument('--alpha', type=float, default=0.5, + help='pseudo label weight in E-step') + parser.add_argument('--beta', type=float, default=0.5, + help='pseudo label weight in M-step') + parser.add_argument('--lm_epochs', type=int, default=10) + parser.add_argument('--gnn_epochs', type=int, default=50) + parser.add_argument('--gnn_lr', type=float, default=0.002) + parser.add_argument('--lm_lr', type=float, default=0.001) + parser.add_argument('--patience', type=int, default=3, + help='Patience for early stopping') + 
parser.add_argument('--verbose', action='store_true',
+                        help='show a progress bar during training or not')
+    parser.add_argument('--em_order', type=str, default='lm',
+                        help='whether to train the LM or the GNN first')
+    parser.add_argument('--lm_use_lora', action='store_true',
+                        help='use LoRA to fine-tune the LM or not')
+    parser.add_argument(
+        '--token_on_disk', action='store_true',
+        help='save tokens on disk and load them from disk '
+        'to avoid duplicated tokenizing')
+    parser.add_argument('--out_dir', type=str, default='output/',
+                        help='output directory')
+    parser.add_argument(
+        '--train_without_ext_pred', action='store_true',
+        help='train GLEM without additional pseudo labels for '
+        'augmenting data (only available for ogbn-products)')
+    args = parser.parse_args()
+    print(args)
+    main(args)
diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py
index 96d51032d818..0b6569d3f92b 100644
--- a/torch_geometric/datasets/__init__.py
+++ b/torch_geometric/datasets/__init__.py
@@ -77,6 +77,7 @@
 from .brca_tgca import BrcaTcga
 from .neurograph import NeuroGraphDataset
 from .web_qsp_dataset import WebQSPDataset
+from .tag_dataset import TAGDataset
 from .dbp15k import DBP15K
 from .aminer import AMiner
@@ -190,6 +191,7 @@
     'BrcaTcga',
     'NeuroGraphDataset',
     'WebQSPDataset',
+    'TAGDataset',
 ]
 hetero_datasets = [
diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py
new file mode 100644
index 000000000000..f25992ced989
--- /dev/null
+++ b/torch_geometric/datasets/tag_dataset.py
@@ -0,0 +1,350 @@
+import os
+import os.path as osp
+from collections.abc import Sequence
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from tqdm import tqdm
+
+from torch_geometric.data import InMemoryDataset, download_google_url
+from torch_geometric.data.data import BaseData
+
+try:
+    from pandas import DataFrame, read_csv
+    WITH_PANDAS = True
+except ImportError:
+    WITH_PANDAS = False
+
+IndexType = Union[slice, Tensor, np.ndarray, Sequence]
+
+
+class TAGDataset(InMemoryDataset):
+    r"""The Text Attributed Graph datasets from the
+    `"Learning on Large-scale Text-attributed Graphs via Variational
+    Inference" <https://arxiv.org/abs/2210.14709>`_ paper.
+    This dataset transforms `ogbn-products` and `ogbn-arxiv` into Text
+    Attributed Graphs, in which each node is associated with a raw text.
+    The resulting dataset can be adapted to a DataLoader (for LM training)
+    and a NeighborLoader (for GNN training). In addition, this class can be
+    used as a wrapper to convert an InMemoryDataset together with a tokenizer
+    and raw text into a Text Attributed Graph.
+
+    Args:
+        root (str): Root directory where the dataset should be saved.
+        dataset (InMemoryDataset): The dataset to wrap
+            (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`).
+        tokenizer_name (str): The tokenizer name for the language model.
+            Be sure to use the same tokenizer name as the `model id` of the
+            model repo on huggingface.co.
+        text (List[str]): List of raw texts associated with the nodes; the
+            order of the list should align with the node list.
+        split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary of
+            split indices; required if your dataset does not provide a
+            get_idx_split function.
+        tokenize_batch_size (int): Batch size used for tokenizing the text;
+            tokenization runs on CPU. (default: 256)
+        token_on_disk (bool): Whether to save the tokens as .pt files on disk.
+            (default: False)
+        text_on_disk (bool): Whether to save the given text (list of str) as a
+            dataframe on disk. (default: False)
+        force_reload (bool): Whether to re-process the dataset.
+            (default: False)
+    .. note::
+        See `examples/llm/glem.py` for example usage.
+    """
+    raw_text_id = {
+        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
+        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
+    }
+
+    def __init__(self, root: str, dataset: InMemoryDataset,
+                 tokenizer_name: str, text: Optional[List[str]] = None,
+                 split_idx: Optional[Dict[str, Tensor]] = None,
+                 tokenize_batch_size: int = 256, token_on_disk: bool = False,
+                 text_on_disk: bool = False,
+                 force_reload: bool = False) -> None:
+        # Set the variables you want to pass in before download & process run:
+        self.name = dataset.name
+        self.text = text
+        self.tokenizer_name = tokenizer_name
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        self.dir_name = '_'.join(dataset.name.split('-'))
+        self.root = osp.join(root, self.dir_name)
+        missing_str_list = []
+        if not WITH_PANDAS:
+            missing_str_list.append('pandas')
+        if len(missing_str_list) > 0:
+            missing_str = ' '.join(missing_str_list)
+            error_out = f"`pip install {missing_str}` to use this dataset."
+ raise ImportError(error_out) + if hasattr(dataset, 'get_idx_split'): + self.split_idx = dataset.get_idx_split() + elif split_idx is not None: + self.split_idx = split_idx + else: + raise ValueError("TAGDataset need split idx for generating " + "is_gold mask, please pass splited index " + "in format of dictionaty with 'train', 'valid' " + "'test' index tensor to 'split_idx'") + if text is not None and text_on_disk: + self.save_node_text(text) + self.text_on_disk = text_on_disk + # init will call download and process + super().__init__(self.root, transform=None, pre_transform=None, + pre_filter=None, force_reload=force_reload) + # after processing and download + # Dataset has to have BaseData as _data + assert dataset._data is not None + self._data = dataset._data # reassign reference + assert self._data is not None + assert dataset._data.y is not None + assert isinstance(self._data, BaseData) + assert self._data.num_nodes is not None + assert isinstance(dataset._data.num_nodes, int) + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + is_good_tensor = self.load_gold_mask() + self._is_gold = is_good_tensor.squeeze() + self._data['is_gold'] = is_good_tensor + if self.text is not None and len(self.text) != self._data.num_nodes: + raise ValueError("The number of text sequence in 'text' should be " + "equal to number of nodes!") + self.token_on_disk = token_on_disk + self.tokenize_batch_size = tokenize_batch_size + self._token = self.tokenize_graph(self.tokenize_batch_size) + self.__num_classes__ = dataset.num_classes + + @property + def num_classes(self) -> int: + return self.__num_classes__ + + @property + def raw_file_names(self) -> List[str]: + file_names = [] + for root, _, files in os.walk(osp.join(self.root, 'raw')): + for file in files: + file_names.append(file) + return file_names + + @property + def processed_file_names(self) -> List[str]: + return [ + 'geometric_data_processed.pt', 'pre_filter.pt', + 'pre_transformed.pt' + ] + + @property + def token(self) -> Dict[str, Tensor]: + if self._token is None: # lazy load + self._token = self.tokenize_graph() + return self._token + + # load is_gold after init + @property + def is_gold(self) -> Tensor: + if self._is_gold is None: + print('lazy load is_gold!!') + self._is_gold = self.load_gold_mask() + return self._is_gold + + def get_n_id(self, node_idx: IndexType) -> Tensor: + if self._n_id is None: + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + return self._n_id[node_idx] + + def load_gold_mask(self) -> Tensor: + r"""Use original train split as gold split, generating is_gold mask + for picking ground truth labels and pseudo labels. + """ + train_split_idx = self.get_idx_split()['train'] + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + is_good_tensor = torch.zeros(self._data.num_nodes, + dtype=torch.bool).view(-1, 1) + is_good_tensor[train_split_idx] = True + return is_good_tensor + + def get_gold(self, node_idx: IndexType) -> Tensor: + r"""Get gold mask for given node_idx. 
+ + Args: + node_idx (torch.tensor): a tensor contain node idx + """ + if self._is_gold is None: + self._is_gold = self.is_gold + return self._is_gold[node_idx] + + def get_idx_split(self) -> Dict[str, Tensor]: + return self.split_idx + + def download(self) -> None: + print('downloading raw text') + raw_text_path = download_google_url(id=self.raw_text_id[self.name], + folder=f'{self.root}/raw', + filename='node-text.csv.gz', + log=True) + text_df = read_csv(raw_text_path) + self.text = list(text_df['text']) + + def process(self) -> None: + if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): + text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) + self.text = list(text_df['text']) + elif self.name in self.raw_text_id: + self.download() + else: + print('The dataset is not ogbn-products nor ogbn-arxiv,' + 'please pass in your raw text string list to `text`') + if self.text is None: + raise ValueError("The TAGDataset only have ogbn-products and " + "ogbn-arxiv raw text in default " + "The raw text of each node is not specified" + "Please pass in 'text' when convert your dataset " + "to Text Attribute Graph Dataset") + + def save_node_text(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') + if osp.exists(node_text_path): + print(f'The raw text is existed at {node_text_path}') + else: + print(f'Saving raw text file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + + def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + r"""Tokenizing the text associate with each node, running in cpu. + + Args: + batch_size (Optional[int]): batch size of list of text for + generating emebdding + Returns: + Dict[str, torch.Tensor]: tokenized graph + """ + data_len = 0 + if self.text is not None: + data_len = len(self.text) + else: + raise ValueError("The TAGDataset need text for tokenization") + token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] + path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + # Check if the .pt files already exist + token_files_exist = any( + os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) + + if token_files_exist and self.token_on_disk: + print('Found tokenized file, loading may take several minutes...') + all_encoded_token = { + k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True) + for k in token_keys + if os.path.exists(os.path.join(path, f'{k}.pt')) + } + return all_encoded_token + + all_encoded_token = {k: [] for k in token_keys} + pbar = tqdm(total=data_len) + + pbar.set_description('Tokenizing Text Attributed Graph') + for i in range(0, data_len, batch_size): + end_index = min(data_len, i + batch_size) + token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], + padding='max_length', truncation=True, + max_length=512, return_tensors="pt") + for k in token.keys(): + all_encoded_token[k].append(token[k]) + pbar.update(end_index - i) + pbar.close() + + all_encoded_token = { + k: torch.cat(v) + for k, v in all_encoded_token.items() if len(v) > 0 + } + if self.token_on_disk: + os.makedirs(path, exist_ok=True) + print('Saving tokens on Disk') + for k, tensor in all_encoded_token.items(): + torch.save(tensor, os.path.join(path, f'{k}.pt')) + print('Token saved:', os.path.join(path, f'{k}.pt')) + os.environ["TOKENIZERS_PARALLELISM"] = 'true' # supressing warning + return 
all_encoded_token + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + class TextDataset(torch.utils.data.Dataset): + r"""This nested dataset provides textual data for each node in + the graph. Factory method to create TextDataset from TAGDataset. + + Args: + tag_dataset (TAGDataset): the parent dataset + """ + def __init__(self, tag_dataset: 'TAGDataset') -> None: + self.tag_dataset = tag_dataset + self.token = tag_dataset.token + assert tag_dataset._data is not None + self._data = tag_dataset._data + + assert tag_dataset._data.y is not None + self.labels = tag_dataset._data.y + + def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: + r"""This function will be called in __getitem__(). + + Args: + node_idx (IndexType): selected node idx in each batch + Returns: + items (Dict[str, Tensor]): input for LM + """ + items = {k: v[node_idx] for k, v in self.token.items()} + return items + + # for LM training + def __getitem__( + self, node_id: IndexType + ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + r"""This function will override the function in + torch.utils.data.Dataset, and will be called when you + iterate batch in the dataloader, make sure all following + key value pairs are present in the return dict. + + Args: + node_id (List[int]): list of node idx for selecting tokens, + labels etc. when iterating data loader for LM + Returns: + items (dict): input k,v pairs for Language model training and + inference + """ + item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {} + item['input'] = self.get_token(node_id) + item['labels'] = self.labels[node_id] + item['is_gold'] = self.tag_dataset.get_gold(node_id) + item['n_id'] = self.tag_dataset.get_n_id(node_id) + return item + + def __len__(self) -> int: + assert self._data.num_nodes is not None + return self._data.num_nodes + + def get(self, idx: int) -> BaseData: + return self._data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + def to_text_dataset(self) -> TextDataset: + r"""Factory Build text dataset from Text Attributed Graph Dataset + each data point is node's associated text token. + """ + return TAGDataset.TextDataset(self) diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 7cfadf0143b2..5860db311ac3 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,7 +29,7 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever - +from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -77,4 +77,5 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'GLEM', ] diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py new file mode 100644 index 000000000000..afc8b09d77c7 --- /dev/null +++ b/torch_geometric/nn/models/glem.py @@ -0,0 +1,384 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm + +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GraphSAGE, basic_gnn + + +class GLEM(torch.nn.Module): + r"""This GNN+LM co-training model is based on GLEM from the `"Learning on + Large-scale Text-attributed Graphs via Variational Inference" + `_ paper. 
+
+    Args:
+        lm_to_use (str): A text encoder with a classification head from the
+            huggingface model repo. (default: TinyBERT)
+        gnn_to_use (torch_geometric.nn.models): The GNN to co-train.
+            (default: GraphSAGE)
+        out_channels (int): Output channels for the LM and GNN; they must be
+            the same.
+        num_gnn_heads (Optional[int]): Number of attention heads, if needed.
+        num_gnn_layers (int): Number of GNN layers.
+        gnn_loss: Loss function for the GNN. (default: CrossEntropyLoss)
+        lm_loss: Loss function for the language model.
+            (default: CrossEntropyLoss)
+        alpha (float): Pseudo-label weight in the E-step (LM optimization).
+            (default: 0.5)
+        beta (float): Pseudo-label weight in the M-step (GNN optimization).
+            (default: 0.5)
+        lm_dtype (torch.dtype): The data type used when loading the LM into
+            memory. (default: torch.bfloat16)
+        lm_use_lora (bool): Whether to fine-tune the LM with LoRA via peft.
+            (default: True)
+        lora_target_modules: The names of the target modules to apply the LoRA
+            adapter to, e.g. ['q_proj', 'v_proj'] for an LLM. (default: None)
+        device (str or torch.device): The device to run the LM and GNN on.
+            (default: 'cpu')
+
+    .. note::
+        See `examples/llm/glem.py` for example usage.
+    """
+    def __init__(
+        self,
+        lm_to_use: str = 'prajjwal1/bert-tiny',
+        gnn_to_use: basic_gnn = GraphSAGE,
+        out_channels: int = 47,
+        gnn_loss=nn.CrossEntropyLoss(reduction='mean'),
+        lm_loss=nn.CrossEntropyLoss(reduction='mean'),
+        alpha: float = 0.5,
+        beta: float = 0.5,
+        lm_dtype: torch.dtype = torch.bfloat16,
+        lm_use_lora: bool = True,
+        lora_target_modules: Optional[Union[List[str], str]] = None,
+        device: Union[str, torch.device] = torch.device('cpu'),
+    ):
+        super().__init__()
+        self.device = device
+        self.lm_loss = lm_loss
+        self.gnn = gnn_to_use
+        self.gnn_loss = gnn_loss
+        self.alpha = alpha
+        self.beta = beta
+        self.gnn_loss = gnn_loss
+        self.lm = lm_to_use
+        from transformers import AutoModelForSequenceClassification
+        self.lm = AutoModelForSequenceClassification.from_pretrained(
+            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
+            offload_folder="offload", trust_remote_code=True)
+        if lm_use_lora:
+            from peft import (
+                LoraConfig,
+                TaskType,
+                get_peft_model,
+                prepare_model_for_kbit_training,
+            )
+            print("Training LM with LORA!")
+            self.lm = prepare_model_for_kbit_training(self.lm)
+            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
+                                lora_alpha=16, lora_dropout=0.05, bias="none",
+                                target_modules=lora_target_modules)
+            self.lm = get_peft_model(self.lm, config)
+            self.lm.print_trainable_parameters()
+        self.lm.config.pad_token_id = self.lm.config.eos_token_id
+        self.lm_device = self.lm.device
+
+        if self.lm.num_labels != self.gnn.out_channels:
+            raise ValueError('''The output channels of the language model \
+                and the gnn should be the same''')
+
+    def pre_train_gnn(self, train_loader: NeighborLoader,
+                      optimizer: torch.optim.Optimizer, num_epochs: int,
+                      patience: int, ext_pseudo_labels: torch.Tensor = None,
+                      is_augmented: bool = False, verbose: bool = True):
+        # Pretrain the GNN; an optional step if you do not have pseudo labels.
+ best_acc = 0 + early_stopping = 0 + # training only based on gold data + for epoch in range(0, num_epochs): + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, + verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def pre_train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain language model + best_acc = 0 + early_stopping = 0 + for epoch in range(1, num_epochs + 1): + acc, loss = self.train_lm(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def train(self, em_phase: str, train_loader: Union[DataLoader, + NeighborLoader], + optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, + epoch: int, is_augmented: bool = False, verbose: bool = False): + r"""GLEM training step, EM steps. + + Args: + em_phase(str): 'gnn' or 'lm' choose which phase you are training on + train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for + lm training, include tokenized data, labels is_gold mask. + use NeighborLoader for gnn training, include x, edge_index. + optimizer (torch.optim.Optimizer): optimizer for training + pseudo_labels(torch.Tensor): the predicted labels used as pseudo + labels + epoch (int): current epoch + is_augmented (bool): will use pseudo_labels or not + verbose (bool): print training progress bar or not + + Returns: + acc (float): training accuracy + loss (float): loss value + """ + pseudo_labels = pseudo_labels.to(self.device) + if em_phase == 'gnn': + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + if em_phase == 'lm': + acc, loss = self.train_lm(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + return acc, loss + + def train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""Language model Training in every epoch. 
+ + Args: + train_loader (loader.dataloader.DataLoader): text token dataloader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn + is_augmented (bool): train with pseudo labels or not + verbose (bool): print training progress bar or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + + """ + all_out = [] + total_loss = total_correct = 0 + num_nodes = train_loader.dataset.indices.size(0) + self.lm.train() + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + for batch in train_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + out = self.lm(**inputs).logits + labels = batch['labels'].to(self.device).squeeze() + # training with pseudo labels or not + if is_augmented: + pl_batch = pseudo_labels[batch['n_id']].to(self.device) + else: + pl_batch = None + loss = self.loss(out, labels, self.lm_loss, + batch['is_gold'].to(self.device), pl_batch, + self.alpha, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + all_out.append(out) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + total_loss += float(loss) + if verbose: + pbar.update(batch['n_id'].size(0)) + + all_out = torch.cat(all_out, dim=0) + approx_acc = total_correct / num_nodes + loss = total_loss / len(train_loader) + if verbose: + pbar.close() + print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + def train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""GNN training step in every epoch. + + Args: + train_loader (loader.NeighborLoader): gnn Neighbor node loader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels(torch.tensor): 1-D tensor, predictions from lm + is_augmented(bool): use pseudo labeled node or not + verbose (bool): print training progress or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + """ + self.gnn.train() + num_nodes = train_loader.input_nodes.size(0) + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + total_loss = total_correct = 0 + all_out = [] + for batch in train_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + all_out.append(out) + labels = batch.y[:batch.batch_size].squeeze() + is_gold_batch = batch.is_gold[:batch.batch_size].squeeze() + # training with pseudo labels or not + if is_augmented and pseudo_labels is not None: + pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]] + else: + pl_batch = None + loss = self.loss(out, labels, self.gnn_loss, is_gold_batch, + pl_batch, self.beta, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + total_loss += float(loss) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + if verbose: + pbar.update(batch.batch_size) + + all_out = torch.cat(all_out, dim=0) + loss = total_loss / len(train_loader) + approx_acc = total_correct / num_nodes + if verbose: + pbar.close() + print(f'Epoch: {epoch:02d} Loss: {loss:.4f} ' + f'Approx. 
Train: {approx_acc:.4f}') + return approx_acc, loss + + @torch.no_grad() + def inference(self, em_phase: str, data_loader: Union[NeighborLoader, + DataLoader], + verbose: bool = False): + r"""GLEM inference step. + + Args: + em_phase(str): 'gnn' or 'lm' + data_loader(dataloader or Neighborloader): + dataloader: for lm training, include tokenized data + nodeloader: for gnn training, include x, edge_index + verbose(bool): print inference progress or not + + Returns: + out (torch.Tensor): n * m tensor, m is number of classes, + n is number of nodes + """ + out = None + if em_phase == 'gnn': + self.gnn.eval() + out = self.inference_gnn(data_loader, verbose) + elif em_phase == 'lm': + self.lm.eval() + out = self.inference_lm(data_loader, verbose) + return out + + @torch.no_grad() + def inference_lm(self, data_loader: DataLoader, verbose: bool = True): + r"""LM inference step. + + Args: + data_loader (Dataloader): include token, labels, and gold mask + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, convert to pseudo labels + by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.dataset._data.num_nodes) + pbar.set_description('LM inference stage') + self.lm.eval() + preds = [] + for batch in data_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + logits = self.lm(**inputs).logits + preds.append(logits) + if verbose: + pbar.update(batch['n_id'].size(0)) + if verbose: + pbar.close() + preds = torch.cat(preds) + return preds + + @torch.no_grad() + def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): + r"""GNN inference step. + + Args: + data_loader(NeighborLoader): include x, edge_index, + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, + convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.data.num_nodes) + pbar.set_description('GNN inference stage') + preds = [] + self.gnn.eval() + for batch in data_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + preds.append(out) + if verbose: + pbar.update(batch.batch_size) + if verbose: + pbar.close() + preds = torch.cat(preds, dim=0) + return preds + + def loss(self, logits: torch.Tensor, labels: torch.Tensor, + loss_func: torch.nn.functional, is_gold: torch.Tensor, + pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5, + is_augmented: bool = True): + r"""Core function of variational EM inference, this function is aming + on combining loss value on gold(original train) and loss value on + pseudo labels. 
+
+        Reference:  # noqa
+
+        Args:
+            logits (torch.Tensor): Predictions from the LM or GNN
+            labels (torch.Tensor): Node labels combined from ground truth and
+                pseudo labels (if provided)
+            loss_func (torch.nn.modules.loss): Loss function for
+                classification
+            is_gold (Tensor): Boolean mask marking ground-truth (gold) labels
+                during training; ~is_gold therefore masks the pseudo labels
+            pseudo_labels (torch.Tensor): Predictions from the other model
+            pl_weight (float): Weight of the pseudo-label loss; alpha in the
+                E-step and beta in the M-step, respectively
+            is_augmented (bool): Whether to use the EM objective with pseudo
+                labels, or to train the GNN and LM with gold data only
+
+        """
+        def deal_nan(x):
+            return 0 if torch.isnan(x) else x
+
+        if is_augmented and (sum(~is_gold) > 0):
+            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
+            # all labels other than the ground-truth (gold) labels
+            pseudo_label_loss = deal_nan(
+                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
+            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
+        else:
+            loss = loss_func(logits, labels)
+        return loss
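As a quick orientation to how the pieces added above fit together, here is a minimal usage sketch in the spirit of `examples/llm/glem.py`: wrap an OGB dataset into a `TAGDataset`, pretrain the LM on gold labels, and run one M-step of GNN training on pseudo-labeled nodes. The dataset root, batch sizes, learning rates, pseudo-label ratio, and the single pretrain-then-one-M-step pass are illustrative assumptions, not values prescribed by the patch.

```python
import torch
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.datasets import TAGDataset
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn.models import GLEM, GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Wrap ogbn-arxiv into a Text Attributed Graph dataset
# ('data/ogb' is an illustrative root directory):
dataset = PygNodePropPredDataset('ogbn-arxiv', root='data/ogb')
tag_dataset = TAGDataset('data/ogb', dataset, 'prajjwal1/bert-tiny',
                         token_on_disk=True)
text_dataset = tag_dataset.to_text_dataset()
split_idx = tag_dataset.get_idx_split()
data = dataset.data  # carries the 'is_gold' node mask added by TAGDataset

# Augment the gold training nodes with a random subset of test nodes that
# will be supervised by pseudo labels (0.5 ratio, as in the example):
gold_idx, test_idx = split_idx['train'], split_idx['test']
perm = torch.randperm(test_idx.numel())[:gold_idx.numel() // 2]
train_idx = torch.cat([gold_idx, test_idx[perm]])

# Loaders: tokenized node text for the LM, sampled subgraphs for the GNN:
gold_dataset = torch.utils.data.Subset(text_dataset, gold_idx)
text_pretrain_loader = DataLoader(gold_dataset, batch_size=256, shuffle=True)
text_all_loader = DataLoader(text_dataset, batch_size=1024, shuffle=False)
graph_train_loader = NeighborLoader(data, input_nodes=train_idx,
                                    num_neighbors=[15, 10, 5],
                                    batch_size=1024, shuffle=True)

gnn = GraphSAGE(in_channels=data.num_features, hidden_channels=256,
                num_layers=3, out_channels=tag_dataset.num_classes)
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
             out_channels=tag_dataset.num_classes, lm_use_lora=False,
             device=device)
lm_opt = torch.optim.Adam(model.lm.parameters(), lr=1e-3)
gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=2e-3)

# Pretrain the LM on gold labels and turn its predictions into pseudo labels:
model.lm = model.lm.to(device)
model.pre_train_lm(text_pretrain_loader, lm_opt, num_epochs=1, patience=3)
preds = model.inference('lm', text_all_loader)  # [num_nodes, num_classes]
pseudo_labels = preds.argmax(dim=-1)

# One M-step: train the GNN on gold + pseudo-labeled nodes:
model.gnn = model.gnn.to(device)
model.train('gnn', graph_train_loader, gnn_opt, pseudo_labels, epoch=1,
            is_augmented=True)
```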