Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for np.random.Generator #6566

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@
MeasurementKey,
MeasurementType,
PeriodicValue,
PRNG_OR_SEED_LIKE,
RANDOM_STATE_OR_SEED_LIKE,
state_vector_to_probabilities,
SympyCondition,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
'QUANTUM_STATE_LIKE',
'QubitOrderOrList',
'RANDOM_STATE_OR_SEED_LIKE',
'PRNG_OR_SEED_LIKE',
'STATE_VECTOR_LIKE',
'Sweepable',
'TParamKey',
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/value/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,6 @@
from cirq.value.type_alias import TParamKey, TParamVal, TParamValComplex

from cirq.value.value_equality_attr import value_equality


from cirq.value.prng import parse_prng, CustomPRNG, PRNG_OR_SEED_LIKE
80 changes: 80 additions & 0 deletions cirq-core/cirq/value/prng.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import TypeVar, Union, overload

import numpy as np

from cirq._doc import document


class CustomPRNG(abc.ABC): ...


_CUSTOM_PRNG_T = TypeVar("_CUSTOM_PRNG_T", bound=CustomPRNG)
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
_PRNG_T = Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T]
_SEED_T = Union[int, None]
PRNG_OR_SEED_LIKE = Union[None, int, np.random.RandomState, np.random.Generator, _CUSTOM_PRNG_T]

document(
PRNG_OR_SEED_LIKE,
"""A pseudorandom number generator or object that can be converted to one.

If an integer or None, turns into a `np.random.Generator` seeded with that
value.

If none of the above, it is used unmodified. In this case, it is assumed
that the object implements whatever methods are required for the use case
at hand. For example, it might be an existing instance of `np.random.Generator`
or `np.random.RandomState` or a custom pseudorandom number generator implementation
and in that case, it has to inherit `cirq.value.CustomPRNG`.
""",
)


@overload
def parse_prng(prng_or_seed: _SEED_T) -> np.random.Generator: ...


@overload
def parse_prng(prng_or_seed: np.random.Generator) -> np.random.Generator: ...


@overload
def parse_prng(prng_or_seed: np.random.RandomState) -> np.random.RandomState: ...


@overload
def parse_prng(prng_or_seed: _CUSTOM_PRNG_T) -> _CUSTOM_PRNG_T: ...


def parse_prng(
prng_or_seed: PRNG_OR_SEED_LIKE,
) -> Union[np.random.Generator, np.random.RandomState, _CUSTOM_PRNG_T]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type of different types tends to be a code smell. Such returned value is less useful for type checking. In addition, Generator and RandomState (not to mention _CUSTOM_PRNG_T) have different APIs so the parse_prng caller would still need to do some isinstance check to ascertain the actual type and figure what methods can be called.

I would propose an alternative approach:

(1) convert the RANDOM_STATE_OR_SEED_LIKE type from Any to the Union of numpy types that can be converted to RandomState and the np.random.Generator type. Hopefully this can be done without too much hassle with typechecks, because the current Any type skips them completely.

(2) extend parse_random_state to accept a Generator object and convert it to RandomState.
Generator-s have bit_generator attribute that can be used to create RandomState.

(3) add method parse_random_generator to the cirq.value.random_state module which would take RANDOM_STATE_OR_SEED_LIKE argument and convert it to a Generator object.
np.random.RandomState() has a _bit_generator attribute that can be used for creating a Generator.
If in some configurations the _bit_generator is not present, we can just use RandomState.randint to get a seed for the np.random.default_rng()

With these steps in place, we can keep all the existing interfaces that take RANDOM_STATE_OR_SEED_LIKE and just start replacing its interpretation from parse_random_state to parse_random_generator as needed.

This would also avoid bifurcation between RANDOM_STATE_OR_SEED_LIKE and PRNG_OR_SEED_LIKE types that may need several major releases to clear up.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ptal

"""Interpret an object as a pseudorandom number generator.

If `prng_or_seed` is None or an integer, returns `np.random.default_rng(prng_or_seed)`.
Otherwise, returns `prng_or_seed` unmodified.

Args:
prng_or_seed: The object to be used as or converted to a pseudorandom
number generator.

Returns:
The pseudorandom number generator object.
"""
if prng_or_seed is None or isinstance(prng_or_seed, int):
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved
return np.random.default_rng(prng_or_seed)
return prng_or_seed
48 changes: 48 additions & 0 deletions cirq-core/cirq/value/prng_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2024 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
import numpy as np
import cirq


class TestPrng(cirq.value.CustomPRNG):
NoureldinYosri marked this conversation as resolved.
Show resolved Hide resolved

def random(self, size):
return tuple(range(size))


def _sample(prng):
return tuple(prng.random(10))
Comment on lines +23 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this. One output from random() is enough to check if 2 generators are at the same seed.



def test_parse_rng() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_parse_rng() -> None:
def test_parse_prng() -> None:

eq = cirq.testing.EqualsTester()

# An `np.random.Generator` or a seed.
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)]
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs]
eq.add_equality_group(*[_sample(g) for g in group])
Comment on lines +30 to +33
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us not check cross-group inequality. Following the test_parse_random_state style is a bit more readable

Suggested change
# An `np.random.Generator` or a seed.
group_inputs: List[Union[int, np.random.Generator]] = [42, np.random.default_rng(42)]
group: List[np.random.Generator] = [cirq.value.parse_prng(s) for s in group_inputs]
eq.add_equality_group(*[_sample(g) for g in group])
# An `np.random.Generator` or a seed.
prngs = [
cirq.value.parse_prng(42),
cirq.value.parse_prng(np.int32(42)),
cirq.value.parse_prng(np.random.default_rng(42)),
]
vals = [prng.random() for prng in prngs]
eq = cirq.testing.EqualsTester()
eq.add_equality_group(*vals)


# A None seed.
prng: np.random.Generator = cirq.value.parse_prng(None)
eq.add_equality_group(_sample(prng))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a noop check for a single value. Perhaps replace with

assert prng is cirq.value.parse_prng(None)

if you are OK with the previous suggestion to have a singleton generator for None.


# Custom PRNG.
custom_prng: TestPrng = cirq.value.parse_prng(TestPrng())
eq.add_equality_group(_sample(custom_prng))

# RandomState PRNG.
random_state: np.random.RandomState = cirq.value.parse_prng(np.random.RandomState(42))
eq.add_equality_group(_sample(random_state))
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/random_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
at hand. For example, it might be an existing instance of
`np.random.RandomState` or a custom pseudorandom number generator
implementation.

Note: prefer to use cirq.PRNG_OR_SEED_LIKE.
""",
)

Expand Down