From d08d7470866261ebc50958ccd20e222142b284da Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 30 May 2024 12:16:11 +0100 Subject: [PATCH] ConstantValue: Support general dtypes --- test/test_literals.py | 8 ++++++++ ufl/constantvalue.py | 11 ++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/test/test_literals.py b/test/test_literals.py index 7e63c89bf..d31035181 100755 --- a/test/test_literals.py +++ b/test/test_literals.py @@ -1,6 +1,8 @@ __authors__ = "Martin Sandve Alnæs" __date__ = "2011-04-14" +import numpy as np + from ufl import PermutationSymbol, as_matrix, as_vector, indices, product from ufl.classes import Indexed from ufl.constantvalue import ComplexValue, FloatValue, IntValue, Zero, as_ufl @@ -29,6 +31,7 @@ def test_float(self): f4 = FloatValue(1.0) f5 = 3 - FloatValue(1) - 1 f6 = 3 * FloatValue(2) / 6 + f7 = as_ufl(np.ones((1,), dtype="d")[0]) assert f1 == f1 self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations! @@ -36,6 +39,7 @@ def test_float(self): assert f2 == f4 assert f2 == f5 assert f2 == f6 + assert f2 == f7 def test_int(self): @@ -45,6 +49,7 @@ def test_int(self): f4 = IntValue(1.0) f5 = 3 - IntValue(1) - 1 f6 = 3 * IntValue(2) / 6 + f7 = as_ufl(np.ones((1,), dtype="int")[0]) assert f1 == f1 self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations! @@ -52,6 +57,7 @@ def test_int(self): assert f1 == f4 assert f1 == f5 assert f2 == f6 # Division produces a FloatValue + assert f1 == f7 def test_complex(self): @@ -62,6 +68,7 @@ def test_complex(self): f5 = ComplexValue(1.0 + 1.0j) f6 = as_ufl(1.0) f7 = as_ufl(1.0j) + f8 = as_ufl(np.array([1+1j], dtype="complex")[0]) assert f1 == f1 assert f1 == f4 @@ -71,6 +78,7 @@ def test_complex(self): assert f5 == f2 + f3 assert f4 == f5 assert f6 + f7 == f2 + f3 + assert f4 == f8 def test_scalar_sums(self): diff --git a/ufl/constantvalue.py b/ufl/constantvalue.py index 0aa320d8f..ef0b69329 100644 --- a/ufl/constantvalue.py +++ b/ufl/constantvalue.py @@ -10,6 +10,7 @@ # Modified by Massimiliano Leoni, 2016. from math import atan2 +import numbers import ufl @@ -506,12 +507,12 @@ def as_ufl(expression): """Converts expression to an Expr if possible.""" if isinstance(expression, (Expr, ufl.BaseForm)): return expression - elif isinstance(expression, complex): - return ComplexValue(expression) - elif isinstance(expression, float): - return FloatValue(expression) - elif isinstance(expression, int): + elif isinstance(expression, numbers.Integral): return IntValue(expression) + elif isinstance(expression, numbers.Real): + return FloatValue(expression) + elif isinstance(expression, numbers.Complex): + return ComplexValue(expression) else: raise ValueError( f"Invalid type conversion: {expression} can not be converted to any UFL type."