diff --git a/README.md b/README.md index 337886b..7d18a40 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ MLStyle.jl - [Time Overhead](#time-overhead) - [Allocation](#allocation) - [Gallery](#gallery) - + - [Contributing to MLStyle](https://github.com/thautwarm/MLStyle.jl#contributing-to-mlstyle) @@ -125,16 +125,16 @@ end @use GADT @data public Exp{T} begin - Sym :: Symbol => Exp{A} where {A} - Val{A} :: A => Exp{A} - App{A, B} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} where {A_ <: A} - Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}} - If{A} :: (Exp{Bool}, Exp{A}, Exp{A}) => Exp{A} + Sym{A} :: Symbol => Exp{A} + Val{A} :: A => Exp{A} + App{A, B, A_} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} + Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}} + If{A} :: (Exp{Bool}, Exp{A}, Exp{A}) => Exp{A} end ``` -A simple intepreter implementation using GADTs could be found at `test/untyped_lam.jl`. +A simple intepreter implemented via GADTs could be found at `test/untyped_lam.jl`. ### Active Patterns @@ -157,7 +157,7 @@ end # RegexMatch("123") ### Prerequisite Recently the rudimentary benchmarks have been finished, which turns out that MLStyle.jl could be extremely fast -when matching cases are complicated, while in terms of some very simple cases(straightforward destruct shallow tuples, arrays and datatypes without recursive invocations), Match.jl could be faster. +when matching cases are complicated, while in terms of some very simple cases(straightforwardly destructure shallow tuples, arrays and datatypes without recursive invocations), Match.jl could be faster. All benchmark scripts are provided at directory [Matrix-Benchmark](https://github.com/thautwarm/MLStyle.jl/blob/master/matrix-benchmark). diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index 3b16cb1..0000000 --- a/docs/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# MLStyle documentation - -Syntax -=========== - -ADT ------------ -[Doc](https://github.com/thautwarm/MLStyle.jl/blob/master/docs/src/syntax/adt.md) here - -Pattern ------------- -[Doc](https://github.com/thautwarm/MLStyle.jl/blob/master/docs/src/syntax/pattern.md) there - -Pattern Function ------------------ -[Doc](https://github.com/thautwarm/MLStyle.jl/blob/master/docs/src/syntax/pattern-function.md) mer. - diff --git a/docs/src/syntax/adt.md b/docs/src/syntax/adt.md index 62856b0..8a7f611 100644 --- a/docs/src/syntax/adt.md +++ b/docs/src/syntax/adt.md @@ -26,7 +26,7 @@ Syntax 'end' = - '@data' ['public' | 'internal'] 'begin' + '@data' ['public' | 'internal' | 'visible' 'in' ] 'begin' ([{}] '::' ( '(' @@ -42,6 +42,34 @@ Syntax ``` +Examples: + +```julia + +@data internal A begin + A1(Int, Int) + A2(a :: Int, b :: Int) + A3(a, b) # equals to `A3(a::Any, b::Any)` +end + +@data B{T} begin + B1(T, Int) + B2(a :: T) +end + +@data visible in MyModule C{T} begin + C1(T) + C2{A} :: Vector{A} => C{A} +end + +abstract type DD end +@data visible in [Main, Base, Core] D{T} <: DD begin + D1 :: Int => D{T} where T # implicit type vars + D2{A, B} :: (A, B, Int) => D{Tuple{A, B}} + D3{A} :: A => D{Array{A, N}} where N # implicit type vars +end +``` + Qualifier ---------------------- @@ -105,8 +133,7 @@ the extension system like Haskell here. Since that you can define your own `where` pattern and export it to any modules. Given an arbitrary Julia module, if you don't use `@use GADT` to enable GADT extensions and, -the qualifier of the your `where` pattern makes it visible here(current module), -your own `where` pattern could work here. +your own `where` pattern just works here. Here's a simple intepreter implemented using GADTs. @@ -149,14 +176,13 @@ And now let's define the operators of our abstract machine. @data public Exp{T} begin # The symbol referes to some variable in current context. - Sym :: Symbol => Exp{A} where {A} + Sym{A} :: Symbol => Exp{A} # Value. Val{A} :: A => Exp{A} # Function application. - # add constraints to implicit tvars to get covariance - App{A, B} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} where {A_ <: A} + App{A, B, A_ <: A} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} # Lambda/Anonymous function. Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}} @@ -166,18 +192,6 @@ And now let's define the operators of our abstract machine. end ``` -Something deserved to be remark here: when using this GADT syntax like - -``` - ConsName{TVars1...} :: ... => Exp{TVars2...} where {TVar3...} -``` - -You can add constraints to both `TVars1` and `TVars3`, and `TVars2` should be -always empty or a sequence of `Symbol`s. Furthermore, `TVars3` are the so-called -implicit type variables, and `TVars1` are the normal generic type variables. - -Let's back to our topic. - To make function abstractions, we need a `substitute` operation. ```julia @@ -251,3 +265,39 @@ ctx = Dict{Symbol, Any}() ``` +Implicit Type Variables of Generalized ADT +---------------------------------------------------- + + +Sometimes you might want this: + +```julia +@use GADT + +@data A{T} begin + A1 :: Int => A{T} where T +end +``` +It means that for all `T`, we have `A{T} >: A1`, where `A1` is a case class and could be used as a constructor. + +You can work with them in this way: +```julia +function string_A() :: A{String} + A1(2) +end + +@assert String == @match string_A() begin + A{T} where T => T +end +``` + +Currently, there're several limitations with implicit type variables, say, you're not expected to use implicit type variables in +the argument types of constructors, like: + +```julia +@data A{T} begin + A1 :: T => A{T} where T # NOT EXPECTED! +end +``` + +It's possible to achieve more flexible implicit type variables, but it's quite difficult for such a package without statically type checking. \ No newline at end of file diff --git a/src/DataType.jl b/src/DataType.jl index ab34155..78bae13 100644 --- a/src/DataType.jl +++ b/src/DataType.jl @@ -7,6 +7,7 @@ using MLStyle.Infras using MLStyle.Pervasives using MLStyle.Render: render using MLStyle.Record: def_record +using MLStyle.TypeVarExtraction export @data is_symbol_capitalized = isCapitalized ∘ string @@ -59,7 +60,7 @@ end function impl(t, variants :: Expr, mod :: Module) l :: LineNumberNode = LineNumberNode(1) - abs_tvars = get_tvars(t) + abs_tvars = collect(get_tvars(t)) defs = [] abst() = isempty(abs_tvars) ? t : :($t{$(abs_tvars...)}) VAR = mangle(mod) @@ -67,18 +68,18 @@ function impl(t, variants :: Expr, mod :: Module) for each in variants.args @match each begin ::LineNumberNode => (l = each) - :($case{$(tvars...)} :: ($(params...), ) => $(ret_ty) where {$(gtvars...)}) || - :($case{$(tvars...)} :: ($(params...), ) => $(ret_ty && Do(gtvars=[]))) || - :($case{$(tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $ret_ty where {$(gtvars...)}) || - :($case{$(tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(gtvars=[]))) || - :($(case && Do(tvars = [])) :: ($(params...), ) => $(ret_ty) where {$(gtvars...)}) || - :($(case && Do(tvars = [])) :: ($(params...), ) => $(ret_ty && Do(gtvars=[]))) || - :($(case && Do(tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty) where {$(gtvars...)}) || - :($(case && Do(tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(gtvars=[]))) || + :($case{$(generic_tvars...)} :: ($(params...), ) => $(ret_ty) where {$(implicit_tvars...)}) || + :($case{$(generic_tvars...)} :: ($(params...), ) => $(ret_ty && Do(implicit_tvars=[]))) || + :($case{$(generic_tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $ret_ty where {$(implicit_tvars...)}) || + :($case{$(generic_tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(implicit_tvars=[]))) || + :($(case && Do(generic_tvars = [])) :: ($(params...), ) => $(ret_ty) where {$(implicit_tvars...)}) || + :($(case && Do(generic_tvars = [])) :: ($(params...), ) => $(ret_ty && Do(implicit_tvars=[]))) || + :($(case && Do(generic_tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty) where {$(implicit_tvars...)}) || + :($(case && Do(generic_tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(implicit_tvars=[]))) || - :($case($((params && Do(tvars=abs_tvars))...))) && Do(ret_ty = abst(), gtvars=[]) => begin + :($case($((params && Do(generic_tvars=abs_tvars))...))) && Do(ret_ty = abst(), implicit_tvars=[]) => begin - config = Dict{Symbol, Any}([(gtvar isa Expr ? gtvar.args[1] : gtvar) => Any for gtvar in gtvars]) + config = Dict{Symbol, Any}([(gtvar isa Expr ? gtvar.args[1] : gtvar) => Any for gtvar in implicit_tvars]) pairs = map(enumerate(params)) do (i, each) @match each begin @@ -91,38 +92,36 @@ function impl(t, variants :: Expr, mod :: Module) definition_body = [Expr(:(::), field, ty) for (field, ty, _) in pairs] constructor_args = [Expr(:(::), field, ty) for (field, _, ty) in pairs] arg_names = [field for (field, _, _) in pairs] - spec_tvars = [tvars..., [Any for _ in gtvars]...] getfields = [:($VAR.$field) for field in arg_names] - - convert_fn = isempty(gtvars) ? nothing : let (=>) = (a, b) -> convert(b, a) - out_tvars = [tvars..., fill(nothing, length(gtvars))...] => Vector{Any} - inp_tvars = [tvars..., fill(nothing, length(gtvars))...] => Vector{Any} - fresh_tvars1 = fill(nothing, length(gtvars)) => Vector{Any} - fresh_tvars2 = fill(nothing, length(gtvars)) => Vector{Any} - tcovs = [] - for (i, _) in enumerate(gtvars) - TAny = mangle(mod) - TCov = mangle(mod) - push!(tcovs, TCov) - fresh_tvars2[end-i + 1] = :($TCov <: $TAny) - fresh_tvars1[end-i + 1] = TAny - inp_tvars[end-i + 1] = TAny - out_tvars[end-i + 1] = TCov + definition_head = :($case{$(generic_tvars...), $(implicit_tvars...)}) + + generic_tvars = collect(map(to_tvar, extract_tvars(generic_tvars))) + implicit_tvars = collect(map(to_tvar, extract_tvars(implicit_tvars))) + + convert_fn = isempty(implicit_tvars) ? nothing : let (=>) = (a, b) -> convert(b, a) + out_tvars = [generic_tvars; implicit_tvars] + inp_tvars = [generic_tvars; [mangle(mod) for _ in implicit_tvars]] + fresh_tvars1 = [] + fresh_tvars2 = [] + n_generic_tvars = length(generic_tvars) + for i in 1 + n_generic_tvars : length(implicit_tvars) + n_generic_tvars + TAny = inp_tvars[i] + TCov = out_tvars[i] + push!(fresh_tvars2, :($TCov <: $TAny)) + push!(fresh_tvars1, TAny) end - tcovs = reverse(tcovs) arg2 = :($VAR :: $case{$(inp_tvars...)}) arg1_cov = :($Type{$case{$(out_tvars...)}}) - arg1_abs = :($Type{$t{$(tcovs...)}}) + arg1_abs = :($Type{$ret_ty}) casted = :($case{$(out_tvars...)}($(getfields...))) quote - $Base.convert(::$arg1_cov, $arg2) where {$(tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted - $Base.convert(::$arg1_abs, $arg2) where {$(tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted + $Base.convert(::$arg1_cov, $arg2) where {$(generic_tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted + $Base.convert(::$arg1_abs, $arg2) where {$(generic_tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted end end - definition_head = :($case{$(tvars...), $(gtvars...)}) def_cons = - isempty(spec_tvars) ? + isempty(generic_tvars) && isempty(implicit_tvars) ? !isempty(constructor_args) ? quote function $case(;$(constructor_args...)) @@ -130,9 +129,19 @@ function impl(t, variants :: Expr, mod :: Module) end end : nothing : - quote - function $case($(constructor_args...), ) where {$(tvars...)} - $case{$(spec_tvars...)}($(arg_names...)) + let spec_tvars = [generic_tvars; [Any for _ in implicit_tvars]] + if isempty(generic_tvars) + quote + function $case($(constructor_args...), ) + $case{$(spec_tvars...)}($(arg_names...)) + end + end + else + quote + function $case($(constructor_args...), ) where {$(generic_tvars...)} + $case{$(spec_tvars...)}($(arg_names...)) + end + end end end diff --git a/src/MLStyle.jl b/src/MLStyle.jl index f39f885..d988f7a 100644 --- a/src/MLStyle.jl +++ b/src/MLStyle.jl @@ -41,6 +41,8 @@ using MLStyle.Pervasives include("Qualification.jl") +include("TypeVarExtraction.jl") + include("StandardPatterns.jl") using MLStyle.StandardPatterns diff --git a/src/StandardPatterns.jl b/src/StandardPatterns.jl index 24f8893..848582f 100644 --- a/src/StandardPatterns.jl +++ b/src/StandardPatterns.jl @@ -1,35 +1,11 @@ module StandardPatterns # This module is designed for creating complex patterns from the primtive ones. - using MLStyle -using MLStyle.Toolz.List: cons, nil using MLStyle.Infras using MLStyle.MatchCore using MLStyle.Qualification +using MLStyle.TypeVarExtraction -struct TypeVar - t :: Symbol -end - -struct Relation - l :: Symbol - op :: Symbol - r -end - - -function any_constraint(t, forall) - - function is_rel(::Relation) - true - end - - function is_rel(::TypeVar) - false - end - - !(t isa Symbol) || any(is_rel, collect(extract_tvars(forall))) -end macro type_matching(t, forall) quote @@ -63,16 +39,6 @@ macro type_matching(t, forall) end |> esc end -function extract_tvars(t :: AbstractArray) - @match t begin - [] => nil() - [hd && if hd isa Symbol end, tl...] => cons(TypeVar(hd), extract_tvars(tl)) - [:($hd <: $r), tl...] => cons(Relation(hd, :<:, r), extract_tvars(tl)) - [:($hd >: $(r)), tl...] => cons(Relation(hd, Symbol(">:"), r), extract_tvars(tl)) - _ => @error "invalid tvars" - end -end - def_pattern(StandardPatterns, predicate = x -> x isa Expr && x.head == :(::), rewrite = (tag, case, mod) -> diff --git a/src/TypeVarExtraction.jl b/src/TypeVarExtraction.jl new file mode 100644 index 0000000..50c0fc7 --- /dev/null +++ b/src/TypeVarExtraction.jl @@ -0,0 +1,61 @@ +module TypeVarExtraction + +using MLStyle +using MLStyle.Toolz.List: cons, nil +using MLStyle.Infras +using MLStyle.MatchCore +using MLStyle.Qualification + +export TypeVar, Relation, ChainRelation +export any_constraint, to_tvar, extract_tvars + +struct TypeVar + t :: Symbol +end + +struct Relation + l :: Symbol + op :: Symbol + r +end + +struct ChainRelation + var :: Symbol + lower + super +end + +function any_constraint(t, forall) + + function is_rel(::Relation) + true + end + + function is_rel(::ChainRelation) + true + end + + function is_rel(::TypeVar) + false + end + + !(t isa Symbol) || any(is_rel, collect(extract_tvars(forall))) +end + +function extract_tvars(t :: AbstractArray) + @match t begin + [] => nil() + [hd && if hd isa Symbol end, tl...] => cons(TypeVar(hd), extract_tvars(tl)) + [:($hd <: $r), tl...] => cons(Relation(hd, :<:, r), extract_tvars(tl)) + [:($hd >: $(r)), tl...] => cons(Relation(hd, Symbol(">:"), r), extract_tvars(tl)) + [:($lower <: $hd <: $super), tl...] || + [:($super >: $hd >: $lower), tl...] => cons(ChainRelation(hd, lower, super)) + _ => @syntax_err "invalid tvars($t)" + end +end + +to_tvar(t::TypeVar) = t.t +to_tvar(t::Relation) = t.l +to_tvar(t::ChainRelation) = t.var + +end \ No newline at end of file diff --git a/test/untyped_lam.jl b/test/untyped_lam.jl index cab4210..d103e57 100644 --- a/test/untyped_lam.jl +++ b/test/untyped_lam.jl @@ -23,7 +23,7 @@ end @data public Exp{T} begin Sym :: Symbol => Exp{A} where {A} Val{A} :: A => Exp{A} - App{A, B} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} where {A_ <: A} + App{A, B, A_ <: A} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}} If{A} :: (Exp{Bool}, Exp{A}, Exp{A}) => Exp{A} end