diff --git a/vault-replica/.obsidian/plugins/Dual/skeleton/core.py b/vault-replica/.obsidian/plugins/Dual/skeleton/core.py index 61069ce..0f79489 100644 --- a/vault-replica/.obsidian/plugins/Dual/skeleton/core.py +++ b/vault-replica/.obsidian/plugins/Dual/skeleton/core.py @@ -3,10 +3,8 @@ import os import glob import torch -from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPT2Tokenizer, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig from util import md_to_text -import json -import random import re @@ -17,7 +15,7 @@ def __init__(self, root_dir): self.entry_regex = os.path.join(root_dir, '**/*md') self.skeleton_ready = False self.essence_ready = False - self.model_name = '' + self.config = {} self.load_skeleton() self.load_essence() @@ -117,20 +115,12 @@ def load_skeleton(self): def load_essence(self): tentative_folder_path = os.path.join(self.root_dir, '.obsidian/plugins/Dual/essence') tentative_file_path = os.path.join(tentative_folder_path, 'pytorch_model.bin') - tentative_config_path = os.path.join(tentative_folder_path, 'config.json') - - with open(tentative_config_path) as file: - config_model = json.load(file) - self.model_name = config_model["_name_or_path"] if self.essence_ready == False and os.path.isfile(tentative_file_path): print('Loading essence...') - self.gen_tokenizer = GPT2Tokenizer.from_pretrained(self.model_name) - - if "gpt-neo" in self.model_name: - self.gen_model = GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id) - else: - self.gen_model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id) + self.config = AutoConfig.from_pretrained(tentative_folder_path) + self.gen_tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path) + self.gen_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=tentative_folder_path, pad_token_id=self.gen_tokenizer.eos_token_id) self.essence_ready = True