This repository has been archived by the owner on May 5, 2023. It is now read-only.

refactor: use auto methods for loading models
onlurking committed Apr 20, 2021
1 parent a04cbd4 commit ba1f07d
Showing 1 changed file with 5 additions and 15 deletions.
vault-replica/.obsidian/plugins/Dual/skeleton/core.py: 20 changes (5 additions & 15 deletions)
@@ -3,10 +3,8 @@
 import os
 import glob
 import torch
-from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 from util import md_to_text
-import json
-import random
 import re


@@ -17,7 +15,7 @@ def __init__(self, root_dir):
         self.entry_regex = os.path.join(root_dir, '**/*md')
         self.skeleton_ready = False
         self.essence_ready = False
-        self.model_name = ''
+        self.config = {}

         self.load_skeleton()
         self.load_essence()
@@ -117,20 +115,12 @@ def load_skeleton(self):
     def load_essence(self):
         tentative_folder_path = os.path.join(self.root_dir, '.obsidian/plugins/Dual/essence')
         tentative_file_path = os.path.join(tentative_folder_path, 'pytorch_model.bin')
-        tentative_config_path = os.path.join(tentative_folder_path, 'config.json')
-
-        with open(tentative_config_path) as file:
-            config_model = json.load(file)
-            self.model_name = config_model["_name_or_path"]

         if self.essence_ready == False and os.path.isfile(tentative_file_path):
             print('Loading essence...')
-            self.gen_tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
-
-            if "gpt-neo" in self.model_name:
-                self.gen_model = GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id)
-            else:
-                self.gen_model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id)
+            self.config = AutoConfig.from_pretrained(tentative_folder_path)
+            self.gen_tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path)
+            self.gen_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id)

             self.essence_ready = True
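
For readers unfamiliar with the Hugging Face Auto classes, here is a minimal, standalone sketch of the loading pattern this commit switches to. It is an illustration rather than code from the repository: essence_dir is a placeholder for the local checkpoint folder (the plugin's .obsidian/plugins/Dual/essence directory), which is assumed to contain config.json and pytorch_model.bin.

# Sketch of the Auto-class loading pattern adopted in this commit.
# essence_dir is a placeholder path, not a value from the repository.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

essence_dir = "/path/to/vault/.obsidian/plugins/Dual/essence"

# config.json records the base model under "_name_or_path".
config = AutoConfig.from_pretrained(essence_dir)

# Resolve the tokenizer from that recorded name instead of hard-coding GPT2Tokenizer.
tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)

# AutoModelForCausalLM reads model_type from the config and instantiates the
# matching architecture (GPT2LMHeadModel, GPTNeoForCausalLM, ...), so no
# explicit "gpt-neo" branch is needed.
model = AutoModelForCausalLM.from_pretrained(
    essence_dir,
    pad_token_id=tokenizer.eos_token_id,
)

Because the Auto classes dispatch on the checkpoint's own config, the per-architecture branching and the manual config.json parsing in the old version become unnecessary, which is where most of the 15 deleted lines come from.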

