Commit 1726151

Added linear preference model and short set of tests running

ianran committed Dec 12, 2023
1 parent d265959 commit 1726151

Showing 6 changed files with 259 additions and 2 deletions.
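As a quick orientation (not part of the diff itself), here is a minimal usage sketch of the new model, mirroring the `test_pref_linear_function` test added below; the toy utility `f_lin` and the two-feature training points are taken from that test:

```python
import numpy as np
import lop

# toy linear utility, copied from the new test below
def f_lin(x, data=None):
    return x[:, 0] + x[:, 1]

X_train = np.array([[0, 0], [1, 2], [2, 4], [3, 2], [4.2, 5.6], [6, 2], [7, 8]])
pairs = lop.generate_fake_pairs(X_train, f_lin, 0) + \
        lop.generate_fake_pairs(X_train, f_lin, 1)

pm = lop.PreferenceLinear()   # linear latent model over the 2 reward features
pm.add(X_train, pairs)        # add pairwise preference labels
pm.optimize()                 # damped Newton fit of the unit-norm weight vector w
y, _ = pm.predict(X_train)    # latent utilities; the second return value is None
```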
6 changes: 6 additions & 0 deletions src/lop/models/PreferenceGP.py
@@ -79,6 +79,7 @@ def __init__(self, cov_func, normalize_gp=True, pareto_pairs=False, \

self.delta_f = 0.0002 # set the convergence to stop
self.maxloops = 100


def optimize(self, optimize_hyperparameter=False):
if optimize_hyperparameter:
@@ -138,6 +139,10 @@ def predict(self, X):
#
# @return an array of output values (n), other output data (variance, covariance, etc.)
def predict_large(self,X):
# lazy optimization of GP
if not self.optimized:
self.optimize(optimize_hyperparameter=self.use_hyper_optimization)

num_at_a_time = 15

num_runs = int(math.ceil(X.shape[0] / num_at_a_time))
@@ -249,6 +254,7 @@ def findMode(self, x_train, y_train, debug=False):
if debug:
print('Optimization ran for: '+str(n_loops))

self.n_loops = n_loops

self.F = F
# calculate W with final F
141 changes: 141 additions & 0 deletions src/lop/models/PreferenceLinear.py
@@ -0,0 +1,141 @@
# Copyright 2023 Ian Rankin
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this
# software and associated documentation files (the "Software"), to deal in the Software
# without restriction, including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
# to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# PreferenceLinear.py
# Written by Ian Rankin - November 2023
#
# A linear latent function to learn the given preferences.

import numpy as np

from lop.models import PreferenceModel

class PreferenceLinear(PreferenceModel):
## init function
# @param pareto_pairs - [opt] specifies whether to consider pareto optimality
# @param other_probits - [opt] allows specification of additional probits
# @param mat_int - [opt] allows specification of different matrix inversion functions
# defaults to the numpy.linalg.pinv invert function
# @param active_learner - defines if there is an active learner for this model
def __init__(self, pareto_pairs=False, other_probits={},
mat_inv=np.linalg.pinv, active_learner=None):
super(PreferenceLinear, self).__init__(pareto_pairs, other_probits,active_learner)

self.mat_inv = mat_inv
self.delta_f = 0.002
self.maxloops = 100


## Predicts the output of the linear model at new locations
# @param X - the input test samples (n,k).
#
# @return an array of output values (n), and None (no variance estimate for the linear model)
def predict(self, X):
# lazy optimization of the linear model
if not self.optimized:
self.optimize()

F = (X @ self.w[:,np.newaxis])[:,0]
return F, None



## optimize
# Runs the optimization step required by the linear preference model.
# @param optimize_hyperparameter - [opt] sets whether to optimize the hyperparameters
def optimize(self, optimize_hyperparameter=False):
if len(self.X_train.shape) > 1:
w = np.random.random(self.X_train.shape[1])
w = w / np.linalg.norm(w, ord=2)
else:
print('Only 1 reward parameter... linear model practically does not make sense')
raise Exception("PreferenceLinear can't optimize a single reward value (just scales it)")
#self.w = np.random.random(1)



# damped newton method
w_err = self.delta_f + 1
n_loops = 0
while w_err > self.delta_f and n_loops < self.maxloops:
# damped newton update (to find max, hence plus sign rather than negative sign.)
W, dpy_dw, py = self.derivatives(self.X_train, self.y_train, w)

gradient = dpy_dw
hess = -W

w_new = self.newton_update(w, # input value to change
gradient=gradient,
hess=hess,
lambda_type="binary", # sets the type of line search ("static", "binary", "iter")
line_search_max_itr=5)

# normalize the weights
w_new = w_new / np.linalg.norm(w_new, ord=2)

# measure error convergence
w_err = np.linalg.norm(w_new - w, ord=2)

w = w_new
n_loops += 1

self.n_loops = n_loops
self.w = w
self.optimized = True


########## Helper functions

## derivatives
# Calculates the derivatives of the log-likelihood for all of the given probits,
# mapped into weight space through F = x @ w.
# @param x - the input data samples (n,k)
# @param y - the given set of labels for the probits
# this is given as a list of [(dk, u, v), ...]
# @param w - the current weight vector (k,)
#
# @return - W, dpy_dw, py
# W - the second order derivative of the probits, mapped to weight space (x.T @ W @ x)
# dpy_dw - the derivative of log P(y|x,theta) with respect to w
# py - log P(y|x,theta) for the given probits
def derivatives(self, x, y, w):
F = (x @ w[:,np.newaxis])[:,0]

W = np.zeros((len(F), len(F)))
grad_ll = np.zeros(len(F))
log_likelihood = 0
for j, probit in enumerate(self.probits):
if self.y_train[j] is not None:
W_local, dpy_df_local, py_local = probit.derivatives(y[j], F)

W += W_local
grad_ll += dpy_df_local
log_likelihood += py_local


# need to multiply by derivative of dl/df * df/dw
grad_ll = (grad_ll[np.newaxis,:] @ x)[0]
W = x.T @ W @ x

return W, grad_ll, log_likelihood


## calculates the loss function of the log likelihood with prior
# this is equation (139)
# @param w - the weights of the function
def loss_func(self, w):
F = (self.X_train @ w[:,np.newaxis])[:,0]
return -self.log_likelyhood_training(F)
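A short aside on the chain rule used in `derivatives` above: with F = x @ w, a gradient and (negative) Hessian computed with respect to F map to weight space as x.T @ dL/dF and x.T @ W @ x. The self-contained sketch below checks that mapping against finite differences; the toy log-likelihood is hypothetical and only stands in for lop's probit terms:

```python
import numpy as np

def toy_loglik(F):
    # hypothetical stand-in for log P(y|x,theta) as a function of the latent F
    return -0.5 * np.sum((F - 1.0) ** 2)

def toy_derivs(F):
    dL_dF = -(F - 1.0)        # gradient of the toy log-likelihood w.r.t. F
    W_F = np.eye(len(F))      # negative Hessian w.r.t. F (plays the role of W above)
    return W_F, dL_dF

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
w = rng.normal(size=2)
F = x @ w

W_F, dL_dF = toy_derivs(F)
dL_dw = x.T @ dL_dF           # matches grad_ll = (grad_ll[np.newaxis,:] @ x)[0]
W_w = x.T @ W_F @ x           # matches W = x.T @ W @ x

# finite-difference check of the weight-space gradient
eps = 1e-6
fd = np.array([(toy_loglik(x @ (w + eps * np.eye(2)[i])) -
                toy_loglik(x @ (w - eps * np.eye(2)[i]))) / (2 * eps)
               for i in range(2)])
assert np.allclose(dL_dw, fd, atol=1e-5)
```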
3 changes: 1 addition & 2 deletions src/lop/models/PreferenceModel.py
@@ -56,6 +56,7 @@ def __init__(self, pareto_pairs=False, other_probits={}, active_learner=None):
self.X_train = None

self.prior_idx = None
self.n_loops = -1


## add_prior
@@ -328,8 +329,6 @@ def newton_update(self, F, gradient, hess,
else:
raise ValueError("newton update given bad lambda type of: " + str(lambda_type))

print("\t lambda = " + str(lamb))

F_new = F - lamb * descent
return F_new

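For context on the `newton_update` call used by both models: the diff shows that it picks a step size `lamb` before applying `F_new = F - lamb * descent`, with a `lambda_type` of "static", "binary", or "iter". The sketch below is a generic, hypothetical illustration of a damped Newton step with a binary (halving) backtracking line search, not lop's actual implementation:

```python
import numpy as np

def damped_newton_step(F, gradient, hess, loss, line_search_max_itr=5):
    # hypothetical sketch: halve the step size until the loss stops increasing
    descent = np.linalg.pinv(hess) @ gradient
    lamb = 1.0
    base = loss(F)
    for _ in range(line_search_max_itr):
        if loss(F - lamb * descent) <= base:
            break
        lamb *= 0.5
    return F - lamb * descent

# toy quadratic with minimum at [1, 1]; converges in one damped Newton step
loss = lambda F: 0.5 * np.sum((F - 1.0) ** 2)
F = np.zeros(2)
F = damped_newton_step(F, F - 1.0, np.eye(2), loss)
print(F)  # -> [1. 1.]
```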
1 change: 1 addition & 0 deletions src/lop/models/__init__.py
@@ -4,3 +4,4 @@
from .PreferenceModel import PreferenceModel
from .GP import GP
from .PreferenceGP import PreferenceGP
from .PreferenceLinear import PreferenceLinear
58 changes: 58 additions & 0 deletions tests/models/test_preference_GP.py
@@ -11,6 +11,9 @@
def f_sin(x, data=None):
return 2 * np.cos(np.pi * (x-2)) * np.exp(-(0.9*x))

def f_sq(x, data=None):
return (x/10.0)**2

def test_pref_GP_construction():
gp = lop.PreferenceGP(lop.RBF_kern(1.0, 1.0))

@@ -69,3 +72,58 @@ def test_pref_GP_function():
assert y[1] < y[i]


def test_pref_GP_abs_bound():
gp = lop.PreferenceGP(lop.RBF_kern(1.0, 0.7), normalize_positive=True)

X_train = np.array([0.2,1.5,2.3,3.2,4.2,6.2,7.3])
y_train = f_sq(X_train)

gp.add(X_train, y_train, type='abs')

assert gp is not None

gp.optimize()

assert gp is not None
assert gp.optimized
assert gp.n_loops > 0 and gp.n_loops < 90

X = np.array([1.5,1.7,3.2])
y = gp(X)

assert isinstance(y, np.ndarray)
assert not np.isnan(y).any()

def test_pref_GP_multiple_probits_does_not_crash():
gp = lop.PreferenceGP(lop.RBF_kern(1.0, 0.7), normalize_positive=True)

X_train = np.array([0,1,2,3,4.2,6,7])
pairs = lop.generate_fake_pairs(X_train, f_sq, 0) + \
lop.generate_fake_pairs(X_train, f_sq, 1) + \
lop.generate_fake_pairs(X_train, f_sq, 2) + \
lop.generate_fake_pairs(X_train, f_sq, 3) + \
lop.generate_fake_pairs(X_train, f_sq, 4)


gp.add(X_train, pairs)

X_train = np.array([0.2,1.5,2.3,3.2,4.2,6.2,7.3])
y_train = f_sq(X_train)

gp.add(X_train, y_train, type='abs')

assert gp is not None

gp.optimize()

assert gp is not None
assert gp.optimized
assert gp.n_loops > 0 and gp.n_loops < 90

X = np.array([1.5,1.7,3.2])
y = gp(X)

assert isinstance(y, np.ndarray)
assert not np.isnan(y).any()


52 changes: 52 additions & 0 deletions tests/models/test_preference_linear.py
@@ -0,0 +1,52 @@
# test_preference_linear.py
# Written by Ian Rankin - December 2023
#
#

import pytest
import lop

import numpy as np


def f_sq(x, data=None):
return (x/10.0)**2

def f_sin(x, data=None):
return 2 * np.cos(np.pi * (x-2)) * np.exp(-(0.9*x))

def f_lin(x, data=None):
#return x[:,0]*x[:,1]
return x[:,0]+x[:,1]

def test_pref_linear_construction():
gp = lop.PreferenceLinear()

assert gp is not None


def test_pref_linear_function():
pm = lop.PreferenceLinear()

X_train = np.array([[0,0],[1,2],[2,4],[3,2],[4.2, 5.6],[6,2],[7,8]])
pairs = lop.generate_fake_pairs(X_train, f_lin, 0) + \
lop.generate_fake_pairs(X_train, f_lin, 1) + \
lop.generate_fake_pairs(X_train, f_lin, 2) + \
lop.generate_fake_pairs(X_train, f_lin, 3) + \
lop.generate_fake_pairs(X_train, f_lin, 4)

pm.add(X_train, pairs)
pm.optimize()

assert pm is not None
assert pm.optimized
assert pm.n_loops > 0 and pm.n_loops < 90

y, _ = pm.predict(X_train)

for i in range(len(X_train)):
if i != 6:
assert y[6] > y[i]
if i!= 0:
assert y[0] < y[i]
