Skip to content

Commit

Permalink
fix: parse aliases before get pretrained weight
Browse files Browse the repository at this point in the history
  • Loading branch information
dest1n1s committed Nov 6, 2024
1 parent 0cdef98 commit 3857be8
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/xlens/pretrained/model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def _get_safe_weights_files(self, safe_weight_index: Any) -> list[str]:
def get_pretrained_weights(
self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
) -> dict[str, jax.Array]:
model_name_or_path = (
model_name_or_path if os.path.isdir(model_name_or_path) else self.rev_alias_map[model_name_or_path]
)
if os.path.isdir(model_name_or_path):
if os.path.isfile(os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)):
resolved_archive_files = [os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)]
Expand Down

0 comments on commit 3857be8

Please sign in to comment.