Skip to content

Commit

Permalink
Remove get_node_value
Browse files Browse the repository at this point in the history
No longer required
  • Loading branch information
mstimberg committed Aug 4, 2023
1 parent bdf03e8 commit bb154dd
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 46 deletions.
28 changes: 10 additions & 18 deletions brian2/codegen/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
brian_dtype_from_dtype,
dtype_hierarchy,
)
from brian2.parsing.rendering import NodeRenderer, get_node_value
from brian2.parsing.rendering import NodeRenderer
from brian2.utils.stringtools import get_identifiers, word_substitute

from .statements import Statement
Expand Down Expand Up @@ -271,7 +271,7 @@ def render_BinOp(self, node):
if op.__class__.__name__ == "Mult":
for operand, other in [(left, right), (right, left)]:
if operand.__class__.__name__ in ["Num", "Constant"]:
op_value = get_node_value(operand)
op_value = operand.value
if op_value == 0:
# Do not remove stateful functions
if node.stateless:
Expand All @@ -286,23 +286,20 @@ def render_BinOp(self, node):
# Handle division by 1, or 0/x
elif op.__class__.__name__ == "Div":
if (
left.__class__.__name__ in ["Num", "Constant"]
and get_node_value(left) == 0
left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
): # 0/x
if node.stateless:
# Do not remove stateful functions
return _replace_with_zero(left, node)
if (
right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 1
right.__class__.__name__ in ["Num", "Constant"] and right.value == 1
): # x/1
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
return left
elif op.__class__.__name__ == "FloorDiv":
if (
left.__class__.__name__ in ["Num", "Constant"]
and get_node_value(left) == 0
left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
): # 0//x
if node.stateless:
# Do not remove stateful functions
Expand All @@ -313,25 +310,22 @@ def render_BinOp(self, node):
if (
left.dtype == right.dtype == "integer"
and right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 1
and right.value == 1
): # x//1
return left
# Handle addition of 0
elif op.__class__.__name__ == "Add":
for operand, other in [(left, right), (right, left)]:
if (
operand.__class__.__name__ in ["Num", "Constant"]
and get_node_value(operand) == 0
and operand.value == 0
):
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]:
return other
# Handle subtraction of 0
elif op.__class__.__name__ == "Sub":
if (
right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 0
):
if right.__class__.__name__ in ["Num", "Constant"] and right.value == 0:
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
return left
Expand All @@ -346,12 +340,10 @@ def render_BinOp(self, node):
]:
for subnode in [node.left, node.right]:
if subnode.__class__.__name__ in ["Num", "Constant"] and not (
get_node_value(subnode) is True or get_node_value(subnode) is False
subnode.value is True or subnode.value is False
):
subnode.dtype = "float"
subnode.value = prefs.core.default_float_dtype(
get_node_value(subnode)
)
subnode.value = prefs.core.default_float_dtype(subnode.value)
return node


Expand Down
3 changes: 1 addition & 2 deletions brian2/parsing/bast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy

from brian2.parsing.rendering import get_node_value
from brian2.utils.logger import get_logger

__all__ = ["brian_ast", "BrianASTRenderer", "dtype_hierarchy"]
Expand Down Expand Up @@ -168,7 +167,7 @@ def render_Name(self, node):

def render_Num(self, node):
node.complexity = 0
node.dtype = brian_dtype_from_value(get_node_value(node))
node.dtype = brian_dtype_from_value(node.value)
node.scalar = True
node.stateless = True
return node
Expand Down
4 changes: 2 additions & 2 deletions brian2/parsing/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ast

from brian2.core.functions import Function
from brian2.parsing.rendering import NodeRenderer, get_node_value
from brian2.parsing.rendering import NodeRenderer
from brian2.units.fundamentalunits import (
DIMENSIONLESS,
DimensionMismatchError,
Expand Down Expand Up @@ -138,7 +138,7 @@ def _get_value_from_expression(expr, variables):
else:
raise ValueError(f"Unknown identifier {name}")
elif expr.__class__ is ast.Constant:
return get_node_value(expr)
return expr.value
elif expr.__class__ is ast.BoolOp:
raise SyntaxError(
"Cannot determine the numerical value for a boolean operation."
Expand Down
30 changes: 6 additions & 24 deletions brian2/parsing/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,9 @@
"NumpyNodeRenderer",
"CPPNodeRenderer",
"SympyNodeRenderer",
"get_node_value",
]


def get_node_value(node):
"""Helper function to mask differences between Python versions"""
try:
value = node.value
except AttributeError:
try:
value = node.n
except AttributeError:
value = None

if value is None:
raise AttributeError(f'Node {node} has neither "n" nor "value" attribute')
return value


class NodeRenderer:
expression_ops = {
# BinOp
Expand Down Expand Up @@ -97,9 +81,9 @@ def render_Name(self, node):
return node.id

def render_Num(self, node):
return repr(get_node_value(node))
return repr(node.value)

def render_Constant(self, node): # For literals in Python 3.8
def render_Constant(self, node):
if node.value is True or node.value is False or node.value is None:
return self.render_NameConstant(node)
else:
Expand Down Expand Up @@ -130,9 +114,7 @@ def render_element_parentheses(self, node):
"""
if node.__class__.__name__ == "Name":
return self.render_node(node)
elif (
node.__class__.__name__ in ["Num", "Constant"] and get_node_value(node) >= 0
):
elif node.__class__.__name__ in ["Num", "Constant"] and node.value >= 0:
return self.render_node(node)
elif node.__class__.__name__ == "Call":
return self.render_node(node)
Expand Down Expand Up @@ -285,10 +267,10 @@ def render_NameConstant(self, node):
return str(node.value)

def render_Num(self, node):
if isinstance(get_node_value(node), numbers.Integral):
return sympy.Integer(get_node_value(node))
if isinstance(node.value, numbers.Integral):
return sympy.Integer(node.value)
else:
return sympy.Float(get_node_value(node))
return sympy.Float(node.value)

def render_BinOp(self, node):
op_name = node.op.__class__.__name__
Expand Down

0 comments on commit bb154dd

Please sign in to comment.