Skip to content

Commit

Permalink
Added new pareto front code to properly reallocate values
Browse files Browse the repository at this point in the history
  • Loading branch information
ianran committed Jun 13, 2024
1 parent d7dd068 commit 9e5d475
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/rdml_graph/mcts/ParetoFront.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def add(self, r, n):
tmp = self.front
tmp_val = self.front_val
self.front = np.empty((self.front.shape[0]*2, self.front.shape[1]))
self.front_val = np.empty(self.size*s, dtype=np.object)
self.front_val = np.empty(self.size*2, dtype=np.object)
self.front[:self.size] = tmp
self.front_val[:self.size] = tmp_val

Expand Down
16 changes: 16 additions & 0 deletions tests/MCTS/test_pareto_front.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,19 @@ def test_pareto_function():
assert pr_idxs[i] == ans[i]


def test_pareto_function_with_realloc():
front = gr.ParetoFront(3, alloc_size=2)



rewards = np.array([[3,4,5], [2, 3,4], [5,2,1], [3,4,6], [3, 2, 1], [2, 7,2]])

for i in range(rewards.shape[0]):
front.check_and_add(rewards[i], i)

ans = [2,3,5]
pr_vals, pr_idxs = front.get()
for i in range(len(ans)):
assert pr_idxs[i] == ans[i]


0 comments on commit 9e5d475

Please sign in to comment.