[Feat] data download/generation docs #1
fedebotu committed Jul 19, 2024
1 parent da68b76 commit cb0ab1e
Showing 3 changed files with 44 additions and 69 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -40,6 +40,19 @@ If you would like to install all dependencies including optional solvers, please

We recommend exploring [this quickstart notebook](examples/1.quickstart.ipynb) to get started with the `RouteFinder` codebase!


### Generating Data

Data may be generated by running the following command:

```bash
python routefinder/data/generate_data.py
```
The generated files are saved under the `data/` directory.

Note that we provide the original test data because generated data may differ slightly across devices due to PyTorch's random number generator. The distribution is the same, however, so results should be comparable. For exact reproducibility, use the uploaded files under the `data/` folder.
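
As a rough illustration of how a generated file might be read back, the sketch below builds the `{variant}/{phase}/{num_loc}.npz` path that `generate_data.py` writes under `data/` and loads it with NumPy. The helper name `load_dataset` is an assumption made for this example and is not part of the RouteFinder API. The array keys stored in each file depend on the generator and are not assumed here.

```python
from pathlib import Path

import numpy as np


def load_dataset(variant, phase, num_loc, root="data"):
    """Load one generated dataset as a dict of NumPy arrays.

    Hypothetical helper: it only mirrors the
    `{variant}/{phase}/{num_loc}.npz` layout used by generate_data.py.
    """
    path = Path(root) / variant / phase / f"{num_loc}.npz"
    with np.load(path) as data:
        # Read every array while the archive is still open
        return {key: data[key] for key in data.files}
```

For instance, assuming a variant named `cvrp` has been generated, `load_dataset("cvrp", "test", 100)` would read `data/cvrp/test/100.npz`.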


### Running

The main runner can be called as follows (the example uses the main baseline):
@@ -50,6 +63,8 @@ python run.py experiment=main/rf/rf-100
You may change the experiment by passing `experiment=YOUR_EXP`, where `YOUR_EXP` is a path under the [`configs/experiment`](configs/experiment) directory.




## 🚚 Available Environments

<div align="center">
40 changes: 29 additions & 11 deletions routefinder/data/generate_data.py
@@ -10,12 +10,21 @@
 folder = "data/"


-def generate(num_loc, num_data, variant, phase="val"):
-    filename = f"{variant}/{phase}/{num_loc}.npz"
+def generate(num_loc, num_data, variant, phase="val", mixed=False):
+    if mixed:
+        # Mixed-backhaul ("mb") variant: insert "m" before the first "b" in the name
+        new_variant = variant[: variant.find("b")] + "m" + variant[variant.find("b") :]
+        filename = f"{new_variant}/{phase}/{num_loc}.npz"
+        backhaul_class = 2
+    else:
+        filename = f"{variant}/{phase}/{num_loc}.npz"
+        backhaul_class = 1

     path = os.path.join(folder, filename)
     os.makedirs(os.path.dirname(path), exist_ok=True)

-    generator = MTVRPGenerator(num_loc=num_loc, variant_preset=variant)
+    generator = MTVRPGenerator(
+        num_loc=num_loc, variant_preset=variant, backhaul_class=backhaul_class
+    )
     env = MTVRPEnv(generator, check_solution=False)
     td_data = env.generator(num_data)
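
The renaming step for mixed variants can be sketched in isolation. The function name `mixed_variant_name` and the variant strings in the usage note are assumptions made for illustration. Unlike the inline expression in `generate`, this sketch raises when the name has no `"b"`, because `str.find` returns `-1` in that case and the slicing would otherwise quietly produce a wrong name.

```python
def mixed_variant_name(variant: str) -> str:
    """Insert "m" before the first "b" in a variant name.

    Hypothetical helper, not part of the committed code.
    """
    idx = variant.find("b")
    if idx == -1:
        # str.find signals "not found" with -1; fail loudly instead of slicing
        raise ValueError(f'variant {variant!r} does not contain "b"')
    return variant[:idx] + "m" + variant[idx:]
```

With this helper, a name such as `"vrpb"` would map to `"vrpmb"`, so the mixed dataset would land in its own directory next to the unmixed one.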

@@ -24,24 +33,33 @@ def generate(num_loc, num_data, variant, phase="val"):


 def main():
-    # validation has less data for faster training
     for variant in MTVRPGenerator.available_variants():
+        # Validation (less data for faster training)
         generate(50, 128, variant, phase="val")
         generate(100, 128, variant, phase="val")
         generate(200, 128, variant, phase="val")

         # Test
         generate(50, 1000, variant, phase="test")
         generate(100, 1000, variant, phase="test")
         generate(200, 1000, variant, phase="test")
         generate(500, 128, variant, phase="test")
         generate(1000, 128, variant, phase="test")
+
+        # Mixed variants: skip any variant whose name does not contain "b"
+        if "b" not in variant:
+            continue
+        else:
+            generate(50, 128, variant, phase="val", mixed=True)
+            generate(100, 128, variant, phase="val", mixed=True)
+            generate(200, 128, variant, phase="val", mixed=True)
+            generate(50, 1000, variant, phase="test", mixed=True)
+            generate(100, 1000, variant, phase="test", mixed=True)
+            generate(200, 1000, variant, phase="test", mixed=True)
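
The sequence of calls in `main` can also be written as a table-driven plan, which makes it easier to see how many files each variant produces. This is only a sketch. `planned_jobs` and the constant names are hypothetical, the variant names passed in are placeholders, and the function lists the jobs without running the generator.

```python
# (num_loc, num_data) pairs, copied from the calls in main()
VAL_SIZES = [(50, 128), (100, 128), (200, 128)]
TEST_SIZES = [(50, 1000), (100, 1000), (200, 1000), (500, 128), (1000, 128)]
MIXED_TEST_SIZES = TEST_SIZES[:3]  # mixed test sets stop at 200 locations


def planned_jobs(variants):
    """Return (variant, phase, num_loc, num_data, mixed) tuples in call order."""
    jobs = []
    for variant in variants:
        for num_loc, num_data in VAL_SIZES:
            jobs.append((variant, "val", num_loc, num_data, False))
        for num_loc, num_data in TEST_SIZES:
            jobs.append((variant, "test", num_loc, num_data, False))
        if "b" not in variant:
            continue  # no mixed datasets for this variant
        for num_loc, num_data in VAL_SIZES:
            jobs.append((variant, "val", num_loc, num_data, True))
        for num_loc, num_data in MIXED_TEST_SIZES:
            jobs.append((variant, "test", num_loc, num_data, True))
    return jobs
```

Under this plan, a variant without `"b"` yields 8 files and a variant with `"b"` yields 14.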


 if __name__ == "__main__":
     input(
-        "WARNING: you should not generate the dataset but download it from Github"
-        " since generation results are not reproducible across devices. Press Enter to continue anyways."
+        "Warning: generated data may differ slightly across devices due to PyTorch's random number generator. "
+        "The distribution will however be the same, so results should be comparable. "
+        "To ensure full reproducibility and make sure the data is exactly the same, you may use the uploaded files under the data/ folder. "
+        "Note that this will overwrite any existing datasets. Press Enter to confirm."
     )

     main()
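
One way to check whether a locally generated file matches an uploaded one is to compare the stored arrays directly. The sketch below assumes both files are `.npz` archives readable with `numpy.load`. The name `same_arrays` is hypothetical. Arrays are compared rather than raw file bytes, because archive metadata such as timestamps may differ even when the data is identical.

```python
import numpy as np


def same_arrays(path_a, path_b):
    """Return True if two .npz files hold the same keys and equal arrays."""
    with np.load(path_a) as a, np.load(path_b) as b:
        if set(a.files) != set(b.files):
            return False
        # Exact, element-wise comparison of every stored array
        return all(np.array_equal(a[key], b[key]) for key in a.files)
```

To keep a reference copy available for such a comparison, the uploaded file would need to be copied elsewhere before regenerating, since the script overwrites existing datasets.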
58 changes: 0 additions & 58 deletions routefinder/data/generate_data_mb.py

This file was deleted.
