aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bench/Main.hs2
-rw-r--r--src/AST.hs91
-rw-r--r--src/AST/Count.hs49
-rw-r--r--src/AST/Pretty.hs15
-rw-r--r--src/AST/UnMonoid.hs2
-rw-r--r--src/AST/Weaken.hs2
-rw-r--r--src/Analysis/Identity.hs16
-rw-r--r--src/CHAD.hs210
-rw-r--r--src/Compile.hs54
-rw-r--r--src/Compile/Exec.hs12
-rw-r--r--src/ForwardAD.hs6
-rw-r--r--src/ForwardAD/DualNumbers.hs2
-rw-r--r--src/Interpreter.hs10
-rw-r--r--src/Simplify.hs8
-rw-r--r--test-framework/Test/Framework.hs77
-rw-r--r--test/Main.hs12
16 files changed, 428 insertions, 140 deletions
diff --git a/bench/Main.hs b/bench/Main.hs
index ec9264b..6db77b5 100644
--- a/bench/Main.hs
+++ b/bench/Main.hs
@@ -36,7 +36,7 @@ import Simplify
gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))
gradCHAD config term =
- compile knownEnv $
+ compileStderr knownEnv $
simplifyFix $ pruneExpr knownEnv $
simplifyFix $ unMonoid $
simplifyFix $
diff --git a/src/AST.hs b/src/AST.hs
index 663b83f..873a8a5 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -4,7 +4,9 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
@@ -16,7 +18,6 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-{-# LANGUAGE FlexibleInstances #-}
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
@@ -62,6 +63,7 @@ data Expr x env t where
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
+ EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)
-- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)
EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
@@ -70,6 +72,7 @@ data Expr x env t where
EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t)
+ EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))
-- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:
-- an implementation is allowed to parallelise this thing and store the b
@@ -231,6 +234,7 @@ typeOf = \case
EConstArr _ n t _ -> STArr n (STScal t)
EBuild _ n _ e -> STArr n (typeOf e)
+ EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)
EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
@@ -238,6 +242,7 @@ typeOf = \case
EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t
+ EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)
EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)
EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2)
@@ -282,6 +287,7 @@ extOf = \case
ELCase x _ _ _ _ -> x
EConstArr x _ _ _ -> x
EBuild x _ _ _ -> x
+ EMap x _ _ -> x
EFold1Inner x _ _ _ _ -> x
ESum1Inner x _ -> x
EUnit x _ -> x
@@ -289,6 +295,7 @@ extOf = \case
EMaximum1Inner x _ -> x
EMinimum1Inner x _ -> x
EReshape x _ _ _ -> x
+ EZip x _ _ -> x
EFold1InnerD1 x _ _ _ _ -> x
EFold1InnerD2 x _ _ _ _ -> x
EConst x _ _ -> x
@@ -331,12 +338,14 @@ travExt f = \case
ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c
EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a
EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b
+ EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b
EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e
EUnit x e -> EUnit <$> f x <*> travExt f e
EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b
EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e
EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e
+ EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b
EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b
EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c
@@ -393,6 +402,7 @@ subst' f w = \case
ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)
EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
ESum1Inner x e -> ESum1Inner x (subst' f w e)
EUnit x e -> EUnit x (subst' f w e)
@@ -400,6 +410,7 @@ subst' f w = \case
EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b)
+ EZip x a b -> EZip x (subst' f w a) (subst' f w b)
EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
EConst x t v -> EConst x t v
@@ -521,34 +532,24 @@ eidxEq (SS n) a b
emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)
emap f arr
- | STArr n t <- typeOf arr
+ | STArr _ t <- typeOf arr
, Dict <- styKnown t
- = ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
+ = EMap ext f arr
ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
ezipWith f arr1 arr2
- | STArr n t1 <- typeOf arr1
+ | STArr _ t1 <- typeOf arr1
, STArr _ t2 <- typeOf arr2
, Dict <- styKnown t1
, Dict <- styKnown t2
- = ELet ext arr1 $
- ELet ext (weakenExpr WSink arr2) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
- weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
+ = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ)
+ IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ)
+ IS (IS i) -> EVar ext t (IS i))
+ f)
+ (EZip ext arr1 arr2)
ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip arr1 arr2 =
- let STArr _ t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2
+ezip = EZip ext
eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
@@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t
go SMTMaybe{} _ = ENil ext
go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e
go SMTScal{} _ = ENil ext
+
+splitSparsePair
+ :: -- given a sparsity
+ STy (TPair a b) -> Sparse (TPair a b) t'
+ -> (forall a' b'.
+ -- I give you back two sparsities for a and b
+ Sparse a a' -> Sparse b b'
+ -- furthermore, I tell you that either your t' is already this (a', b') pair...
+ -> Either
+ (t' :~: TPair a' b')
+ -- or I tell you how to construct a' and b' from t', given an actual t'
+ (forall r' env.
+ Idx env t'
+ -> (forall env'.
+ (forall c. Ex env' c -> Ex env c)
+ -> Ex env' a' -> Ex env' b' -> r')
+ -> r')
+ -> r)
+ -> r
+splitSparsePair _ SpAbsent k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+splitSparsePair _ (SpPair s1 s2) k1 =
+ k1 s1 s2 $ Left Refl
+splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k =
+ let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in
+ k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 ->
+ k2 (elet $
+ emaybe (EVar ext (STMaybe (applySparse s t)) i)
+ (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2)))
+ (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ)))))
+ (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ))
+
+splitSparsePair _ (SpSparse SpAbsent) k =
+ k SpAbsent SpAbsent $ Right $ \_ k2 ->
+ k2 id (ENil ext) (ENil ext)
+-- -- TODO: having to handle sparse-of-sparse at all is ridiculous
+splitSparsePair t (SpSparse (SpSparse s)) k =
+ splitSparsePair t (SpSparse s) $ \s1 s2 eres ->
+ k s1 s2 $ Right $ \i k2 ->
+ case eres of
+ Left refl -> case refl of {}
+ Right f ->
+ f IZ $ \wrap e1 e2 ->
+ k2 (\body ->
+ elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i)
+ (ENothing ext (applySparse s t))
+ (evar IZ)) $
+ wrap body)
+ e1 e2
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 296c021..bc02417 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of
SsNone ->
occCountX SsFull a $ \env1 mka ->
occCountX SsNone b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
use (EBuild ext n (mka env') $
use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
@@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of
SsArr' s' ->
occCountX SsFull a $ \env1 mka ->
occCountX s' b $ \env2'' mkb ->
- withSome (scaleMany (Some env2'')) $ \env2' ->
- occEnvPop' env2' $ \env2 s2 ->
- withSome (Some env1 <> Some env2) $ \env ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
k env $ \env' ->
EBuild ext n (mka env') $
elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
EFold1Inner _ commut a b c ->
occCountX SsFull a $ \env1''' mka ->
withSome (scaleMany (Some env1''')) $ \env1'' ->
@@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of
k env $ \env' ->
EReshape ext n (mkesh env') (mke env')
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
EFold1InnerD1 _ cm e1 e2 e3 ->
case s of
-- If nothing is necessary, we can execute a fold and then proceed to ignore it
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 68fc629..2c51b85 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -204,6 +204,14 @@ ppExpr' d val expr = case expr of
<> hardline <> e')
(ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
+ EMap _ a b -> do
+ let STArr _ t1 = typeOf b
+ name <- genNameIfUsedIn' "i" t1 IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 0) $
+ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
+
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
@@ -238,7 +246,12 @@ ppExpr' d val expr = case expr of
EReshape _ n esh e -> do
esh' <- ppExpr' 11 val esh
e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e'
+ return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
+
+ EZip _ e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
EFold1InnerD1 _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index a22b73f..1712ba5 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -38,6 +38,7 @@ unMonoid = \case
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
@@ -45,6 +46,7 @@ unMonoid = \case
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 3a97fd1..f0820b8 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -129,7 +129,7 @@ wCopies bs w =
let bs' = slistMap (\_ -> Const ()) bs
in WStack bs' bs' WId w
-wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
+wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 6301dc1..71da793 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -202,6 +202,15 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> shidsToVec dim shids
pure (res, EBuild res dim e1' e2')
+ EMap _ e1 e2 -> do
+ let STArr _ t = typeOf e2
+ x1 <- genIds t
+ (_, e1') <- idana (x1 `SCons` env) e1
+ (v2, e2') <- idana env e2
+ let VIArr _ sh = v2
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EMap res e1' e2')
+
EFold1Inner _ cm e1 e2 e3 -> do
let t1 = typeOf e1
x1 <- genIds t1
@@ -250,6 +259,13 @@ idana env expr = case expr of
res <- VIArr <$> genId <*> shidsToVec dim v1
pure (res, EReshape res dim e1' e2')
+ EZip _ e1 e2 -> do
+ (v1, e1') <- idana env e1
+ (_, e2') <- idana env e2
+ let VIArr _ sh = v1
+ res <- VIArr <$> genId <*> pure sh
+ pure (res, EZip res e1' e2')
+
EFold1InnerD1 _ cm e1 e2 e3 -> do
let t1 = typeOf e2
x1 <- genIds t1
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 7594a0f..72ce36d 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1048,73 +1048,46 @@ drev des accumMap sd = \case
, let eltty = typeOf orige
, shty :: STy shty <- tTup (sreplicate ndim tIx)
, Refl <- indexTupD1Id ndim ->
- deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
- let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in
- subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
- accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
- let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
- case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 ->
- case lemAppendNil @e_binds of { Refl ->
- let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in
- let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in
- let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
- let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
- Ret (mergePrimalBindings
- `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she)
+ drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 ->
+ let library = #ix (shty `SCons` SNil)
+ &. #e0 (bindingsBinds e0)
+ &. #propr (d1e provars)
+ &. #d1env (desD1E des)
+ &. #d (auto1 @sdElt)
+ &. #tape (auto1 @e_tape)
+ &. #pro (d2ace provars)
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #darr (auto1 @(TArr ndim sdElt))
+ &. #tapearr (auto1 @(TArr ndim e_tape)) in
+ Ret (proPrimalBinds
`bpush` EBuild ext ndim
- (EVar ext shty IZ)
- (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #propr (d1e envPro)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#ix :++: #sh :++: #propr :++: #d1env))
+ (weakenExpr (wSinks (d1e provars)) (drevPrimal des she))
+ (letBinds (fst (weakenBindingsE (autoWeak library
+ (#ix :++: #d1env)
+ (#ix :++: #propr :++: #d1env))
e0)) $
- let w = autoWeak (#ix (shty `SCons` SNil)
- &. #sh (shty `SCons` SNil)
- &. #e0 (bindingsBinds e0)
- &. #propr (d1e envPro)
- &. #d1env (desD1E des)
- &. #d1env' (desD1E usedDes))
- (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
- (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env)
- w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env'))
- in EPair ext (weakenExpr w e1) (collectexpr w'))
- `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ))
- (SEYesR (SENo (SEYesR (subenvAll (d1e envPro)))))
- (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ)))
- (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub)))
- (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in
+ weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env)
+ (#e0 :++: #ix :++: #propr :++: #d1env))
+ (EPair ext e1 e1tape))
+ `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ))
+ (SEYesR (SENo (subenvAll (d1e provars))))
+ (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ)))
+ (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub)
+ (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in
ESnd ext $
- uninvertTup (d2e envPro) (STArr ndim STNil) $
- makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $
- EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $
- -- the cotangent for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ))
- (EVar ext shty IZ)) $
- -- the tape for this element
- ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ))
- (EVar ext shty (IS IZ))) $
- let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE)
- in letBinds (rebinds IZ) $
- weakenExpr (autoWeak (#d (auto1 @sdElt)
- &. #pro (d2ace envPro)
- &. #etape (subList (bindingsBinds e0) subtapeE)
- &. #prerebinds prerebinds
- &. #tape (auto1 @(Tape e_tape))
- &. #ix (auto1 @shty)
- &. #darr (auto1 @(TArr ndim sdElt))
- &. #tapearr (auto1 @(TArr ndim (Tape e_tape)))
- &. #sh (auto1 @shty)
- &. #propr (d1e envPro)
- &. #d2acUsed (d2ace (select SAccum usedDes))
- &. #d2acEnv (d2ace (select SAccum des)))
- (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
- ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv)
- .> wPro (subList (bindingsBinds e0) subtapeE))
- e2)
- }}
+ wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $
+ EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $
+ -- the tape for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ))
+ (EVar ext shty IZ)) $
+ -- the cotangent for this element
+ ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ))
+ (EVar ext shty (IS IZ))) $
+ weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv)
+ (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv))
+ e2)
+
+ EMap{} -> undefined
EFold1Inner _ commut origef ex₀ earr
| SpArr @_ @sdElt sdElt <- sd
@@ -1346,6 +1319,33 @@ drev des accumMap sd = \case
(EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $
weakenExpr (WCopy (WSink .> WSink)) e2)
+ EZip _ a b
+ | SpArr sd' <- sd
+ , STArr n t1 <- typeOf a
+ , STArr _ t2 <- typeOf b ->
+ splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE ->
+ case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons`
+ toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of
+ { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) ->
+ subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B ->
+ Ret binds
+ subtape
+ (EZip ext a1 b1)
+ subBoth
+ (case pairSplitE of
+ Left Refl ->
+ let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2)
+ (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2)
+ Right f -> f IZ $ \wrapPair pick1 pick2 ->
+ elet (emap (wrapPair (EPair ext pick1 pick2))
+ (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $
+ plus_A_B
+ (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2)
+ (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2))
+ }
+
ENothing{} -> err_unsupported "ENothing"
EJust{} -> err_unsupported "EJust"
EMaybe{} -> err_unsupported "EMaybe"
@@ -1476,6 +1476,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
, Refl <- lemAppendNil @tapebinds ->
RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2
+drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False)
+ => Descr env sto
+ -> VarMap Int (D2AcE (Select env sto "accum"))
+ -> (STy a, Storage s)
+ -> Sparse (D2 t) dt
+ -> Expr ValId (a : env) t
+ -> (forall provars shbinds tape d2a'.
+ SList STy provars
+ -> Subenv (D2E (Select env sto "merge")) (D2E provars)
+ -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator)
+ -> Bindings Ex (D1 a : D1E env) shbinds
+ -> Ex (Append shbinds (D1 a : D1E env)) (D1 t)
+ -> Ex (Append shbinds (D1 a : D1E env)) tape
+ -> Sparse (D2 a) d2a'
+ -> (forall env' b.
+ D1E provars :> env'
+ -> Ex (Append (D2AcE provars) env') b
+ -> Ex ( env') (TPair b (Tup (D2E provars))))
+ -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a'
+ -> r)
+ -> r
+drevLambda des accumMap (argty, argsto) sd origef k =
+ let t = typeOf origef in
+ deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') ->
+ let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in
+ subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro ->
+ let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in
+ let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in
+ let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in
+ case prf1 prodes argty argsto of { Refl ->
+ case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 ->
+ let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in
+ extractContrib prodes argty argsto subEf $ \argSp getSparseArg ->
+ let library = #fbinds (bindingsBinds ef0)
+ &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf)
+ &. #ftape (auto1 @(Tape e_tape))
+ &. #arg (d1 argty `SCons` SNil)
+ &. #d (applySparse sd (d2 t) `SCons` SNil)
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes)
+ &. #propr (d1e envPro)
+ &. #d2acUsed (d2ace (select SAccum usedDes))
+ &. #d2acEnv (d2ace (select SAccum des))
+ &. #d2acPro (d2ace envPro)
+ &. #efPrerebinds efPrerebinds in
+ k envPro
+ (subenvD2E (subenvCompose subMergeUsed proSub))
+ mergePrimalBindings
+ (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0))
+ (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
+ (#fbinds :++: #arg :++: #d1env))
+ ef1)
+ (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env)))
+ argSp
+ (\wpro1 body ->
+ uninvertTup (d2e envPro) (typeOf body) $
+ makeAccumulators wpro1 envPro $
+ body)
+ (letBinds (efRebinds (IS IZ)) $
+ weakenExpr
+ (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed))
+ ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv)
+ .> wPro (subList (bindingsBinds ef0) subtapeEf))
+ (getSparseArg ef2))
+ }}
+ where
+ extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False)
+ => proxy env sto -> proxy2 a -> Storage s
+ -- if s == "merge", this simplifies to SubenvS '[D2 a] t'
+ -- if s == "discr", this simplifies to SubenvS '[] t'
+ -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t'
+ -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r
+ extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id
+ extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext)
+ extractContrib _ _ SDiscr SETop k' = k' SpAbsent id
+
+ prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s
+ -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum"
+ prf1 _ _ SMerge = Refl
+ prf1 _ _ SDiscr = Refl
+
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
diff --git a/src/Compile.hs b/src/Compile.hs
index f2063ee..d6ad7ec 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -8,7 +8,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-module Compile (compile) where
+module Compile (compile, compileStderr) where
import Control.Applicative (empty)
import Control.Monad (forM_, when, replicateM)
@@ -71,28 +71,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0
-- | Emit extra C code that checks stuff
emitChecks :: Bool; emitChecks = toEnum 0
+-- | Returns compiled function plus compilation output (warnings)
compile :: SList STy env -> Ex env t
- -> IO (SList Value env -> IO (Rep t))
+ -> IO (SList Value env -> IO (Rep t), String)
compile = \env expr -> do
codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))
let (source, offsets) = compileToString codeID env expr
when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"
when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>"
- lib <- buildKernel source "kernel"
+ (lib, compileOutput) <- buildKernel source "kernel"
let result_type = typeOf expr
result_size = sizeofSTy result_type
- return $ \val -> do
- allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
- let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
- serialiseArguments args ptr $ do
- callKernelFun lib ptr
- ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
- when (ok /= 1) $
- ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
- deserialise result_type ptr (koResultOffset offsets)
+ let function val = do
+ allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do
+ let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets)
+ serialiseArguments args ptr $ do
+ callKernelFun lib ptr
+ ok <- peekByteOff @Word8 ptr (koOkResOffset offsets)
+ when (ok /= 1) $
+ ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing)
+ deserialise result_type ptr (koResultOffset offsets)
+ return (function, compileOutput)
where
serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r
serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k =
@@ -100,6 +102,15 @@ compile = \env expr -> do
serialiseArguments args ptr k
serialiseArguments _ _ k = k
+-- | 'compile', but writes any produced C compiler output to stderr.
+compileStderr :: SList STy env -> Ex env t
+ -> IO (SList Value env -> IO (Rep t))
+compileStderr env expr = do
+ (fun, output) <- compile env expr
+ when (not (null output)) $
+ hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>"
+ return fun
+
data StructDecl = StructDecl
String -- ^ name
@@ -791,6 +802,15 @@ compile' env = \case
return (CELit arrname)
+ -- TODO: actually generate decent code here
+ EMap _ e1 e2 -> do
+ let STArr n _ = typeOf e2
+ compile' env $
+ elet e2 $
+ EBuild ext n (EShape ext (evar IZ)) $
+ elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy (WSink .> WSink)) e1
+
EFold1Inner _ commut efun ex0 earr -> do
let STArr (SS n) t = typeOf earr
@@ -940,6 +960,16 @@ compile' env = \case
[("buf", CEProj (CELit arrname) "buf")
,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))])
+ -- TODO: actually generate decent code here
+ EZip _ e1 e2 -> do
+ let STArr n _ = typeOf e1
+ compile' env $
+ elet e1 $
+ elet (weakenExpr WSink e2) $
+ EBuild ext n (EShape ext (evar (IS IZ))) $
+ EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ))
+
EFold1InnerD1 _ commut efun ex0 earr -> do
let STArr (SS n) t = typeOf earr
STPair _ bty = typeOf efun
diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs
index cc6d5fa..ad4180f 100644
--- a/src/Compile/Exec.hs
+++ b/src/Compile/Exec.hs
@@ -30,7 +30,7 @@ debug = False
-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)
data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ())))
-buildKernel :: String -> String -> IO KernelLib
+buildKernel :: String -> String -> IO (KernelLib, String)
buildKernel csource funname = do
template <- (++ "/tmp.chad.") <$> getTempDir
path <- mkdtemp template
@@ -43,7 +43,8 @@ buildKernel csource funname = do
,"-Wall", "-Wextra"
,"-Wno-unused-variable", "-Wno-unused-but-set-variable"
,"-Wno-unused-parameter", "-Wno-unused-function"
- ,"-Wno-alloc-size-larger-than"] -- ideally we'd keep this, but gcc reports false positives
+ ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives
+ ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain
(ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource
-- Print the source before the GCC output.
@@ -51,11 +52,6 @@ buildKernel csource funname = do
ExitSuccess -> return ()
ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>"
- when (not (null gccStdout)) $
- hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>"
- when (not (null gccStderr)) $
- hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>"
-
case ec of
ExitSuccess -> return ()
ExitFailure{} -> do
@@ -72,7 +68,7 @@ buildKernel csource funname = do
_ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))
when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"
dlclose dl)
- return (KernelLib ref)
+ return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr)
foreign import ccall "dynamic"
wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO ()
diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs
index b353def..6655423 100644
--- a/src/ForwardAD.hs
+++ b/src/ForwardAD.hs
@@ -254,8 +254,10 @@ makeFwdADArtifactInterp env expr =
in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)
{-# NOINLINE makeFwdADArtifactCompile #-}
-makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t)
-makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr)
+makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String)
+makeFwdADArtifactCompile env expr = do
+ (fun, output) <- compile (dne env) (dfwdDN expr)
+ return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output)
drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)
drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr)
diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs
index 44bdbb2..a1e9d0d 100644
--- a/src/ForwardAD/DualNumbers.hs
+++ b/src/ForwardAD/DualNumbers.hs
@@ -153,6 +153,7 @@ dfwdDN = \case
(EConstArr ext n t x)
EBuild _ n a b
| Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b)
+ EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)
ESum1Inner _ e ->
let STArr n (STScal t) = typeOf e
@@ -168,6 +169,7 @@ dfwdDN = \case
EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)
EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e
EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e
+ EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b)
EReshape _ n esh e
| Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)
EConst _ t x -> scalTyCase t
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index 9e3d2a6..d982261 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -116,6 +116,9 @@ interpret'Rec env = \case
EBuild _ dim a b -> do
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a
arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b)
+ EMap _ a b -> do
+ let STArr _ t = typeOf b
+ arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b
EFold1Inner _ _ a b c -> do
let t = typeOf b
let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
@@ -150,6 +153,13 @@ interpret'Rec env = \case
sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh
arr <- interpret' env e
return $ arrayReshape sh arr
+ EZip _ a b -> do
+ arr1 <- interpret' env a
+ arr2 <- interpret' env b
+ let sh = arrayShape arr1
+ when (sh /= arrayShape arr2) $
+ error "Interpreter: mismatched shapes in EZip"
+ return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))
EFold1InnerD1 _ _ a b c -> do
let t = typeOf b
let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a
diff --git a/src/Simplify.hs b/src/Simplify.hs
index b89d7f6..1889adc 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -200,12 +200,14 @@ simplify'Rec = \case
EMaybe ext (ESnd ext e1) (ESnd ext e2) e3
-- TODO: more array indexing
- EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2
+ EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2
+ EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1
EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1
-- TODO: more array shape
EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1
+ EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2)
-- TODO: more constant folding
EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
@@ -308,6 +310,7 @@ simplify'Rec = \case
ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]
EConstArr _ n t v -> pure $ EConstArr ext n t v
EBuild _ n a b -> [simprec| EBuild ext n *a *b |]
+ EMap _ a b -> [simprec| EMap ext *a *b |]
EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]
ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]
EUnit _ e -> [simprec| EUnit ext *e |]
@@ -315,6 +318,7 @@ simplify'Rec = \case
EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]
EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]
EReshape _ n a b -> [simprec| EReshape ext n *a *b |]
+ EZip _ a b -> [simprec| EZip ext *a *b |]
EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]
EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]
EConst _ t v -> pure $ EConst ext t v
@@ -364,6 +368,7 @@ hasAdds = \case
ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c
EConstArr _ _ _ _ -> False
EBuild _ _ a b -> hasAdds a || hasAdds b
+ EMap _ a b -> hasAdds a || hasAdds b
EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ESum1Inner _ e -> hasAdds e
EUnit _ e -> hasAdds e
@@ -371,6 +376,7 @@ hasAdds = \case
EMaximum1Inner _ e -> hasAdds e
EMinimum1Inner _ e -> hasAdds e
EReshape _ _ a b -> hasAdds a || hasAdds b
+ EZip _ a b -> hasAdds a || hasAdds b
EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c
ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e
diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs
index 80711b2..b7d0dc2 100644
--- a/test-framework/Test/Framework.hs
+++ b/test-framework/Test/Framework.hs
@@ -1,9 +1,12 @@
+{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
module Test.Framework (
@@ -11,12 +14,16 @@ module Test.Framework (
testGroup,
groupSetCollapse,
testProperty,
- withResource,
- withResource',
runTests,
defaultMain,
Options(..),
+ -- * Resources
+ withResource,
+ withResource',
+ TestCtx,
+ outputWarningText,
+
-- * Compatibility
TestName,
) where
@@ -29,11 +36,13 @@ import Control.Monad (forM, when, forM_)
import Control.Monad.IO.Class
import Data.IORef
import Data.List (isInfixOf, intercalate)
-import Data.Maybe (isJust, mapMaybe, fromJust)
+import Data.Maybe (mapMaybe, fromJust)
+import Data.Monoid (All(..), Any(..), Sum(..))
import Data.PQueue.Prio.Min qualified as PQ
import Data.String (fromString)
import Data.Time.Clock
import GHC.Conc (getNumProcessors)
+import GHC.Generics (Generic, Generically(..))
import System.Console.ANSI qualified as ANSI
import System.Console.Concurrent (outputConcurrent)
import System.Console.Regions
@@ -57,10 +66,16 @@ type TestName = String
data TestTree
= Group GroupOpts String [TestTree]
- | forall a. Resource String (IO a) (a -> IO ()) (a -> TestTree)
+ | forall a. Resource String ((?testCtx :: TestCtx) => IO a) ((?testCtx :: TestCtx) => a -> IO ()) (a -> TestTree)
-- ^ Name is not specified by user, but inherited from the tree below
| HP String H.Property
+data TestCtx = TestCtx
+ { tctxOutput :: String -> IO () }
+
+outputWarningText :: (?testCtx :: TestCtx) => String -> IO ()
+outputWarningText = tctxOutput ?testCtx
+
-- Not exported because a Resource is not supposed to have a name in the first place
treeName :: TestTree -> String
treeName (Group _ name _) = name
@@ -82,13 +97,13 @@ groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name
groupSetCollapse _ = error "groupSetCollapse: not called on a Group"
-- | The @a -> TestTree@ function must use the @a@ only inside properties: the
--- functoin will be passed 'undefined' when exploring the test tree (without
+-- function will be passed 'undefined' when exploring the test tree (without
-- running properties).
-withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree
+withResource :: ((?testCtx :: TestCtx) => IO a) -> ((?testCtx :: TestCtx) => a -> IO ()) -> (a -> TestTree) -> TestTree
withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun
-- | Same caveats as 'withResource'.
-withResource' :: IO a -> (a -> TestTree) -> TestTree
+withResource' :: ((?testCtx :: TestCtx) => IO a) -> (a -> TestTree) -> TestTree
withResource' make fun = withResource make (\_ -> return ()) fun
testProperty :: String -> H.Property -> TestTree
@@ -226,7 +241,7 @@ runTests options = \tree' ->
successVar <- newEmptyMVar
runTreePar Nothing [] [] tree successVar
readMVar successVar
- else isJust <$> runTreeSeq 0 [] tree
+ else getAll . seqresAllSuccess <$> runTreeSeq 0 [] tree
stats <- readIORef statsRef
endtm <- getCurrentTime
let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm)
@@ -284,6 +299,9 @@ runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runRe
let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")"
path = intercalate "/" (reverse (pathitem : revpath))
idxlist = reverse revidxlist
+ let ?testCtx = TestCtx (\str ->
+ outputConcurrent (ansiYellow ++ "## Warning for " ++ path ++ ":" ++ ansiReset ++
+ "\n" ++ str ++ "\n"))
submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do
setConsoleRegion makeRegion ('|' : path ++ " [R] making...")
@@ -337,37 +355,51 @@ submitOrRunIn (Just reg) _idxlist outvar fun = do
result <- fun reg
forM_ outvar $ \mvar -> putMVar mvar result
+data SeqRes = SeqRes
+ { seqresHaveWarnings :: Any
+ , seqresAllSuccess :: All
+ , seqresNumLines :: Sum Int }
+ deriving (Generic)
+ deriving (Semigroup, Monoid) via Generically SeqRes
+
-- | If all tests are successful, returns the number of output lines produced
runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool)
- => Int -> [String] -> TestTree -> IO (Maybe Int)
+ => Int -> [String] -> TestTree -> IO SeqRes
runTreeSeq indent revpath (Group opts name trees) = do
putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout
starttm <- getCurrentTime
- mlns <- fmap (fmap sum . sequence) . forM trees $
- runTreeSeq (indent + 1) (name : revpath)
+ res <- fmap mconcat . forM trees $
+ runTreeSeq (indent + 1) (name : revpath)
endtm <- getCurrentTime
- case mlns of
- Just lns | goCollapse opts, ?istty -> do
+ if not (getAny (seqresHaveWarnings res)) && getAll (seqresAllSuccess res) && goCollapse opts && ?istty
+ then do
let thislen = 2*indent + length name
+ let Sum lns = seqresNumLines res
putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++
ANSI.setCursorColumnCode 0 ++
replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++
ansiGreen ++ "OK" ++ ansiReset ++
prettyDuration False (realToFrac (diffUTCTime endtm starttm))
- return (Just 1)
- _ -> return ((+1) <$> mlns)
+ return (mempty { seqresNumLines = 1 })
+ else return (res <> (mempty { seqresNumLines = 1 }))
runTreeSeq indent path (Resource _ make cleanup fun) = do
+ outputted <- newIORef False
+ let ?testCtx = TestCtx (\str -> do
+ atomicModifyIORef' outputted (\_ -> (True, ()))
+ putStrLn (ansiYellow ++ "## Warning for " ++ (intercalate "/" (reverse path)) ++
+ ":" ++ ansiReset ++ "\n" ++ str))
value <- make
- success <- runTreeSeq indent path (fun value)
+ res <- runTreeSeq indent path (fun value)
cleanup value
- return success
+ warnings <- readIORef outputted
+ return (res <> (mempty { seqresHaveWarnings = Any warnings }))
runTreeSeq indent path (HP name prop) = do
let thislen = 2*indent + length name
let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' '
when ?istty $ putStr prefix >> hFlush stdout
(ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop
putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout
- return (if ok then Just 1 else Nothing)
+ return (mempty { seqresAllSuccess = All ok, seqresNumLines = 1 })
runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int)
=> (H.Report H.Progress -> IO ())
@@ -489,10 +521,11 @@ ansi :: (?istty :: Bool) => String -> String
ansi | ?istty = id
| otherwise = const ""
-ansiRed, ansiGreen, ansiReset :: (?istty :: Bool) => String
-ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red])
-ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green])
-ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset])
+ansiRed, ansiYellow, ansiGreen, ansiReset :: (?istty :: Bool) => String
+ansiRed = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red])
+ansiYellow = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Vivid ANSI.Yellow])
+ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green])
+ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset])
-- getTermIsDark :: IO Bool
-- getTermIsDark = do
diff --git a/test/Main.hs b/test/Main.hs
index 4bc9082..d586973 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -11,6 +11,7 @@
{-# LANGUAGE UndecidableInstances #-}
module Main where
+import Control.Monad (when)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State
import Data.Bifunctor
@@ -352,7 +353,10 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS
dtermSChadSUSP = pruneExpr env dtermSChadSUS
in
- withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC ->
+ withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS
+ when (not (null output)) $
+ outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>"
+ return fun) $ \fwdartifactC ->
withCompiled env dtermSChadSUSP $ \dcompSChadSUSP ->
testProperty testname $ property $ do
annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr))
@@ -416,7 +420,11 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e
diff tansCompSChadSUSP closeIshE' tansFwd
withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree
-withCompiled env expr = withResource (compile env expr) (\_ -> pure ())
+withCompiled env expr = withResource' $ do
+ (fun, output) <- compile env expr
+ when (not (null output)) $
+ outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>"
+ return fun
gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])
gen_gmm = do