Skip to content

Commit

Permalink
refactor GADTs, discard some features and make their occurrences clar…
Browse files Browse the repository at this point in the history
…ified
  • Loading branch information
thautwarm committed Mar 4, 2019
1 parent 8bd913e commit 3077863
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 115 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ MLStyle.jl
- [Time Overhead](#time-overhead)
- [Allocation](#allocation)
- [Gallery](#gallery)

- [Contributing to MLStyle](https://github.com/thautwarm/MLStyle.jl#contributing-to-mlstyle)


Expand Down Expand Up @@ -125,16 +125,16 @@ end
@use GADT

@data public Exp{T} begin
Sym :: Symbol => Exp{A} where {A}
Val{A} :: A => Exp{A}
App{A, B} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} where {A_ <: A}
Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}}
If{A} :: (Exp{Bool}, Exp{A}, Exp{A}) => Exp{A}
Sym{A} :: Symbol => Exp{A}
Val{A} :: A => Exp{A}
App{A, B, A_} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B}
Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}}
If{A} :: (Exp{Bool}, Exp{A}, Exp{A}) => Exp{A}
end

```

A simple intepreter implementation using GADTs could be found at `test/untyped_lam.jl`.
A simple intepreter implemented via GADTs could be found at `test/untyped_lam.jl`.


### Active Patterns
Expand All @@ -157,7 +157,7 @@ end # RegexMatch("123")
### Prerequisite

Recently the rudimentary benchmarks have been finished, which turns out that MLStyle.jl could be extremely fast
when matching cases are complicated, while in terms of some very simple cases(straightforward destruct shallow tuples, arrays and datatypes without recursive invocations), Match.jl could be faster.
when matching cases are complicated, while in terms of some very simple cases(straightforwardly destructure shallow tuples, arrays and datatypes without recursive invocations), Match.jl could be faster.

All benchmark scripts are provided at directory [Matrix-Benchmark](https://github.com/thautwarm/MLStyle.jl/blob/master/matrix-benchmark).

Expand Down
17 changes: 0 additions & 17 deletions docs/README.md

This file was deleted.

86 changes: 68 additions & 18 deletions docs/src/syntax/adt.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Syntax
'end'
<GADT> =
'@data' ['public' | 'internal'] <Type> 'begin'
'@data' ['public' | 'internal' | 'visible' 'in' <Seq Module>] <Type> 'begin'
(<ConsName>[{<Seq TVar>}] '::'
( '('
Expand All @@ -42,6 +42,34 @@ Syntax
```

Examples:

```julia

@data internal A begin
A1(Int, Int)
A2(a :: Int, b :: Int)
A3(a, b) # equals to `A3(a::Any, b::Any)`
end

@data B{T} begin
B1(T, Int)
B2(a :: T)
end

@data visible in MyModule C{T} begin
C1(T)
C2{A} :: Vector{A} => C{A}
end

abstract type DD end
@data visible in [Main, Base, Core] D{T} <: DD begin
D1 :: Int => D{T} where T # implicit type vars
D2{A, B} :: (A, B, Int) => D{Tuple{A, B}}
D3{A} :: A => D{Array{A, N}} where N # implicit type vars
end
```

Qualifier
----------------------

Expand Down Expand Up @@ -105,8 +133,7 @@ the extension system like Haskell here.

Since that you can define your own `where` pattern and export it to any modules.
Given an arbitrary Julia module, if you don't use `@use GADT` to enable GADT extensions and,
the qualifier of the your `where` pattern makes it visible here(current module),
your own `where` pattern could work here.
your own `where` pattern just works here.


Here's a simple intepreter implemented using GADTs.
Expand Down Expand Up @@ -149,14 +176,13 @@ And now let's define the operators of our abstract machine.
@data public Exp{T} begin

# The symbol referes to some variable in current context.
Sym :: Symbol => Exp{A} where {A}
Sym{A} :: Symbol => Exp{A}

# Value.
Val{A} :: A => Exp{A}

# Function application.
# add constraints to implicit tvars to get covariance
App{A, B} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B} where {A_ <: A}
App{A, B, A_ <: A} :: (Exp{Fun{A, B}}, Exp{A_}) => Exp{B}

# Lambda/Anonymous function.
Lam{A, B} :: (Symbol, Exp{B}) => Exp{Fun{A, B}}
Expand All @@ -166,18 +192,6 @@ And now let's define the operators of our abstract machine.
end
```

Something deserved to be remark here: when using this GADT syntax like

```
ConsName{TVars1...} :: ... => Exp{TVars2...} where {TVar3...}
```

You can add constraints to both `TVars1` and `TVars3`, and `TVars2` should be
always empty or a sequence of `Symbol`s. Furthermore, `TVars3` are the so-called
implicit type variables, and `TVars1` are the normal generic type variables.

Let's back to our topic.

To make function abstractions, we need a `substitute` operation.

```julia
Expand Down Expand Up @@ -251,3 +265,39 @@ ctx = Dict{Symbol, Any}()
```


Implicit Type Variables of Generalized ADT
----------------------------------------------------


Sometimes you might want this:

```julia
@use GADT

@data A{T} begin
A1 :: Int => A{T} where T
end
```
It means that for all `T`, we have `A{T} >: A1`, where `A1` is a case class and could be used as a constructor.

You can work with them in this way:
```julia
function string_A() :: A{String}
A1(2)
end

@assert String == @match string_A() begin
A{T} where T => T
end
```

Currently, there're several limitations with implicit type variables, say, you're not expected to use implicit type variables in
the argument types of constructors, like:

```julia
@data A{T} begin
A1 :: T => A{T} where T # NOT EXPECTED!
end
```

It's possible to achieve more flexible implicit type variables, but it's quite difficult for such a package without statically type checking.
81 changes: 45 additions & 36 deletions src/DataType.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using MLStyle.Infras
using MLStyle.Pervasives
using MLStyle.Render: render
using MLStyle.Record: def_record
using MLStyle.TypeVarExtraction

export @data
is_symbol_capitalized = isCapitalized string
Expand Down Expand Up @@ -59,26 +60,26 @@ end

function impl(t, variants :: Expr, mod :: Module)
l :: LineNumberNode = LineNumberNode(1)
abs_tvars = get_tvars(t)
abs_tvars = collect(get_tvars(t))
defs = []
abst() = isempty(abs_tvars) ? t : :($t{$(abs_tvars...)})
VAR = mangle(mod)

for each in variants.args
@match each begin
::LineNumberNode => (l = each)
:($case{$(tvars...)} :: ($(params...), ) => $(ret_ty) where {$(gtvars...)}) ||
:($case{$(tvars...)} :: ($(params...), ) => $(ret_ty && Do(gtvars=[]))) ||
:($case{$(tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $ret_ty where {$(gtvars...)}) ||
:($case{$(tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(gtvars=[]))) ||
:($(case && Do(tvars = [])) :: ($(params...), ) => $(ret_ty) where {$(gtvars...)}) ||
:($(case && Do(tvars = [])) :: ($(params...), ) => $(ret_ty && Do(gtvars=[]))) ||
:($(case && Do(tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty) where {$(gtvars...)}) ||
:($(case && Do(tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(gtvars=[]))) ||
:($case{$(generic_tvars...)} :: ($(params...), ) => $(ret_ty) where {$(implicit_tvars...)}) ||
:($case{$(generic_tvars...)} :: ($(params...), ) => $(ret_ty && Do(implicit_tvars=[]))) ||
:($case{$(generic_tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $ret_ty where {$(implicit_tvars...)}) ||
:($case{$(generic_tvars...)} :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(implicit_tvars=[]))) ||
:($(case && Do(generic_tvars = [])) :: ($(params...), ) => $(ret_ty) where {$(implicit_tvars...)}) ||
:($(case && Do(generic_tvars = [])) :: ($(params...), ) => $(ret_ty && Do(implicit_tvars=[]))) ||
:($(case && Do(generic_tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty) where {$(implicit_tvars...)}) ||
:($(case && Do(generic_tvars = [])) :: $(arg_ty && Do(params = [arg_ty])) => $(ret_ty && Do(implicit_tvars=[]))) ||

:($case($((params && Do(tvars=abs_tvars))...))) && Do(ret_ty = abst(), gtvars=[]) => begin
:($case($((params && Do(generic_tvars=abs_tvars))...))) && Do(ret_ty = abst(), implicit_tvars=[]) => begin

config = Dict{Symbol, Any}([(gtvar isa Expr ? gtvar.args[1] : gtvar) => Any for gtvar in gtvars])
config = Dict{Symbol, Any}([(gtvar isa Expr ? gtvar.args[1] : gtvar) => Any for gtvar in implicit_tvars])

pairs = map(enumerate(params)) do (i, each)
@match each begin
Expand All @@ -91,48 +92,56 @@ function impl(t, variants :: Expr, mod :: Module)
definition_body = [Expr(:(::), field, ty) for (field, ty, _) in pairs]
constructor_args = [Expr(:(::), field, ty) for (field, _, ty) in pairs]
arg_names = [field for (field, _, _) in pairs]
spec_tvars = [tvars..., [Any for _ in gtvars]...]
getfields = [:($VAR.$field) for field in arg_names]

convert_fn = isempty(gtvars) ? nothing : let (=>) = (a, b) -> convert(b, a)
out_tvars = [tvars..., fill(nothing, length(gtvars))...] => Vector{Any}
inp_tvars = [tvars..., fill(nothing, length(gtvars))...] => Vector{Any}
fresh_tvars1 = fill(nothing, length(gtvars)) => Vector{Any}
fresh_tvars2 = fill(nothing, length(gtvars)) => Vector{Any}
tcovs = []
for (i, _) in enumerate(gtvars)
TAny = mangle(mod)
TCov = mangle(mod)
push!(tcovs, TCov)
fresh_tvars2[end-i + 1] = :($TCov <: $TAny)
fresh_tvars1[end-i + 1] = TAny
inp_tvars[end-i + 1] = TAny
out_tvars[end-i + 1] = TCov
definition_head = :($case{$(generic_tvars...), $(implicit_tvars...)})

generic_tvars = collect(map(to_tvar, extract_tvars(generic_tvars)))
implicit_tvars = collect(map(to_tvar, extract_tvars(implicit_tvars)))

convert_fn = isempty(implicit_tvars) ? nothing : let (=>) = (a, b) -> convert(b, a)
out_tvars = [generic_tvars; implicit_tvars]
inp_tvars = [generic_tvars; [mangle(mod) for _ in implicit_tvars]]
fresh_tvars1 = []
fresh_tvars2 = []
n_generic_tvars = length(generic_tvars)
for i in 1 + n_generic_tvars : length(implicit_tvars) + n_generic_tvars
TAny = inp_tvars[i]
TCov = out_tvars[i]
push!(fresh_tvars2, :($TCov <: $TAny))
push!(fresh_tvars1, TAny)
end
tcovs = reverse(tcovs)
arg2 = :($VAR :: $case{$(inp_tvars...)})
arg1_cov = :($Type{$case{$(out_tvars...)}})
arg1_abs = :($Type{$t{$(tcovs...)}})
arg1_abs = :($Type{$ret_ty})
casted = :($case{$(out_tvars...)}($(getfields...)))
quote
$Base.convert(::$arg1_cov, $arg2) where {$(tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted
$Base.convert(::$arg1_abs, $arg2) where {$(tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted
$Base.convert(::$arg1_cov, $arg2) where {$(generic_tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted
$Base.convert(::$arg1_abs, $arg2) where {$(generic_tvars...), $(fresh_tvars1...), $(fresh_tvars2...)} = $casted
end
end

definition_head = :($case{$(tvars...), $(gtvars...)})
def_cons =
isempty(spec_tvars) ?
isempty(generic_tvars) && isempty(implicit_tvars) ?
!isempty(constructor_args) ?
quote
function $case(;$(constructor_args...))
$case($(arg_names...))
end
end :
nothing :
quote
function $case($(constructor_args...), ) where {$(tvars...)}
$case{$(spec_tvars...)}($(arg_names...))
let spec_tvars = [generic_tvars; [Any for _ in implicit_tvars]]
if isempty(generic_tvars)
quote
function $case($(constructor_args...), )
$case{$(spec_tvars...)}($(arg_names...))
end
end
else
quote
function $case($(constructor_args...), ) where {$(generic_tvars...)}
$case{$(spec_tvars...)}($(arg_names...))
end
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/MLStyle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ using MLStyle.Pervasives

include("Qualification.jl")

include("TypeVarExtraction.jl")

include("StandardPatterns.jl")
using MLStyle.StandardPatterns

Expand Down
36 changes: 1 addition & 35 deletions src/StandardPatterns.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,11 @@
module StandardPatterns
# This module is designed for creating complex patterns from the primtive ones.

using MLStyle
using MLStyle.Toolz.List: cons, nil
using MLStyle.Infras
using MLStyle.MatchCore
using MLStyle.Qualification
using MLStyle.TypeVarExtraction

struct TypeVar
t :: Symbol
end

struct Relation
l :: Symbol
op :: Symbol
r
end


function any_constraint(t, forall)

function is_rel(::Relation)
true
end

function is_rel(::TypeVar)
false
end

!(t isa Symbol) || any(is_rel, collect(extract_tvars(forall)))
end

macro type_matching(t, forall)
quote
Expand Down Expand Up @@ -63,16 +39,6 @@ macro type_matching(t, forall)
end |> esc
end

function extract_tvars(t :: AbstractArray)
@match t begin
[] => nil()
[hd && if hd isa Symbol end, tl...] => cons(TypeVar(hd), extract_tvars(tl))
[:($hd <: $r), tl...] => cons(Relation(hd, :<:, r), extract_tvars(tl))
[:($hd >: $(r)), tl...] => cons(Relation(hd, Symbol(">:"), r), extract_tvars(tl))
_ => @error "invalid tvars"
end
end

def_pattern(StandardPatterns,
predicate = x -> x isa Expr && x.head == :(::),
rewrite = (tag, case, mod) ->
Expand Down
Loading

0 comments on commit 3077863

Please sign in to comment.