Skip to content

Commit

Permalink
Update tests to not compare InvalidNumbers using strings
Browse files Browse the repository at this point in the history
  • Loading branch information
jsiirola committed Oct 22, 2024
1 parent e8b1a6b commit 4defde4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 30 deletions.
22 changes: 11 additions & 11 deletions pyomo/repn/tests/ampl/test_nlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
import pyomo.environ as pyo

_invalid_1j = r'InvalidNumber\((\([-+0-9.e]+\+)?1j\)?\)'
nan = float('nan')


class INFO(object):
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_errors_divide_by_0(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -186,7 +186,7 @@ def test_errors_divide_by_0(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -201,7 +201,7 @@ def test_errors_divide_by_0(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -216,7 +216,7 @@ def test_errors_divide_by_0(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -231,7 +231,7 @@ def test_errors_divide_by_0(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -424,7 +424,7 @@ def test_errors_negative_frac_pow(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertRegex(str(repn.const), _invalid_1j)
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(1j))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -440,7 +440,7 @@ def test_errors_negative_frac_pow(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertRegex(str(repn.const), _invalid_1j)
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(1j))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -460,7 +460,7 @@ def test_errors_unary_func(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -484,7 +484,7 @@ def test_errors_propagate_nan(self):
)
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -494,7 +494,7 @@ def test_errors_propagate_nan(self):
repn = info.visitor.walk_expression((expr, None, None, 1))
self.assertEqual(repn.nl, None)
self.assertEqual(repn.mult, 1)
self.assertEqual(str(repn.const), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.const, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down
36 changes: 19 additions & 17 deletions pyomo/repn/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_scalars(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -203,7 +203,7 @@ def test_scalars(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -214,7 +214,7 @@ def test_scalars(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(1j)')
self.assertEqual(repn.constant, InvalidNumber(1j))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -253,7 +253,7 @@ def test_npv(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -488,7 +488,7 @@ def test_monomial(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -498,7 +498,7 @@ def test_monomial(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand All @@ -508,7 +508,7 @@ def test_monomial(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -551,7 +551,7 @@ def test_monomial(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -765,7 +765,8 @@ def test_linear(self):
with LoggingIntercept() as LOG:
repn = LinearRepnVisitor(*cfg).walk_expression(e)
self.assertIn(
"DEPRECATED: Encountered 0*nan in expression tree.", LOG.getvalue()
"DEPRECATED: Encountered 0*InvalidNumber(nan) in expression tree.",
LOG.getvalue(),
)

self.assertEqual(cfg.subexpr, {})
Expand Down Expand Up @@ -1447,9 +1448,8 @@ def test_errors_propagate_nan(self):
"\texpression: (x + 1)/p\n",
)
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertEqual(len(repn.linear), 1)
self.assertEqual(str(repn.linear[id(m.x)]), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertStructuredAlmostEqual(repn.linear, {id(m.x): InvalidNumber(nan)})
self.assertEqual(repn.nonlinear, None)

expr = m.y + m.x + m.z + ((3 * m.x) / m.p) / m.y
Expand All @@ -1464,16 +1464,16 @@ def test_errors_propagate_nan(self):
)
self.assertEqual(repn.multiplier, 1)
self.assertEqual(repn.constant, 1)
self.assertEqual(len(repn.linear), 2)
self.assertEqual(repn.linear[id(m.z)], 1)
self.assertEqual(str(repn.linear[id(m.x)]), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(
repn.linear, {id(m.z): 1, id(m.x): InvalidNumber(nan)}
)
self.assertEqual(repn.nonlinear, None)

m.y.fix(None)
expr = log(m.y) + 3
repn = LinearRepnVisitor(*cfg).walk_expression(expr)
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(nan)')
self.assertStructuredAlmostEqual(repn.constant, InvalidNumber(nan))
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down Expand Up @@ -1631,7 +1631,9 @@ def test_nonnumeric(self):
self.assertEqual(cfg.var_map, {})
self.assertEqual(cfg.var_order, {})
self.assertEqual(repn.multiplier, 1)
self.assertEqual(str(repn.constant), 'InvalidNumber(array([3, 4]))')
self.assertStructuredAlmostEqual(
repn.constant, InvalidNumber(numpy.array([3, 4]))
)
self.assertEqual(repn.linear, {})
self.assertEqual(repn.nonlinear, None)

Expand Down
4 changes: 2 additions & 2 deletions pyomo/repn/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def test_apply_operation(self):
pyomo.repn.util.HALT_ON_EVALUATION_ERROR = False
with LoggingIntercept() as LOG:
val = apply_node_operation(div, [1, 0])
self.assertEqual(str(val), "InvalidNumber(nan)")
self.assertStructuredAlmostEqual(val, InvalidNumber(float('nan')))
self.assertEqual(
LOG.getvalue(),
"Exception encountered evaluating expression 'div(1, 0)'\n"
Expand Down Expand Up @@ -293,7 +293,7 @@ class Visitor(object):
pyomo.repn.util.HALT_ON_EVALUATION_ERROR = False
with LoggingIntercept() as LOG:
val = complex_number_error(1j, visitor, exp)
self.assertEqual(str(val), "InvalidNumber(1j)")
self.assertEqual(val, InvalidNumber(1j))
self.assertEqual(
LOG.getvalue(),
"Complex number returned from expression\n"
Expand Down

0 comments on commit 4defde4

Please sign in to comment.