Skip to content

Commit

Permalink
implement generate_atsp_data
Browse files Browse the repository at this point in the history
  • Loading branch information
abcdhhhh committed Oct 21, 2024
1 parent b0e4aa0 commit c024697
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions rl4co/data/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"op": ["const", "unif", "dist"],
"mdpp": [None],
"pdp": [None],
"atsp": [None]
}


Expand Down Expand Up @@ -212,6 +213,16 @@ def generate_mdpp_data(
"action_mask": available.astype(bool),
}

def generate_atsp_data(dataset_size, atsp_size, tmat_class: bool = True):
cost_matrix = np.random.uniform(size=(dataset_size, atsp_size, atsp_size))
cost_matrix[..., np.arange(atsp_size), np.arange(atsp_size)] = 0
if tmat_class:
for i in range(atsp_size):
cost_matrix = np.minimum(cost_matrix, cost_matrix[..., :, [i]] + cost_matrix[..., [i], :])
return {
"cost_matrix": cost_matrix.astype(np.float32)
}


def generate_dataset(
filename: Union[str, List[str]] = None,
Expand Down

0 comments on commit c024697

Please sign in to comment.