diff --git a/src/expr/scalar/convert.jl b/src/expr/scalar/convert.jl index f126803..46a9d38 100644 --- a/src/expr/scalar/convert.jl +++ b/src/expr/scalar/convert.jl @@ -1,7 +1,14 @@ Base.convert(::Type{Index.Type}, x::Int) = Index.Constant(x) Base.convert(::Type{Index.Type}, x::Symbol) = Index.Variable(x) -Base.convert(::Type{Num.Type}, x::Real) = Num.Real(x) +Base.convert(::Type{Num.Type}, x::Real) = if iszero(x) + Num.Zero +elseif isone(x) + Num.One +else + Num.Real(x) +end + Base.convert(::Type{Num.Type}, x::Complex) = if iszero(imag(x)) Num.Real(real(x)) elseif iszero(real(x)) @@ -9,6 +16,7 @@ elseif iszero(real(x)) else Num.Complex(real(x), imag(x)) end + Base.convert(::Type{Num.Type}, x::typeof(MathConstants.e)) = Num.Euler Base.convert(::Type{Num.Type}, x::typeof(pi)) = Num.Pi @@ -39,6 +47,10 @@ end Base.convert(::Type{Num.Type}, x::Scalar.Type) = onlyif_constant(x->x.:1, x) Base.convert(::Type{T}, x::Num.Type) where {T <: Number} = if isa_variant(x, Num.Real) return x.:1 +elseif isa_variant(x, Num.Zero) + return zero(T) +elseif isa_variant(x, Num.One) + return one(T) elseif isa_variant(x, Num.Imag) return convert(T, x.:1) * im elseif isa_variant(x, Num.Complex) diff --git a/src/expr/scalar/data.jl b/src/expr/scalar/data.jl index 2e17f20..6b31683 100644 --- a/src/expr/scalar/data.jl +++ b/src/expr/scalar/data.jl @@ -46,6 +46,8 @@ end This is the basic numeric type. """ @data Num begin + Zero + One Real(Float64) Imag(Float64) Complex(Float64, Float64) diff --git a/src/expr/scalar/show.jl b/src/expr/scalar/show.jl index a5fdd99..591faa5 100644 --- a/src/expr/scalar/show.jl +++ b/src/expr/scalar/show.jl @@ -4,6 +4,10 @@ function Data.show_data(io::IO, x::Num.Type) return f.print("Num.Pi") elseif isa_variant(x, Num.Euler) return f.print("Num.Euler") + elseif isa_variant(x, Num.Zero) + return f.print("Num.Zero") + elseif isa_variant(x, Num.One) + return f.print("Num.One") end f.show(Data.variant_type(x)) diff --git a/src/expr/scalar/syntax.jl b/src/expr/scalar/syntax.jl index b00c04d..b4dcb4b 100644 --- a/src/expr/scalar/syntax.jl +++ b/src/expr/scalar/syntax.jl @@ -1 +1,12 @@ # some overloads for the syntax of scalar expressions +function Base.:(+)(lhs::Scalar.Type, rhs::Scalar.Type) + if isa_variant(lhs, Scalar.Constant) && isa_variant(rhs, Scalar.Constant) + Scalar.Constant(Number(lhs.:1) + Number(rhs.:1)) + elseif isa_variant(rhs, Scalar.Constant) + Scalar.Sum(rhs, Dict(lhs => 1)) + elseif isa_variant(lhs, Scalar.Constant) + Scalar.Sum(lhs, Dict(rhs => 1)) + else + Scalar.Sum(Scalar.Constant(0), Dict(lhs => 1, rhs => 1)) + end +end diff --git a/test/expr/scalar.jl b/test/expr/scalar.jl index 8d887b8..27477a1 100644 --- a/test/expr/scalar.jl +++ b/test/expr/scalar.jl @@ -16,3 +16,5 @@ x = Scalar.Sum( Scalar.Constant(2.0) => 2.0, ) ) + +Scalar.Constant(1.5) + Scalar.Variable(:x)