Add dtype arg to Gemma HF conversion script (#1452)
nkovela1 authored Feb 22, 2024
1 parent e5a91f7 commit f1428e6
Showing 1 changed file with 29 additions and 8 deletions.
--- a/tools/gemma/export_gemma_to_hf.py
+++ b/tools/gemma/export_gemma_to_hf.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import contextlib
 import os
 
 import torch
@@ -116,6 +116,19 @@
     "A path containing the vocabulary (must be a `.spm` file or equivalent). "
     "If not passed, the vocabulary of the preset will be used.",
 )
+flags.DEFINE_string(
+    "dtype",
+    "float32",
+    "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.",
+)
+
+
+@contextlib.contextmanager
+def _set_default_tensor_type(dtype: torch.dtype):
+    """Sets the default torch dtype to the given dtype."""
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(torch.float)
 
 
 def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path):
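The new `--dtype` flag feeds `_set_default_tensor_type`, and `torch.set_default_dtype` changes the dtype that newly constructed floating-point tensors receive, which is how wrapping `convert_checkpoints` in this context manager controls the precision of the exported weights. A minimal sketch of the effect (the tensors below are illustrative, not part of the script):

import torch

torch.set_default_dtype(torch.float16)
x = torch.zeros(4)  # new floating-point tensors now default to float16
assert x.dtype == torch.float16

torch.set_default_dtype(torch.float)  # restore the float32 default
assert torch.zeros(4).dtype == torch.float32

Note that the context manager restores `torch.float` after the `yield` without a `try`/`finally`, so the default is only reset when the wrapped block exits cleanly.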
@@ -310,17 +323,25 @@ def flag_error_handler():
         "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
         "for your model to the `--size` flag."
     )
+    if FLAGS.dtype:
+        dtype = getattr(torch, FLAGS.dtype)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(
+                "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. "
+                "`float32`, `float16`, etc.) to the `--dtype` flag."
+            )
 
 
 def main(_):
     flag_error_handler()
-    convert_checkpoints(
-        FLAGS.preset,
-        FLAGS.weights_file,
-        FLAGS.size,
-        FLAGS.output_dir,
-        FLAGS.vocab_path,
-    )
+    with _set_default_tensor_type(getattr(torch, FLAGS.dtype)):
+        convert_checkpoints(
+            FLAGS.preset,
+            FLAGS.weights_file,
+            FLAGS.size,
+            FLAGS.output_dir,
+            FLAGS.vocab_path,
+        )
 
 
 if __name__ == "__main__":
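The validation resolves the flag string with `getattr(torch, FLAGS.dtype)` and checks that the result is a `torch.dtype`, which rejects `torch` attributes that exist but are not dtypes (e.g. `abs`). A completely unknown name (say `flaot32`) raises `AttributeError` inside `getattr` before the `isinstance` check runs; a sketch of a variant that folds both failure modes into one error (the `resolve_dtype` helper is hypothetical, not part of the script):

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Default to None so unknown names (e.g. "flaot32") fall through to
    # the isinstance check instead of raising AttributeError; non-dtype
    # attributes such as torch.abs are rejected the same way.
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid `dtype`: {name!r}")
    return dtype

print(resolve_dtype("bfloat16"))  # torch.bfloat16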

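With the flag in place, a conversion can be pinned to half precision from the command line. The flag values below are placeholders, not part of this commit (`--preset` and `--output_dir` are defined elsewhere in the script):

python tools/gemma/export_gemma_to_hf.py \
    --preset=gemma_2b_en \
    --size=2b \
    --output_dir=./gemma_2b_hf \
    --dtype=float16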