-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
480 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
results/ | ||
results/ | ||
|
||
**.rew |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# GP_visualization.py | ||
# Written Ian Rankin December 2021 | ||
# | ||
# A set of functions to visualize the learned GPs. | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import pdb | ||
|
||
import lop | ||
|
||
def record_gp_state(model, fake_f, bounds=[(0,0),(1.5,1.5)], folder='./', \ | ||
file_header='', visualize=False): | ||
# Generate test grid | ||
num_side = 25 | ||
x = np.linspace(bounds[0][0], bounds[0][1], num_side) | ||
y = np.linspace(bounds[1][0], bounds[1][1], num_side) | ||
|
||
X, Y = np.meshgrid(x,y) | ||
pts = np.vstack([X.ravel(), Y.ravel()]).transpose() | ||
|
||
|
||
fake_ut = fake_f(pts) | ||
max_fake = np.linalg.norm(fake_ut, ord=np.inf) | ||
fake_ut = fake_ut / max_fake | ||
|
||
pred_ut, pred_sigma = model.predict_large(pts) | ||
|
||
|
||
# save useful data | ||
# save entire pickle file | ||
gp_filename = folder+file_header+'_gp.p' | ||
print(gp_filename) | ||
#pickle.dump(gp, open(gp_filename, "wb")) | ||
# Save useful visualization data | ||
viz_filename = folder+file_header+'_viz' | ||
print(viz_filename) | ||
np.savez(viz_filename, \ | ||
pts=pts, \ | ||
fake_ut=fake_ut, \ | ||
pred_ut=pred_ut, \ | ||
pred_sigma=pred_sigma, \ | ||
GP_pts=model.X_train, \ | ||
GP_pref=model.y_train, \ | ||
GP_prior_idx=model.prior_idx) | ||
|
||
|
||
if visualize == True: | ||
visualize_data(X, Y, num_side, fake_ut, pred_ut, pred_sigma, | ||
model.X_train, model.prior_idx, \ | ||
folder, file_header, \ | ||
also_display=False) | ||
|
||
|
||
def visualize_data(X,Y, num_side, fake_ut, pred_ut, pred_sigma, \ | ||
GP_pts, GP_prior_idx, \ | ||
folder='./', file_header='', also_display=False): | ||
|
||
Z_pred = np.reshape(pred_ut, (num_side, num_side)) | ||
Z_fake = np.reshape(fake_ut, (num_side, num_side)) | ||
Z_sigma = np.reshape(pred_sigma, (num_side, num_side)) | ||
Z_std = np.sqrt(Z_sigma) | ||
Z_ucb = Z_pred + Z_std | ||
|
||
plt.figure() | ||
|
||
plt.pcolor(X, Y, Z_pred) | ||
plt.contour(X, Y, Z_pred) | ||
if GP_prior_idx is None: | ||
GP_prior_idx = [0,0] | ||
if GP_pts is not None: | ||
plt.scatter(GP_pts[GP_prior_idx[1]:,0],GP_pts[GP_prior_idx[1]:,1]) | ||
else: | ||
plt.scatter([],[]) | ||
plt.title(file_header+'_active samples') | ||
plt.xlabel('Migratory fish reward') | ||
plt.ylabel('Sea floor fish reward') | ||
|
||
plt.savefig(folder+file_header+'_active_samples.jpg') | ||
|
||
plt.figure() | ||
|
||
plt.pcolor(X, Y, Z_pred) | ||
plt.contour(X, Y, Z_pred) | ||
if GP_pts is not None: | ||
plt.scatter(GP_pts[:,0],GP_pts[:,1]) | ||
else: | ||
plt.scatter([],[]) | ||
plt.title(file_header+'_with prior points') | ||
plt.xlabel('Migratory fish reward') | ||
plt.ylabel('Sea floor fish reward') | ||
plt.savefig(folder+file_header+'_with_prior.jpg') | ||
|
||
if also_display: | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
######### for active learning selectors | ||
default_to_pareto: false | ||
always_select_best: true | ||
|
||
# alpha | ||
UCB_scalar: 1.0 | ||
|
||
|
||
|
||
######## For model | ||
pareto_pairs: false | ||
|
||
rbf_sigma: 0.3 | ||
rbf_lengthscale: 0.25 | ||
|
||
hyperparameter_optimization: false | ||
|
||
# does this need to be updated for depending on the fake function being used? | ||
normalize_gp: true | ||
normalize_postive: false | ||
|
||
|
||
add_model_prior: false | ||
prior_pts: 25 | ||
prior_bounds: | ||
- | ||
- 0.0 | ||
- 1.5 | ||
- | ||
- 0.0 | ||
- 1.5 | ||
|
Oops, something went wrong.