Revert "Adjust multifun"
This reverts commit c7a2f15.

Work around lebrice/SimpleParsing#322
JasonGross committed Aug 21, 2024
1 parent eb073c5 commit dd6b835
Showing 1 changed file with 186 additions and 13 deletions.
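For orientation, the gist of what this revert changes: the declarative simple_parsing setup is backed out in favour of explicit stdlib argparse flags. Below is a minimal, self-contained sketch distilled from the diff that follows; the --K and --func flags and their defaults are copied from the diff, the illustrative argv is made up, and no claim is made here about the specifics of lebrice/SimpleParsing#322.

# Reverted style (shown for contrast, as it appears in the removed lines):
#     import simple_parsing
#     parser = simple_parsing.ArgumentParser(...)
#     parser.add_arguments(Multifun, dest="experiment_config", default=...)
# Restored style: explicit argparse flags, two of which are reproduced here.
import argparse

parser = argparse.ArgumentParser(
    description="Train a model with configurable attention rate."
)
parser.add_argument(
    "--K", metavar="K", type=int, default=10,
    help="The length of the list to take the reduction of.",
)
parser.add_argument(
    "--func", metavar="FUNC", type=str, nargs="+", default=["max", "min"],
    help="The functions to apply to the list.",
)
args = parser.parse_args(["--K", "2", "--func", "max", "min"])  # illustrative argv
print(args.K, tuple(args.func))  # -> 2 ('max', 'min')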
199 changes: 186 additions & 13 deletions gbmi/exp_multifun/train.py
@@ -5,10 +5,9 @@
import sys
from dataclasses import dataclass, field
from functools import cache
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Literal, Optional, Sequence, Tuple, Union

import numpy as np
import simple_parsing
import torch
import torch.nn.functional as F
from jaxtyping import Bool, Float, Integer
@@ -595,26 +594,200 @@ def test_dataloader(self):
        return DataLoader(self.data_test, batch_size=self.config.batch_size)


def main(argv: List[str] = sys.argv):
    parser = simple_parsing.ArgumentParser(
def config_of_argv(argv=sys.argv) -> tuple[Config[Multifun], dict]:
    parser = argparse.ArgumentParser(
        description="Train a model with configurable attention rate."
    )
    parser.add_arguments(
        Multifun, dest="experiment_config", default=MULTIFUN_OF_2_CONFIG.experiment
    )
    add_force_argument(parser)
    add_no_save_argument(parser)
    Config.add_arguments(parser, default=MULTIFUN_OF_2_CONFIG)

    # add --K N argument accepting 2 and 10
    parser.add_argument(
        "--K",
        metavar="K",
        type=int,
        default=10,
        help="The length of the list to take the reduction of.",
    )
    parser.add_argument(
        "--func",
        metavar="FUNC",
        type=str,
        nargs="+",
        default=["max", "min"],
        help="The functions to apply to the list.",
    )
    parser.add_argument(
        "--force-adjacent-gap",
        metavar="K",
        type=str,
        action="append",
        help="For --K 2, include all sequences (n, n±K) in training set. Accepts int and comma-separated-list.",
    )
    parser.add_argument(
        "--training-ratio",
        type=float,
        default=0.7,
        help="For --K 2, the fraction of sequences to include in training.",
    )
    parser.add_argument(
        "--use-log1p",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use a more accurate implementation of log_softmax.",
    )
    parser.add_argument(
        "--use-end-of-sequence",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use an end-of-sequence token so the query-side attention vector is fixed.",
    )
    parser.add_argument("--weight-decay", type=float, default=None, help="Weight decay")
    parser.add_argument(
        "--optimizer",
        choices=["Adam", "AdamW", "SGD"],
        default="Adam",
        help="The optimizer to use.",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument(
        "--betas",
        type=float,
        nargs=2,
        default=(0.9, 0.999),
        help="coefficients used for computing running averages of gradient and its square",
    )
    parser.add_argument(
        "--summary-slug-extra", type=str, default="", help="Extra model description"
    )
    parser.add_argument(
        "--pick-max-first",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Pick the maximum value first, then fill in the rest of the sequence. Only meaningful for --K N > 2.",
    )
    parser.add_argument(
        "--use-kaiming-init",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use torch.nn.init.kaiming_uniform_, rather than HookedTransformer's init.",
    )
    parser.add_argument(
        "--log-matrix-interp",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Log matrices every train step",
    )
    parser.add_argument(
        "--checkpoint-matrix-interp",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Log matrices for checkpointing",
    )
    parser.add_argument(
        "--log-final-matrix-interp",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Log matrices after training",
    )
    HOOKED_TRANSFORMER_CONFIG_ARGS = set(
        (
            "normalization_type",
            "d_model",
            "d_head",
            "n_layers",
            "n_heads",
            "d_vocab",
            "dtype",
            "eps",
        )
    )
    Config.add_arguments(parser)
    add_HookedTransformerConfig_arguments(parser, HOOKED_TRANSFORMER_CONFIG_ARGS)
    args = parser.parse_args(argv[1:])

    config = Config(args.experiment_config)
    config = config.update_from_args(args)
    print("Model config:", MultifunTrainingWrapper.build_model(config).cfg)
    config = set_params(
        (MULTIFUN_OF_2_CONFIG if args.K <= 2 else MULTIFUN_OF_10_SINGLE_CONFIG),
        {
            ("experiment", "seq_len"): args.K,
            ("experiment", "funcs"): tuple(args.func),
            ("experiment", "use_end_of_sequence"): args.use_end_of_sequence,
            ("experiment", "use_log1p"): args.use_log1p,
            ("experiment", "optimizer"): args.optimizer,
            ("experiment", "summary_slug_extra"): args.summary_slug_extra,
            ("experiment", "train_dataset_cfg", "pick_max_first"): args.pick_max_first,
            ("experiment", "logging_options"): ModelMatrixLoggingOptions.all(),
            ("experiment", "log_matrix_on_run_batch_prefixes"): set()
            | ({"test_"} if args.log_final_matrix_interp else set())
            | ({"periodic_test_"} if args.checkpoint_matrix_interp else set())
            | ({""} if args.log_matrix_interp else set()),
        },
    ).update_from_args(args)
    config.experiment = MultifunTrainingWrapper.update_config_from_model_config(
        config.experiment,
        update_HookedTransformerConfig_from_args(
            config,
            MultifunTrainingWrapper.build_model_config(config),
            args,
            HOOKED_TRANSFORMER_CONFIG_ARGS,
        ),
    )
    config.experiment.__post_init__()  # for seq_len, d_vocab
    if args.weight_decay is not None:
        config.experiment.optimizer_kwargs["weight_decay"] = args.weight_decay
    config.experiment.optimizer_kwargs.update(
        {"lr": args.lr, "betas": tuple(args.betas)}
    )
    if args.K <= 2:
        if args.force_adjacent_gap:
            force_adjacent = tuple(
                sorted(
                    set(
                        int(k.strip())
                        for s in args.force_adjacent_gap
                        for k in s.split(",")
                    )
                )
            )
            config = set_params(
                config,
                {
                    (
                        "experiment",
                        "train_dataset_cfg",
                        "force_adjacent",
                    ): force_adjacent,
                    (
                        "experiment",
                        "test_dataset_cfg",
                        "force_adjacent",
                    ): force_adjacent,
                },
            )
        config = set_params(
            config,
            {
                (
                    "experiment",
                    "train_dataset_cfg",
                    "training_ratio",
                ): args.training_ratio,
                (
                    "experiment",
                    "test_dataset_cfg",
                    "training_ratio",
                ): args.training_ratio,
            },
        )
    return config, dict(force=args.force, save_to=args.save_to)


def main(argv=sys.argv):
    config, kwargs = config_of_argv(argv)
    print("Training model:", config)
    train_or_load_model(config, force=args.force, save_to=args.save_to)
    return train_or_load_model(config, **kwargs)


# %%
if __name__ == "__main__":
    main()
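For reference, a hedged usage sketch of the restored entry points, assuming the gbmi package is importable and that the flags added by add_force_argument, add_no_save_argument, Config.add_arguments, and add_HookedTransformerConfig_arguments all have defaults; the flag values below are illustrative, not taken from this commit.

# Illustrative only: builds a Config[Multifun] plus train_or_load_model kwargs
# from an argv list, mirroring how main() consumes sys.argv.
from gbmi.exp_multifun.train import config_of_argv, main

config, kwargs = config_of_argv(
    [
        "train.py",
        "--K", "2",
        "--func", "max", "min",
        "--force-adjacent-gap", "1,2",
        "--training-ratio", "0.7",
        "--optimizer", "AdamW",
        "--weight-decay", "1.0",
    ]
)
print(config)  # the assembled training config
# main(["train.py", "--K", "2"])  # would train or load the model via train_or_load_model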

