diff options
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 91 | 
1 files changed, 71 insertions, 20 deletions
@@ -4,7 +4,9 @@  {-# LANGUAGE DeriveTraversable #-}  {-# LANGUAGE EmptyCase #-}  {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -16,7 +18,6 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-}  module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where  import Data.Functor.Const @@ -62,6 +63,7 @@ data Expr x env t where    -- array operations    EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))    EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) +  EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)    -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)    EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)    ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) @@ -70,6 +72,7 @@ data Expr x env t where    EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) +  EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))    -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:    -- an implementation is allowed to parallelise this thing and store the b @@ -231,6 +234,7 @@ typeOf = \case    EConstArr _ n t _ -> STArr n (STScal t)    EBuild _ n _ e -> STArr n (typeOf e) +  EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)    EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EUnit _ e -> STArr SZ (typeOf e) @@ -238,6 +242,7 @@ typeOf = \case    EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t +  EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)    EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)    EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) @@ -282,6 +287,7 @@ extOf = \case    ELCase x _ _ _ _ -> x    EConstArr x _ _ _ -> x    EBuild x _ _ _ -> x +  EMap x _ _ -> x    EFold1Inner x _ _ _ _ -> x    ESum1Inner x _ -> x    EUnit x _ -> x @@ -289,6 +295,7 @@ extOf = \case    EMaximum1Inner x _ -> x    EMinimum1Inner x _ -> x    EReshape x _ _ _ -> x +  EZip x _ _ -> x    EFold1InnerD1 x _ _ _ _ -> x    EFold1InnerD2 x _ _ _ _ -> x    EConst x _ _ -> x @@ -331,12 +338,14 @@ travExt f = \case    ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c    EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a    EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b +  EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b    EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e    EUnit x e -> EUnit <$> f x <*> travExt f e    EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b    EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e    EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e +  EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b    EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b    EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c @@ -393,6 +402,7 @@ subst' f w = \case    ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)    EConstArr x n t a -> EConstArr x n t a    EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) +  EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)    EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    ESum1Inner x e -> ESum1Inner x (subst' f w e)    EUnit x e -> EUnit x (subst' f w e) @@ -400,6 +410,7 @@ subst' f w = \case    EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)    EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)    EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) +  EZip x a b -> EZip x (subst' f w a) (subst' f w b)    EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EConst x t v -> EConst x t v @@ -521,34 +532,24 @@ eidxEq (SS n) a b  emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)  emap f arr -  | STArr n t <- typeOf arr +  | STArr _ t <- typeOf arr    , Dict <- styKnown t -  = ELet ext arr $ -      EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -          weakenExpr (WCopy (WSink .> WSink)) f +  = EMap ext f arr  ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)  ezipWith f arr1 arr2 -  | STArr n t1 <- typeOf arr1 +  | STArr _ t1 <- typeOf arr1    , STArr _ t2 <- typeOf arr2    , Dict <- styKnown t1    , Dict <- styKnown t2 -  = ELet ext arr1 $ -    ELet ext (weakenExpr WSink arr2) $ -      EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -        ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ -          weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +  = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) +                                   IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) +                                   IS (IS i) -> EVar ext t (IS i)) +                    f) +             (EZip ext arr1 arr2)  ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = -  let STArr _ t1 = typeOf arr1 -      STArr _ t2 = typeOf arr2 -  in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext  eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a  eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t      go SMTMaybe{} _ = ENil ext      go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e      go SMTScal{} _ = ENil ext + +splitSparsePair +  :: -- given a sparsity +     STy (TPair a b) -> Sparse (TPair a b) t' +  -> (forall a' b'. +          -- I give you back two sparsities for a and b +          Sparse a a' -> Sparse b b' +          -- furthermore, I tell you that either your t' is already this (a', b') pair... +       -> Either +            (t' :~: TPair a' b') +            -- or I tell you how to construct a' and b' from t', given an actual t' +            (forall r' env. +                 Idx env t' +              -> (forall env'. +                     (forall c. Ex env' c -> Ex env c) +                  -> Ex env' a' -> Ex env' b' -> r') +              -> r') +       -> r) +  -> r +splitSparsePair _ SpAbsent k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = +  k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = +  let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in +  k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> +  k2 (elet $ +       emaybe (EVar ext (STMaybe (applySparse s t)) i) +         (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) +         (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) +     (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = +  splitSparsePair t (SpSparse s) $ \s1 s2 eres -> +  k s1 s2 $ Right $ \i k2 -> +  case eres of +    Left refl -> case refl of {} +    Right f -> +      f IZ $ \wrap e1 e2 -> +        k2 (\body -> +              elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) +                     (ENothing ext (applySparse s t)) +                     (evar IZ)) $ +                wrap body) +           e1 e2  | 
