diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/AST.hs | 99 | ||||
| -rw-r--r-- | src/AST/Count.hs | 75 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 29 | ||||
| -rw-r--r-- | src/AST/SplitLets.hs | 6 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken/Auto.hs | 44 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 26 | ||||
| -rw-r--r-- | src/CHAD.hs | 302 | ||||
| -rw-r--r-- | src/Compile.hs | 67 | ||||
| -rw-r--r-- | src/Compile/Exec.hs | 12 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 6 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 18 | ||||
| -rw-r--r-- | src/Language.hs | 32 | ||||
| -rw-r--r-- | src/Language/AST.hs | 41 | ||||
| -rw-r--r-- | src/Simplify.hs | 8 |
17 files changed, 530 insertions, 241 deletions
@@ -4,7 +4,9 @@ {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} @@ -16,7 +18,6 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where import Data.Functor.Const @@ -62,21 +63,23 @@ data Expr x env t where -- array operations EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: -- an implementation is allowed to parallelise this thing and store the b -- values in some implementation-defined order. -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (t1 : t1 : env) (TPair t1 b) + -> Expr x (TPair t1 t1 : env) (TPair t1 b) -> Expr x env t1 -> Expr x env (TArr (S n) t1) -> Expr x env (TPair (TArr n t1) -- normal primal fold output @@ -231,6 +234,7 @@ typeOf = \case EConstArr _ n t _ -> STArr n (STScal t) EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) @@ -238,6 +242,7 @@ typeOf = \case EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) @@ -282,6 +287,7 @@ extOf = \case ELCase x _ _ _ _ -> x EConstArr x _ _ _ -> x EBuild x _ _ _ -> x + EMap x _ _ -> x EFold1Inner x _ _ _ _ -> x ESum1Inner x _ -> x EUnit x _ -> x @@ -289,6 +295,7 @@ extOf = \case EMaximum1Inner x _ -> x EMinimum1Inner x _ -> x EReshape x _ _ _ -> x + EZip x _ _ -> x EFold1InnerD1 x _ _ _ _ -> x EFold1InnerD2 x _ _ _ _ -> x EConst x _ _ -> x @@ -331,12 +338,14 @@ travExt f = \case ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e EUnit x e -> EUnit <$> f x <*> travExt f e EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c @@ -393,14 +402,16 @@ subst' f w = \case ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) ESum1Inner x e -> ESum1Inner x (subst' f w e) EUnit x e -> EUnit x (subst' f w e) EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) - EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) @@ -521,34 +532,24 @@ eidxEq (SS n) a b emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) emap f arr - | STArr n t <- typeOf arr + | STArr _ t <- typeOf arr , Dict <- styKnown t - = ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f + = EMap ext f arr ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) ezipWith f arr1 arr2 - | STArr n t1 <- typeOf arr1 + | STArr _ t1 <- typeOf arr1 , STArr _ t2 <- typeOf arr2 , Dict <- styKnown t1 , Dict <- styKnown t2 - = ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = - let STArr _ t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t go SMTMaybe{} _ = ENil ext go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 296c021..a53822d 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of SsNone -> occCountX SsFull a $ \env1 mka -> occCountX SsNone b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> use (EBuild ext n (mka env') $ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ @@ -535,31 +534,49 @@ occCountX initialS topexpr k = case topexpr of SsArr' s' -> occCountX SsFull a $ \env1 mka -> occCountX s' b $ \env2'' mkb -> - withSome (scaleMany (Some env2'')) $ \env2' -> - occEnvPop' env2' $ \env2 s2 -> - withSome (Some env1 <> Some env2) $ \env -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> k env $ \env' -> EBuild ext n (mka env') $ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1''' mka -> - withSome (scaleMany (Some env1''')) $ \env1'' -> - occEnvPop' env1'' $ \env1' s2 -> - occEnvPop' env1' $ \env1 s1 -> - let s0 = case s of + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of SsNone -> Some SsNone SsArr' s' -> Some s' in - withSome (Some s1 <> Some s2 <> s0) $ \sElt -> + withSome (s1 <> s0) $ \sElt -> occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (Some env1 <> Some env2 <> Some env3) $ \env -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsArr sElt) s $ EFold1Inner ext commut (projectSmallerSubstruc SsFull sElt $ - mka (OccPush (OccPush env' () sElt) () sElt)) + mka (OccPush env' () (SsPair sElt sElt))) (mkb env') (mkc env') ESum1Inner _ e -> handleReduction (ESum1Inner ext) e @@ -608,6 +625,27 @@ occCountX initialS topexpr k = case topexpr of k env $ \env' -> EReshape ext n (mkesh env') (mke env') + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + EFold1InnerD1 _ cm e1 e2 e3 -> case s of -- If nothing is necessary, we can execute a fold and then proceed to ignore it @@ -628,7 +666,7 @@ occCountX initialS topexpr k = case topexpr of elet (mapExt (\_ -> ext) e3) $ EPair ext (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1))) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) (mapExt (\_ -> ext) (weakenExpr WSink e2)) (evar IZ)) in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> @@ -638,15 +676,14 @@ occCountX initialS topexpr k = case topexpr of -- If at least some of the additional stores are required, we need to keep this a mapAccum SsPair' _ (SsArr' sB) -> -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_2' mka -> - occEnvPop' env1_2' $ \env1_1' _ -> + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> occEnvPop' env1_1' $ \env1' _ -> occCountX SsFull e2 $ \env2 mkb -> occCountX SsFull e3 $ \env3 mkc -> withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> k env $ \env' -> projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull)) + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) (mkb env') (mkc env') EFold1InnerD2 _ cm ef ebog ed -> diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 68fc629..ecdaa88 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,15 +204,22 @@ ppExpr' d val expr = case expr of <> hardline <> e') (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn' "i" t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + EFold1Inner _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf a) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1i" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] ESum1Inner _ e -> do e' <- ppExpr' 11 val e @@ -238,17 +245,21 @@ ppExpr' d val expr = case expr of EReshape _ n esh e -> do esh' <- ppExpr' 11 val esh e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] EFold1InnerD1 _ cm a b c -> do - name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a - name2 <- genNameIfUsedIn (typeOf b) IZ a - a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a b' <- ppExpr' 11 val b c' <- ppExpr' 11 val c let opname = "fold1iD1" ++ ppCommut cm return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c'] + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] EFold1InnerD2 _ cm ef ebog ed -> do let STArr _ tB = typeOf ebog diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs index d276e44..267dd87 100644 --- a/src/AST/SplitLets.hs +++ b/src/AST/SplitLets.hs @@ -34,10 +34,10 @@ splitLets' = \sub -> \case in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) EFold1Inner x cm a b c -> let STArr _ t1 = typeOf c - in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD1 x cm a b c -> let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c) + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) EFold1InnerD2 x cm a b c -> let STArr _ tB = typeOf b STArr _ t2 = typeOf c @@ -56,12 +56,14 @@ splitLets' = \sub -> \case ELInr x t e -> ELInr x t (splitLets' sub e) EConstArr x n t a -> EConstArr x n t a EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) ESum1Inner x e -> ESum1Inner x (splitLets' sub e) EUnit x e -> EUnit x (splitLets' sub e) EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (splitLets' sub e) EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index a22b73f..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -38,6 +38,7 @@ unMonoid = \case ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) EConstArr _ n t x -> EConstArr ext n t x EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) ESum1Inner _ e -> ESum1Inner ext (unMonoid e) EUnit _ e -> EUnit ext (unMonoid e) @@ -45,6 +46,7 @@ unMonoid = \case EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) EConst _ t x -> EConst ext t x diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 3a97fd1..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -129,7 +129,7 @@ wCopies bs w = let bs' = slistMap (\_ -> Const ()) bs in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env wRaiseAbove SNil _ = WClosed wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index c6efe37..7370df1 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -11,6 +11,7 @@ {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE AllowAmbiguousTypes #-} @@ -23,6 +24,7 @@ module AST.Weaken.Auto ( ) where import Data.Functor.Const +import Data.Kind (Constraint) import GHC.OverloadedLabels import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) @@ -39,18 +41,21 @@ type family Lookup name list where -- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) -- | Pre-weaken with a weakening LPreW :: forall name1 name2 segments. SegmentName name1 -> SegmentName name2 -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments (Lookup name1 segments) - (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) infixr :++: -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where fromLabel = LSeg (symbolSing @name) newtype SegmentName name = SegmentName (SSymbol name) @@ -60,6 +65,18 @@ instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') wh fromLabel = SegmentName symbolSing +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + data SSegments (segments :: [(Symbol, [t])]) where SSegNil :: SSegments '[] SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) @@ -74,7 +91,7 @@ auto1 :: SList (Const ()) '[t] auto1 = Const () `SCons` SNil infixr &. -(&.) :: SSegments segs1 -> SSegments segs2 -> SSegments (Append segs1 segs2) +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) (&.) = ssegmentsAppend where ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) @@ -118,12 +135,12 @@ linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ seg) +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) | Refl <- lemAppendNil @seg = LinAppPreW name1 name2 w LinEnd @@ -151,8 +168,7 @@ pullDown segs name@SSymbol linlayout kNotFound k = k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) -sortLinLayouts :: forall segments env1 env2. - SSegments segments +sortLinLayouts :: SSegments segments -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) @@ -169,8 +185,8 @@ sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail sortLinLayouts _ LinEnd LinApp{} = WClosed sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" -autoWeak :: forall segments env1 env2. - SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 autoWeak segs ly1 ly2 = preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 6301dc1..7b896a3 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -202,11 +202,19 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim shids pure (res, EBuild res dim e1' e2') + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + EFold1Inner _ cm e1 e2 e3 -> do let t1 = typeOf e1 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ (_ :< sh) = v3 @@ -250,11 +258,17 @@ idana env expr = case expr of res <- VIArr <$> genId <*> shidsToVec dim v1 pure (res, EReshape res dim e1' e2') + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + EFold1InnerD1 _ cm e1 e2 e3 -> do let t1 = typeOf e2 - x1 <- genIds t1 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 (_, e2') <- idana env e2 (v3, e3') <- idana env e3 let VIArr _ sh'@(_ :< sh) = v3 diff --git a/src/CHAD.hs b/src/CHAD.hs index 7594a0f..9da5395 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1048,93 +1048,58 @@ drev des accumMap sd = \case , let eltty = typeOf orige , shty :: STy shty <- tTup (sreplicate ndim tIx) , Refl <- indexTupD1Id ndim -> - deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> - let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> - case lemAppendNil @e_binds of { Refl -> - let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in - let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - Ret (mergePrimalBindings - `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) + drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds `bpush` EBuild ext ndim - (EVar ext shty IZ) - (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#ix :++: #sh :++: #propr :++: #d1env)) + (weakenExpr (wSinks (d1e provars)) (drevPrimal des she)) + (letBinds (fst (weakenBindingsE (autoWeak library + (#ix :++: #d1env) + (#ix :++: #propr :++: #d1env)) e0)) $ - let w = autoWeak (#ix (shty `SCons` SNil) - &. #sh (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #propr (d1e envPro) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes)) - (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) - w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) - in EPair ext (weakenExpr w e1) (collectexpr w')) - `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) - (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) - (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) - (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in + weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env) + (#e0 :++: #ix :++: #propr :++: #d1env)) + (EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in ESnd ext $ - uninvertTup (d2e envPro) (STArr ndim STNil) $ - makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ - EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ - let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) - in letBinds (rebinds IZ) $ - weakenExpr (autoWeak (#d (auto1 @sdElt) - &. #pro (d2ace envPro) - &. #etape (subList (bindingsBinds e0) subtapeE) - &. #prerebinds prerebinds - &. #tape (auto1 @(Tape e_tape)) - &. #ix (auto1 @shty) - &. #darr (auto1 @(TArr ndim sdElt)) - &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) - &. #sh (auto1 @shty) - &. #propr (d1e envPro) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des))) - (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) - .> wPro (subList (bindingsBinds e0) subtapeE)) - e2) - }} + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap{} -> error "TODO: CHAD EMap" EFold1Inner _ commut origef ex₀ earr | SpArr @_ @sdElt sdElt <- sd , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> - deleteUnused (descrList des) (occEnvPopSome (occEnvPopSome (occCountAll origef))) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR (SEYesR usedSub)) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (d2 eltty) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - let (mergePrimalBindings', _) = weakenBindingsE (sinkWithBindings bindsx₀a) mergePrimalBindings in - case drev (prodes `DPush` (eltty, Nothing, SMerge) `DPush` (eltty, Nothing, SMerge)) accumMapPro (spDense (d2M eltty)) ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - let bogTy = STArr (SS ndim) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) subEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy primalTy = STPair (STArr ndim (d1 eltty)) bogTy - zipPrimalTy = STPair (d1 eltty) (STPair (d1 eltty) (tapeTy (subList (bindingsBinds ef0) subtapeEf))) - library = #xy (d1 eltty `SCons` d1 eltty `SCons` SNil) + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) &. #parr (auto1 @(TArr (S n) (D1 elt))) &. #px₀ (auto1 @(D1 elt)) &. #px (auto1 @(D1 elt)) @@ -1145,70 +1110,52 @@ drev des accumMap sd = \case &. #x₀abinds (bindingsBinds bindsx₀a) &. #fbinds (bindingsBinds ef0) &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) - &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) - &. #ftape (auto1 @(Tape e_tape)) - &. #primalzip (zipPrimalTy `SCons` SNil) - &. #efPrerebinds efPrerebinds - &. #propr (d1e envPro) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes) - &. #d2acUsed (d2ace (select SAccum usedDes)) &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace envPro) + &. #d2acPro (d2ace provars) &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> - subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) $ \subx₀af _ _ plus_x₀a_f -> - Ret (bconcat bindsx₀a mergePrimalBindings' + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' `bpush` weakenExpr wOverPrimalBindings ex₀1 `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 `bpush` EFold1InnerD1 ext commut (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in - letBinds (fst (weakenBindingsE (autoWeak library - (#xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - layout) - ef0)) $ - elet (weakenExpr (autoWeak library (#fbinds :++: #xy :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: layout)) - ef1) $ - EPair ext - (evar IZ) + letBinds (fst (weakenBindingsE (autoWeak library (#xy :++: #d1env) layout) ef0)) $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1) (EPair ext - (evar IZ) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#px :++: #fbinds :++: layout))))) + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: layout) @> IZ)) + (weakenExpr (autoWeak library (#fbinds :++: #xy :++: #d1env) (#fbinds :++: layout)) ef1tape))) (EVar ext (d1 eltty) (IS (IS IZ))) (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e envPro))))))) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) (EFst ext (EVar ext primalTy IZ)) subx₀af (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in elet - (uninvertTup (d2e envPro) (STPair (STArr ndim (d2 eltty)) (STArr (SS ndim) (d2 eltty))) $ - makeAccumulators (autoWeak library #propr layout1) envPro $ - let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut - (elet (ESnd ext (ESnd ext (EVar ext zipPrimalTy (IS IZ)))) $ - elet (EFst ext (ESnd ext (EVar ext zipPrimalTy (IS (IS IZ))))) $ - elet (EFst ext (EVar ext zipPrimalTy (IS (IS (IS IZ))))) $ - letBinds (efRebinds (IS (IS IZ))) $ - let layout3 = (#ftapebinds :++: #efPrerebinds) :++: #xy :++: #ftape :++: #d :++: #primalzip :++: layout2 in - elet (expandSubenvZeros (autoWeak library #xy layout3) (eltty `SCons` eltty `SCons` SNil) subEf $ - weakenExpr (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) layout3 - .> wPro (subList (bindingsBinds ef0) subtapeEf)) - ef2) $ - EPair ext (ESnd ext (EFst ext (evar IZ))) (ESnd ext (evar IZ))) - (ezip - (EVar ext (STArr (SS ndim) (d1 eltty)) (autoWeak library #parr layout2 @> IZ)) - (ESnd ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) - (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) - (EFst ext $ EVar ext primalTy (autoWeak library #primal layout2 @> IZ)))) $ + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) subEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ plus_x₀a_f (plus_x₀_a (elet (EIdx0 ext (EFold1Inner ext Commut - (EPlus ext (d2M eltty) (EVar ext (d2 eltty) (IS IZ)) (EVar ext (d2 eltty) IZ)) + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) (eflatten (EFst ext (EFst ext (evar IZ)))))) $ weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) @@ -1216,7 +1163,6 @@ drev des accumMap sd = \case (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) (ESnd ext (evar IZ))) - } EUnit _ e | SpArr sdElt <- sd @@ -1240,9 +1186,8 @@ drev des accumMap sd = \case (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) sub (ELet ext (EFold1Inner ext Commut - (sparsePlus (d2M eltty) sdElt' - (EVar ext (applySparse sdElt' (d2 eltty)) (IS IZ)) - (EVar ext (applySparse sdElt' (d2 eltty)) IZ)) + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) (inj2 (ENil ext)) (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ weakenExpr (WCopy WSink) e2) @@ -1346,6 +1291,33 @@ drev des accumMap sd = \case (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + ENothing{} -> err_unsupported "ENothing" EJust{} -> err_unsupported "EJust" EMaybe{} -> err_unsupported "EMaybe" @@ -1476,6 +1448,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of , Refl <- lemAppendNil @tapebinds -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + -- TODO: proper primal-only transform that doesn't depend on D1 = Id drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) drevPrimal des e diff --git a/src/Compile.hs b/src/Compile.hs index f2063ee..8627905 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -8,7 +8,7 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module Compile (compile, compileStderr) where import Control.Applicative (empty) import Control.Monad (forM_, when, replicateM) @@ -71,28 +71,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0 -- | Emit extra C code that checks stuff emitChecks :: Bool; emitChecks = toEnum 0 +-- | Returns compiled function plus compilation output (warnings) compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) + -> IO (SList Value env -> IO (Rep t), String) compile = \env expr -> do codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) let (source, offsets) = compileToString codeID env expr when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - lib <- buildKernel source "kernel" + (lib, compileOutput) <- buildKernel source "kernel" let result_type = typeOf expr result_size = sizeofSTy result_type - return $ \val -> do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) where serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -100,6 +102,15 @@ compile = \env expr -> do serialiseArguments args ptr k serialiseArguments _ _ k = k +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + data StructDecl = StructDecl String -- ^ name @@ -791,6 +802,15 @@ compile' env = \case return (CELit arrname) + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + EFold1Inner _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr @@ -820,11 +840,14 @@ compile' env = \case -- kvar <- if vecwid > 1 then genName' "k" else return "" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -834,6 +857,7 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts <> pure (SAsg accvar funres)) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) @@ -940,6 +964,16 @@ compile' env = \case [("buf", CEProj (CELit arrname) "buf") ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + EFold1InnerD1 _ commut efun ex0 earr -> do let STArr (SS n) t = typeOf earr STPair _ bty = typeOf efun @@ -967,12 +1001,14 @@ compile' env = \case jvar <- genName' "j" accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" - (funres, funStmts) <- scope $ compile' (Const arreltlit `SCons` Const accvar `SCons` env) efun funresvar <- genName' "res" ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + pairstrname <- emitStruct (STPair t t) emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ pure (SVarDecl False (repSTy t) accvar (CELit x0name)) <> x0incrStmts -- we're copying x0 here @@ -982,8 +1018,9 @@ compile' env = \case -- what comes out of the function anyway, so that's -- fine, but we do need to increment the array element. arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index cc6d5fa..ad4180f 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -30,7 +30,7 @@ debug = False -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> String -> IO KernelLib +buildKernel :: String -> String -> IO (KernelLib, String) buildKernel csource funname = do template <- (++ "/tmp.chad.") <$> getTempDir path <- mkdtemp template @@ -43,7 +43,8 @@ buildKernel csource funname = do ,"-Wall", "-Wextra" ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" ,"-Wno-unused-parameter", "-Wno-unused-function" - ,"-Wno-alloc-size-larger-than"] -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource -- Print the source before the GCC output. @@ -51,11 +52,6 @@ buildKernel csource funname = do ExitSuccess -> return () ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" - when (not (null gccStdout)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>" - when (not (null gccStderr)) $ - hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>" - case ec of ExitSuccess -> return () ExitFailure{} -> do @@ -72,7 +68,7 @@ buildKernel csource funname = do _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" dlclose dl) - return (KernelLib ref) + return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) foreign import ccall "dynamic" wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b353def..6655423 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -254,8 +254,10 @@ makeFwdADArtifactInterp env expr = in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) {-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) -makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do + (fun, output) <- compile (dne env) (dfwdDN expr) + return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 44bdbb2..a1e9d0d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -153,6 +153,7 @@ dfwdDN = \case (EConstArr ext n t x) EBuild _ n a b | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) ESum1Inner _ e -> let STArr n (STScal t) = typeOf e @@ -168,6 +169,7 @@ dfwdDN = \case EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) EReshape _ n esh e | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) EConst _ t x -> scalTyCase t diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 9e3d2a6..e1c81cd 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -116,13 +116,16 @@ interpret'Rec env = \case EBuild _ dim a b -> do sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b EFold1Inner _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] ESum1Inner _ e -> do arr <- interpret' env e let STArr _ (STScal t) = typeOf e @@ -150,16 +153,23 @@ interpret'Rec env = \case sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh arr <- interpret' env e return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) EFold1InnerD1 _ _ a b c -> do let t = typeOf b - let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a x0 <- interpret' env b arr <- interpret' env c let sh `ShCons` n = arrayShape arr -- TODO: this is very inefficient, even for an interpreter; with mutable -- arrays this can be a lot better with no lists res <- arrayGenerateM sh $ \idx -> do - (y, stores) <- mapAccumLM f x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] return (y, arrayFromList (ShNil `ShCons` n) stores) return (arrayMap fst res ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> diff --git a/src/Language.hs b/src/Language.hs index 31b4b87..c1a6248 100644 --- a/src/Language.hs +++ b/src/Language.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} @@ -15,6 +16,8 @@ module Language ( Lookup, ) where +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + import Array import AST import AST.Sparse.Types @@ -113,7 +116,19 @@ map_ (v :-> a) b NEDrop (SS SZ) (NEDrop (SS SZ) a) fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1 :-> v2 :-> e1) e2 e3 = NEFold1Inner v1 v2 e1 e2 e3 +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) sum1i e = NESum1Inner e @@ -135,7 +150,20 @@ reshape = NEReshape fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD1 v1 v2 e1 e2 e3 +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) diff --git a/src/Language/AST.hs b/src/Language/AST.hs index c9d05c9..a3b8130 100644 --- a/src/Language/AST.hs +++ b/src/Language/AST.hs @@ -4,7 +4,9 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneKindSignatures #-} @@ -17,7 +19,7 @@ module Language.AST where import Data.Kind (Type) import Data.Type.Equality import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..)) +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) import Array import AST @@ -50,7 +52,7 @@ data NExpr env t where -- array operations NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 t -> Var name2 t -> NExpr ('(name2, t) : '(name1, t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEUnit :: NExpr env t -> NExpr env (TArr Z t) NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) @@ -58,7 +60,7 @@ data NExpr env t where NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - NEFold1InnerD1 :: Var n1 t1 -> Var n2 t1 -> NExpr ('(n2, t1) : '(n1, t1) : env) (TPair t1 b) + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) @@ -96,11 +98,16 @@ data NExpr env t where NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t deriving instance Show (NExpr env t) -type family Lookup name env where - Lookup "_" _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup name ('(name, t) : env) = t - Lookup name (_ : env) = Lookup name env +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env type family DropNth i env where DropNth Z (_ : env) = env @@ -209,7 +216,7 @@ fromNamedExpr val = \case NEConstArr n t x -> EConstArr ext n t x NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEFold1Inner n1 n2 a b c -> EFold1Inner ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) NESum1Inner e -> ESum1Inner ext (go e) NEUnit e -> EUnit ext (go e) NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) @@ -217,7 +224,7 @@ fromNamedExpr val = \case NEMinimum1Inner e -> EMinimum1Inner ext (go e) NEReshape n a b -> EReshape ext n (go a) (go b) - NEFold1InnerD1 n1 n2 a b c -> EFold1InnerD1 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) NEConst t x -> EConst ext t x @@ -275,3 +282,17 @@ dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env dropNthW SZ (_ `NPush` _) = WSink dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Simplify.hs b/src/Simplify.hs index b89d7f6..1889adc 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -200,12 +200,14 @@ simplify'Rec = \case EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 -- TODO: more array indexing - EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 -- TODO: more array shape EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 + EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) -- TODO: more constant folding EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -308,6 +310,7 @@ simplify'Rec = \case ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] EConstArr _ n t v -> pure $ EConstArr ext n t v EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EMap _ a b -> [simprec| EMap ext *a *b |] EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] EUnit _ e -> [simprec| EUnit ext *e |] @@ -315,6 +318,7 @@ simplify'Rec = \case EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EZip _ a b -> [simprec| EZip ext *a *b |] EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] EConst _ t v -> pure $ EConst ext t v @@ -364,6 +368,7 @@ hasAdds = \case ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c EConstArr _ _ _ _ -> False EBuild _ _ a b -> hasAdds a || hasAdds b + EMap _ a b -> hasAdds a || hasAdds b EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ESum1Inner _ e -> hasAdds e EUnit _ e -> hasAdds e @@ -371,6 +376,7 @@ hasAdds = \case EMaximum1Inner _ e -> hasAdds e EMinimum1Inner _ e -> hasAdds e EReshape _ _ a b -> hasAdds a || hasAdds b + EZip _ a b -> hasAdds a || hasAdds b EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e |
