
Commit

Remove some unnecessary tf.casts (#829)
uri-granta authored Mar 14, 2024
1 parent 45b230d commit d320e3e
Showing 18 changed files with 140 additions and 59 deletions.
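The recurring pattern in this commit: for Python scalars, tf.cast(value, dtype) first materialises the value as a tensor in its default dtype (e.g. int32) and then adds a Cast op, whereas tf.constant(value, dtype) creates the tensor directly in the target dtype. A minimal sketch of the before/after, not taken from the diff itself:

import tensorflow as tf

loc_before = tf.cast(0, tf.float64)      # converts 0 to an int32 tensor, then casts it to float64
loc_after = tf.constant(0, tf.float64)   # builds the float64 tensor directly, no Cast involved

assert loc_before.dtype == loc_after.dtype == tf.float64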
27 changes: 26 additions & 1 deletion tests/unit/models/gpflow/test_sampler.py
@@ -16,7 +16,7 @@
import math
import unittest
from typing import Any, Callable, List, Tuple, Type
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import gpflow
import numpy.testing as npt
@@ -26,6 +26,7 @@
from check_shapes.exceptions import ShapeMismatchError
from scipy import stats

from tests.unit.models.gpflow.test_interface import _QuadraticPredictor
from tests.util.misc import TF_DEBUGGING_ERROR_TYPES, ShapeLike, quadratic, random_seed
from tests.util.models.gpflow.models import (
GaussianProcess,
@@ -57,6 +58,7 @@
SupportsPredictJoint,
)
from trieste.objectives import Branin
from trieste.types import TensorType

REPARAMETRIZATION_SAMPLERS: List[Type[ReparametrizationSampler[SupportsPredictJoint]]] = [
BatchReparametrizationSampler,
@@ -350,6 +352,29 @@ def test_batch_reparametrization_sampler_samples_are_repeatable(qmc: bool, qmc_s
npt.assert_allclose(sampler.sample(xs), sampler.sample(xs))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("qmc_skip", [True, False])
@pytest.mark.parametrize("dtype", [tf.float32, tf.float64])
def test_batch_reparametrization_sampler_doesnt_cast(
qmc: bool, qmc_skip: bool, dtype: tf.DType
) -> None:
sampler = BatchReparametrizationSampler(100, _QuadraticPredictor(), qmc=qmc, qmc_skip=qmc_skip)
xs = tf.random.uniform([3, 1, 7, 7], dtype=dtype)

original_tf_cast = tf.cast

def patched_tf_cast(x: TensorType, dtype: tf.DType) -> TensorType:
# ensure there are no unnecessary casts from float64 to float32 or vice versa
if isinstance(x, tf.Tensor) and x.dtype in (tf.float32, tf.float64) and x.dtype != dtype:
raise ValueError(f"unexpected cast: {x} to {dtype}")
return original_tf_cast(x, dtype)

with patch("tensorflow.cast", side_effect=patched_tf_cast):
samples = sampler.sample(xs)
assert samples.dtype is dtype
npt.assert_allclose(samples, sampler.sample(xs))


@pytest.mark.parametrize("qmc", [True, False])
@pytest.mark.parametrize("qmc_skip", [True, False])
def test_batch_reparametrization_sampler_different_batch_sizes(qmc: bool, qmc_skip: bool) -> None:
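The no-cast check above reappears in several test modules below, so it could plausibly be factored into a shared helper. A minimal sketch, assuming it lived somewhere under tests/util — the helper name and location are hypothetical, not part of this commit:

from contextlib import contextmanager
from typing import Iterator
from unittest.mock import patch

import tensorflow as tf


@contextmanager
def assert_no_float_casts() -> Iterator[None]:
    # Fail if anything casts a tensor between float32 and float64 inside the block.
    original_tf_cast = tf.cast

    def patched_tf_cast(x: tf.Tensor, dtype: tf.DType) -> tf.Tensor:
        if isinstance(x, tf.Tensor) and x.dtype in (tf.float32, tf.float64) and x.dtype != dtype:
            raise ValueError(f"unexpected cast: {x} to {dtype}")
        return original_tf_cast(x, dtype)

    with patch("tensorflow.cast", side_effect=patched_tf_cast):
        yield

A test would then simply wrap the call under inspection, e.g. with assert_no_float_casts(): samples = sampler.sample(xs).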
13 changes: 12 additions & 1 deletion tests/unit/models/gpflux/test_sampler.py
@@ -25,6 +25,7 @@
from __future__ import annotations

from typing import Callable, Tuple
from unittest.mock import patch

import gpflow.kernels
import gpflux.layers
@@ -151,7 +152,17 @@ def test_dgp_reparam_sampler_sample_is_repeatable(

sampler = DeepGaussianProcessReparamSampler(100, model)
xs = tf.random.uniform([100, 2], minval=-10.0, maxval=10.0, dtype=tf.float64)[:, None, :]
npt.assert_allclose(sampler.sample(xs), sampler.sample(xs))

# also check there are no unnecessary casts from float64 to float32 or vice versa
original_tf_cast = tf.cast

def patched_tf_cast(x: TensorType, dtype: tf.DType) -> TensorType:
if isinstance(x, tf.Tensor) and x.dtype in (tf.float32, tf.float64) and x.dtype != dtype:
raise ValueError(f"unexpected cast: {x} to {dtype}")
return original_tf_cast(x, dtype)

with patch("tensorflow.cast", side_effect=patched_tf_cast):
npt.assert_allclose(sampler.sample(xs), sampler.sample(xs))


@random_seed
4 changes: 2 additions & 2 deletions tests/unit/models/keras/test_models.py
@@ -530,8 +530,8 @@ def test_deep_ensemble_prepare_data_call(
bootstrap_data: bool,
) -> None:
n_rows = 100
x = tf.constant(np.arange(0, n_rows, 1), shape=[n_rows, 1])
y = tf.constant(np.arange(0, n_rows, 1), shape=[n_rows, 1])
x = tf.constant(np.arange(0, n_rows, 1), shape=[n_rows, 1], dtype=tf.float32)
y = tf.constant(np.arange(0, n_rows, 1), shape=[n_rows, 1], dtype=tf.float32)
example_data = Dataset(x, y)

model, _, _ = trieste_deep_ensemble_model(example_data, ensemble_size, bootstrap_data, False)
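For context on the test_models.py change: np.arange yields integer data, so without an explicit dtype these constants would be integer tensors rather than the float32 inputs the ensemble model is built with. A small illustration, independent of the diff:

import numpy as np
import tensorflow as tf

x_int = tf.constant(np.arange(0, 100, 1), shape=[100, 1])
x_f32 = tf.constant(np.arange(0, 100, 1), shape=[100, 1], dtype=tf.float32)

print(x_int.dtype)  # an integer dtype, inherited from numpy (typically int64)
print(x_f32.dtype)  # float32, matching the model's expected input dtype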
32 changes: 24 additions & 8 deletions tests/unit/models/keras/test_sampler.py
@@ -16,6 +16,7 @@

import random
from typing import Any, Callable, Optional, cast
from unittest.mock import patch

import numpy as np
import numpy.testing as npt
@@ -55,26 +56,41 @@ def _num_outputs_fixture(request: Any) -> int:
return request.param


@pytest.mark.parametrize(
"dtype", [pytest.param(tf.float64, id="float64"), pytest.param(tf.float32, id="float32")]
)
def test_ensemble_trajectory_sampler_returns_trajectory_function_with_correctly_shaped_output(
num_evals: int,
batch_size: int,
dim: int,
diversify: bool,
num_outputs: int,
dtype: tf.DType,
) -> None:
"""
Inputs should be [N,B,d] while output should be [N,B,M]. Note that for diversify
option only single output models are allowed.
"""
example_data = empty_dataset([dim], [num_outputs])
test_data = tf.random.uniform([num_evals, batch_size, dim]) # [N, B, d]
example_data = empty_dataset([dim], [num_outputs], dtype)
test_data = tf.random.uniform([num_evals, batch_size, dim], dtype=dtype) # [N, B, d]

model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE)

sampler = DeepEnsembleTrajectorySampler(model, diversify=diversify)
trajectory = sampler.get_trajectory()

assert trajectory(test_data).shape == (num_evals, batch_size, num_outputs)
original_tf_cast = tf.cast

def patched_tf_cast(x: TensorType, dtype: tf.DType) -> TensorType:
# ensure there are no unnecessary casts from float64 to float32 or vice versa
if isinstance(x, tf.Tensor) and x.dtype in (tf.float32, tf.float64) and x.dtype != dtype:
raise ValueError(f"unexpected cast: {x} to {dtype}")
return original_tf_cast(x, dtype)

with patch("tensorflow.cast", side_effect=patched_tf_cast):
samples = trajectory(test_data)
assert samples.dtype == dtype
assert samples.shape == (num_evals, batch_size, num_outputs)


def test_ensemble_trajectory_sampler_returns_deterministic_trajectory(
@@ -223,7 +239,7 @@ def test_ensemble_trajectory_sampler_eps_broadcasted_correctly() -> None:
"""
We check if eps are broadcasted correctly in diversify mode.
"""
example_data = empty_dataset([1], [1])
example_data = empty_dataset([1], [1], tf.float32)
test_data = tf.linspace([-10.0], [10.0], 100)
test_data = tf.expand_dims(test_data, -2) # [N, 1, d]
test_data = tf.tile(test_data, [1, 2, 1]) # [N, 2, D]
@@ -234,7 +250,7 @@ def test_ensemble_trajectory_sampler_eps_broadcasted_correctly() -> None:
trajectory = trajectory_sampler.get_trajectory()

_ = trajectory(test_data) # first call needed to initialize the state
trajectory._eps.assign(tf.constant([[0], [1]], dtype=tf.float64)) # type: ignore
trajectory._eps.assign(tf.constant([[0], [1]], dtype=tf.float32)) # type: ignore
evals = trajectory(test_data)

npt.assert_array_less(
@@ -354,7 +370,7 @@ def test_ensemble_trajectory_sampler_update_trajectory_updates_and_doesnt_retrac
batch_size = 2
num_data = 100

example_data = empty_dataset([dim], [1])
example_data = empty_dataset([dim], [1], tf.float32)
test_data = tf.random.uniform([num_data, batch_size, dim]) # [N, B, d]

model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE)
@@ -434,7 +450,7 @@ def test_ensemble_trajectory_sampler_returns_state(batch_size: int, diversify: b
dim = 3
num_evals = 10

example_data = empty_dataset([dim], [1])
example_data = empty_dataset([dim], [1], tf.float32)
test_data = tf.random.uniform([num_evals, batch_size, dim]) # [N, B, d]

model, _, _ = trieste_deep_ensemble_model(example_data, _ENSEMBLE_SIZE)
@@ -443,7 +459,7 @@ def test_ensemble_trajectory_sampler_returns_state(batch_size: int, diversify: b
trajectory = cast(deep_ensemble_trajectory, sampler.get_trajectory())

if diversify:
dtype = tf.float64
dtype = tf.float32
rnd_state_name = "eps"
else:
dtype = tf.int32
9 changes: 6 additions & 3 deletions tests/util/misc.py
@@ -128,14 +128,17 @@ def mk_dataset(
)


def empty_dataset(query_point_shape: ShapeLike, observation_shape: ShapeLike) -> Dataset:
def empty_dataset(
query_point_shape: ShapeLike, observation_shape: ShapeLike, dtype: tf.DType = tf.float64
) -> Dataset:
"""
:param query_point_shape: The shape of a *single* query point.
:param observation_shape: The shape of a *single* observation.
:param dtype: The dtype of the query points and observations.
:return: An empty dataset with points of the specified shapes and dtype.
"""
qp = tf.zeros(tf.TensorShape([0]) + query_point_shape, tf.float64)
obs = tf.zeros(tf.TensorShape([0]) + observation_shape, tf.float64)
qp = tf.zeros(tf.TensorShape([0]) + query_point_shape, dtype)
obs = tf.zeros(tf.TensorShape([0]) + observation_shape, dtype)
return Dataset(qp, obs)


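With the extended signature, callers can ask for an empty dataset in either precision. A usage sketch (the call sites are illustrative):

import tensorflow as tf

from tests.util.misc import empty_dataset

ds64 = empty_dataset([2], [1])              # defaults to tf.float64, as before
ds32 = empty_dataset([2], [1], tf.float32)  # explicit float32, e.g. for Keras-based models

assert ds64.query_points.dtype == tf.float64
assert ds32.observations.dtype == tf.float32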
10 changes: 5 additions & 5 deletions trieste/acquisition/function/active_learning.py
@@ -222,7 +222,7 @@ def acquisition(x: TensorType) -> TensorType:
t = (threshold - mean) / stdev
t_plus = t + alpha
t_minus = t - alpha
normal = tfp.distributions.Normal(tf.cast(0, x.dtype), tf.cast(1, x.dtype))
normal = tfp.distributions.Normal(tf.constant(0, x.dtype), tf.constant(1, x.dtype))

if delta == 1:
G = (
@@ -361,11 +361,11 @@ def __init__(
self._integration_points = integration_points

if threshold is None:
self._weights = tf.cast(1.0, integration_points.dtype)
self._weights = tf.constant(1.0, integration_points.dtype)

else:
if isinstance(threshold, float):
t_threshold = tf.cast([threshold], integration_points.dtype)
t_threshold = tf.constant([threshold], integration_points.dtype)
else:
t_threshold = tf.cast(threshold, integration_points.dtype)

@@ -504,10 +504,10 @@ def __call__(self, x: TensorType) -> TensorType:
mean, variance = self._model.predict(tf.squeeze(x, -2))
variance = tf.maximum(variance, self._jitter)

normal = tfp.distributions.Normal(tf.cast(0, mean.dtype), tf.cast(1, mean.dtype))
normal = tfp.distributions.Normal(tf.constant(0, mean.dtype), tf.constant(1, mean.dtype))
p = normal.cdf((mean / tf.sqrt(variance + 1)))

C2 = (math.pi * tf.math.log(tf.cast(2, mean.dtype))) / 2
C2 = (math.pi * tf.math.log(tf.constant(2, mean.dtype))) / 2
Ef = (tf.sqrt(C2) / tf.sqrt(variance + C2)) * tf.exp(-(mean**2) / (2 * (variance + C2)))

return -p * tf.math.log(p + self._jitter) - (1 - p) * tf.math.log(1 - p + self._jitter) - Ef
6 changes: 3 additions & 3 deletions trieste/acquisition/function/entropy.py
@@ -205,7 +205,7 @@ def __call__(self, x: TensorType) -> TensorType:
fsd, CLAMP_LB, fmean.dtype.max
) # clip below to improve numerical stability

normal = tfp.distributions.Normal(tf.cast(0, fmean.dtype), tf.cast(1, fmean.dtype))
normal = tfp.distributions.Normal(tf.constant(0, fmean.dtype), tf.constant(1, fmean.dtype))
gamma = (tf.squeeze(self._samples) - fmean) / fsd

log_minus_cdf = normal.log_cdf(-gamma)
@@ -496,7 +496,7 @@ def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, 1]
) # clip below to improve numerical stability
gamma = (tf.squeeze(self._samples) - fmean) / fsd

normal = tfp.distributions.Normal(tf.cast(0, fmean.dtype), tf.cast(1, fmean.dtype))
normal = tfp.distributions.Normal(tf.constant(0, fmean.dtype), tf.constant(1, fmean.dtype))
log_minus_cdf = normal.log_cdf(-gamma)
ratio = tf.math.exp(normal.log_prob(gamma) - log_minus_cdf)
inner_log = 1 + rho_squared * ratio * (gamma - ratio)
@@ -785,7 +785,7 @@ def __call__(self, x: TensorType) -> TensorType:
rho_squared = (cov**2) / (fvar * yvar)
rho_squared = tf.clip_by_value(rho_squared, 0.0, 1.0)

normal = tfp.distributions.Normal(tf.cast(0, fmean.dtype), tf.cast(1, fmean.dtype))
normal = tfp.distributions.Normal(tf.constant(0, fmean.dtype), tf.constant(1, fmean.dtype))
gamma = (tf.squeeze(self._samples) - fmean) / fsd
log_minus_cdf = normal.log_cdf(-gamma)
ratio = tf.math.exp(normal.log_prob(gamma) - log_minus_cdf)
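The standard-normal construction above threads the dtype through tf.constant because tfp infers the distribution dtype from its parameters: plain Python floats would produce a float32 distribution, which then fails against float64 means. A short illustration, independent of the diff:

import tensorflow as tf
import tensorflow_probability as tfp

fmean = tf.zeros([3, 1], dtype=tf.float64)

normal_f32 = tfp.distributions.Normal(0.0, 1.0)  # Python floats default to float32
normal_f64 = tfp.distributions.Normal(tf.constant(0, fmean.dtype), tf.constant(1, fmean.dtype))

print(normal_f32.dtype)                  # tf.float32 -- would not accept float64 arguments
print(normal_f64.log_cdf(-fmean).dtype)  # tf.float64, consistent with the model outputs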
8 changes: 4 additions & 4 deletions trieste/acquisition/function/function.py
@@ -595,7 +595,7 @@ def fast_constraints_feasibility(
def acquisition(x: TensorType) -> TensorType:
if smoothing_function is None:
_smoothing_function = tfp.distributions.Normal(
tf.cast(0.0, x.dtype), tf.cast(1e-3, x.dtype)
tf.constant(0.0, x.dtype), tf.constant(1e-3, x.dtype)
).cdf
else:
_smoothing_function = smoothing_function
@@ -637,8 +637,8 @@ def __init__(
tf.debugging.assert_less_equal(float(min_feasibility_probability), 1.0)
else:
dtype = min_feasibility_probability.dtype
tf.debugging.assert_greater_equal(min_feasibility_probability, tf.cast(0, dtype))
tf.debugging.assert_less_equal(min_feasibility_probability, tf.cast(1, dtype))
tf.debugging.assert_greater_equal(min_feasibility_probability, tf.constant(0, dtype))
tf.debugging.assert_less_equal(min_feasibility_probability, tf.constant(1, dtype))

self._objective_tag = objective_tag
self._constraint_builder = constraint_builder
@@ -1896,7 +1896,7 @@ def __call__(self, x: TensorType) -> TensorType:

if not self._initialized:
normal = tfp.distributions.Normal(
tf.cast(0.0, dtype=x.dtype), tf.cast(1.0, dtype=x.dtype)
tf.constant(0.0, dtype=x.dtype), tf.constant(1.0, dtype=x.dtype)
)
spread = 0.5 + 0.5 * tf.range(1, batch_size + 1, dtype=x.dtype) / (
tf.cast(batch_size, dtype=x.dtype) + 1.0
2 changes: 1 addition & 1 deletion trieste/acquisition/function/greedy_batch.py
@@ -347,7 +347,7 @@ def __call__(self, x: TensorType) -> TensorType:
)
standardised_distances = (pairwise_distances - self._radius) / self._scale

normal = tfp.distributions.Normal(tf.cast(0, x.dtype), tf.cast(1, x.dtype))
normal = tfp.distributions.Normal(tf.constant(0, x.dtype), tf.constant(1, x.dtype))
penalization = normal.cdf(standardised_distances)
return tf.reduce_prod(penalization, axis=-1)

6 changes: 3 additions & 3 deletions trieste/acquisition/optimizer.py
@@ -238,7 +238,7 @@ def generate_initial_points(
remainder = vectorization % tf.shape(candidates)[1]
tf.debugging.assert_equal(
remainder,
tf.cast(0, dtype=remainder.dtype),
tf.constant(0, dtype=remainder.dtype),
message=(
f"""
The vectorization of the target function {vectorization} must be a multiple of
@@ -436,7 +436,7 @@ def optimize_continuous(
remainder = V % tf.shape(random_points)[1]
tf.debugging.assert_equal(
remainder,
tf.cast(0, dtype=remainder.dtype),
tf.constant(0, dtype=remainder.dtype),
message=(
f"""
The vectorization of the target function {V} must be a multiple of the batch
@@ -612,7 +612,7 @@ def _objective_value_and_gradient(x: TensorType) -> Tuple[TensorType, TensorType
remainder = V % len(bounds)
tf.debugging.assert_equal(
remainder,
tf.cast(0, dtype=remainder.dtype),
tf.constant(0, dtype=remainder.dtype),
message=(
f"""
The vectorization of the target function {V} must be a multiple of the length
4 changes: 3 additions & 1 deletion trieste/acquisition/sampler.py
@@ -184,7 +184,7 @@ def sample(
fsd = tf.math.sqrt(fvar)

def probf(y: tf.Tensor) -> tf.Tensor: # Build empirical CDF for Pr(y*^hat<y)
unit_normal = tfp.distributions.Normal(tf.cast(0, fmean.dtype), tf.cast(1, fmean.dtype))
unit_normal = tfp.distributions.Normal(
tf.constant(0, fmean.dtype), tf.constant(1, fmean.dtype)
)
log_cdf = unit_normal.log_cdf(-(y - fmean) / fsd)
return 1 - tf.exp(tf.reduce_sum(log_cdf, axis=0))
