Skip to content

Commit

Permalink
Minor cleanup in torchchat/cli/builder.py (#1308)
Browse files Browse the repository at this point in the history
Beautify a series of similar checks & fix a spelling error.
  • Loading branch information
swolchok authored Oct 18, 2024
1 parent 4f2f4fb commit fa6f9b6
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,16 @@ def __post_init__(self):
if self.dso_path and self.pte_path:
raise RuntimeError("specify either DSO path or PTE path, but not both")

if self.checkpoint_path and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
)
if self.checkpoint_dir and (self.dso_path or self.pte_path):
print(
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
)
if self.gguf_path and (self.dso_path or self.pte_path):
print(
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
)
if not (self.dso_path) and not (self.pte_path):
if self.dso_path or self.pte_path:
ignored_params = [
(self.checkpoint_path, "checkpoint path"),
(self.checkpoint_dir, "checkpoint dir"),
(self.gguf_path, "GGUF path"),
]
for param, param_msg in ignored_params:
if param:
print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
else:
self.prefill_possible = True

@classmethod
Expand Down Expand Up @@ -446,7 +443,7 @@ def _maybe_init_distributed(
return world_mesh, parallel_dims


def _maybe_parellelize_model(
def _maybe_parallelize_model(
model: nn.Module,
builder_args: BuilderArgs,
world_mesh: DeviceMesh,
Expand Down Expand Up @@ -486,7 +483,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
Expand Down

0 comments on commit fa6f9b6

Please sign in to comment.