Skip to content

Commit

Permalink
training time
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 committed Oct 1, 2024
2 parents 197c4b9 + 127fbb2 commit b160381
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions examples/llm/ogbg_code2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ def get_loss_ogbg(model, batch, **kwargs) -> torch.Tensor:
questions = [master_prompt for i in range(len(batch.y))]
labels = ['|'.join(label) for label in batch.y]
print("batch.desc=", batch.desc)
return model(questions, batch.x.to(torch.float), batch.edge_index, batch.batch, labels,
batch.edge_attr, batch.desc)
return model(questions, batch.x.to(torch.float), batch.edge_index,
batch.batch, labels, batch.edge_attr, batch.desc)


def inference_step_ogbg(model, batch, **kwargs):
questions = [master_prompt for i in range(len(batch.y))]
pred = model.inference(questions, batch.x.to(torch.float), batch.edge_index, batch.batch,
batch.edge_attr, batch.desc)
pred = model.inference(questions, batch.x.to(torch.float),
batch.edge_index, batch.batch, batch.edge_attr,
batch.desc)
labels = ['|'.join(label) for label in batch.y]
eval_data = {
"pred": pred,
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/datasets/ogbg_code2.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def process(self) -> None:
# combine all node information into a single feature tensor, let the GNN+LLM figure it out
new_obj.x = torch.cat(
(old_obj.x, old_obj.node_is_attributed,
old_obj.node_dfs_order, old_obj.node_depth), dim=1).to(torch.float)
old_obj.node_dfs_order, old_obj.node_depth),
dim=1).to(torch.float)
# extract raw python function for use by LLM
func_name_tokens = old_obj.y
new_obj.func_signature, new_obj.desc = self.get_raw_python_from_df(
Expand Down

0 comments on commit b160381

Please sign in to comment.