
Commit

Added a few UCB learning tests. Need to fix the failing test. Oops!
ianran committed Dec 19, 2023
1 parent 02cf18a commit 98417a3
Showing 4 changed files with 69 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/lop/active_learning/__init__.py
@@ -1,4 +1,5 @@
# init the active_learning subfolder

from .ActiveLearner import ActiveLearner
from .BestLearner import BestLearner, WorstLearner
from .UCBLearner import UCBLearner
23 changes: 20 additions & 3 deletions src/lop/models/Model.py
@@ -78,9 +78,26 @@ def select_best(self, candidate_pts):

## select
# This function calls the active learner and specifies the number of alternatives to select
def select(self, candidate_pts, num_alts, prefer_num=-1):
    return self.active_learner.select(candidate_pts, num_alts, prefer_num)

# A wrapper around calling model.active_learner.select
# @param candidate_pts - a numpy array of points (n x k), n = number of points, k = number of dimensions
# @param num_alts - the number of alternatives to select (including the highest mean)
# @param prev_selection - [opt, default = []] a list of indices that have already been selected
# @param prefer_pts - [default = None] the points at the start of the candidates
#                     to prefer selecting from. Given as:
#                     a. A number of points at the start of candidate_pts to prefer
#                     b. A set of points to prefer to select.
#                     c. 'pareto' to indicate preferring points on the Pareto front.
#                     d. 0 to explicitly ignore preferred selections.
#                     e. None (default), which assumes 0 unless defaulting to pareto is enabled.
# @param return_not_selected - [opt, default = False] if true, also returns the points that were
#                     not selected when there is a preference for certain points;
#                     returns [] if set to true but there is no preference.
#
# @return [highest_mean, highest_selection, next_highest_selection, ...],
#           the selection values for candidate_pts;
#           only returns the highest mean if "always select best" is set
def select(self, candidate_pts, num_alts, prev_selection=[], prefer_pts=None, return_not_selected=False):
    return self.active_learner.select(candidate_pts, num_alts, prev_selection, prefer_pts, return_not_selected)
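
For reference, a minimal usage sketch of the new select signature (an illustration based on the docstring and the tests in this commit, not code from the repository):

import numpy as np
import lop

# a trivial model with an active learner attached, mirroring the new test below
model = lop.SimplelestModel(active_learner=lop.BestLearner())
candidate_pts = np.array([0, 1, 2, 4, 4.6])

# default call: no previous selections, no preferred points
idxs = model.select(candidate_pts, num_alts=2)

# prefer the first 3 candidates, with index 0 already selected (per the docstring above)
idxs = model.select(candidate_pts, num_alts=2, prev_selection=[0], prefer_pts=3)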

## SimplelestModel
# A simple model that just outputs exactly its input, no matter what
40 changes: 38 additions & 2 deletions tests/active_learning/test_UCB_learner.py
@@ -8,5 +8,41 @@
import lop


def test_UCB_learner():
    assert False

def f_sin(x, data=None):
    x = 6-x
    return 2 * np.cos(np.pi * (x[:]-2)) * np.exp(-(0.9*x[:]))


def test_UCB_learner_constructs():
    al = lop.UCBLearner()
    model = lop.Model(active_learner=al)

    assert isinstance(al, lop.UCBLearner)
    assert isinstance(model, lop.Model)

def test_UCB_learner_trains_basic_GP():
    al = lop.UCBLearner()
    model = lop.GP(lop.RBF_kern(0.5, 1.0), active_learner=al)

    np.random.seed(5) # just to ensure it doesn't break the test on a bad dice roll
    for i in range(10):
        # generate a random candidate set to select test points from
        x_candidates = np.random.random(20)*10

        test_pt_idxs = model.select(x_candidates, 2)

        x_train = x_candidates[test_pt_idxs]
        y_train = f_sin(x_train)

        model.add(x_train, y_train)

    x_test = np.array([0,1,2,3,4.5,7,9])
    y_test = f_sin(x_test)
    y_pred = model(x_test)

    assert (np.abs(y_pred - y_test) < 2.0).all()
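
For context, a UCB-style learner such as the one under test typically scores each candidate by its predicted mean plus a scaled predictive standard deviation, then picks the top scorers. A minimal sketch of that idea (an assumption for illustration, not the lop.UCBLearner implementation; beta is a hypothetical exploration weight):

import numpy as np

def ucb_select(mu, sigma, num_alts, beta=1.0):
    # score = exploitation (predicted mean) + exploration (scaled uncertainty)
    scores = mu + beta * sigma
    # indices of the num_alts highest-scoring candidates, best first
    return np.argsort(-scores)[:num_alts]

# toy example: 5 candidates with predicted means and standard deviations
mu = np.array([0.2, 0.9, 0.5, 0.1, 0.7])
sigma = np.array([0.9, 0.1, 0.3, 1.2, 0.2])
print(ucb_select(mu, sigma, num_alts=2))   # [3 0]: high-uncertainty points win at beta=1.0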

9 changes: 9 additions & 0 deletions tests/active_learning/test_active_learner.py
@@ -66,6 +66,15 @@ def test_active_learning_select_best():
    with pytest.raises(Exception):
        best_idx = al.select_best(pts, {0,4,5}, {0,1,2,3,4,5,6})

def test_model_calls_select_function_correctly():
    al = lop.BestLearner()
    model = lop.SimplelestModel(active_learner=al)

    idxs = model.select(np.array([0,1,2,4,4.6]), 2)

    assert len(idxs) == 2
    assert model is not None




