Skip to content

Commit

Permalink
add a simple overload, still need a pattern matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger-luo committed Oct 11, 2023
1 parent 276a18c commit db2b779
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/expr/scalar/convert.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
Base.convert(::Type{Index.Type}, x::Int) = Index.Constant(x)
Base.convert(::Type{Index.Type}, x::Symbol) = Index.Variable(x)

Base.convert(::Type{Num.Type}, x::Real) = Num.Real(x)
Base.convert(::Type{Num.Type}, x::Real) = if iszero(x)
Num.Zero
elseif isone(x)
Num.One
else
Num.Real(x)
end

Base.convert(::Type{Num.Type}, x::Complex) = if iszero(imag(x))
Num.Real(real(x))
elseif iszero(real(x))
Num.Imag(imag(x))
else
Num.Complex(real(x), imag(x))
end

Base.convert(::Type{Num.Type}, x::typeof(MathConstants.e)) = Num.Euler
Base.convert(::Type{Num.Type}, x::typeof(pi)) = Num.Pi

Expand Down Expand Up @@ -39,6 +47,10 @@ end
Base.convert(::Type{Num.Type}, x::Scalar.Type) = onlyif_constant(x->x.:1, x)
Base.convert(::Type{T}, x::Num.Type) where {T <: Number} = if isa_variant(x, Num.Real)
return x.:1
elseif isa_variant(x, Num.Zero)
return zero(T)
elseif isa_variant(x, Num.One)
return one(T)
elseif isa_variant(x, Num.Imag)
return convert(T, x.:1) * im
elseif isa_variant(x, Num.Complex)
Expand Down
2 changes: 2 additions & 0 deletions src/expr/scalar/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ end
This is the basic numeric type.
"""
@data Num begin
Zero
One
Real(Float64)
Imag(Float64)
Complex(Float64, Float64)
Expand Down
4 changes: 4 additions & 0 deletions src/expr/scalar/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ function Data.show_data(io::IO, x::Num.Type)
return f.print("Num.Pi")
elseif isa_variant(x, Num.Euler)
return f.print("Num.Euler")
elseif isa_variant(x, Num.Zero)
return f.print("Num.Zero")
elseif isa_variant(x, Num.One)
return f.print("Num.One")
end

f.show(Data.variant_type(x))
Expand Down
11 changes: 11 additions & 0 deletions src/expr/scalar/syntax.jl
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# some overloads for the syntax of scalar expressions
function Base.:(+)(lhs::Scalar.Type, rhs::Scalar.Type)
if isa_variant(lhs, Scalar.Constant) && isa_variant(rhs, Scalar.Constant)
Scalar.Constant(Number(lhs.:1) + Number(rhs.:1))
elseif isa_variant(rhs, Scalar.Constant)
Scalar.Sum(rhs, Dict(lhs => 1))
elseif isa_variant(lhs, Scalar.Constant)
Scalar.Sum(lhs, Dict(rhs => 1))
else
Scalar.Sum(Scalar.Constant(0), Dict(lhs => 1, rhs => 1))
end
end
2 changes: 2 additions & 0 deletions test/expr/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ x = Scalar.Sum(
Scalar.Constant(2.0) => 2.0,
)
)

Scalar.Constant(1.5) + Scalar.Variable(:x)

0 comments on commit db2b779

Please sign in to comment.