diff --git a/docs/language-reference.rst b/docs/language-reference.rst index 6c926b9376..d29419eb58 100644 --- a/docs/language-reference.rst +++ b/docs/language-reference.rst @@ -936,6 +936,10 @@ is optional if ``body`` is a ``let`` expression. The binding is not let-generalised, meaning it has a monomorphic type. This can be significant if ``e`` is of functional type. +If ``e`` is of type ``i64`` and ``pat`` binds only a single name +``v``, then the type of the overall expression is the type of +``body``, but with any occurence of ``v`` replaced by ``e``. + ``let [n] pat = e in body`` ........................... @@ -1156,11 +1160,11 @@ sizes. Size going out of scope ....................... -An unknown size is created when the proper size of an array refers to -a name that has gone out of scope:: +An unknown size is created in some cases when the a type references a +name that has gone out of scope:: - let c = a + b - in replicate c 0 + match ... + case #some c -> replicate c 0 The type of ``replicate c 0`` is ``[c]i32``, but since ``c`` is locally bound, the type of the entire expression is ``[k]i32`` for diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index a444b9d0d7..9e7e9ff110 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -610,7 +610,20 @@ checkExp (AppExp (LetPat sizes pat e body loc) _) = do bindingPat sizes' pat t $ \pat' -> do body' <- checkExp body body_t <- expTypeFully body' - (body_t', retext) <- unscopeType loc (patNames pat') body_t + + -- If the bound expression is of type i64, then we replace the + -- pattern name with the expression in the type of the body. + -- Otherwise, we need to come up with unknown sizes for the + -- sizes going out of scope. + t' <- normType t -- Might be overloaded integer until now. + (body_t', retext) <- + case (t', patNames pat') of + (Scalar (Prim (Signed Int64)), [v]) + | not $ hasBinding e' -> do + let f x = if x == v then Just (ExpSubst e') else Nothing + pure (applySubst f body_t, []) + _ -> + unscopeType loc (patNames pat') body_t pure $ AppExp diff --git a/tests/returntype-error4.fut b/tests/returntype-error4.fut index 013cbb4494..15525e522f 100644 --- a/tests/returntype-error4.fut +++ b/tests/returntype-error4.fut @@ -2,7 +2,7 @@ -- error: Cannot generalise def foo n = - let m = n+1 + let (m,_) = (n+1,true) in (iota ((m+1)+1), zip (iota (m+1)), zip (iota m)) diff --git a/tests/shapes/letshape5.fut b/tests/shapes/letshape5.fut index 2a4f39b5cd..568fe1944e 100644 --- a/tests/shapes/letshape5.fut +++ b/tests/shapes/letshape5.fut @@ -1,7 +1,8 @@ -- A size goes out of scope. -- == --- error: "m" +-- input { 2i64 } +-- output { [0i64,1i64,2i64] } -def main (n: i64) : [n]i32 = - let m = n +def main (n: i64) : [n+1]i64 = + let m = n + 1 in iota m diff --git a/tests/shapes/unknown-param.fut b/tests/shapes/unknown-param.fut index 971e9dfb71..5175c9c9d9 100644 --- a/tests/shapes/unknown-param.fut +++ b/tests/shapes/unknown-param.fut @@ -4,5 +4,5 @@ -- error: Unknown size.*in parameter def f (x: bool) = - let n = if x then 10 else 20 + let (n,_) = if x then (10,true) else (20,true) in \(_: [n]bool) -> true diff --git a/tests/sumtypes/sumtype51.fut b/tests/sumtypes/sumtype51.fut index ec5dceb0cc..c782f18f09 100644 --- a/tests/sumtypes/sumtype51.fut +++ b/tests/sumtypes/sumtype51.fut @@ -5,7 +5,7 @@ type option 'a = #None | #Some a def gen () : ?[n].[n]i32 = - let n = 0 + let (n,_) = (0,true) in replicate n 0i32 entry main b: option ([]i32) = diff --git a/tests/sumtypes/sumtype52.fut b/tests/sumtypes/sumtype52.fut index c9acdffcdd..1ed6fca70d 100644 --- a/tests/sumtypes/sumtype52.fut +++ b/tests/sumtypes/sumtype52.fut @@ -5,7 +5,7 @@ type option 'a = #None | #Some a def gen () : ?[n].[n]i32 = - let n = 0 + let (n,_) = (0,true) in replicate n 0i32 def ite b t f = if b then t() else f() diff --git a/tests/unscoping.fut b/tests/unscoping.fut index d11fc0f29f..903b06d874 100644 --- a/tests/unscoping.fut +++ b/tests/unscoping.fut @@ -2,7 +2,7 @@ -- error: Cannot apply "bar" to "xs" def foo n = - let m = n+1 + let (m,_) = (n+1,true) in (iota ((m+1)+1), \_ -> iota (m+1), \_ -> iota m)