From c582f7151d04b60dd9525fd6dfbdc2091a70cfae Mon Sep 17 00:00:00 2001 From: thautwarm Date: Sun, 12 Feb 2023 17:30:52 +0800 Subject: [PATCH] introduce MLStyle.enum_matcher --- Project.toml | 2 +- docs/syntax/pattern.md | 16 +++++--- src/MatchImpl.jl | 93 +++++++++++++++++++++++++++++++++++++++--- test/issues/154.jl | 63 ++++++++++++++++++++++++++++ test/runtests.jl | 3 +- 5 files changed, 163 insertions(+), 14 deletions(-) create mode 100644 test/issues/154.jl diff --git a/Project.toml b/Project.toml index e282a7b..90ddac5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLStyle" uuid = "d8e11817-5142-5d16-987a-aa16d5891078" authors = ["thautwarm "] -version = "0.4.16" +version = "0.4.17" [deps] diff --git a/docs/syntax/pattern.md b/docs/syntax/pattern.md index 48b8fc7..228bb36 100644 --- a/docs/syntax/pattern.md +++ b/docs/syntax/pattern.md @@ -350,24 +350,28 @@ You can extend following APIs for your pattern objects, to implement custom patt `MLStyle.pattern_unref(pat_obj, expr_to_pat, [:a, :b]`. - `MLStyle.is_enum` - + In a pattern `[A, B]`, usually we think both `A` and `B` are capturing patterns. However, it is handy if we can have a pattern `A` whose match means comparing to the global variable `A`. To achieve this, we provide `MLStyle.is_enum`. For a visible global variable `A`, if `MLStyle.is_enum(A) == true`, a symbol `A` will compile into a pattern with `MLStyle.pattern_uncall(A, expr_to_ast, [], [], [])`. +- `MLStyle.enum_matcher(E, value_to_match)`: + + If `MLStyle.is_enum(E) == true`, we will call `MLStyle.enum_matcher(E, value_to_match)` to compile `E` into a pattern. + We present some examples for understandability: ### Support Pattern Matching for Julia Enums ```julia-console julia> using MLStyle -julia> using MLStyle.AbstractPatterns: literal julia> @enum E E1 E2 # mark E1, E2 as non-capturing patterns julia> MLStyle.is_enum(::E) = true -# tell the compiler how to match E1, E2 -julia> MLStyle.pattern_uncall(e::E, _, _, _, _) = literal(e) +# tell the compiler how to match E1 and E2 +# NOTE: make sure it evaluates to a boolean value! +julia> MLStyle.enum_matcher(e::E, expr) = :($e === $expr) julia> x = E2 julia> @match x begin E1 => "match E1!" @@ -383,13 +387,13 @@ julia> @macroexpand @match x begin :(let var"##return#261" = nothing var"##263" = x - if var"##263" === E1 + if E1 === var"##263" var"##return#261" = let "match E1!" end $(Expr(:symbolicgoto, Symbol("####final#262#264"))) end - if var"##263" === E2 + if E1 === var"##263" var"##return#261" = let "match E2!" end diff --git a/src/MatchImpl.jl b/src/MatchImpl.jl index 770ae85..653e03a 100644 --- a/src/MatchImpl.jl +++ b/src/MatchImpl.jl @@ -5,7 +5,19 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@compi end export is_enum, - pattern_uncall, pattern_unref, pattern_unmacrocall, @switch, @case, @tryswitch, @match, @trymatch, Where, gen_match, gen_switch + enum_matcher, + pattern_uncall, + pattern_unref, + pattern_unmacrocall, + @switch, + @case, + @tryswitch, + @match, + @trymatch, + Where, + gen_match, + gen_switch + export Q import MLStyle using MLStyle: mlstyle_report_deprecation_msg! @@ -17,8 +29,70 @@ using MLStyle.AbstractPatterns using MLStyle.AbstractPatterns.BasicPatterns OptionalLn = Union{LineNumberNode, Nothing} +""" + is_enum(EnumPattern)::Bool + +Convert the pattern `EnumPattern` to `EnumPattern()`. + +e.g., + ``` + abstract type AbsS end + struct S1 <: AbsS end + struct S2 <: AbsS end + MLStyle.pattern_uncall(::Type{S}, self, _, _, _) where {S<:AbsS} = literal(S()) + MLStyle.is_enum(::Type{<:AbsS}) = true + + x = S1() + @match x begin + S2 => 1 + S1 => 2 + end + ``` + +""" is_enum(_)::Bool = false -function pattern_uncall end + +""" + enum_matcher(Enum, value)::Expr + +Generates the expression used to test if `value` is the case `Enum`. + +NOTE that this only works when `is_enum(Enum)` is `true`!!! + + @match V begin + Enum => ... + @end + +Above single case matches when + +1. `enum_matcher(Enum, ::Any)` is not defined and `V == Enum`. +2. The expression generated from `enum_matcher(Enum, :V)` + evaluates to `true` under the current module. + +""" +function enum_matcher end + +struct _EnumCase{E} + pattern::E +end + +function pattern_uncall(enumCase::_EnumCase{E}, self, type_params, type_args, args) where E + isempty(type_params) || error("Enum type should not have type parameters!") + isempty(type_args) || error("Enum type should not have type arguments!") + isempty(args) || error("Enum type should not have arguments!") + + let enumPattern = enumCase.pattern + if hasmethod(MLStyle.enum_matcher, Tuple{E, Any}) + function via_enum_matcher(target, _, _) + return MLStyle.enum_matcher(enumPattern, target) + end + guard(via_enum_matcher) + else + pattern_uncall(enumPattern, self, type_params, type_args, args) + end + end +end + function pattern_unref end function pattern_unmacrocall(macro_func, self::Function, args::AbstractArray) @sswitch args begin @@ -85,6 +159,12 @@ function guess_type_from_expr(m::Module, ex::Any, tps::Set{Symbol}) end end +struct ModuleBoundedEx2tf <: Function + m::Module +end + +@inline (self::ModuleBoundedEx2tf)(arg) = ex2tf(self.m, arg) + ex2tf(m::Module, @nospecialize(a)) = literal(a) ex2tf(m::Module, l::LineNumberNode) = wildcard ex2tf(m::Module, q::QuoteNode) = literal(q.value) @@ -97,8 +177,8 @@ ex2tf(m::Module, n::Symbol) = else if isdefined(m, n) p = getfield(m, n) - rec(x) = ex2tf(m, x) - is_enum(p) && return pattern_uncall(p, rec, [], [], []) + rec = ModuleBoundedEx2tf(m) + is_enum(p) && return pattern_uncall(_EnumCase(p), rec, [], [], []) end P_capture(n) end @@ -112,7 +192,8 @@ function ex2tf(m::Module, s::QuotePattern) end function ex2tf(m::Module, w::Where) - rec(x) = ex2tf(m, x) + rec = ModuleBoundedEx2tf(m) + @sswitch w begin @case Where(; value = val, type = t, type_parameters = tps) @@ -170,7 +251,7 @@ end function ex2tf(m::Module, ex::Expr) eval = m.eval - rec(x) = ex2tf(m, x) + rec = ModuleBoundedEx2tf(m) @sswitch ex begin @case Expr(:||, args) diff --git a/test/issues/154.jl b/test/issues/154.jl new file mode 100644 index 0000000..29c4c2f --- /dev/null +++ b/test/issues/154.jl @@ -0,0 +1,63 @@ +using MLStyle +import MLStyle.AbstractPatterns + +abstract type Enum154 end + +struct Enum154_1_Cons <: Enum154 end + +struct Enum154_2_Cons <: Enum154 + x::Vector{Int} +end +MLStyle.@as_record Enum154_2_Cons + +MLStyle.is_enum(::Enum154) = true +MLStyle.enum_matcher(enum::Enum154, expr) = :($enum === $expr) + +const Enum154_1 = Enum154_1_Cons() + +function Base.:(==)(a::Enum154, b::Enum154) + @match (a, b) begin + (Enum154_1, Enum154_1) => true + (Enum154_2_Cons(xs), Enum154_2_Cons(ys)) => xs == ys + _ => false + end +end + +# traditional behaviour + +@enum JuliaEnum_154 begin + JuliaEnum_154_a + JuliaEnum_154_b + JuliaEnum_154_c +end + +MLStyle.is_enum(::JuliaEnum_154) = true + +MLStyle.pattern_uncall(a::JuliaEnum_154, ::Vararg) = MLStyle.AbstractPatterns.literal(a) + +function eq_154(a, b) + @match (a, b) begin + (JuliaEnum_154_a, JuliaEnum_154_a) => true + (JuliaEnum_154_b, JuliaEnum_154_b) => true + (JuliaEnum_154_c, JuliaEnum_154_c) => true + _ => false + end +end + +@testset "issue 154" begin + @testset "tag matching support" begin + @test Enum154_1 == Enum154_1 + @test Enum154_2_Cons([1, 2, 3]) == Enum154_2_Cons([1, 2, 3]) + @test Enum154_2_Cons([1, 2, 3]) != Enum154_2_Cons([1, 2, 4]) + @test Enum154_1 != Enum154_2_Cons([1, 2, 3]) + end + + @testset "traditional" begin + @test eq_154(JuliaEnum_154_a, JuliaEnum_154_a) + @test eq_154(JuliaEnum_154_b, JuliaEnum_154_b) + @test eq_154(JuliaEnum_154_c, JuliaEnum_154_c) + @test !eq_154(JuliaEnum_154_a, JuliaEnum_154_b) + @test !eq_154(JuliaEnum_154_b, JuliaEnum_154_c) + @test !eq_154(JuliaEnum_154_c, JuliaEnum_154_a) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 48d5ece..a9a546a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,7 +52,6 @@ MODULE = TestModule @use GADT -include("issues/109.jl") include("when.jl") include("switch.jl") include("untyped_lam.jl") @@ -80,5 +79,7 @@ include("MQuery/test.jl") include("issues/87.jl") include("issues/62.jl") +include("issues/109.jl") +include("issues/154.jl") end