Skip to content

Commit

Permalink
Implement the unwarp method of half rank component warper. This imple…
Browse files Browse the repository at this point in the history
…mentation mostly follows cpp output unwarping.

PiperOrigin-RevId: 520694297
  • Loading branch information
SetarehAr authored and copybara-github committed Mar 30, 2023
1 parent cc239b5 commit 53a703b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 24 deletions.
83 changes: 68 additions & 15 deletions vizier/_src/algorithms/designers/gp/output_warpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,26 @@
from tensorflow_probability.substrates import jax as tfp


def _validate_labels(
    labels_arr: chex.Array, warping: bool = True
) -> chex.Array:
  """Checks the shape and values of the labels and returns a sanitized copy.

  The input is never modified: all work happens on a fresh float NumPy copy,
  in which every `-inf` entry is replaced by `nan` (i.e. marked infeasible).

  Args:
    labels_arr: Array of labels; must have shape (num_points, 1).
    warping: If True, additionally require at least two distinct finite values
      whenever the labels contain no infeasible (nan / -inf) entries.

  Returns:
    A float copy of `labels_arr` with `-inf` replaced by `nan`.

  Raises:
    ValueError: If the shape is not (num_points, 1), if any entry is `+inf`,
      or if `warping` is set and the labels hold fewer than two distinct
      finite values while containing no infeasible entry.
  """
  # `np.array` always yields a new, writable NumPy array, so the in-place
  # assignment below cannot touch the caller's data and also works for
  # immutable array types.
  labels_arr = np.array(labels_arr, dtype=float)
  if not (labels_arr.ndim == 2 and labels_arr.shape[-1] == 1):
    raise ValueError('Labels need to be an array of shape (num_points, 1).')
  if np.isposinf(labels_arr).any():
    raise ValueError('Inifinity metric value is not valid.')
  neg_inf_mask = np.isneginf(labels_arr)
  if neg_inf_mask.any():
    labels_arr[neg_inf_mask] = np.nan
  if warping and not np.isnan(labels_arr).any():
    # Without infeasible points, warping needs some spread in the labels.
    finite_values = labels_arr[np.isfinite(labels_arr)]
    if np.unique(finite_values).size <= 1:
      raise ValueError(
          'Labels need to include at least two finite unique value in the'
          ' absence of infeaible points.'
      )
  return labels_arr


class OutputWarper(abc.ABC):
Expand Down Expand Up @@ -103,7 +112,9 @@ def warp(self, labels_arr: chex.Array) -> chex.Array:
Returns:
(num_points, 1) shaped array of warped labels.
"""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = copy.deepcopy(labels_arr)
if np.isneginf(labels_arr).any():
labels_arr[np.isneginf(labels_arr)] = np.nan
if np.isfinite(labels_arr).all() and len(
np.unique(labels_arr).flatten()) == 1:
return np.zeros(labels_arr.shape)
Expand All @@ -127,7 +138,7 @@ def unwarp(self, labels_arr: chex.Array) -> chex.Array:
Returns:
(num_points, 1) shaped array of unwarped labels.
"""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = copy.deepcopy(labels_arr)
if (
np.isfinite(labels_arr).all()
and len(np.unique(labels_arr).flatten()) == 1
Expand Down Expand Up @@ -194,6 +205,10 @@ class HalfRankComponent(OutputWarper):
Note that this warping is performed on finite values of the array and NaNs are
untouched.
"""
# Statistics recorded by `warp` and read back by `unwarp`. They are declared
# as annotated attrs fields; a chained assignment of the form
# `_x = Optional[T] = ...` would attempt item assignment on `Optional` and
# fail when the class body is executed.
_median: Optional[float] = attr.field(default=None)
_stddev: Optional[float] = attr.field(default=None)
_dedup_median_index: Optional[int] = attr.field(default=None)
_unique_labels: Optional[chex.Array] = attr.field(default=None)

def _estimate_std_of_good_half(
self, unique_labels: chex.Array, threshold: float
Expand Down Expand Up @@ -221,16 +236,20 @@ def _estimate_std_of_good_half(

def warp(self, labels_arr: chex.Array) -> chex.Array:
"""See base class."""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
if labels_arr.size == 1:
return labels_arr
labels_arr = labels_arr.flatten()
# Compute median, unique labels, and ranks.
median = np.nanmedian(labels_arr)
self._median = median
self._stddev = np.nanstd(labels_arr)
unique_labels = np.unique(labels_arr[np.isfinite(labels_arr)])
self._unique_labels = unique_labels
ranks = stats.rankdata(labels_arr, method='dense') # nans ranked last.

dedup_median_index = unique_labels.searchsorted(median, 'left')
self._dedup_median_index = dedup_median_index
denominator = dedup_median_index + (unique_labels[dedup_median_index]
== median) * .5
estimated_std = self._estimate_std_of_good_half(unique_labels, median)
Expand All @@ -248,9 +267,42 @@ def warp(self, labels_arr: chex.Array) -> chex.Array:
return np.reshape(labels_arr, [-1, 1])

def unwarp(self, labels_arr: chex.Array) -> chex.Array:
  """Maps warped labels back to the original label space.

  Uses the statistics stored on the instance by the preceding `warp` call
  (`_median`, `_stddev`, `_dedup_median_index`, `_unique_labels`).
  Non-negative warped values are mapped linearly through the stored median and
  standard deviation. Negative warped values are converted to a fractional
  rank through the normal CDF and then linearly interpolated between the
  stored unique labels; ranks below zero are extrapolated below the smallest
  stored label.

  Args:
    labels_arr: (num_points, 1) shaped array of warped labels, without NaNs.

  Returns:
    (num_points, 1) shaped array of unwarped labels.

  Raises:
    ValueError: If the input contains NaNs, if no statistics have been stored
      yet, or if a computed rank or scale falls outside its valid range.
  """
  labels_arr = _validate_labels(labels_arr, warping=False)
  if np.isnan(labels_arr).any():
    raise ValueError('Array passed to unwarp cannot include nans.')
  if (
      self._median is None
      or self._stddev is None
      or self._dedup_median_index is None
      or self._unique_labels is None
  ):
    raise ValueError('unwarp cannot be called before warp.')
  if self._dedup_median_index == 0:
    return self._median + self._stddev * labels_arr

  # Fix the good/bad split up front: once the good half is overwritten with
  # unwarped values (which may themselves be negative), the sign of an entry
  # no longer tells which half it belongs to.
  bad_mask = labels_arr < 0.0
  good_mask = ~bad_mask
  warped_bad = labels_arr[bad_mask]

  labels_arr[good_mask] = self._median + self._stddev * labels_arr[good_mask]

  # Fractional rank of each bad point among the unique labels.
  rank_bad = (
      2 * stats.norm.cdf(warped_bad) * (self._dedup_median_index + 0.5) - 0.5
  )
  if (rank_bad < -0.5).any() or (
      rank_bad > 1.0001 * self._dedup_median_index
  ).any():
    raise ValueError('Rank needs to be within [-0.5, 1.0001 * median-index].')

  unique_labels = np.asarray(self._unique_labels)
  min_label = np.min(unique_labels)
  scale = self._stddev + self._median - min_label
  if scale < 0.0:
    raise ValueError('Scale needs to be non-negative.')

  labels_bad = np.ones(warped_bad.shape)
  interpolate = rank_bad >= 0.0
  ranks = rank_bad[interpolate]
  last_index = unique_labels.size - 1
  # Clamp the indices so that a rank rounding up to the last index cannot read
  # past the end of `unique_labels`.
  lower = np.minimum(np.floor(ranks).astype(int), last_index)
  upper = np.minimum(lower + 1, last_index)
  fraction = ranks - np.floor(ranks)
  labels_bad[interpolate] = (
      unique_labels[lower] * (1 - fraction) + unique_labels[upper] * fraction
  )
  # Ranks below zero extrapolate linearly below the smallest label.
  labels_bad[~interpolate] = min_label + scale * rank_bad[~interpolate]

  labels_arr[bad_mask] = labels_bad
  return labels_arr


@attr.define
Expand All @@ -267,7 +319,7 @@ class LogWarperComponent(OutputWarper):

def warp(self, labels_arr: chex.Array) -> chex.Array:
"""See base class."""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
self._labels_min = np.nanmin(labels_arr)
self._labels_max = np.nanmax(labels_arr)
labels_arr = labels_arr.flatten()
Expand Down Expand Up @@ -303,7 +355,7 @@ class InfeasibleWarperComponent(OutputWarper):
"""Warps the infeasible/nan value to feasible/finite values."""

def warp(self, labels_arr: chex.Array) -> chex.Array:
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
labels_arr = labels_arr.flatten()
labels_range = np.nanmax(labels_arr) - np.nanmin(labels_arr)
warped_bad_value = np.nanmin(labels_arr) - (0.5 * labels_range + 1)
Expand All @@ -328,7 +380,7 @@ def warp(self, labels_arr: chex.Array) -> chex.Array:
Returns:
(num_points, 1) shaped array of standardize labels.
"""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
if np.isnan(labels_arr).all():
raise ValueError('Labels need to have at least one non-NaN entry.')
labels_finite_ind = np.isfinite(labels_arr)
Expand Down Expand Up @@ -360,7 +412,8 @@ def warp(self, labels_arr: chex.Array) -> chex.Array:
Returns:
(num_points, 1) shaped array of normalized labels.
"""
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)

if np.isnan(labels_arr).all():
raise ValueError('Labels need to have at least one non-NaN entry.')
if np.nanmax(labels_arr) == np.nanmax(labels_arr):
Expand Down Expand Up @@ -451,7 +504,7 @@ def _estimate_variance(self, labels_arr: chex.Array) -> float:
(4 * num_points))**2)

def warp(self, labels_arr: chex.Array) -> chex.Array:
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
labels_finite_ind = np.isfinite(labels_arr)
labels_arr_finite = labels_arr[labels_finite_ind]
labels_median = np.median(labels_arr_finite)
Expand Down Expand Up @@ -496,7 +549,7 @@ def __init__(
self.use_rank = use_rank

def warp(self, labels_arr: chex.Array) -> chex.Array:
labels_arr = _validate_and_deepcopy(labels_arr)
labels_arr = _validate_labels(labels_arr)
labels_arr = np.asarray(labels_arr, dtype=np.float64)
labels_arr_flattened = labels_arr.flatten()
if self.use_rank:
Expand Down
24 changes: 15 additions & 9 deletions vizier/_src/algorithms/designers/gp/output_warpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,6 @@ def warper(self) -> OutputWarper:
def always_maps_to_finite(self) -> bool:
return True

def test_all_nonfinite_labels(self):
  """Warping labels made only of -inf / NaN entries yields all -1 values."""
  infeasible_labels = np.array([[-np.inf], [np.nan], [np.nan], [-np.inf]])
  expected = -1 * np.ones(shape=infeasible_labels.shape).flatten()
  matches = self.warper.warp(infeasible_labels) == expected
  self.assertTrue(matches.all())

@parameterized.parameters([
dict(labels=np.zeros(shape=(5, 1))),
dict(labels=np.ones(shape=(5, 1))),
Expand Down Expand Up @@ -376,5 +367,20 @@ def test_known_arrays(self):
# TODO: Add a couple of parameterized test cases.
self.skipTest('No test cases provided')


class OutputWarperPipelineTest(absltest.TestCase):
  """Edge-case checks for the default output warper pipeline."""

  def test_all_nonfinite_labels(self):
    """Labels made only of -inf / NaN entries are all warped to -1."""
    infeasible_labels = np.array([[-np.inf], [np.nan], [np.nan], [-np.inf]])
    pipeline = output_warpers.OutputWarperPipeline()
    expected = -1 * np.ones(shape=infeasible_labels.shape).flatten()
    matches = pipeline.warp(infeasible_labels) == expected
    self.assertTrue(matches.all())


if __name__ == '__main__':
absltest.main()

0 comments on commit 53a703b

Please sign in to comment.