
Zanj integration: datasets & training #177

Merged (89 commits, Apr 28, 2023)
Conversation

@mivanit (Member) commented Apr 13, 2023

(this is a mega pr, sorry)

configs

Modifying configs from the command line is now easier!

  • ConfigHolder.get_config_multisource() (see the sketch after this list)
    • takes one of: a config object, a file to read the config from, or a list of preset names (one for each of the sub-configs)
    • also takes a dotlist-dict to modify any parameters of the config
  • GPTDataset().to_fname() is used to generate the filename for saving (and also for finding a matching dataset to load/download). MazeDatasetConfig also implements this in a custom way
  • MazeDatasetConfig now has a maze_ctor_kwargs field, for passing keyword arguments to maze generation (see Constrained depth first search #183)
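
As a rough illustration of the intended interface -- a sketch only, where the import path, the preset names, and the parameter names `cfg_names`/`kwargs_in` are assumptions rather than the confirmed signature:

```python
from maze_transformer.training.config import ConfigHolder  # assumed import path

# build a config from named presets (one per sub-config), then override
# individual fields via a dotlist-dict; all names below are illustrative
cfg: ConfigHolder = ConfigHolder.get_config_multisource(
    cfg_names=["demo-dataset", "tiny-model", "test-v1"],
    kwargs_in={
        "dataset_cfg.n_mazes": 100,   # dotlist key -> nested config field
        "train_cfg.batch_size": 32,
    },
)
```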

maze dataset

You can now get a MazeDataset from just a config -- it will load, download, or generate the dataset on the fly. The mess of dataset storage formats we had before is gone: a MazeDataset contains a list of SolvedMaze objects and returns one when you call __getitem__. We also added filters and fixed some parallelization issues!

  • GPTDataset().from_config() is a new, simplified way of getting a dataset: simply pass a config, and it will attempt to load from a local directory, download, or generate. Any of these steps can be disabled, and kwargs (for things like the number of cores to use) are passed down. (see the sketch after this list)
  • canonical representation of the dataset as a list of SolvedMaze
  • mazes_objs, mazes_tokens, and mazes_array are now cached properties. They still work, but might be slow since they are not parallelized
  • MazeDataset.__getitem__() now returns a SolvedMaze
  • create_dataset() is deprecated but should still work. Remove this?
  • filtering! You can specify filters in the config under the applied_filters field, or call dataset.filter_by.your_filter_func(your_arg=your_val). Both work the same under the hood
  • you can specify in from_config() whether to run in parallel (default: no). This is useful since parallelization has huge overhead for small datasets; tests are now much faster
  • there may have been some issues with parallelization using the same fixed seed across all processes. This was fixed in Constrained depth first search #183, but in a hacky way
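
A sketch of the resulting workflow. The import path, config field names, and filter name are illustrative; the flags for disabling the load/download/generate steps exist per the list above, but their exact names are not shown here:

```python
from maze_transformer.training.mazedataset import MazeDataset, MazeDatasetConfig  # assumed paths

cfg = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=100)  # field names illustrative

# tries local load, then download, then generation; each step can be
# disabled via kwargs, which are also forwarded (e.g. to the worker pool)
dataset: MazeDataset = MazeDataset.from_config(cfg)

solved_maze = dataset[0]  # __getitem__ now returns a SolvedMaze

# equivalent to listing the filter under applied_filters in the config:
filtered = dataset.filter_by.your_filter_func(your_arg="your_val")
```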

training

Models are now saved as ZANJ objects, and the command-line interface is improved.

  • train() now:
    • saves models as ZANJ
    • returns the trained ZanjHookedTransformer
  • train_model():
    • now returns a TrainingResult, which contains the output path and the model (and eventually logging info, perhaps?)
    • for config, the interface is inherited from ConfigHolder.get_config_multisource(), and kwargs are passed as the modification dict (see the sketch after this list)
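
A sketch of how train_model() might now be called. The import path and the TrainingResult attribute names (`output_path`, `model`) follow the description above but are not a confirmed API:

```python
from maze_transformer.training.train_model import TrainingResult, train_model  # assumed path

# config selection mirrors ConfigHolder.get_config_multisource();
# extra kwargs act as the dotlist modification dict (names illustrative)
result: TrainingResult = train_model(
    base_path="runs/demo",
    cfg_names=["demo-dataset", "tiny-model", "test-v1"],
    **{"train_cfg.batch_size": 16},
)

model = result.model        # the trained ZanjHookedTransformer
print(result.output_path)   # where the ZANJ-serialized model was written
```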

remaining todos:

  • add seed to config
  • allow filtering a dataset by path length (or whatever else you want) via dataset.filter_by.some_function(**kwargs)
    • add MazeDataset().custom_maze_filter(), which takes a custom function (operating on mazes) as an argument -- this makes it easier to add new filters in notebooks etc.
    • @register_wrap_dataset_filter wraps a function which takes a dataset and kwargs, and returns a dataset. We might want a different decorator, register_wrap_solved_maze_filter, which just takes a function (m: SolvedMaze, **kwargs) -> bool and wraps it in a regular python filter() (see the sketch after this list)
  • dependencies
  • tests
  • [~] [out of scope for this PR] implement downloading of maze datasets from wandb
  • speed up generation code
    • re-add parallelization in MazeDataset.generate()
    • [~] [out of scope for this PR] generally speed up / vectorize generation code
    • [~] [out of scope for this PR] if possible: JIT-compile with numba?
  • [~] [out of scope for this PR] integrating with new training code from @afspies in experiments repo
  • augmented maze generation
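
The second decorator proposed above might look roughly like this -- a sketch only, where the import path, the MazeDataset constructor arguments, the `mazes`/`solution` attributes, and the registration step are all assumptions:

```python
from functools import wraps
from typing import Callable

from maze_transformer.training.mazedataset import MazeDataset, SolvedMaze  # assumed path


def register_wrap_solved_maze_filter(
    predicate: Callable[..., bool],
) -> Callable[..., MazeDataset]:
    """wrap a per-maze predicate `(m: SolvedMaze, **kwargs) -> bool` into a
    dataset-level filter, so new filters can be written one maze at a time"""

    @wraps(predicate)
    def dataset_filter(dataset: MazeDataset, **kwargs) -> MazeDataset:
        # plain python filter() over the dataset's list of SolvedMaze
        kept = list(filter(lambda m: predicate(m, **kwargs), dataset.mazes))
        return MazeDataset(cfg=dataset.cfg, mazes=kept)  # ctor args illustrative

    # here the wrapped filter would also be registered so it is reachable
    # via dataset.filter_by.<name>, like register_wrap_dataset_filter does
    return dataset_filter


# usage: a one-maze predicate becomes a whole-dataset filter
@register_wrap_solved_maze_filter
def path_longer_than(m: SolvedMaze, min_length: int = 1) -> bool:
    return len(m.solution) > min_length  # `solution` attribute assumed
```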

questions:

  • how should we handle hashes being included in GPTDataset.to_fname()?

    resolved: no strong opinions; we can change this later without too much cost. Including the hash for now (see the sketch after this list).

  • should MazeDataset.__getitem__() give a SolvedMaze, string or tokenized array?

    resolved: use SolvedMaze everywhere
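
For context, a to_fname()-style scheme as resolved above (readable config fields plus a short config hash) could look like this standalone sketch -- not the actual implementation:

```python
import hashlib
import json


def config_fname(name: str, cfg: dict) -> str:
    """readable config fields plus a short, stable hash of the full
    serialized config, so distinct configs get distinct filenames"""
    blob = json.dumps(cfg, sort_keys=True).encode()
    short_hash = int(hashlib.md5(blob).hexdigest(), 16) % 10**5
    return f"{name}-g{cfg['grid_n']}-n{cfg['n_mazes']}-h{short_hash}"


print(config_fname("test", {"grid_n": 3, "n_mazes": 5}))
# -> something like "test-g3-n5-h12345" (hash digits will vary)
```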

Base automatically changed from add-maze-from-ascii to main April 20, 2023 20:29
mivanit and others added 13 commits April 20, 2023 14:37

  • Revert "…getedLatticeMaze": there is a bug, but my fix does not fix it! This reverts commit 88002f6.
  • Return SolvedMazes from dataset.__getitem__
  • Move tokenization into Maze classes
  • Move batch preprocessing into dataloader
  • Lots of tests for datasets
  • Tidy up filters a bit and allow positional args
  • Speed up tests by using a non-parallel dataloader
  • integration-v1 training config renamed to test-v1
  • constraint options for the `gen_dfs` generation algorithm (by @canrager)
  • added `maze_ctor_kwargs` to `MazeDatasetConfig` to allow setting those options (see the sketch below)
  • fixed some issues arising from parallelism + fixed seed (this was hacky)
  • minor things:
    • bumped muutils to 0.3.10
    • we now use `Coord` and `CoordArray` (numpy) in many places, instead of tuples/lists
    • separated `MAZE_DATASET_CONFIGS` out to [maze_transformer/training/maze_dataset_configs.py](https://github.com/AISC-understanding-search/maze-transformer/pull/184/files#diff-ab008b2d4ddb7138116afef18584f657832ec00430af732f195136a63b0debaf)
    • some random junk

Co-authored-by: mivanit <[email protected]>
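
For illustration, setting the new gen_dfs constraint options through maze_ctor_kwargs might look like this. The import paths, the maze_ctor field, and the constraint kwarg name are assumptions; check gen_dfs itself for the real options:

```python
from maze_transformer.training.mazedataset import MazeDataset, MazeDatasetConfig  # assumed paths
from maze_transformer.generation.generators import LatticeMazeGenerators  # assumed path

cfg = MazeDatasetConfig(
    name="demo-constrained",
    grid_n=5,
    n_mazes=10,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
    # forwarded verbatim to gen_dfs; the kwarg below is illustrative
    maze_ctor_kwargs={"max_tree_depth": 20},
)
dataset = MazeDataset.from_config(cfg)
```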
@mivanit mivanit marked this pull request as ready for review April 28, 2023 03:08
@mivanit (Member, Author) commented Apr 28, 2023

@valedan here are the remaining problems which we need to fix before merging. Once tests pass, I think we are good to go!

  • transposing issue in the baseline solver tests: tests/unit/maze_transformer/evaluation/test_baseline_models.py
    • unclear why this only causes issues with the baseline solver
    • basically, the solution the baseline solver produces is valid, but only in the transpose of the maze. This is probably happening somewhere in the tokenization and might be indicative of a larger issue of some kind, but I hope not.
  • RESOLVED: bumping pytest fixed it! 3 errors of fixture "mocker" not found in tests/unit/maze_transformer/training/test_dataset.py

@valedan (Contributor) left a comment

🚀
