Skip to content

Commit

Permalink
fix: load model in cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
alaeddine-13 committed Aug 14, 2023
1 parent e5cda88 commit e7a553b
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions textbook/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
PreTrainedModel,
AutoModelForCausalLM,
GPTBigCodeConfig,
GPTBigCodeForCausalLM,
)


Expand Down Expand Up @@ -68,10 +67,15 @@ class StarCoder:

def __init__(self, debug: bool = False):
self._init_tokenizer()
self.model = GPTBigCodeForCausalLM.from_pretrained(
self.base_model,
config=self.config if not debug else self.debug_config,
)
if debug:
self.model = AutoModelForCausalLM.from_pretrained(
self.base_model,
config=self.debug_config,
)
else:
self.model = AutoModelForCausalLM.from_pretrained(self.base_model).to(
"cuda"
)

def _init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
Expand Down

0 comments on commit e7a553b

Please sign in to comment.