aboutsummaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs91
1 files changed, 71 insertions, 20 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 663b83f..873a8a5 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -4,7 +4,9 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -16,7 +18,6 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE FlexibleInstances #-}
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
@@ -62,6 +63,7 @@ data Expr x env t where
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
-- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
@@ -70,6 +72,7 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
-- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
-- an implementation is allowed to parallelise this thing and store the b
@@ -231,6 +234,7 @@ typeOf = \case
EConstArr _ n t _ -> STArr n (STScal t)
EBuild _ n _ e -> STArr n (typeOf e)
+ EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
@@ -238,6 +242,7 @@ typeOf = \case
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
+ EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
@@ -282,6 +287,7 @@ extOf = \case
ELCase x _ _ _ _ -> x
EConstArr x _ _ _ -> x
EBuild x _ _ _ -> x
+ EMap x _ _ -> x
EFold1Inner x _ _ _ _ -> x
ESum1Inner x _ -> x
EUnit x _ -> x
@@ -289,6 +295,7 @@ extOf = \case
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
EReshape x _ _ _ -> x
+ EZip x _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
EFold1InnerD2 x _ _ _ _ -> x
EConst x _ _ -> x
@@ -331,12 +338,14 @@ travExt f = \case
ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
EUnit x e -> EUnit <$> f x <*> travExt f e
EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
@@ -393,6 +402,7 @@ subst' f w = \case
ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
@@ -400,6 +410,7 @@ subst' f w = \case
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
+ EZip x a b -> EZip x (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
@@ -521,34 +532,24 @@ eidxEq (SS n) a b
emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
emap f arr
- | STArr n t <- typeOf arr
+ | STArr _ t <- typeOf arr
, Dict <- styKnown t
- = ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
+ = EMap ext f arr
ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
ezipWith f arr1 arr2
- | STArr n t1 <- typeOf arr1
+ | STArr _ t1 <- typeOf arr1
, STArr _ t2 <- typeOf arr2
, Dict <- styKnown t1
, Dict <- styKnown t2
- = ELet ext arr1 $
- ELet ext (weakenExpr WSink arr2) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
- weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
+ = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
+ IS (IS i) -> EVar ext t (IS i))
+ f)
+ (EZip ext arr1 arr2)
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip arr1 arr2 =
- let STArr _ t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2
+ezip = EZip ext
eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
@@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t
go SMTMaybe{} _ = ENil ext
go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
go SMTScal{} _ = ENil ext
+
+splitSparsePair
+ :: -- given a sparsity
+ STy (TPair a b) -> Sparse (TPair a b) t'
+ -> (forall a' b'.
+ -- I give you back two sparsities for a and b
+ Sparse a a' -> Sparse b b'
+ -- furthermore, I tell you that either your t' is already this (a', b') pair...
+ -> Either
+ (t' :~: TPair a' b')
+ -- or I tell you how to construct a' and b' from t', given an actual t'
+ (forall r' env.
+ Idx env t'
+ -> (forall env'.
+ (forall c. Ex env' c -> Ex env c)
+ -> Ex env' a' -> Ex env' b' -> r')
+ -> r')
+ -> r)
+ -> r
+splitSparsePair _ SpAbsent k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+splitSparsePair _ (SpPair s1 s2) k1 =
+ k1 s1 s2 $ Left Refl
+splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
+ let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
+ k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
+ k2 (elet $
+ emaybe (EVar ext (STMaybe (applySparse s t)) i)
+ (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
+ (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
+ (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ))
+
+splitSparsePair _ (SpSparse SpAbsent) k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+-- -- TODO: having to handle sparse-of-sparse at all is ridiculous
+splitSparsePair t (SpSparse (SpSparse s)) k =
+ splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
+ k s1 s2 $ Right $ \i k2 ->
+ case eres of
+ Left refl -> case refl of {}
+ Right f ->
+ f IZ $ \wrap e1 e2 ->
+ k2 (\body ->
+ elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
+ (ENothing ext (applySparse s t))
+ (evar IZ)) $
+ wrap body)
+ e1 e2