Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare set expression for output to target expression #359

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def change_address_gene_of_output_node(self, new_address: int, output_node_idx:
self.dna = dna

def set_expression_for_output(
self, dna_insert: List[int], hidden_start_node: int = 0, output_node_idx: int = 0
self,
dna_insert: List[int],
target_expression: str,
hidden_start_node: int = 0,
output_node_idx: int = 0,
):
"""Set an expression for one output node

Expand All @@ -370,6 +374,9 @@ def set_expression_for_output(
----------
dna_insert: List[int]
dna segment to be inserted at the first hidden nodes.
target_expression: str
Expression the output node should compile to. Numbers must be written as float.
Defaults to None.
HenrikMettler marked this conversation as resolved.
Show resolved Hide resolved
hidden_start_node: int
Index of the hidden node, where the insert starts.
Relative to the first hidden node.
Expand All @@ -388,6 +395,23 @@ def set_expression_for_output(
self.change_address_gene_of_output_node(
new_address=last_inserted_node, output_node_idx=output_node_idx
)
try:
HenrikMettler marked this conversation as resolved.
Show resolved Hide resolved
import sympy

except ModuleNotFoundError:
raise ModuleNotFoundError(
"Can not check output expression. No module named 'sympy' (extra requirement)"
)

if target_expression is not None:
HenrikMettler marked this conversation as resolved.
Show resolved Hide resolved
if self._n_outputs > 1:
output_as_sympy = CartesianGraph(self).to_sympy()[output_node_idx]
else:
output_as_sympy = CartesianGraph(self).to_sympy()

target_expression_as_sympy = sympy.parse_expr(target_expression)
if not output_as_sympy == target_expression_as_sympy:
raise ValueError("expression of output and target expression do not match")

def reorder(self, rng: np.random.RandomState) -> None:
"""Reorder the genome
Expand Down
28 changes: 23 additions & 5 deletions test/test_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,17 +835,35 @@ def test_set_expression_for_output(genome_params, rng):
genome = cgp.Genome(**genome_params)
genome.randomize(rng)

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna)

x_0 = sympy.symbols("x_0")
x_1 = sympy.symbols("x_1")

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna, target_expression="x_0 + x_1")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1

new_dna = [1, 0, 1]
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 - x_1")
assert CartesianGraph(genome).to_sympy() == x_0 - x_1

new_dna = [0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 3] # x_0+x_1; 1.0; 0; x_0+x_1 + 1.0
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1.0")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1 + 1.0

with pytest.raises(ValueError):
# setting an int in the str causes an error
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1")
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 * 1.0")

genome2_params = {
"n_inputs": 2,
"n_outputs": 2,
"primitives": (cgp.Add, cgp.Sub, cgp.ConstantFloat),
}
genome2 = cgp.Genome(**genome2_params)
genome2.randomize(rng)

genome2.set_expression_for_output(
new_dna, output_node_idx=1, target_expression="x_0 + x_1 + 1.0"
)
assert CartesianGraph(genome2).to_sympy()[1] == x_0 + x_1 + 1.0