Skip to content

Commit

Permalink
training time
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 committed Oct 1, 2024
1 parent 9e84977 commit 6db64cb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions examples/llm/g_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,7 @@ def adjust_learning_rate(param_group, LR, epoch):
epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'
loader = tqdm(train_loader, desc=epoch_str)
for step, batch in enumerate(loader):
if step > 50:
print("training on 50 random instances to start...")
if step > 1000:
break
optimizer.zero_grad()
loss = get_loss(model, batch, model_save_name=model_save_name)
Expand All @@ -212,8 +211,7 @@ def adjust_learning_rate(param_group, LR, epoch):
model.eval()
with torch.no_grad():
for step, batch in enumerate(val_loader):
if step > 50:
print("val'ing on 50 random instances to start...")
if step > 1000:
break
loss = get_loss(model, batch, model_save_name=model_save_name)
val_loss += loss.item()
Expand All @@ -239,8 +237,7 @@ def adjust_learning_rate(param_group, LR, epoch):
print("Final evaluation...")
progress_bar_test = tqdm(range(len(test_loader)))
for step, batch in enumerate(test_loader):
if step > 50:
print("testing on 50 random instances to start...")
if step > 1000:
break
with torch.no_grad():
eval_output.append(
Expand Down
2 changes: 1 addition & 1 deletion examples/llm/ogbg_code2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# hyperparams are hardcoded
import torch
from g_retriever import train

import gc
from torch_geometric.datasets import OGBG_Code2
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM
Expand Down

0 comments on commit 6db64cb

Please sign in to comment.