-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Include chain break points in returned embedding context #447
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,9 +26,58 @@ | |
'majority_vote', | ||
'weighted_random', | ||
'MinimizeEnergy', | ||
'break_points', | ||
] | ||
|
||
|
||
def break_points(samples, embedding): | ||
"""Identify breakpoints in each chain. | ||
|
||
Args: | ||
samples (array_like): | ||
Samples as a nS x nV array_like object where nS is the number of samples and nV is the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should also specify that the samples should be for the embedded problem |
||
number of variables. The values should all be 0/1 or -1/+1. | ||
embedding (dwave.embedding.transforms.EmbeddedStructure): | ||
Mapping from source graph to target graph as a dict of form {s: [t, ...], ...}, | ||
where s is a source-model variable and t is a target-model variable. | ||
|
||
Returns: | ||
list: A list, of size nS, of `dict`: | ||
|
||
dict: A dictionary whose keys are variables of a BQM, and values are lists of | ||
2-tuples `(u, v)` representing edges in the target graph. The existent of an edge indicates | ||
`u` and `v` disagree in its value, constituting a chain break. | ||
|
||
The index of the list corresponds to the index of the sample in `samples`. | ||
|
||
Examples: | ||
|
||
>>> from dwave.embedding.transforms import EmbeddedStructure | ||
|
||
>>> embedding = EmbeddedStructure([(0,1), (1,2)], {0: [0, 1, 2]}) | ||
>>> samples = np.array([[-1, +1, -1], [-1, -1, -1]], dtype=np.int8) | ||
>>> dwave.embedding.break_points(samples, embedding) | ||
[{0: [(0, 1), (1, 2)]}, {}] | ||
|
||
>>> embedding = EmbeddedStructure([(0,1), (1,2), (0,2)], {0: [0, 1, 2]}) | ||
>>> samples = np.array([[-1, +1, -1], [-1, +1, +1]], dtype=np.int8) | ||
>>> dwave.embedding.break_points(samples, embedding) | ||
[{0: [(0, 1), (1, 2)]}, {0: [(0, 1), (0, 2)]}] | ||
""" | ||
|
||
result = [] | ||
for sample in samples: | ||
bps = {} | ||
for node in embedding.keys(): | ||
chain_edges = embedding.chain_edges(node) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to avoid some confusing errors, you probably want try:
chain_edges = embedding.chain_edges(nodes)
except AttributeError:
raise TypeError("'embedding' must be a dwave.embedding.EmbeddedStructure") from None |
||
broken_edges = [(u, v) for u, v in chain_edges if sample[u] != sample[v]] | ||
if len(broken_edges) > 0: | ||
bps[node] = broken_edges | ||
result.append(bps) | ||
|
||
return result | ||
|
||
|
||
def broken_chains(samples, chains): | ||
"""Find the broken chains. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,7 @@ | |
|
||
from dwave.embedding import (target_to_source, unembed_sampleset, embed_bqm, | ||
chain_to_quadratic, EmbeddedStructure) | ||
from dwave.embedding.chain_breaks import break_points | ||
from dwave.system.warnings import WarningHandler, WarningAction | ||
|
||
__all__ = ('EmbeddingComposite', | ||
|
@@ -289,7 +290,8 @@ def async_unembed(response): | |
if return_embedding: | ||
sampleset.info['embedding_context'].update( | ||
embedding_parameters=embedding_parameters, | ||
chain_strength=embedding.chain_strength) | ||
chain_strength=embedding.chain_strength, | ||
break_points=break_points(response, embedding)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a non-trivial performance hit. IMO we either should not do this by default or we need to write a more performant implementation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, ideally this would be a lazy proxy. But FWIW, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's true, unless for instance the inspector is imported. My inclination is to not include this in the embedding composite for now, but document how to use the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the PR according to feedback except for this comment. Should I move the statistic to I think the lazy proxy approach would require storing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the best thing is just to remove the statistic from the embedding composite altogether. I would then add an example to the |
||
|
||
if chain_break_fraction and len(sampleset): | ||
warninghandler.issue("All samples have broken chains", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be expanded to samples like. You can use
array, labels = dimod.as_samples(samples_like)
to get the numpy array.The reason I say this is because in Ocean you can get embeddings that look like
{'a': ['b', 'c']}
. The QPU (currently) only uses integer labels for its qubits, but that might change in the future.Should add a test for this as well.