diff options
| -rw-r--r-- | bench/Main.hs | 2 | ||||
| -rw-r--r-- | src/AST.hs | 91 | ||||
| -rw-r--r-- | src/AST/Count.hs | 49 | ||||
| -rw-r--r-- | src/AST/Pretty.hs | 15 | ||||
| -rw-r--r-- | src/AST/UnMonoid.hs | 2 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 2 | ||||
| -rw-r--r-- | src/Analysis/Identity.hs | 16 | ||||
| -rw-r--r-- | src/CHAD.hs | 210 | ||||
| -rw-r--r-- | src/Compile.hs | 54 | ||||
| -rw-r--r-- | src/Compile/Exec.hs | 12 | ||||
| -rw-r--r-- | src/ForwardAD.hs | 6 | ||||
| -rw-r--r-- | src/ForwardAD/DualNumbers.hs | 2 | ||||
| -rw-r--r-- | src/Interpreter.hs | 10 | ||||
| -rw-r--r-- | src/Simplify.hs | 8 | ||||
| -rw-r--r-- | test-framework/Test/Framework.hs | 77 | ||||
| -rw-r--r-- | test/Main.hs | 12 | 
16 files changed, 428 insertions, 140 deletions
diff --git a/bench/Main.hs b/bench/Main.hs index ec9264b..6db77b5 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -36,7 +36,7 @@ import Simplify  gradCHAD :: KnownEnv env => CHADConfig -> Ex env (TScal TF64) -> IO (SList Value env -> IO (Double, Rep (Tup (D2E env))))  gradCHAD config term = -  compile knownEnv $ +  compileStderr knownEnv $      simplifyFix $ pruneExpr knownEnv $      simplifyFix $ unMonoid $      simplifyFix $ @@ -4,7 +4,9 @@  {-# LANGUAGE DeriveTraversable #-}  {-# LANGUAGE EmptyCase #-}  {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-}  {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-}  {-# LANGUAGE LambdaCase #-}  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE QuantifiedConstraints #-} @@ -16,7 +18,6 @@  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeOperators #-}  {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-}  module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where  import Data.Functor.Const @@ -62,6 +63,7 @@ data Expr x env t where    -- array operations    EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))    EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) +  EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t)    -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right)    EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)    ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) @@ -70,6 +72,7 @@ data Expr x env t where    EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))    EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) +  EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b))    -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically:    -- an implementation is allowed to parallelise this thing and store the b @@ -231,6 +234,7 @@ typeOf = \case    EConstArr _ n t _ -> STArr n (STScal t)    EBuild _ n _ e -> STArr n (typeOf e) +  EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a)    EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t    ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EUnit _ e -> STArr SZ (typeOf e) @@ -238,6 +242,7 @@ typeOf = \case    EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t    EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t +  EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2)    EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb)    EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) @@ -282,6 +287,7 @@ extOf = \case    ELCase x _ _ _ _ -> x    EConstArr x _ _ _ -> x    EBuild x _ _ _ -> x +  EMap x _ _ -> x    EFold1Inner x _ _ _ _ -> x    ESum1Inner x _ -> x    EUnit x _ -> x @@ -289,6 +295,7 @@ extOf = \case    EMaximum1Inner x _ -> x    EMinimum1Inner x _ -> x    EReshape x _ _ _ -> x +  EZip x _ _ -> x    EFold1InnerD1 x _ _ _ _ -> x    EFold1InnerD2 x _ _ _ _ -> x    EConst x _ _ -> x @@ -331,12 +338,14 @@ travExt f = \case    ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c    EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a    EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b +  EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b    EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e    EUnit x e -> EUnit <$> f x <*> travExt f e    EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b    EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e    EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e +  EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b    EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b    EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c    EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c @@ -393,6 +402,7 @@ subst' f w = \case    ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)    EConstArr x n t a -> EConstArr x n t a    EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) +  EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b)    EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    ESum1Inner x e -> ESum1Inner x (subst' f w e)    EUnit x e -> EUnit x (subst' f w e) @@ -400,6 +410,7 @@ subst' f w = \case    EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)    EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)    EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) +  EZip x a b -> EZip x (subst' f w a) (subst' f w b)    EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)    EConst x t v -> EConst x t v @@ -521,34 +532,24 @@ eidxEq (SS n) a b  emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b)  emap f arr -  | STArr n t <- typeOf arr +  | STArr _ t <- typeOf arr    , Dict <- styKnown t -  = ELet ext arr $ -      EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -          weakenExpr (WCopy (WSink .> WSink)) f +  = EMap ext f arr  ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)  ezipWith f arr1 arr2 -  | STArr n t1 <- typeOf arr1 +  | STArr _ t1 <- typeOf arr1    , STArr _ t2 <- typeOf arr2    , Dict <- styKnown t1    , Dict <- styKnown t2 -  = ELet ext arr1 $ -    ELet ext (weakenExpr WSink arr2) $ -      EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ -        ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) IZ)) $ -        ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) -                           (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ -          weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f +  = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) +                                   IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) +                                   IS (IS i) -> EVar ext t (IS i)) +                    f) +             (EZip ext arr1 arr2)  ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = -  let STArr _ t1 = typeOf arr1 -      STArr _ t2 = typeOf arr2 -  in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 +ezip = EZip ext  eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a  eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) @@ -652,3 +653,53 @@ makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy t      go SMTMaybe{} _ = ENil ext      go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e      go SMTScal{} _ = ENil ext + +splitSparsePair +  :: -- given a sparsity +     STy (TPair a b) -> Sparse (TPair a b) t' +  -> (forall a' b'. +          -- I give you back two sparsities for a and b +          Sparse a a' -> Sparse b b' +          -- furthermore, I tell you that either your t' is already this (a', b') pair... +       -> Either +            (t' :~: TPair a' b') +            -- or I tell you how to construct a' and b' from t', given an actual t' +            (forall r' env. +                 Idx env t' +              -> (forall env'. +                     (forall c. Ex env' c -> Ex env c) +                  -> Ex env' a' -> Ex env' b' -> r') +              -> r') +       -> r) +  -> r +splitSparsePair _ SpAbsent k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = +  k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = +  let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in +  k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> +  k2 (elet $ +       emaybe (EVar ext (STMaybe (applySparse s t)) i) +         (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) +         (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) +     (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = +  k SpAbsent SpAbsent $ Right $ \_ k2 -> +  k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = +  splitSparsePair t (SpSparse s) $ \s1 s2 eres -> +  k s1 s2 $ Right $ \i k2 -> +  case eres of +    Left refl -> case refl of {} +    Right f -> +      f IZ $ \wrap e1 e2 -> +        k2 (\body -> +              elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) +                     (ENothing ext (applySparse s t)) +                     (evar IZ)) $ +                wrap body) +           e1 e2 diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 296c021..bc02417 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -523,9 +523,8 @@ occCountX initialS topexpr k = case topexpr of        SsNone ->          occCountX SsFull a $ \env1 mka ->          occCountX SsNone b $ \env2'' mkb -> -        withSome (scaleMany (Some env2'')) $ \env2' -> -        occEnvPop' env2' $ \env2 s2 -> -        withSome (Some env1 <> Some env2) $ \env -> +        occEnvPop' env2'' $ \env2' s2 -> +        withSome (Some env1 <> scaleMany (Some env2')) $ \env ->          k env $ \env' ->            use (EBuild ext n (mka env') $                   use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ @@ -535,14 +534,31 @@ occCountX initialS topexpr k = case topexpr of        SsArr' s' ->          occCountX SsFull a $ \env1 mka ->          occCountX s' b $ \env2'' mkb -> -        withSome (scaleMany (Some env2'')) $ \env2' -> -        occEnvPop' env2' $ \env2 s2 -> -        withSome (Some env1 <> Some env2) $ \env -> +        occEnvPop' env2'' $ \env2' s2 -> +        withSome (Some env1 <> scaleMany (Some env2')) $ \env ->          k env $ \env' ->            EBuild ext n (mka env') $              elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $                weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) +  EMap _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1'' mka -> +        occEnvPop' env1'' $ \env1' s1 -> +        occCountX (SsArr s1) b $ \env2 mkb -> +        withSome (scaleMany (Some env1') <> Some env2) $ \env -> +        k env $ \env' -> +          use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ +            ENil ext +      SsArr' s' -> +        occCountX s' a $ \env1'' mka -> +        occEnvPop' env1'' $ \env1' s1 -> +        occCountX (SsArr s1) b $ \env2 mkb -> +        withSome (scaleMany (Some env1') <> Some env2) $ \env -> +        k env $ \env' -> +          EMap ext (mka (OccPush env' () s1)) (mkb env') +    EFold1Inner _ commut a b c ->      occCountX SsFull a $ \env1''' mka ->      withSome (scaleMany (Some env1''')) $ \env1'' -> @@ -608,6 +624,27 @@ occCountX initialS topexpr k = case topexpr of          k env $ \env' ->            EReshape ext n (mkesh env') (mke env') +  EZip _ a b -> +    case s of +      SsNone -> +        occCountX SsNone a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mka env') $ use (mkb env') $ ENil ext +      SsArr' SsNone -> +        occCountX (SsArr SsNone) a $ \env1 mka -> +        occCountX SsNone b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          use (mkb env') $ mka env' +      SsArr' (SsPair' s1 s2) -> +        occCountX (SsArr s1) a $ \env1 mka -> +        occCountX (SsArr s2) b $ \env2 mkb -> +        withSome (Some env1 <> Some env2) $ \env -> +        k env $ \env' -> +          EZip ext (mka env') (mkb env') +    EFold1InnerD1 _ cm e1 e2 e3 ->      case s of        -- If nothing is necessary, we can execute a fold and then proceed to ignore it diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 68fc629..2c51b85 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -204,6 +204,14 @@ ppExpr' d val expr = case expr of             <> hardline <> e')          (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) +  EMap _ a b -> do +    let STArr _ t1 = typeOf b +    name <- genNameIfUsedIn' "i" t1 IZ a +    a' <- ppExpr' 0 (Const name `SCons` val) a +    b' <- ppExpr' 11 val b +    return $ ppParen (d > 0) $ +      ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] +    EFold1Inner _ cm a b c -> do      name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a      name2 <- genNameIfUsedIn (typeOf a) IZ a @@ -238,7 +246,12 @@ ppExpr' d val expr = case expr of    EReshape _ n esh e -> do      esh' <- ppExpr' 11 val esh      e' <- ppExpr' 11 val e -    return $ ppParen (d > 10) $ ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr <+> esh' <+> e' +    return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + +  EZip _ e1 e2 -> do +    e1' <- ppExpr' 11 val e1 +    e2' <- ppExpr' 11 val e2 +    return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']    EFold1InnerD1 _ cm a b c -> do      name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs index a22b73f..1712ba5 100644 --- a/src/AST/UnMonoid.hs +++ b/src/AST/UnMonoid.hs @@ -38,6 +38,7 @@ unMonoid = \case    ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)    EConstArr _ n t x -> EConstArr ext n t x    EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) +  EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)    EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)    ESum1Inner _ e -> ESum1Inner ext (unMonoid e)    EUnit _ e -> EUnit ext (unMonoid e) @@ -45,6 +46,7 @@ unMonoid = \case    EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)    EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)    EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) +  EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)    EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)    EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)    EConst _ t x -> EConst ext t x diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 3a97fd1..f0820b8 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -129,7 +129,7 @@ wCopies bs w =    let bs' = slistMap (\_ -> Const ()) bs    in WStack bs' bs' WId w -wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env  wRaiseAbove SNil _ = WClosed  wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs index 6301dc1..71da793 100644 --- a/src/Analysis/Identity.hs +++ b/src/Analysis/Identity.hs @@ -202,6 +202,15 @@ idana env expr = case expr of      res <- VIArr <$> genId <*> shidsToVec dim shids      pure (res, EBuild res dim e1' e2') +  EMap _ e1 e2 -> do +    let STArr _ t = typeOf e2 +    x1 <- genIds t +    (_, e1') <- idana (x1 `SCons` env) e1 +    (v2, e2') <- idana env e2 +    let VIArr _ sh = v2 +    res <- VIArr <$> genId <*> pure sh +    pure (res, EMap res e1' e2') +    EFold1Inner _ cm e1 e2 e3 -> do      let t1 = typeOf e1      x1 <- genIds t1 @@ -250,6 +259,13 @@ idana env expr = case expr of      res <- VIArr <$> genId <*> shidsToVec dim v1      pure (res, EReshape res dim e1' e2') +  EZip _ e1 e2 -> do +    (v1, e1') <- idana env e1 +    (_, e2') <- idana env e2 +    let VIArr _ sh = v1 +    res <- VIArr <$> genId <*> pure sh +    pure (res, EZip res e1' e2') +    EFold1InnerD1 _ cm e1 e2 e3 -> do      let t1 = typeOf e2      x1 <- genIds t1 diff --git a/src/CHAD.hs b/src/CHAD.hs index 7594a0f..72ce36d 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1048,73 +1048,46 @@ drev des accumMap sd = \case      , let eltty = typeOf orige      , shty :: STy shty <- tTup (sreplicate ndim tIx)      , Refl <- indexTupD1Id ndim -> -    deleteUnused (descrList des) (occEnvPopSome (occCountAll orige)) $ \(usedSub :: Subenv env env') -> -    let e = unsafeWeakenWithSubenv (SEYesR usedSub) orige in -    subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> -    accumPromote sdElt usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> -    let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in -    case drev (prodes `DPush` (shty, Nothing, SDiscr)) accumMapPro sdElt e of { Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 SETop e2 -> -    case lemAppendNil @e_binds of { Refl -> -    let tapety = tapeTy (subList (bindingsBinds e0) subtapeE) in -    let collectexpr = bindingsCollectTape (bindingsBinds e0) subtapeE in -    let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in -    let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in -    Ret (mergePrimalBindings -          `bpush` weakenExpr (wSinks (d1e envPro)) (drevPrimal des she) +    drevLambda des accumMap (shty, SDiscr) sdElt orige $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> +    let library = #ix (shty `SCons` SNil) +                  &. #e0 (bindingsBinds e0) +                  &. #propr (d1e provars) +                  &. #d1env (desD1E des) +                  &. #d (auto1 @sdElt) +                  &. #tape (auto1 @e_tape) +                  &. #pro (d2ace provars) +                  &. #d2acEnv (d2ace (select SAccum des)) +                  &. #darr (auto1 @(TArr ndim sdElt)) +                  &. #tapearr (auto1 @(TArr ndim e_tape)) in +    Ret (proPrimalBinds            `bpush` EBuild ext ndim -                    (EVar ext shty IZ) -                    (letBinds (fst (weakenBindingsE (autoWeak (#ix (shty `SCons` SNil) -                                                               &. #sh (shty `SCons` SNil) -                                                               &. #propr (d1e envPro) -                                                               &. #d1env (desD1E des) -                                                               &. #d1env' (desD1E usedDes)) -                                                              (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                                              (#ix :++: #sh :++: #propr :++: #d1env)) +                    (weakenExpr (wSinks (d1e provars)) (drevPrimal des she)) +                    (letBinds (fst (weakenBindingsE (autoWeak library +                                                              (#ix :++: #d1env) +                                                              (#ix :++: #propr :++: #d1env))                                                      e0)) $ -                       let w = autoWeak (#ix (shty `SCons` SNil) -                                         &. #sh (shty `SCons` SNil) -                                         &. #e0 (bindingsBinds e0) -                                         &. #propr (d1e envPro) -                                         &. #d1env (desD1E des) -                                         &. #d1env' (desD1E usedDes)) -                                        (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) -                                        (#e0 :++: #ix :++: #sh :++: #propr :++: #d1env) -                           w' = w .> wCopies (bindingsBinds e0) (WClosed @(shty : D1E env')) -                       in EPair ext (weakenExpr w e1) (collectexpr w')) -          `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) IZ)) -        (SEYesR (SENo (SEYesR (subenvAll (d1e envPro))))) -        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) tapety)) (IS IZ))) -        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E (subenvCompose subMergeUsed proSub))) -        (let sinkOverEnvPro = wSinks @(sd : TArr ndim (Tape e_tape) : Tup (Replicate ndim TIx) : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace envPro) in +                      weakenExpr (autoWeak library (#e0 :++: #ix :++: #d1env) +                                                   (#e0 :++: #ix :++: #propr :++: #d1env)) +                                 (EPair ext e1 e1tape)) +          `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) +        (SEYesR (SENo (subenvAll (d1e provars)))) +        (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) +        (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) +        (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in           ESnd ext $ -           uninvertTup (d2e envPro) (STArr ndim STNil) $ -             makeAccumulators @_ @_ @(TArr ndim TNil) (WSink .> WSink .> WSink .> wRaiseAbove (d1e envPro) (d2ace (select SAccum des))) envPro $ -               EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ -                 -- the cotangent for this element -                 ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) -                                    (EVar ext shty IZ)) $ -                 -- the tape for this element -                 ELet ext (EIdx ext (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) -                                    (EVar ext shty (IS IZ))) $ -                 let (rebinds, prerebinds) = reconstructBindings (subList (bindingsBinds e0) subtapeE) -                 in letBinds (rebinds IZ) $ -                      weakenExpr (autoWeak (#d (auto1 @sdElt) -                                            &. #pro (d2ace envPro) -                                            &. #etape (subList (bindingsBinds e0) subtapeE) -                                            &. #prerebinds prerebinds -                                            &. #tape (auto1 @(Tape e_tape)) -                                            &. #ix (auto1 @shty) -                                            &. #darr (auto1 @(TArr ndim sdElt)) -                                            &. #tapearr (auto1 @(TArr ndim (Tape e_tape))) -                                            &. #sh (auto1 @shty) -                                            &. #propr (d1e envPro) -                                            &. #d2acUsed (d2ace (select SAccum usedDes)) -                                            &. #d2acEnv (d2ace (select SAccum des))) -                                           (#pro :++: #d :++: #etape :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) -                                           ((#etape :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #propr :++: #d2acEnv) -                                  .> wPro (subList (bindingsBinds e0) subtapeE)) -                                 e2) -    }} +           wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ +             EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ +               -- the tape for this element +               ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> sinkOverEnvPro @> IS IZ)) +                                  (EVar ext shty IZ)) $ +               -- the cotangent for this element +               ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> WSink .> sinkOverEnvPro @> IZ)) +                                  (EVar ext shty (IS IZ))) $ +                weakenExpr (autoWeak library (#d :++: #tape :++: #pro :++: #d2acEnv) +                                             (#d :++: #tape :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) +                           e2) + +  EMap{} -> undefined    EFold1Inner _ commut origef ex₀ earr      | SpArr @_ @sdElt sdElt <- sd @@ -1346,6 +1319,33 @@ drev des accumMap sd = \case                                    (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $            weakenExpr (WCopy (WSink .> WSink)) e2) +  EZip _ a b +    | SpArr sd' <- sd +    , STArr n t1 <- typeOf a +    , STArr _ t2 <- typeOf b -> +    splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> +    case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` +                        toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of +    { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> +    subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> +    Ret binds +        subtape +        (EZip ext a1 b1) +        subBoth +        (case pairSplitE of +           Left Refl -> +             let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in +             plus_A_B +               (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) +               (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) +           Right f -> f IZ $ \wrapPair pick1 pick2 -> +             elet (emap (wrapPair (EPair ext pick1 pick2)) +                    (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ +             plus_A_B +               (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) +               (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) +    } +    ENothing{} -> err_unsupported "ENothing"    EJust{} -> err_unsupported "EJust"    EMaybe{} -> err_unsupported "EMaybe" @@ -1476,6 +1476,88 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of      , Refl <- lemAppendNil @tapebinds ->          RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) +           => Descr env sto +           -> VarMap Int (D2AcE (Select env sto "accum")) +           -> (STy a, Storage s) +           -> Sparse (D2 t) dt +           -> Expr ValId (a : env) t +           -> (forall provars shbinds tape d2a'. +                  SList STy provars +               -> Subenv (D2E (Select env sto "merge")) (D2E provars) +               -> Bindings Ex (D1E env) (D1E provars)  -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) +               -> Bindings Ex (D1 a : D1E env) shbinds +               -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) +               -> Ex (Append shbinds (D1 a : D1E env)) tape +               -> Sparse (D2 a) d2a' +               -> (forall env' b. +                      D1E provars :> env' +                   -> Ex (Append (D2AcE provars) env') b +                   -> Ex (                       env') (TPair b (Tup (D2E provars)))) +               -> Ex (dt : tape : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' +               -> r) +           -> r +drevLambda des accumMap (argty, argsto) sd origef k = +  let t = typeOf origef in +  deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> +  let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in +  subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> +  accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> +  let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in +  let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in +  let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in +  case prf1 prodes argty argsto of { Refl -> +  case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> +  let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in +  extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> +  let library = #fbinds (bindingsBinds ef0) +                &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) +                &. #ftape (auto1 @(Tape e_tape)) +                &. #arg (d1 argty `SCons` SNil) +                &. #d (applySparse sd (d2 t) `SCons` SNil) +                &. #d1env (desD1E des) +                &. #d1env' (desD1E usedDes) +                &. #propr (d1e envPro) +                &. #d2acUsed (d2ace (select SAccum usedDes)) +                &. #d2acEnv (d2ace (select SAccum des)) +                &. #d2acPro (d2ace envPro) +                &. #efPrerebinds efPrerebinds in +  k envPro +    (subenvD2E (subenvCompose subMergeUsed proSub)) +    mergePrimalBindings +    (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) +    (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) +                                  (#fbinds :++: #arg :++: #d1env)) +                ef1) +    (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) +    argSp +    (\wpro1 body -> +      uninvertTup (d2e envPro) (typeOf body) $ +        makeAccumulators wpro1 envPro $ +          body) +    (letBinds (efRebinds (IS IZ)) $ +      weakenExpr +        (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) +                          ((#ftapebinds :++: #efPrerebinds) :++: #d :++: #ftape :++: #d2acPro :++: #d2acEnv) +         .> wPro (subList (bindingsBinds ef0) subtapeEf)) +        (getSparseArg ef2)) +  }} +  where +    extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) +                   => proxy env sto -> proxy2 a -> Storage s +                      -- if s == "merge", this simplifies to SubenvS '[D2 a] t' +                      -- if s == "discr", this simplifies to SubenvS '[] t' +                   -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' +                   -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r +    extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id +    extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) +    extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + +    prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s +         -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" +    prf1 _ _ SMerge = Refl +    prf1 _ _ SDiscr = Refl +  -- TODO: proper primal-only transform that doesn't depend on D1 = Id  drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)  drevPrimal des e diff --git a/src/Compile.hs b/src/Compile.hs index f2063ee..d6ad7ec 100644 --- a/src/Compile.hs +++ b/src/Compile.hs @@ -8,7 +8,7 @@  {-# LANGUAGE PolyKinds #-}  {-# LANGUAGE TupleSections #-}  {-# LANGUAGE TypeApplications #-} -module Compile (compile) where +module Compile (compile, compileStderr) where  import Control.Applicative (empty)  import Control.Monad (forM_, when, replicateM) @@ -71,28 +71,30 @@ debugAllocs :: Bool; debugAllocs = toEnum 0  -- | Emit extra C code that checks stuff  emitChecks :: Bool; emitChecks = toEnum 0 +-- | Returns compiled function plus compilation output (warnings)  compile :: SList STy env -> Ex env t -        -> IO (SList Value env -> IO (Rep t)) +        -> IO (SList Value env -> IO (Rep t), String)  compile = \env expr -> do    codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i))    let (source, offsets) = compileToString codeID env expr    when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>"    when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" -  lib <- buildKernel source "kernel" +  (lib, compileOutput) <- buildKernel source "kernel"    let result_type = typeOf expr        result_size = sizeofSTy result_type -  return $ \val -> do -    allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do -      let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) -      serialiseArguments args ptr $ do -        callKernelFun lib ptr -        ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) -        when (ok /= 1) $ -          ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) -        deserialise result_type ptr (koResultOffset offsets) +  let function val = do +        allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do +          let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) +          serialiseArguments args ptr $ do +            callKernelFun lib ptr +            ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) +            when (ok /= 1) $ +              ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) +            deserialise result_type ptr (koResultOffset offsets) +  return (function, compileOutput)    where      serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r      serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = @@ -100,6 +102,15 @@ compile = \env expr -> do          serialiseArguments args ptr k      serialiseArguments _ _ k = k +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t +              -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do +  (fun, output) <- compile env expr +  when (not (null output)) $ +    hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" +  return fun +  data StructDecl = StructDecl      String  -- ^ name @@ -791,6 +802,15 @@ compile' env = \case      return (CELit arrname) +  -- TODO: actually generate decent code here +  EMap _ e1 e2 -> do +    let STArr n _ = typeOf e2 +    compile' env $ +      elet e2 $ +        EBuild ext n (EShape ext (evar IZ)) $ +          elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ +            weakenExpr (WCopy (WSink .> WSink)) e1 +    EFold1Inner _ commut efun ex0 earr -> do      let STArr (SS n) t = typeOf earr @@ -940,6 +960,16 @@ compile' env = \case                [("buf", CEProj (CELit arrname) "buf")                ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) +  -- TODO: actually generate decent code here +  EZip _ e1 e2 -> do +    let STArr n _ = typeOf e1 +    compile' env $ +      elet e1 $ +      elet (weakenExpr WSink e2) $ +        EBuild ext n (EShape ext (evar (IS IZ))) $ +          EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) +                    (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) +    EFold1InnerD1 _ commut efun ex0 earr -> do      let STArr (SS n) t = typeOf earr          STPair _ bty = typeOf efun diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs index cc6d5fa..ad4180f 100644 --- a/src/Compile/Exec.hs +++ b/src/Compile/Exec.hs @@ -30,7 +30,7 @@ debug = False  -- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs)  data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) -buildKernel :: String -> String -> IO KernelLib +buildKernel :: String -> String -> IO (KernelLib, String)  buildKernel csource funname = do    template <- (++ "/tmp.chad.") <$> getTempDir    path <- mkdtemp template @@ -43,7 +43,8 @@ buildKernel csource funname = do               ,"-Wall", "-Wextra"               ,"-Wno-unused-variable", "-Wno-unused-but-set-variable"               ,"-Wno-unused-parameter", "-Wno-unused-function" -             ,"-Wno-alloc-size-larger-than"]  -- ideally we'd keep this, but gcc reports false positives +             ,"-Wno-alloc-size-larger-than"  -- ideally we'd keep this, but gcc reports false positives +             ,"-Wno-maybe-uninitialized"]  -- maximum1i goes out of range if its input is empty, yes, don't complain    (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource    -- Print the source before the GCC output. @@ -51,11 +52,6 @@ buildKernel csource funname = do      ExitSuccess -> return ()      ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" -  when (not (null gccStdout)) $ -    hPutStrLn stderr $ "[chad] Kernel compilation: GCC stdout: <<<\n" ++ gccStdout ++ ">>>" -  when (not (null gccStderr)) $ -    hPutStrLn stderr $ "[chad] Kernel compilation: GCC stderr: <<<\n" ++ gccStderr ++ ">>>" -    case ec of      ExitSuccess -> return ()      ExitFailure{} -> do @@ -72,7 +68,7 @@ buildKernel csource funname = do    _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1))                             when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)"                             dlclose dl) -  return (KernelLib ref) +  return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr)  foreign import ccall "dynamic"    wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs index b353def..6655423 100644 --- a/src/ForwardAD.hs +++ b/src/ForwardAD.hs @@ -254,8 +254,10 @@ makeFwdADArtifactInterp env expr =    in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr)  {-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t) -makeFwdADArtifactCompile env expr = FwdADArtifact env (typeOf expr) . (unsafePerformIO .) <$> compile (dne env) (dfwdDN expr) +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do +  (fun, output) <- compile (dne env) (dfwdDN expr) +  return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output)  drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env)  drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs index 44bdbb2..a1e9d0d 100644 --- a/src/ForwardAD/DualNumbers.hs +++ b/src/ForwardAD/DualNumbers.hs @@ -153,6 +153,7 @@ dfwdDN = \case      (EConstArr ext n t x)    EBuild _ n a b      | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) +  EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b)    EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c)    ESum1Inner _ e ->      let STArr n (STScal t) = typeOf e @@ -168,6 +169,7 @@ dfwdDN = \case    EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b)    EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e    EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e +  EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b)    EReshape _ n esh e      | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e)    EConst _ t x -> scalTyCase t diff --git a/src/Interpreter.hs b/src/Interpreter.hs index 9e3d2a6..d982261 100644 --- a/src/Interpreter.hs +++ b/src/Interpreter.hs @@ -116,6 +116,9 @@ interpret'Rec env = \case    EBuild _ dim a b -> do      sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a      arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) +  EMap _ a b -> do +    let STArr _ t = typeOf b +    arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b    EFold1Inner _ _ a b c -> do      let t = typeOf b      let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a @@ -150,6 +153,13 @@ interpret'Rec env = \case      sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh      arr <- interpret' env e      return $ arrayReshape sh arr +  EZip _ a b -> do +    arr1 <- interpret' env a +    arr2 <- interpret' env b +    let sh = arrayShape arr1 +    when (sh /= arrayShape arr2) $ +      error "Interpreter: mismatched shapes in EZip" +    return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i))    EFold1InnerD1 _ _ a b c -> do      let t = typeOf b      let f = \x y -> interpret' (V t y `SCons` V t x `SCons` env) a diff --git a/src/Simplify.hs b/src/Simplify.hs index b89d7f6..1889adc 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -200,12 +200,14 @@ simplify'Rec = \case        EMaybe ext (ESnd ext e1) (ESnd ext e2) e3    -- TODO: more array indexing -  EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet e3 e2 +  EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 +  EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1    EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)    EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1    -- TODO: more array shape    EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 +  EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2)    -- TODO: more constant folding    EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) @@ -308,6 +310,7 @@ simplify'Rec = \case    ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |]    EConstArr _ n t v -> pure $ EConstArr ext n t v    EBuild _ n a b -> [simprec| EBuild ext n *a *b |] +  EMap _ a b -> [simprec| EMap ext *a *b |]    EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |]    ESum1Inner _ e -> [simprec| ESum1Inner ext *e |]    EUnit _ e -> [simprec| EUnit ext *e |] @@ -315,6 +318,7 @@ simplify'Rec = \case    EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |]    EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |]    EReshape _ n a b -> [simprec| EReshape ext n *a *b |] +  EZip _ a b -> [simprec| EZip ext *a *b |]    EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |]    EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |]    EConst _ t v -> pure $ EConst ext t v @@ -364,6 +368,7 @@ hasAdds = \case    ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c    EConstArr _ _ _ _ -> False    EBuild _ _ a b -> hasAdds a || hasAdds b +  EMap _ a b -> hasAdds a || hasAdds b    EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    ESum1Inner _ e -> hasAdds e    EUnit _ e -> hasAdds e @@ -371,6 +376,7 @@ hasAdds = \case    EMaximum1Inner _ e -> hasAdds e    EMinimum1Inner _ e -> hasAdds e    EReshape _ _ a b -> hasAdds a || hasAdds b +  EZip _ a b -> hasAdds a || hasAdds b    EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c    ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e diff --git a/test-framework/Test/Framework.hs b/test-framework/Test/Framework.hs index 80711b2..b7d0dc2 100644 --- a/test-framework/Test/Framework.hs +++ b/test-framework/Test/Framework.hs @@ -1,9 +1,12 @@ +{-# LANGUAGE DeriveGeneric #-}  {-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-}  {-# LANGUAGE ExistentialQuantification #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-}  {-# LANGUAGE ImplicitParams #-}  {-# LANGUAGE ImportQualifiedPost #-}  {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TupleSections #-}  module Test.Framework ( @@ -11,12 +14,16 @@ module Test.Framework (    testGroup,    groupSetCollapse,    testProperty, -  withResource, -  withResource',    runTests,    defaultMain,    Options(..), +  -- * Resources +  withResource, +  withResource', +  TestCtx, +  outputWarningText, +    -- * Compatibility    TestName,  ) where @@ -29,11 +36,13 @@ import Control.Monad (forM, when, forM_)  import Control.Monad.IO.Class  import Data.IORef  import Data.List (isInfixOf, intercalate) -import Data.Maybe (isJust, mapMaybe, fromJust) +import Data.Maybe (mapMaybe, fromJust) +import Data.Monoid (All(..), Any(..), Sum(..))  import Data.PQueue.Prio.Min qualified as PQ  import Data.String (fromString)  import Data.Time.Clock  import GHC.Conc (getNumProcessors) +import GHC.Generics (Generic, Generically(..))  import System.Console.ANSI qualified as ANSI  import System.Console.Concurrent (outputConcurrent)  import System.Console.Regions @@ -57,10 +66,16 @@ type TestName = String  data TestTree    = Group GroupOpts String [TestTree] -  | forall a. Resource String (IO a) (a -> IO ()) (a -> TestTree) +  | forall a. Resource String ((?testCtx :: TestCtx) => IO a) ((?testCtx :: TestCtx) => a -> IO ()) (a -> TestTree)        -- ^ Name is not specified by user, but inherited from the tree below    | HP String H.Property +data TestCtx = TestCtx +  { tctxOutput :: String -> IO () } + +outputWarningText :: (?testCtx :: TestCtx) => String -> IO () +outputWarningText = tctxOutput ?testCtx +  -- Not exported because a Resource is not supposed to have a name in the first place  treeName :: TestTree -> String  treeName (Group _ name _) = name @@ -82,13 +97,13 @@ groupSetCollapse (Group opts name trees) = Group opts { goCollapse = True } name  groupSetCollapse _ = error "groupSetCollapse: not called on a Group"  -- | The @a -> TestTree@ function must use the @a@ only inside properties: the --- functoin will be passed 'undefined' when exploring the test tree (without +-- function will be passed 'undefined' when exploring the test tree (without  -- running properties). -withResource :: IO a -> (a -> IO ()) -> (a -> TestTree) -> TestTree +withResource :: ((?testCtx :: TestCtx) => IO a) -> ((?testCtx :: TestCtx) => a -> IO ()) -> (a -> TestTree) -> TestTree  withResource make cleanup fun = Resource (treeName (fun undefined)) make cleanup fun  -- | Same caveats as 'withResource'. -withResource' :: IO a -> (a -> TestTree) -> TestTree +withResource' :: ((?testCtx :: TestCtx) => IO a) -> (a -> TestTree) -> TestTree  withResource' make fun = withResource make (\_ -> return ()) fun  testProperty :: String -> H.Property -> TestTree @@ -226,7 +241,7 @@ runTests options = \tree' ->                           successVar <- newEmptyMVar                           runTreePar Nothing [] [] tree successVar                           readMVar successVar -             else isJust <$> runTreeSeq 0 [] tree +             else getAll . seqresAllSuccess <$> runTreeSeq 0 [] tree        stats <- readIORef statsRef        endtm <- getCurrentTime        let ?istty = isterm in printStats (treeNumTests tree) stats (diffUTCTime endtm starttm) @@ -284,6 +299,9 @@ runTreePar topmparregion revidxlist revpath toptree@Resource{} topoutvar = runRe        let pathitem = '[' : show depth ++ "](" ++ inhname ++ ")"            path = intercalate "/" (reverse (pathitem : revpath))            idxlist = reverse revidxlist +      let ?testCtx = TestCtx (\str -> +                       outputConcurrent (ansiYellow ++ "## Warning for " ++ path ++ ":" ++ ansiReset ++ +                                         "\n" ++ str ++ "\n"))        submitOrRunIn mparregion idxlist Nothing $ \makeRegion -> do          setConsoleRegion makeRegion ('|' : path ++ " [R] making...") @@ -337,37 +355,51 @@ submitOrRunIn (Just reg) _idxlist outvar fun = do    result <- fun reg    forM_ outvar $ \mvar -> putMVar mvar result +data SeqRes = SeqRes +  { seqresHaveWarnings :: Any +  , seqresAllSuccess :: All +  , seqresNumLines :: Sum Int } +  deriving (Generic) +  deriving (Semigroup, Monoid) via Generically SeqRes +  -- | If all tests are successful, returns the number of output lines produced  runTreeSeq :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int, ?istty :: Bool) -           => Int -> [String] -> TestTree -> IO (Maybe Int) +           => Int -> [String] -> TestTree -> IO SeqRes  runTreeSeq indent revpath (Group opts name trees) = do    putStrLn (replicate (2 * indent) ' ' ++ name) >> hFlush stdout    starttm <- getCurrentTime -  mlns <- fmap (fmap sum . sequence) . forM trees $ -            runTreeSeq (indent + 1) (name : revpath) +  res <- fmap mconcat . forM trees $ +           runTreeSeq (indent + 1) (name : revpath)    endtm <- getCurrentTime -  case mlns of -    Just lns | goCollapse opts, ?istty -> do +  if not (getAny (seqresHaveWarnings res)) && getAll (seqresAllSuccess res) && goCollapse opts && ?istty +    then do        let thislen = 2*indent + length name +      let Sum lns = seqresNumLines res        putStrLn $ concat (replicate (lns+1) (ANSI.cursorUpCode 1 ++ ANSI.clearLineCode)) ++                   ANSI.setCursorColumnCode 0 ++                   replicate (2 * indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' ' ++                   ansiGreen ++ "OK" ++ ansiReset ++                   prettyDuration False (realToFrac (diffUTCTime endtm starttm)) -      return (Just 1) -    _ -> return ((+1) <$> mlns) +      return (mempty { seqresNumLines = 1 }) +    else return (res <> (mempty { seqresNumLines = 1 }))  runTreeSeq indent path (Resource _ make cleanup fun) = do +  outputted <- newIORef False +  let ?testCtx = TestCtx (\str -> do +                   atomicModifyIORef' outputted (\_ -> (True, ())) +                   putStrLn (ansiYellow ++ "## Warning for " ++ (intercalate "/" (reverse path)) ++ +                             ":" ++ ansiReset ++ "\n" ++ str))    value <- make -  success <- runTreeSeq indent path (fun value) +  res <- runTreeSeq indent path (fun value)    cleanup value -  return success +  warnings <- readIORef outputted +  return (res <> (mempty { seqresHaveWarnings = Any warnings }))  runTreeSeq indent path (HP name prop) = do    let thislen = 2*indent + length name    let prefix = replicate (2*indent) ' ' ++ name ++ ": " ++ replicate (?maxlen - thislen) ' '    when ?istty $ putStr prefix >> hFlush stdout    (ok, rendered) <- runHP (outputProgress (?maxlen + 2)) path name prop    putStrLn ((if ?istty then ANSI.clearFromCursorToLineEndCode else prefix) ++ rendered) >> hFlush stdout -  return (if ok then Just 1 else Nothing) +  return (mempty { seqresAllSuccess = All ok, seqresNumLines = 1 })  runHP :: (?options :: Options, ?stats :: IORef Stats, ?maxlen :: Int)        => (H.Report H.Progress -> IO ()) @@ -489,10 +521,11 @@ ansi :: (?istty :: Bool) => String -> String  ansi | ?istty = id       | otherwise = const "" -ansiRed, ansiGreen, ansiReset :: (?istty :: Bool) => String -ansiRed   = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red]) -ansiGreen = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green]) -ansiReset = ansi (ANSI.setSGRCode [ANSI.Reset]) +ansiRed, ansiYellow, ansiGreen, ansiReset :: (?istty :: Bool) => String +ansiRed    = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Red]) +ansiYellow = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Vivid ANSI.Yellow]) +ansiGreen  = ansi (ANSI.setSGRCode [ANSI.SetColor ANSI.Foreground ANSI.Dull ANSI.Green]) +ansiReset  = ansi (ANSI.setSGRCode [ANSI.Reset])  -- getTermIsDark :: IO Bool  -- getTermIsDark = do diff --git a/test/Main.hs b/test/Main.hs index 4bc9082..d586973 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -11,6 +11,7 @@  {-# LANGUAGE UndecidableInstances #-}  module Main where +import Control.Monad (when)  import Control.Monad.Trans.Class (lift)  import Control.Monad.Trans.State  import Data.Bifunctor @@ -352,7 +353,10 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e        dtermSChadSUS = simplifyFix $ unMonoid dtermSChadS        dtermSChadSUSP = pruneExpr env dtermSChadSUS    in -  withResource (makeFwdADArtifactCompile env exprS) (\_ -> pure ()) $ \fwdartifactC -> +  withResource' (do (fun, output) <- makeFwdADArtifactCompile env exprS +                    when (not (null output)) $ +                      outputWarningText $ "Forward AD compile GCC output: <<<\n" ++ output ++ ">>>" +                    return fun) $ \fwdartifactC ->    withCompiled env dtermSChadSUSP $ \dcompSChadSUSP ->      testProperty testname $ property $ do        annotate (concat (unSList (\t -> ppSTy 0 t ++ " -> ") env) ++ ppSTy 0 (typeOf expr)) @@ -416,7 +420,11 @@ adTestGenChad testname config env envGenerator expr exprS primalSfun | Dict <- e        diff tansCompSChadSUSP closeIshE' tansFwd  withCompiled :: SList STy env -> Ex env t -> ((SList Value env -> IO (Rep t)) -> TestTree) -> TestTree -withCompiled env expr = withResource (compile env expr) (\_ -> pure ()) +withCompiled env expr = withResource' $ do +  (fun, output) <- compile env expr +  when (not (null output)) $ +    outputWarningText $ "Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" +  return fun  gen_gmm :: Gen (SList Value [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64])  gen_gmm = do  | 
