diff --git a/dwave/embedding/__init__.py b/dwave/embedding/__init__.py index 84392556..6c05f740 100644 --- a/dwave/embedding/__init__.py +++ b/dwave/embedding/__init__.py @@ -20,7 +20,7 @@ from dwave.embedding.diagnostic import diagnose_embedding, is_valid_embedding, verify_embedding -from dwave.embedding.chain_breaks import broken_chains +from dwave.embedding.chain_breaks import broken_chains, break_points from dwave.embedding.chain_breaks import discard, majority_vote, weighted_random, MinimizeEnergy from dwave.embedding.transforms import embed_bqm, embed_ising, embed_qubo, unembed_sampleset, EmbeddedStructure diff --git a/dwave/embedding/chain_breaks.py b/dwave/embedding/chain_breaks.py index 8e6303ee..4600c47c 100644 --- a/dwave/embedding/chain_breaks.py +++ b/dwave/embedding/chain_breaks.py @@ -26,9 +26,75 @@ 'majority_vote', 'weighted_random', 'MinimizeEnergy', + 'break_points', ] +def break_points(samples_like, embedding): + """Identify breakpoints in each chain. + + Args: + samples_like (dimod.typing.SamplesLike): + A collection of raw samples for the embedded problem. + Each sample's variables' values should be 0/1 or -1/+1. + + embedding (dwave.embedding.transforms.EmbeddedStructure): + Mapping from source graph to target graph as a dict of form {s: [t, ...], ...}, + where s is a source-model variable and t is a target-model variable. + + Returns: + list: A list of `dict`. The size of the list is equal to number of input samples: + + dict: A dictionary whose keys are variable labels of the logical BQM + (the problem you care about), and values are lists of 2-tuples `(u, v)` + representing edges in the target graph (the QPU graph). The existence + of an edge (u, v) indicates `u` and `v` disagree in its value, i.e., + a break point in the chain. + + The index of the list corresponds to the index of the sample in `samples_like`. + + Examples: + + >>> from dwave.embedding.transforms import EmbeddedStructure + + >>> embedding = EmbeddedStructure([(0,1), (1,2)], {0: [0, 1, 2]}) + >>> samples = np.array([[-1, +1, -1], [-1, -1, -1]], dtype=np.int8) + >>> dwave.embedding.break_points(samples, embedding) + [{0: [(0, 1), (1, 2)]}, {}] + + >>> embedding = EmbeddedStructure([(0,1), (1,2), (0,2)], {0: [0, 1, 2]}) + >>> samples = np.array([[-1, +1, -1], [-1, +1, +1]], dtype=np.int8) + >>> dwave.embedding.break_points(samples, embedding) + [{0: [(0, 1), (1, 2)]}, {0: [(0, 1), (0, 2)]}] + + >>> samples = [{"a": -1, "b": +1, "c": -1}, {"a": -1, "b": -1, "c": -1},] + >>> target_edges = [("a", "b"), ("b", "c")] + >>> chains = {"x": ["a", "b", "c"]} + >>> embedding = EmbeddedStructure(target_edges, chains) + >>> dwave.embedding.break_points(samples, embedding) + [{"x": [("a", "b"), ("b", "c")]}, {}] + """ + result = [] + samples, labels = dimod.as_samples(samples_like) + label_to_i = {label: idx for idx, label in enumerate(labels)} + for sample in samples: + bps = {} + for node in embedding.keys(): + try: + chain_edges = embedding.chain_edges(node) + except AttributeError: + raise TypeError("'embedding' must be a dwave.embedding.EmbeddedStructure") from None + + broken_edges = [(u, v) for u, v in chain_edges + if sample[label_to_i[u]] != sample[label_to_i[v]]] + if len(broken_edges) > 0: + bps[node] = broken_edges + + result.append(bps) + + return result + + def broken_chains(samples, chains): """Find the broken chains. diff --git a/dwave/system/composites/embedding.py b/dwave/system/composites/embedding.py index 9021827f..581f7513 100644 --- a/dwave/system/composites/embedding.py +++ b/dwave/system/composites/embedding.py @@ -34,6 +34,7 @@ from dwave.embedding import (target_to_source, unembed_sampleset, embed_bqm, chain_to_quadratic, EmbeddedStructure) +from dwave.embedding.chain_breaks import break_points from dwave.system.warnings import WarningHandler, WarningAction __all__ = ('EmbeddingComposite', @@ -289,7 +290,8 @@ def async_unembed(response): if return_embedding: sampleset.info['embedding_context'].update( embedding_parameters=embedding_parameters, - chain_strength=embedding.chain_strength) + chain_strength=embedding.chain_strength, + break_points=break_points(response, embedding)) if chain_break_fraction and len(sampleset): warninghandler.issue("All samples have broken chains", diff --git a/tests/test_embedding_chain_breaks.py b/tests/test_embedding_chain_breaks.py index 963f98d5..7a45554d 100644 --- a/tests/test_embedding_chain_breaks.py +++ b/tests/test_embedding_chain_breaks.py @@ -20,6 +20,108 @@ import numpy as np import dwave.embedding +from dwave.embedding.transforms import EmbeddedStructure + + +class TestBreakPoints(unittest.TestCase): + def test_break_points_no_samples(self): + # No samples + target_edges = [(0, 1), (1, 2)] + chains = {0: [0, 1, 2]} + embedding = EmbeddedStructure(target_edges, chains) + + samples = np.array([], dtype=np.int8) + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [] + + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_no_breaks(self): + # No breaks :D + target_edges = [(0, 1), (1, 2)] + chains = {0: [0, 1, 2]} + embedding = EmbeddedStructure(target_edges, chains) + + samples = np.array([[+1, +1, +1], + [+1, +1, +1]], dtype=np.int8) + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{}, {}] + + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_chain(self): + # Target chain of length 3, one embedded variable + target_edges = [(0, 1), (1, 2)] + chains = {0: [0, 1, 2]} + embedding = EmbeddedStructure(target_edges, chains) + + samples = np.array([[-1, +1, -1], + [-1, -1, -1]], dtype=np.int8) + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{0: [(0, 1), (1, 2)]}, + {}] + + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_chain_string_labels(self): + # Target chain of length 3, one embedded variable, but labels are strings + target_edges = [("a", "b"), ("b", "c")] + chains = {"x": ["a", "b", "c"]} + embedding = EmbeddedStructure(target_edges, chains) + + # samples = np.array([[-1, +1, -1], + # [-1, -1, -1]], dtype=np.int8) + samples = [{"a": -1, "b": +1, "c": -1}, + {"a": -1, "b": -1, "c": -1},] + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{"x": [("a", "b"), ("b", "c")]}, + {}] + + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_loop(self): + # Target triangle, one embedded variable + target_edges = [(0, 1), (1, 2), (0, 2)] + chains = {0: [0, 1, 2]} + embedding = EmbeddedStructure(target_edges, chains) + + samples = np.array([[-1, +1, -1], + [-1, +1, +1]], dtype=np.int8) + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{0: [(0, 1), (1, 2)]}, + {0: [(0, 1), (0, 2)]}] + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_chain_2(self): + # Target triangle, two embedded variables + target_edges = [(0, 1), (1, 2)] + chains = {0: [0, 1], 1: [2]} + + embedding = EmbeddedStructure(target_edges, chains) + samples = np.array([[-1, +1, -1], + [-1, -1, +1]], dtype=np.int8) + + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{0: [(0, 1)]}, + {}] + np.testing.assert_array_equal(answer, break_points) + + def test_break_points_loop_2(self): + # Target square, two embedded variables + target_edges = [(0, 1), (1, 2), (2, 3), (0, 3)] + chains = {0: [0, 1], 1: [2, 3]} + embedding = EmbeddedStructure(target_edges, chains) + samples = np.array([[-1, -1, +1, -1], + [-1, +1, +1, -1]], dtype=np.int8) + break_points = dwave.embedding.break_points(samples, embedding) + answer = [{1: [(2, 3)]}, {0: [(0, 1)], + 1: [(2, 3)]}] + np.testing.assert_array_equal(answer, break_points) class TestBrokenChains(unittest.TestCase): diff --git a/tests/test_embedding_composite.py b/tests/test_embedding_composite.py index 4be9bc43..17ebaf2f 100644 --- a/tests/test_embedding_composite.py +++ b/tests/test_embedding_composite.py @@ -248,6 +248,10 @@ def test_return_embedding(self): embedding = sampleset.info['embedding_context']['embedding'] self.assertEqual(set(embedding), {'a', 'c'}) + self.assertIn('break_points', sampleset.info['embedding_context']) + break_points = sampleset.info['embedding_context']['break_points'] + self.assertEqual(break_points, []) + self.assertIn('chain_break_method', sampleset.info['embedding_context']) self.assertEqual(sampleset.info['embedding_context']['chain_break_method'], 'majority_vote') # the default @@ -277,6 +281,10 @@ def test_return_embedding_as_class_variable(self): embedding = sampleset.info['embedding_context']['embedding'] self.assertEqual(set(embedding), {'a', 'c'}) + self.assertIn('break_points', sampleset.info['embedding_context']) + break_points = sampleset.info['embedding_context']['break_points'] + self.assertEqual(break_points, []) + self.assertIn('chain_break_method', sampleset.info['embedding_context']) self.assertEqual(sampleset.info['embedding_context']['chain_break_method'], 'majority_vote') # the default