summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal2
-rw-r--r--src/AST.hs20
-rw-r--r--src/AST/Accum.hs58
-rw-r--r--src/AST/Count.hs2
-rw-r--r--src/AST/Pretty.hs21
-rw-r--r--src/AST/Sparse.hs110
-rw-r--r--src/AST/Sparse/Types.hs107
-rw-r--r--src/AST/SplitLets.hs2
-rw-r--r--src/AST/UnMonoid.hs111
-rw-r--r--src/Analysis/Identity.hs4
-rw-r--r--src/CHAD.hs117
-rw-r--r--src/CHAD/Accum.hs45
-rw-r--r--src/CHAD/Top.hs54
-rw-r--r--src/CHAD/Types.hs16
-rw-r--r--src/Compile.hs171
-rw-r--r--src/Data.hs8
-rw-r--r--src/Example.hs3
-rw-r--r--src/Interpreter.hs151
-rw-r--r--src/Language.hs6
-rw-r--r--src/Language/AST.hs5
-rw-r--r--src/Simplify.hs309
-rw-r--r--test/Main.hs29
22 files changed, 726 insertions, 625 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index b8510d2..b7270e4 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -19,12 +19,14 @@ library
AST.Env
AST.Pretty
AST.Sparse
+ AST.Sparse.Types
AST.SplitLets
AST.Types
AST.UnMonoid
AST.Weaken
AST.Weaken.Auto
CHAD
+ CHAD.Accum
CHAD.EnvDescr
CHAD.Top
CHAD.Types
diff --git a/src/AST.hs b/src/AST.hs
index c24e3e7..5aab4fc 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -25,6 +25,7 @@ import Data.Kind (Type)
import Array
import AST.Accum
+import AST.Sparse.Types
import AST.Types
import AST.Weaken
import CHAD.Types
@@ -91,11 +92,16 @@ data Expr x env t where
ERecompute :: x t -> Expr x env t -> Expr x env t
-- accumulation effect on monoids
+ -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it
+ -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not
+ -- need to create any zeros.
EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
- EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
+ -- The 'Sparse' here is eliminated to dense by UnMonoid.
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t
EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t
@@ -218,9 +224,10 @@ typeOf = \case
ERecompute _ e -> typeOf e
EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ _ -> STNil
+ EAccum _ _ _ _ _ _ _ -> STNil
EZero _ t _ -> fromSMTy t
+ EDeepZero _ t _ -> fromSMTy t
EPlus _ t _ _ -> fromSMTy t
EOneHot _ t _ _ _ -> fromSMTy t
@@ -261,8 +268,9 @@ extOf = \case
ECustom x _ _ _ _ _ _ _ _ -> x
ERecompute x _ -> x
EWith x _ _ _ -> x
- EAccum x _ _ _ _ _ -> x
+ EAccum x _ _ _ _ _ _ -> x
EZero x _ _ -> x
+ EDeepZero x _ _ -> x
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
@@ -306,8 +314,9 @@ travExt f = \case
ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2
ERecompute x e -> ERecompute <$> f x <*> travExt f e
EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2
- EAccum x t p e1 e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> travExt f e2 <*> travExt f e3
+ EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3
EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e
+ EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e
EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b
EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b
EError x t s -> EError <$> f x <*> pure t <*> pure s
@@ -364,8 +373,9 @@ subst' f w = \case
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
ERecompute x e -> ERecompute x (subst' f w e)
EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3)
EZero x t e -> EZero x t (subst' f w e)
+ EDeepZero x t e -> EDeepZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 158b4d9..619c2b1 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,5 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
@@ -33,35 +34,38 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type data StillDense = AI_D | AI_S
-data SStillDense dense where
- SAI_D :: SStillDense AI_D
- SAI_S :: SStillDense AI_S
-deriving instance Show (SStillDense dense)
+type data AIDense = AID | AIS
-type family AcIdx dense p t where
- AcIdx dense APHere t = TNil
- AcIdx AI_D (APFst p) (TPair a b) = AcIdx AI_D p a
- AcIdx AI_D (APSnd p) (TPair a b) = AcIdx AI_D p b
- AcIdx AI_S (APFst p) (TPair a b) = TPair (AcIdx AI_S p a) (ZeroInfo b)
- AcIdx AI_S (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AI_S p b)
- AcIdx dense (APLeft p) (TLEither a b) = AcIdx AI_S p a
- AcIdx dense (APRight p) (TLEither a b) = AcIdx AI_S p b
- AcIdx dense (APJust p) (TMaybe a) = AcIdx AI_S p a
- AcIdx AI_D (APArrIdx p) (TArr n a) = TPair (Tup (Replicate n TIx)) (AcIdx AI_D p a)
- AcIdx AI_S (APArrIdx p) (TArr n a) =
- -- ((index, shapes info), recursive info)
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx AI_S p a)
- -- AcIdx AI_D (APArrSlice m) (TArr n a) =
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
-- -- index
-- Tup (Replicate m TIx)
- -- AcIdx AI_S (APArrSlice m) (TArr n a) =
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-type AcIdxD p t = AcIdx AI_D p t
-type AcIdxS p t = AcIdx AI_S p t
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
@@ -88,6 +92,16 @@ tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 03a36f6..05be524 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -134,7 +134,7 @@ occCountGeneral onehot unpush alter many = go WId
ECustom _ _ _ _ _ _ _ a b -> re a <> re b
ERecompute _ e -> re e
EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a b e -> re a <> re b <> re e
+ EAccum _ _ _ a _ b e -> re a <> re b <> re e
EZero _ _ e -> re e
EPlus _ _ a b -> re a <> re b
EOneHot _ _ _ a b -> re a <> re b
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 41da656..fef9686 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -25,6 +25,7 @@ import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -304,18 +305,24 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
EZero _ t e1 -> do
e1' <- ppExpr' 11 val e1
return $ ppParen (d > 0) $
annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
EPlus _ t a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
@@ -368,6 +375,16 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
index 369d395..f0a1f2a 100644
--- a/src/AST/Sparse.hs
+++ b/src/AST/Sparse.hs
@@ -1,116 +1,19 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
-module AST.Sparse where
+module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
-import Data.Kind (Constraint, Type)
import Data.Type.Equality
import AST
+import AST.Sparse.Types
+import Data (SBool(..))
-data Sparse t t' where
- SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
- SpAbsent :: Sparse t TNil
-
- SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
- SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
- SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
- SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
- SpScal :: Sparse (TScal t) (TScal t)
-deriving instance Show (Sparse t t')
-
-class ApplySparse f where
- applySparse :: Sparse t t' -> f t -> f t'
-
-instance ApplySparse STy where
- applySparse (SpSparse s) t = STMaybe (applySparse s t)
- applySparse SpAbsent _ = STNil
- applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
- applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
- applySparse SpScal t = t
-
-instance ApplySparse SMTy where
- applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
- applySparse SpAbsent _ = SMTNil
- applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
- applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
- applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
- applySparse SpScal t = t
-
-
-class IsSubType s where
- type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
- subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
- subtTrans :: s a b -> s b c -> s a c
- subtFull :: IsSubTypeSubject s f => f t -> s t t
-
-instance IsSubType (:~:) where
- type IsSubTypeSubject (:~:) f = ()
- subtApply = gcastWith
- subtTrans = trans
- subtFull _ = Refl
-
-instance IsSubType Sparse where
- type IsSubTypeSubject Sparse f = f ~ SMTy
- subtApply = applySparse
-
- subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
- subtTrans _ SpAbsent = SpAbsent
- subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
- subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
- subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
- subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
- subtTrans SpScal SpScal = SpScal
-
- subtFull = spDense
-
-spDense :: SMTy t -> Sparse t t
-spDense SMTNil = SpAbsent
-spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
-spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
-spDense (SMTMaybe t) = SpMaybe (spDense t)
-spDense (SMTArr _ t) = SpArr (spDense t)
-spDense (SMTScal _) = SpScal
-
-isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
-isDense SMTNil SpAbsent = Just Refl
-isDense _ SpSparse{} = Nothing
-isDense _ SpAbsent = Nothing
-isDense (SMTPair t1 t2) (SpPair s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTLEither t1 t2) (SpLEither s1 s2)
- | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
- | otherwise = Nothing
-isDense (SMTMaybe t) (SpMaybe s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTArr _ t) (SpArr s)
- | Just Refl <- isDense t s = Just Refl
- | otherwise = Nothing
-isDense (SMTScal _) SpScal = Just Refl
-
-isAbsent :: Sparse t t' -> Bool
-isAbsent (SpSparse s) = isAbsent s
-isAbsent SpAbsent = True
-isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
-isAbsent (SpMaybe s) = isAbsent s
-isAbsent (SpArr s) = isAbsent s
-isAbsent SpScal = False
-
sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
sparsePlus _ SpAbsent _ _ = ENil ext
sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
@@ -143,11 +46,6 @@ sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS I
sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
-data SBool b where
- SF :: SBool False
- ST :: SBool True
-deriving instance Show (SBool b)
-
data Injection sp a b where
-- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
-- 'sparsePlusS' can provide injections even if the caller doesn't require
diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs
new file mode 100644
index 0000000..10cac4e
--- /dev/null
+++ b/src/AST/Sparse/Types.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.Sparse.Types where
+
+import AST.Types
+
+import Data.Kind (Type, Constraint)
+import Data.Type.Equality
+
+
+data Sparse t t' where
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
+deriving instance Show (Sparse t t')
+
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
+
+
+class IsSubType s where
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull _ = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ SMTy
+ subtApply = applySparse
+
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 3c353d4..2dad17a 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -63,7 +63,7 @@ splitLets' = \sub -> \case
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
ERecompute x e -> ERecompute x (splitLets' sub e)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 389dd5a..d498aaa 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -1,18 +1,22 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
+module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
+import AST.Sparse.Types
import Data
--- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
--- into their concrete implementations.
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -49,7 +53,10 @@ unMonoid = \case
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
@@ -66,6 +73,27 @@ zero (SMTScal t) _ = case t of
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil _ = ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
plus SMTNil _ _ = ENil ext
plus (SMTPair t1 t2) a b =
@@ -143,3 +171,78 @@ onehot typ topprj idx arg = case (typ, topprj) of
(onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
(ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs
index 4501c32..2fd321d 100644
--- a/src/Analysis/Identity.hs
+++ b/src/Analysis/Identity.hs
@@ -307,11 +307,11 @@ idana env expr = case expr of
let res = VIPair v2 x2
pure (res, EWith res t e1' e2')
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
(_, e1') <- idana env e1
(_, e2') <- idana env e2
(_, e3') <- idana env e3
- pure (VINil, EAccum VINil t prj e1' e2' e3')
+ pure (VINil, EAccum VINil t prj e1' sp e2' e3')
EZero _ t e1 -> do
-- Approximate the result of EZero to be independent from the zero info
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 3dedec3..621aa3e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -34,7 +34,6 @@ module CHAD (
import Data.Functor.Const
import Data.Some
-import Data.Type.Bool (If)
import Data.Type.Equality (type (==), testEquality)
import GHC.Stack (HasCallStack)
@@ -45,6 +44,7 @@ import AST.Count
import AST.Env
import AST.Sparse
import AST.Weaken.Auto
+import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -348,28 +348,8 @@ opt2UnSparse = go . opt2
go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary"
------------------------------------- MONOIDS -----------------------------------
-
-d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
-d2zeroInfo STNil _ = ENil ext
-d2zeroInfo (STPair a b) e =
- eunPair e $ \_ e1 e2 ->
- EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
-d2zeroInfo STEither{} _ = ENil ext
-d2zeroInfo STLEither{} _ = ENil ext
-d2zeroInfo STMaybe{} _ = ENil ext
-d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
-d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
-d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
-
-
----------------------------------- SPARSITY -----------------------------------
-subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
-subenvD1E SETop = SETop
-subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
-subenvD1E (SENo sub) = SENo (subenvD1E sub)
-
expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a)
expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e
expandSparse t (SpSparse sp) epr e =
@@ -499,23 +479,6 @@ assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
--------------------------------- ACCUMULATORS ---------------------------------
-makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators _ SNil e = e
-makeAccumulators w (t `SCons` envpro) e =
- makeAccumulators (WPop w) envpro $
- EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
fromArrayValId :: Maybe (ValId t) -> Maybe Int
fromArrayValId (Just (VIArr i _)) = Just i
fromArrayValId _ = Nothing
@@ -788,8 +751,7 @@ drev des accumMap sd = \case
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (d2e (select SMerge des)))
(let ty = applySparse sd (d2M t)
- in accumulateSparse SAI_D (d2M t) sd (EVar ext (fromSMTy ty) IZ) $ \w prj val idx ->
- EAccum ext (d2M t) prj idx val (EVar ext (STAccum (d2M t)) (w @> IS accI)))
+ in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI)))
Idx2Me tupI ->
Ret BTop
@@ -1275,6 +1237,7 @@ drev des accumMap sd = \case
EWith{} -> err_accum
EZero{} -> err_monoid
+ EDeepZero{} -> err_monoid
EPlus{} -> err_monoid
EOneHot{} -> err_monoid
@@ -1392,76 +1355,6 @@ drevScoped des accumMap argty argsto argids sd expr = case argsto of
-- TODO: proper primal-only transform that doesn't depend on D1 = Id
drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t)
drevPrimal des e
- | Refl <- chadD1Id (typeOf e)
- , Refl <- chadD1EId (descrList des)
+ | Refl <- d1Identity (typeOf e)
+ , Refl <- d1eIdentity (descrList des)
= mapExt (const ext) e
- where
- chadD1Id :: STy a -> D1 a :~: a
- chadD1Id STNil = Refl
- chadD1Id (STPair a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STLEither a b) | Refl <- chadD1Id a, Refl <- chadD1Id b = Refl
- chadD1Id (STMaybe a) | Refl <- chadD1Id a = Refl
- chadD1Id (STArr _ a) | Refl <- chadD1Id a = Refl
- chadD1Id (STScal _) = Refl
- chadD1Id STAccum{} = error "accumulators not allowed in source program"
-
- chadD1EId :: SList STy l -> D1E l :~: l
- chadD1EId SNil = Refl
- chadD1EId (SCons t l) | Refl <- chadD1Id t, Refl <- chadD1EId l = Refl
-
-accumulateSparse
- :: SStillDense dense -> SMTy t -> Sparse t t' -> Ex env t'
- -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' b -> Ex env' (AcIdx dense p t) -> Ex env' TNil)
- -> Ex env TNil
-accumulateSparse dense topty topsp arg accum = case (dense, topty, topsp) of
- (_, _, s) | Just Refl <- isDense topty s ->
- accum WId SAPHere arg (ENil ext)
- (_, SMTScal _, SpScal) ->
- accum WId SAPHere arg (ENil ext) -- should be handled by isDense already, but meh
- (_, _, SpSparse s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse dense topty s (evar IZ) (\w -> accum (WPop w)))
- (_, _, SpAbsent) ->
- ENil ext
- (SAI_D, SMTPair t1 t2, SpPair s1 s2) ->
- eunPair arg $ \w1 e1 e2 ->
- elet (accumulateSparse dense t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
- accumulateSparse dense t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
- (SAI_S, SMTPair{}, SpPair{}) ->
- error "TODO: accumulating into pair inside coproduct unimplemented"
- -- There are two different ways this can be accomplished:
- -- 1. Ensure we have the requisite ZeroInfo here. This means that an
- -- accum-mode variable reference will (if its incoming cotangent is
- -- sparse enough) need to store some ZeroInfo fragments computed from
- -- the primal (not necessarily the entire primal). Doing this properly,
- -- i.e. not just storing a full D1 but only the required ZeroInfo
- -- fragments, is possible and not too inefficient but a bit of
- -- engineering again.
- -- 2. When creating an accumulator, don't initialise it with a generic
- -- EZero based on a ZeroInfo, but instead a special "deep zero" based on
- -- probably a full D1. This deep zero also initialises Left/Right/Just
- -- modelled after the primal. With this, an accumulation needs no zero
- -- info whatsoever (!) under the assumption that it receives a cotangent
- -- that is compatible with the primal it is propagated back to.
- (_, SMTLEither t1 t2, SpLEither s1 s2) ->
- elcase arg
- (ENil ext)
- (accumulateSparse SAI_S t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
- (accumulateSparse SAI_S t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
- (_, SMTMaybe t, SpMaybe s) ->
- emaybe arg
- (ENil ext)
- (accumulateSparse SAI_S t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
- (SAI_D, SMTArr n t, SpArr s) ->
- let tn = tTup (sreplicate n tIx) in
- elet arg $
- elet (EBuild ext n (EShape ext (evar IZ)) $
- accumulateSparse dense t s
- (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
- (\w prj val idx -> accum (WPop (WPop w)) (SAPArrIdx prj) val (EPair ext (EVar ext tn (w @> IZ)) idx))) $
- ENil ext
- (SAI_S, SMTArr{}, SpArr{}) ->
- error "TODO: accumulating into array inside coproduct unimplemented"
- -- See the pair case above, same reasoning
diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs
new file mode 100644
index 0000000..8c7794a
--- /dev/null
+++ b/src/CHAD/Accum.hs
@@ -0,0 +1,45 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE TypeOperators #-}
+-- | TODO this module is a grab-bag of random utility functions that are shared
+-- between CHAD and CHAD.Top.
+module CHAD.Accum where
+
+import AST
+import CHAD.Types
+import Data
+import AST.Env
+
+
+d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t))
+d2zeroInfo STNil _ = ENil ext
+d2zeroInfo (STPair a b) e =
+ eunPair e $ \_ e1 e2 ->
+ EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2)
+d2zeroInfo STEither{} _ = ENil ext
+d2zeroInfo STLEither{} _ = ENil ext
+d2zeroInfo STMaybe{} _ = ENil ext
+d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e
+d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext
+d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program"
+
+makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
+makeAccumulators _ SNil e = e
+makeAccumulators w (t `SCons` envpro) e =
+ makeAccumulators (WPop w) envpro $
+ EWith ext (d2M t) (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e
+
+uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
+uninvertTup SNil _ e = EPair ext e (ENil ext)
+uninvertTup (t `SCons` list) tcore e =
+ ELet ext (uninvertTup list (STPair tcore t) e) $
+ let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
+ in EPair ext
+ (EFst ext (EFst ext (EVar ext recT IZ)))
+ (EPair ext
+ (ESnd ext (EVar ext recT IZ))
+ (ESnd ext (EFst ext (EVar ext recT IZ))))
+
+subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env')
+subenvD1E SETop = SETop
+subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub)
+subenvD1E (SENo sub) = SENo (subenvD1E sub)
diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs
index 130174a..484779e 100644
--- a/src/CHAD/Top.hs
+++ b/src/CHAD/Top.hs
@@ -12,9 +12,12 @@ module CHAD.Top where
import Analysis.Identity
import AST
+import AST.Env
+import AST.Sparse
import AST.SplitLets
import AST.Weaken.Auto
import CHAD
+import CHAD.Accum
import CHAD.EnvDescr
import CHAD.Types
import Data
@@ -43,36 +46,22 @@ accumDescr (t `SCons` env) k = accumDescr env $ \des ->
if hasArrays t then k (des `DPush` (t, Nothing, SAccum))
else k (des `DPush` (t, Nothing, SMerge))
-d1Identity :: STy t -> D1 t :~: t
-d1Identity = \case
- STNil -> Refl
- STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
- STMaybe t | Refl <- d1Identity t -> Refl
- STArr _ t | Refl <- d1Identity t -> Refl
- STScal _ -> Refl
- STAccum{} -> error "Accumulators not allowed in input program"
-
-d1eIdentity :: SList STy env -> D1E env :~: env
-d1eIdentity SNil = Refl
-d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
-
reassembleD2E :: Descr env sto
+ -> D1E env :> env'
-> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge"))))
-> Ex env' (Tup (D2E env))
-reassembleD2E DTop _ = ENil ext
-reassembleD2E (des `DPush` (_, _, SAccum)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EFst ext (EVar ext (typeOf e) IZ)))
- (ESnd ext (EVar ext (typeOf e) IZ))))
- (ESnd ext (EFst ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (_, _, SMerge)) e =
- ELet ext e $
- EPair ext (reassembleD2E des (EPair ext (EFst ext (EVar ext (typeOf e) IZ))
- (EFst ext (ESnd ext (EVar ext (typeOf e) IZ)))))
- (ESnd ext (ESnd ext (EVar ext (typeOf e) IZ)))
-reassembleD2E (des `DPush` (t, _, SDiscr)) e = EPair ext (reassembleD2E des e) (ezeroD2 t)
+reassembleD2E DTop _ _ = ENil ext
+reassembleD2E (des `DPush` (_, _, SAccum)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e1 $ \w2 e11 e12 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12
+reassembleD2E (des `DPush` (_, _, SMerge)) w e =
+ eunPair e $ \w1 e1 e2 ->
+ eunPair e2 $ \w2 e21 e22 ->
+ EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22
+reassembleD2E (des `DPush` (t, _, SDiscr)) w e =
+ EPair ext (reassembleD2E des (WPop w) e)
+ (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ))))
chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env)))
chad config env (term :: Ex env t)
@@ -82,21 +71,22 @@ chad config env (term :: Ex env t)
let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr)))
tvar = STPair t1 (tTup (d2e (select SAccum descr)))
in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $
- makeAccumulators (select SAccum descr) $
+ makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $
weakenExpr (autoWeak (#d (auto1 @(D2 t))
&. #acenv (d2ace (select SAccum descr))
&. #tl (d1e env))
(#d :++: #acenv :++: #tl)
(#acenv :++: #d :++: #tl)) $
- freezeRet descr (drev descr VarMap.empty term')) $
+ freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $
EPair ext (EFst ext (EFst ext (EVar ext tvar IZ)))
- (reassembleD2E descr (EPair ext (ESnd ext (EVar ext tvar IZ))
- (ESnd ext (EFst ext (EVar ext tvar IZ)))))
+ (reassembleD2E descr (WSink .> WSink)
+ (EPair ext (ESnd ext (EVar ext tvar IZ))
+ (ESnd ext (EFst ext (EVar ext tvar IZ)))))
| False <- chcArgArrayAccum config
, Refl <- mergeEnvNoAccum env
, Refl <- mergeEnvOnlyMerge env
- = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty term')
+ = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term')
where
term' = identityAnalysis env (splitLets term)
diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs
index 83f013d..8b3a8db 100644
--- a/src/CHAD/Types.hs
+++ b/src/CHAD/Types.hs
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module CHAD.Types where
@@ -124,3 +125,18 @@ lemZeroInfoScal STI64 = Refl
lemZeroInfoScal STF32 = Refl
lemZeroInfoScal STF64 = Refl
lemZeroInfoScal STBool = Refl
+
+d1Identity :: STy t -> D1 t :~: t
+d1Identity = \case
+ STNil -> Refl
+ STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl
+ STMaybe t | Refl <- d1Identity t -> Refl
+ STArr _ t | Refl <- d1Identity t -> Refl
+ STScal _ -> Refl
+ STAccum{} -> error "Accumulators not allowed in input program"
+
+d1eIdentity :: SList STy env -> D1E env :~: env
+d1eIdentity SNil = Refl
+d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl
diff --git a/src/Compile.hs b/src/Compile.hs
index 722b432..a5c4fb7 100644
--- a/src/Compile.hs
+++ b/src/Compile.hs
@@ -45,6 +45,7 @@ import qualified Prelude
import Array
import AST
import AST.Pretty (ppSTy, ppExpr)
+import AST.Sparse.Types (isDense)
import Compile.Exec
import Data
import Interpreter.Rep
@@ -1002,95 +1003,7 @@ compile' env = \case
rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t))
return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)]
- EAccum _ t prj eidx eval eacc -> do
- let -- Assumes v is a value of type (SMTArr n t1), and initialises it to a
- -- full zero array with the given zero info (for the type SMTArr n t1).
- initZeroArray :: SNat n -> SMTy a -> String -> String -> CompM ()
- initZeroArray n t1 v vzi = do
- shszname <- genName' "inacshsz"
- emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n vzi)
- newarrName <- allocArray "initZero" Calloc "inacarr" n (fromSMTy t1) (Just (CELit shszname)) (compileArrShapeComponents n vzi)
- emit $ SAsg v (CELit newarrName)
- forM_ (initZeroFromMemset t1) $ \f1 -> do
- ivar <- genName' "i"
- ((), initStmts) <- scope $ f1 (v++"["++ivar++"]") (vzi++"["++ivar++"]")
- emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) initStmts
-
- -- If something needs to be done to properly initialise this type to
- -- zero after memory has already been initialised to all-zero bytes,
- -- returns an action that does so.
- -- initZeroFromMemset (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroFromMemset :: SMTy a -> Maybe (String -> String -> CompM ())
- initZeroFromMemset SMTNil = Nothing
- initZeroFromMemset (SMTPair t1 t2) =
- case (initZeroFromMemset t1, initZeroFromMemset t2) of
- (Nothing, Nothing) -> Nothing
- (mf1, mf2) -> Just $ \v vzi -> do
- forM_ mf1 $ \f1 -> f1 (v++".a") (vzi++".a")
- forM_ mf2 $ \f2 -> f2 (v++".b") (vzi++".b")
- initZeroFromMemset SMTLEither{} = Nothing
- initZeroFromMemset SMTMaybe{} = Nothing
- initZeroFromMemset (SMTArr n t1) = Just $ \v vzi -> initZeroArray n t1 v vzi
- initZeroFromMemset SMTScal{} = Nothing
-
- let -- initZeroZI (type) (variable of that type to initialise to zero) (variable to a ZeroInfo for the type)
- initZeroZI :: SMTy a -> String -> String -> CompM ()
- initZeroZI SMTNil _ _ = return ()
- initZeroZI (SMTPair t1 t2) v vzi = do
- initZeroZI t1 (v++".a") (vzi++".a")
- initZeroZI t2 (v++".b") (vzi++".b")
- initZeroZI SMTLEither{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI SMTMaybe{} v _ = emit $ SAsg (v++".tag") (CELit "0")
- initZeroZI (SMTArr n t1) v vzi = initZeroArray n t1 v vzi
- initZeroZI (SMTScal sty) v _ = case sty of
- STI32 -> emit $ SAsg v (CELit "0")
- STI64 -> emit $ SAsg v (CELit "0l")
- STF32 -> emit $ SAsg v (CELit "0.0f")
- STF64 -> emit $ SAsg v (CELit "0.0")
-
- let -- Initialise an uninitialised accumulation value, potentially already
- -- with the addend, potentially to zero depending on the nature of the
- -- projection.
- -- 1. If the projection indexes only through dense monoids before
- -- reaching SAPHere, the thing cannot be initialised to zero with
- -- only an AcIdx; it would need to model a zero after the addend,
- -- which is stupid and redundant. In this case, we return Left:
- -- (accumulation value) (AcIdx value) (addend value).
- -- The addend is copied, not consumed. (We can't reliably _always_
- -- consume it, so it's not worth trying to do it sometimes.)
- -- 2. Otherwise, a sparse monoid is found along the way, and we can
- -- initalise the dense prefix of the path to zero by setting the
- -- indexed-through sparse value to a sparse zero. Afterwards, the
- -- main recursion can proceed further. In this case, we return
- -- Right: (accumulation value) (AcIdx value)
- -- initZeroChunk (type) (projection) (variable of that type to initialise to zero) (variable to an AcIdx for the type)
- initZeroChunk :: SMTy a -> SAcPrj p a b
- -> Either (String -> String -> String -> CompM ()) -- dense initialisation with addend
- (String -> String -> CompM ()) -- zero initialisation of sparse chunk
- initZeroChunk izaitoptyp izaitopprj = case (izaitoptyp, izaitopprj) of
- -- reached target before the first sparse constructor
- (t1 , SAPHere ) -> Left $ \v _ addend -> do
- incrementVarAlways "initZeroSparse" Increment (fromSMTy t1) addend
- emit $ SAsg v (CELit addend)
- -- sparse types
- (SMTMaybe{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- (SMTLEither{} , _ ) -> Right $ \v _ -> emit $ SAsg (v++".tag") (CELit "0")
- -- dense types
- (SMTPair t1 t2, SAPFst prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- f (v++".a") (i++".a")
- initZeroZI t2 (v++".b") (i++".b")
- (SMTPair t1 t2, SAPSnd prj') -> applySkeleton (initZeroChunk t2 prj') $ \f v i -> do
- initZeroZI t1 (v++".a") (i++".a")
- f (v++".b") (i++".b")
- (SMTArr n t1, SAPArrIdx prj') -> applySkeleton (initZeroChunk t1 prj') $ \f v i -> do
- initZeroArray n t1 v (i++".a.b")
- linidxvar <- genName' "li"
- emit $ SVarDecl False (repSTy tIx) linidxvar (toLinearIdx n v (i++".a.a"))
- f (v++".buf->xs["++linidxvar++"]") (i++".b")
- where
- applySkeleton (Left densef) skel = Left $ \v i addend -> skel (\v' i' -> densef v' i' addend) v i
- applySkeleton (Right sparsef) skel = Right $ \v i -> skel (\v' i' -> sparsef v' i') v i
-
+ EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do
let -- Add a value (s) into an existing accumulation value (d). If a sparse
-- component of d is encountered, s is copied there.
add :: SMTy a -> String -> String -> CompM ()
@@ -1160,67 +1073,55 @@ compile' env = \case
accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM ()
accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend
- accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") (i++".a") addend
- accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") (i++".b") addend
-
- accumRef (SMTLEither ta tb) prj0 v i addend = do
- let chunkres = case prj0 of SAPLeft prj' -> initZeroChunk ta prj'
- SAPRight prj' -> initZeroChunk tb prj'
- subv = v ++ (case prj0 of SAPLeft{} -> ".l"; SAPRight{} -> ".r")
- tagval = case prj0 of SAPLeft{} -> "1"
- SAPRight{} -> "2"
- ((), stmtsAdd) <- scope $ case prj0 of SAPLeft prj' -> accumRef ta prj' subv i addend
- SAPRight prj' -> accumRef tb prj' subv i addend
- case chunkres of
- Left densef -> do
- ((), stmtsSet) <- scope $ densef subv i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsSet)
- stmtsAdd -- TODO: emit check for consistency of tags?
- Right sparsef -> do
- ((), stmtsInit) <- scope $ sparsef subv i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit tagval)) <> stmtsInit) mempty
- forM_ stmtsAdd emit
+ accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend
+ accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend
+
+ accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef ta prj' (v++".l") i addend
+ accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tb prj' (v++".r") i addend
accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do
- case initZeroChunk tj prj' of
- Left densef -> do
- ((), stmtsSet1) <- scope $ densef (v++".j") i addend
- ((), stmtsAdd1) <- scope $ accumRef tj prj' (v++".j") i addend
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsSet1)
- stmtsAdd1
- Right sparsef -> do
- ((), stmtsInit1) <- scope $ sparsef (v++".j") i
- emit $ SIf (CEBinop (CELit (v++".tag")) "==" (CELit "0"))
- (pure (SAsg (v++".tag") (CELit "1")) <> stmtsInit1) mempty
- accumRef tj prj' (v++".j") i addend
+ when emitChecks $ do
+ emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1"))
+ (pure $ SVerbatim $
+ "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++
+ "return false;")
+ mempty
+ accumRef tj prj' (v++".j") i addend
accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do
when emitChecks $ do
let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]"
- forM_ (zip3 [0::Int ..]
- (indexTupleComponents n (i++".a.a"))
- (compileArrShapeComponents n (i++".a.b"))) $ \(j, ixcomp, shcomp) -> do
+ forM_ (zip [0::Int ..]
+ (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do
let a .||. b = CEBinop a "||" b
emit $ SIf (CEBinop ixcomp "<" (CELit "0")
.||.
- CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
- .||.
- CEBinop shcomp "!=" (CELit (v ++ ".buf->sh[" ++ show j ++ "]")))
+ CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".buf->sh[" ++ show j ++ "]"))))
(pure $ SVerbatim $
"fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++
- "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=" ++ shfmt ++ ")\\n\", " ++
+ "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++
v ++ ".buf" ++
concat [", " ++ v ++ ".buf->sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a.a")] ++
- concat [", " ++ printCExpr 2 comp "" | comp <- compileArrShapeComponents n (i++".a.b")] ++
+ concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++
"); " ++
"return false;")
mempty
- accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a.a")) "]") (i++".b") addend
+ accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend
nameidx <- compileAssign "acidx" env eidx
nameval <- compileAssign "acval" env eval
@@ -1234,6 +1135,9 @@ compile' env = \case
return $ CEStruct (repSTy STNil) []
+ EAccum{} ->
+ error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)"
+
EError _ t s -> do
let padleft len c s' = replicate (len - length s) c ++ s'
escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c]
@@ -1247,6 +1151,7 @@ compile' env = \case
return $ CEStruct name []
EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
+ EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)"
diff --git a/src/Data.hs b/src/Data.hs
index e86aaa6..e6978c8 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -8,12 +8,13 @@
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-module Data (module Data, (:~:)(Refl)) where
+module Data (module Data, (:~:)(Refl), If) where
import Data.Functor.Product
import Data.GADT.Compare
import Data.GADT.Show
import Data.Some
+import Data.Type.Bool (If)
import Data.Type.Equality
import Unsafe.Coerce (unsafeCoerce)
@@ -184,3 +185,8 @@ instance Applicative Bag where
instance Semigroup (Bag t) where (<>) = BTwo
instance Monoid (Bag t) where mempty = BNone
+
+data SBool b where
+ SF :: SBool False
+ ST :: SBool True
+deriving instance Show (SBool b)
diff --git a/src/Example.hs b/src/Example.hs
index d3f6d0d..b320ead 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -162,8 +162,7 @@ neuralGo =
ELet ext (EConst ext STF64 1.0) $
chad defaultConfig knownEnv neural
(primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of
- (primal', (((((), Just (Just dlay1_1'a, Just dlay1_1'b)), Just (Just dlay2_1'a, Just dlay2_1'b)), Just dlay3_1'), Just dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
- _ -> undefined
+ (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1')
(Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0
in trace (ppExpr knownEnv revderiv) $
(primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2))
diff --git a/src/Interpreter.hs b/src/Interpreter.hs
index b3576ce..ffc2929 100644
--- a/src/Interpreter.hs
+++ b/src/Interpreter.hs
@@ -21,6 +21,7 @@ module Interpreter (
) where
import Control.Monad (foldM, join, when, forM_)
+import Data.Bifunctor (bimap)
import Data.Bitraversable (bitraverse)
import Data.Char (isSpace)
import Data.Functor.Identity
@@ -35,6 +36,7 @@ import Debug.Trace
import Array
import AST
import AST.Pretty
+import AST.Sparse.Types
import Data
import Interpreter.Rep
@@ -158,14 +160,17 @@ interpret'Rec env = \case
initval <- interpret' env e1
withAccum t (typeOf e2) initval $ \accum ->
interpret' (V (STAccum t) accum `SCons` env) e2
- EAccum _ t p e1 e2 e3 -> do
+ EAccum _ t p e1 sp e2 e3 -> do
idx <- interpret' env e1
val <- interpret' env e2
accum <- interpret' env e3
- accumAddSparseD t p accum idx val
+ accumAddSparseD t p accum idx sp val
EZero _ t ezi -> do
zi <- interpret' env ezi
return $ zeroM t zi
+ EDeepZero _ t ezi -> do
+ zi <- interpret' env ezi
+ return $ deepZeroM t zi
EPlus _ t a b -> do
a' <- interpret' env a
b' <- interpret' env b
@@ -216,6 +221,19 @@ zeroM typ zi = case typ of
STF32 -> 0.0
STF64 -> 0.0
+deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t
+deepZeroM typ zi = case typ of
+ SMTNil -> ()
+ SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi))
+ SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi
+ SMTMaybe t -> fmap (deepZeroM t) zi
+ SMTArr _ t -> arrayMap (deepZeroM t) zi
+ SMTScal sty -> case sty of
+ STI32 -> 0
+ STI64 -> 0
+ STF32 -> 0.0
+ STF64 -> 0.0
+
addM :: SMTy t -> Rep t -> Rep t -> Rep t
addM typ a b = case typ of
SMTNil -> ()
@@ -256,15 +274,6 @@ withAccum t _ initval f = AcM $ do
val <- readAc t accum
return (out, val)
-newAcZero :: SMTy t -> Rep (ZeroInfo t) -> IO (RepAc t)
-newAcZero typ zi = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> bitraverse (newAcZero t1) (newAcZero t2) zi
- SMTLEither{} -> newIORef Nothing
- SMTMaybe _ -> newIORef Nothing
- SMTArr _ t -> arrayMapM (newAcZero t) zi
- SMTScal sty -> numericIsNum sty $ newIORef 0
-
newAcDense :: SMTy a -> Rep a -> IO (RepAc a)
newAcDense typ val = case typ of
SMTNil -> return ()
@@ -274,22 +283,6 @@ newAcDense typ val = case typ of
SMTArr _ t1 -> arrayMapM (newAcDense t1) val
SMTScal _ -> newIORef val
-newAcSparse :: SMTy a -> SAcPrj p a b -> Rep (AcIdxS p a) -> Rep b -> IO (RepAc a)
-newAcSparse typ prj idx val = case (typ, prj) of
- (_, SAPHere) -> newAcDense typ val
-
- (SMTPair t1 t2, SAPFst prj') ->
- (,) <$> newAcSparse t1 prj' (fst idx) val <*> newAcZero t2 (snd idx)
- (SMTPair t1 t2, SAPSnd prj') ->
- (,) <$> newAcZero t1 (fst idx) <*> newAcSparse t2 prj' (snd idx) val
-
- (SMTLEither t1 _, SAPLeft prj') -> newIORef . Just . Left =<< newAcSparse t1 prj' idx val
- (SMTLEither _ t2, SAPRight prj') -> newIORef . Just . Right =<< newAcSparse t2 prj' idx val
-
- (SMTMaybe t1, SAPJust prj') -> newIORef . Just =<< newAcSparse t1 prj' idx val
-
- (SMTArr n t, SAPArrIdx prj') -> onehotArray (\idx' -> newAcSparse t prj' idx' val) (newAcZero t) n prj' idx
-
onehotArray :: Monad m
=> (Rep (AcIdxS p a) -> m v) -- ^ the "one"
-> (Rep (ZeroInfo a) -> m v) -- ^ the "zero"
@@ -309,81 +302,67 @@ readAc typ val = case typ of
SMTArr _ t -> traverse (readAc t) val
SMTScal _ -> readIORef val
-accumAddDense :: SMTy a -> RepAc a -> Rep a -> AcM s ()
-accumAddDense typ ref val = case typ of
- SMTNil -> return ()
- SMTPair t1 t2 -> do
- accumAddDense t1 (fst ref) (fst val)
- accumAddDense t2 (snd ref) (snd val)
- SMTLEither{} ->
- case val of
- Nothing -> return ()
- Just (Left val1) -> accumAddSparse typ (SAPLeft SAPHere) ref () val1
- Just (Right val2) -> accumAddSparse typ (SAPRight SAPHere) ref () val2
- SMTMaybe{} ->
- case val of
- Nothing -> return ()
- Just val' -> accumAddSparse typ (SAPJust SAPHere) ref () val'
- SMTArr _ t1 ->
- forM_ [0 .. arraySize ref - 1] $ \i ->
- accumAddDense t1 (arrayIndexLinear ref i) (arrayIndexLinear val i)
- SMTScal sty -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
-
-accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Rep b -> AcM s ()
-accumAddSparseD typ prj ref idx val = case (typ, prj) of
- (_, SAPHere) -> accumAddDense typ ref val
+accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s ()
+accumAddSparseD typ prj ref idx sp val = case (typ, prj) of
+ (_, SAPHere) -> accumAddDense typ ref sp val
- (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx val
- (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx val
+ (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val
+ (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val
(SMTLEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val
Right{} -> error "Mismatched Either in accumAddSparseD (r +l)")
(SMTLEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val
Left{} -> error "Mismatched Either in accumAddSparseD (l +r)")
(SMTMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)")
+ (\ac -> accumAddSparseD t1 prj' ac idx sp val)
(SMTArr n t1, SAPArrIdx prj') ->
let (arrindex', idx') = idx
arrindex = unTupRepIdx IxNil IxCons n arrindex'
arrsh = arrayShape ref
linindex = toLinearIndex arrsh arrindex
- in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' val
-
-accumAddSparse :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxS p a) -> Rep b -> AcM s ()
-accumAddSparse typ prj ref idx val = case (typ, prj) of
- (_, SAPHere) -> accumAddDense typ ref val
-
- (SMTPair t1 _, SAPFst prj') -> accumAddSparse t1 prj' (fst ref) (fst idx) val
- (SMTPair _ t2, SAPSnd prj') -> accumAddSparse t2 prj' (snd ref) (snd idx) val
-
- (SMTLEither t1 _, SAPLeft prj') ->
- realiseMaybeSparse ref (Left <$> newAcSparse t1 prj' idx val)
- (\case Left ac1 -> accumAddSparse t1 prj' ac1 idx val
- Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
- (SMTLEither _ t2, SAPRight prj') ->
- realiseMaybeSparse ref (Right <$> newAcSparse t2 prj' idx val)
- (\case Right ac2 -> accumAddSparse t2 prj' ac2 idx val
- Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
-
- (SMTMaybe t1, SAPJust prj') ->
- realiseMaybeSparse ref (newAcSparse t1 prj' idx val)
- (\ac -> accumAddSparse t1 prj' ac idx val)
-
- (SMTArr n t1, SAPArrIdx prj') ->
- let ((arrindex', ziarr), idx') = idx
- arrindex = unTupRepIdx IxNil IxCons n arrindex'
- arrsh = arrayShape ziarr
- linindex = toLinearIndex arrsh arrindex
- in accumAddSparse t1 prj' (arrayIndexLinear ref linindex) idx' val
+ in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val
+accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s ()
+accumAddDense typ ref sp val = case (typ, sp) of
+ (_, _) | isAbsent sp -> return ()
+ (_, SpAbsent) -> return ()
+ (_, SpSparse s) ->
+ case val of
+ Nothing -> return ()
+ Just val' -> accumAddDense typ ref s val'
+ (SMTPair t1 t2, SpPair s1 s2) -> do
+ accumAddDense t1 (fst ref) s1 (fst val)
+ accumAddDense t2 (snd ref) s2 (snd val)
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ case val of
+ Nothing -> return ()
+ Just (Left val1) ->
+ realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)")
+ (\case Left ac1 -> accumAddDense t1 ac1 s1 val1
+ Right{} -> error "Mismatched Either in accumAddSparse (r +l)")
+ Just (Right val2) ->
+ realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)")
+ (\case Right ac2 -> accumAddDense t2 ac2 s2 val2
+ Left{} -> error "Mismatched Either in accumAddSparse (l +r)")
+ (SMTMaybe t, SpMaybe s) ->
+ case val of
+ Nothing -> return ()
+ Just val' ->
+ realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)")
+ (\ac -> accumAddDense t ac s val')
+ (SMTArr _ t1, SpArr s) ->
+ forM_ [0 .. arraySize ref - 1] $ \i ->
+ accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i)
+ (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ()))
+-- TODO: makeval is always 'error' now. Simplify?
realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s ()
realiseMaybeSparse ref makeval modifyval =
-- Try modifying what's already in ref. The 'join' makes the snd
diff --git a/src/Language.hs b/src/Language.hs
index 63279df..4e6d604 100644
--- a/src/Language.hs
+++ b/src/Language.hs
@@ -17,6 +17,7 @@ module Language (
import Array
import AST
+import AST.Sparse.Types
import AST.Types
import CHAD.Types
import Data
@@ -176,7 +177,10 @@ with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum
with a (n :-> b) = NEWith (knownMTy @t) a n b
accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
-accum p a b c = NEAccum knownMTy p a b c
+accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c
+
+accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
+accumS p a sp b c = NEAccum knownMTy p a sp b c
(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool)
diff --git a/src/Language/AST.hs b/src/Language/AST.hs
index 92792b3..be98ccf 100644
--- a/src/Language/AST.hs
+++ b/src/Language/AST.hs
@@ -21,6 +21,7 @@ import GHC.TypeLits (Symbol, SSymbol, symbolSing, KnownSymbol, TypeError, ErrorM
import Array
import AST
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -76,7 +77,7 @@ data NExpr env t where
-- accumulation effect on monoids
NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t)
- NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil
+ NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil
-- partiality
NEError :: STy a -> String -> NExpr env a
@@ -221,7 +222,7 @@ fromNamedExpr val = \case
NERecompute e -> ERecompute ext (go e)
NEWith t a n b -> EWith ext t (go a) (lambda val n b)
- NEAccum t p a b c -> EAccum ext t p (go a) (go b) (go c)
+ NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c)
NEError t s -> EError ext t s
diff --git a/src/Simplify.hs b/src/Simplify.hs
index d3b850f..74b6601 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,7 +1,9 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE QuasiQuotes #-}
@@ -19,13 +21,14 @@ import Control.Monad (ap)
import Data.Bifunctor (first)
import Data.Function (fix)
import Data.Monoid (Any(..))
-import Data.Type.Equality (testEquality)
import Debug.Trace
import AST
import AST.Count
import AST.Pretty
+import AST.Sparse.Types
+import AST.UnMonoid (acPrjCompose)
import Data
import Simplify.TH
@@ -81,22 +84,28 @@ runSM (SM f) = first getAny (f id)
smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt)
smReconstruct core = SM (\ctx -> (Any False, ctx core))
-tellActed :: SM tenv tt env t ()
-tellActed = SM (\_ -> (Any True, ()))
+class Monad m => ActedMonad m where
+ tellActed :: m ()
+ hideActed :: m a -> m a
+ liftActed :: (Any, a) -> m a
+
+instance ActedMonad ((,) Any) where
+ tellActed = (Any True, ())
+ hideActed (_, x) = (Any False, x)
+ liftActed = id
+
+instance ActedMonad (SM tenv tt env t) where
+ tellActed = SM (\_ -> tellActed)
+ hideActed (SM f) = SM (\ctx -> hideActed (f ctx))
+ liftActed pair = SM (\_ -> pair)
-- more convenient in practice
-acted :: SM tenv tt env t a -> SM tenv tt env t a
+acted :: ActedMonad m => m a -> m a
acted m = tellActed >> m
within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a
within subctx (SM f) = SM $ \ctx -> f (ctx . subctx)
-acted' :: (Any, a) -> (Any, a)
-acted' (_, x) = (Any True, x)
-
-liftActed :: (Any, a) -> SM tenv tt env t a
-liftActed pair = SM $ \_ -> pair
-
simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t)
simplify' expr
| scLogging ?config = do
@@ -167,10 +176,10 @@ simplify'Rec = \case
ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2)))
EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body))
EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e)))
- EAccum _ t p e1 (ELet _ rhs body) acc ->
+ EAccum _ t p e1 sp (ELet _ rhs body) acc ->
acted $ simplify' $
ELet ext rhs $
- EAccum ext t p (weakenExpr WSink e1) body (weakenExpr WSink acc)
+ EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc)
-- let () = e in () ~> e
ELet _ e1 (ENil _) | STNil <- typeOf e1 ->
@@ -194,6 +203,9 @@ simplify'Rec = \case
EIdx _ (EReplicate1Inner _ _ e2) e3 -> acted $ simplify' $ EIdx ext e2 (EFst ext e3)
EIdx _ (EUnit _ e1) _ -> acted $ simplify' $ e1
+ -- TODO: more array shape
+ EShape _ (EBuild _ _ e _) -> acted $ simplify' e
+
-- TODO: more constant folding
EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext))
EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext))
@@ -222,23 +234,40 @@ simplify'Rec = \case
acted $ simplify' $ EUnit ext (substInline (ENil ext) e)
-- monoid rules
- EAccum _ t p e1 e2 acc -> do
- e1' <- within (\e1' -> EAccum ext t p e1' e2 acc ) $ simplify' e1
- e2' <- within (\e2' -> EAccum ext t p e1' e2' acc ) $ simplify' e2
- acc' <- within (\acc' -> EAccum ext t p e1' e2' acc') $ simplify' acc
- simplifyOneHotTerm (OneHotTerm SAI_D t p e1' e2')
+ EAccum _ t p e1 sp e2 acc -> do
+ e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1
+ e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2
+ acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc
+ simplifyOHT (OneHotTerm SAID t p e1' sp e2')
(acted $ return (ENil ext))
- (\e -> return (EAccum ext t SAPHere (ENil ext) e acc'))
- (\(OneHotTerm SAI_D t' p' e1'' e2'') -> return (EAccum ext t' p' e1'' e2'' acc'))
+ (\sp' (InContext w wrap e) -> do
+ e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e
+ return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')))
+ (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do
+ -- The acted management here is a hideous mess.
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2''
+ return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')))
EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e
EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e
EOneHot _ t p e1 e2 -> do
e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1
e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2
- simplifyOneHotTerm (OneHotTerm SAI_S t p e1' e2')
+ simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2')
(acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2)))
- (\e -> acted $ return e)
- (\(OneHotTerm SAI_S t' p' e1'' e2'') -> return (EOneHot ext t' p' e1'' e2''))
+ (\sp' (InContext _ wrap e) ->
+ case isDense t sp' of
+ Just Refl -> do
+ e' <- hideActed $ within wrap $ simplify' e
+ return (wrap e')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
+ (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) ->
+ case isDense (acPrjTy p' t') sp' of
+ Just Refl -> do
+ e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1''
+ e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2''
+ return (wrap $ EOneHot ext t' p' e1''' e2''')
+ Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse")
-- type-specific equations for plus
EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) ->
@@ -302,8 +331,9 @@ simplify'Rec = \case
e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1)
e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2)
pure (EWith ext t e1' e2')
- EZero _ t e -> [simprec| EZero ext t *e |] -- EZero ext t <$> simplify' e
- EPlus _ t a b -> [simprec| EPlus ext t *a *b |] -- EPlus ext t <$> simplify' a <*> simplify' b
+ EZero _ t e -> [simprec| EZero ext t *e |]
+ EDeepZero _ t e -> [simprec| EDeepZero ext t *e |]
+ EPlus _ t a b -> [simprec| EPlus ext t *a *b |]
EError _ t s -> pure $ EError ext t s
cheapExpr :: Expr x env t -> Bool
@@ -353,8 +383,9 @@ hasAdds = \case
EOp _ _ e -> hasAdds e
EWith _ _ a b -> hasAdds a || hasAdds b
ERecompute _ e -> hasAdds e
- EAccum _ _ _ _ _ _ -> True
+ EAccum _ _ _ _ _ _ _ -> True
EZero _ _ e -> hasAdds e
+ EDeepZero _ _ e -> hasAdds e
EPlus _ _ a b -> hasAdds a || hasAdds b
EOneHot _ _ _ a b -> hasAdds a || hasAdds b
EError _ _ _ -> False
@@ -373,51 +404,161 @@ checkAccumInScope = \case SNil -> False
check (STScal _) = False
check STAccum{} = True
-data OneHotTerm dense env p a b where
- OneHotTerm :: SStillDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Ex env b -> OneHotTerm dense env p a b
-deriving instance Show (OneHotTerm dense env p a b)
-
-simplifyOneHotTerm :: OneHotTerm dense env p a b
- -> SM tenv tt env t r -- ^ Zero case (onehot is actually zero)
- -> (Ex env a -> SM tenv tt env t r) -- ^ Trivial case (no zeros in onehot)
- -> (forall p' b'. OneHotTerm dense env p' a b' -> SM tenv tt env t r)
- -> SM tenv tt env t r
-simplifyOneHotTerm (OneHotTerm dense t1 prj1 idx1 val1) kzero ktriv k = do
- val1' <- liftActed $ recogniseMonoid (acPrjTy prj1 t1) val1
- case val1' of
- EZero{} -> kzero
- EOneHot _ t2 prj2 idx2 val2
- | Just Refl <- testEquality (acPrjTy prj1 t1) t2 -> do
- tellActed -- record, whatever happens later, that we've modified something
- concatOneHots dense t1 prj1 idx1 prj2 idx2 $ \prj12 idx12 ->
- simplifyOneHotTerm (OneHotTerm dense t1 prj12 idx12 val2) kzero ktriv k
- _ -> case prj1 of
- SAPHere -> ktriv val1
- _ -> k (OneHotTerm dense t1 prj1 idx1 val1)
+data OneHotTerm dense env a where
+ OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a
+deriving instance Show (OneHotTerm dense env a)
+
+data InContext f env (a :: Ty) where
+ InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a
+
+simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do
+ val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val
+ return $ OneHotTerm dense t prj idx sp val'
+
+simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a)
+simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) =
+ unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 ->
+ acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' ->
+ return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2)
+simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht
+
+simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a)
+simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val))
+ | Just Refl <- isDense (acPrjTy prj1 t1) sp =
+ let idx2' :: Ex env (AcIdx dense p2 c)
+ idx2' = case dense of
+ SAID -> reduceAcIdx t2 prj2 idx2
+ SAIS -> idx2
+ in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' ->
+ acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val
+simplifyOHT_concat oht = return oht
+
+-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is
+-- -- dense, then the Sparse in the output will also be dense. This property is
+-- -- used when simplifying EOneHot, which cannot represent sparsity.
+simplifyOHT :: ActedMonad m => OneHotTerm dense env a
+ -> m r -- ^ Zero case (onehot is actually zero)
+ -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot)
+ -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified
+ -> m r
+simplifyOHT oht kzero ktriv k = do
+ -- traceM $ "sOHT: input " ++ show oht
+ oht1 <- simplifyOHT_recogniseMonoid oht
+ -- traceM $ "sOHT: recog " ++ show oht1
+ InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1
+ -- traceM $ "sOHT: unspa " ++ show oht2
+ oht3 <- simplifyOHT_concat oht2
+ -- traceM $ "sOHT: conca " ++ show oht3
+ -- traceM ""
+ case oht3 of
+ OneHotTerm _ _ _ _ _ EZero{} -> kzero
+ OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val)
+ _ -> k (InContext w1 wrap1 oht3)
+
+-- Sets the acted flag whenever a non-trivial projection is returned or the
+-- output Sparse is different from the input Sparse.
+unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a'
+ -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s)
+ -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r
+unsparseOneHotD topsp topval k = case (topsp, topval) of
+ -- eliminate always-Just sparse onehot
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k
+
+ -- expand the top levels of a onehot for a sparse type into a onehot for the
+ -- corresponding non-sparse type
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPFst spprj) idx' s1' e'
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPSnd spprj) idx' s1' e'
+ (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPLeft spprj) idx' s1' e'
+ (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPRight spprj) idx' s1' e'
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' ->
+ acted $ k w wrap (SAPJust spprj) idx' s1' e'
+ (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val)
+ | Dict <- styKnown (typeOf idx) ->
+ unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' ->
+ acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e'
+
+ -- anything else we don't know how to improve
+ _ -> k WId id SAPHere (ENil ext) topsp topval
+
+{-
+unsparseOneHotS :: ActedMonad m
+ => Sparse a a' -> Ex env a'
+ -> (forall b. Sparse a b -> Ex env b -> m r) -> m r
+unsparseOneHotS topsp topval k = case (topsp, topval) of
+ -- order is relevant to make sure we set the acted flag correctly
+ (SpAbsent, v@ENil{}) -> k SpAbsent v
+ (SpAbsent, v@EZero{}) -> k SpAbsent v
+ (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+ (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext))
+
+ -- the unsparsifying
+ (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) ->
+ acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k
+
+ -- recursion
+ -- TODO: coproducts could safely become projections as they do not need
+ -- zeroinfo. But that would only work if the coproduct is at the top, because
+ -- as soon as we hit a product, we need zeroinfo to make it a projection and
+ -- we don't have that.
+ (SpSparse s, e) -> k (SpSparse s) e
+ (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' ->
+ acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext))
+ (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' ->
+ acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do
+ case s2 of SpAbsent -> pure () ; _ -> tellActed
+ k (SpLEither s1' SpAbsent) (ELInl ext STNil e')
+ (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) ->
+ unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do
+ case s1 of SpAbsent -> pure () ; _ -> tellActed
+ acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e')
+ (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' ->
+ k (SpMaybe s1') (EJust ext e')
+ (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) ->
+ unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' ->
+ k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e')
+ _ -> _
+-}
-- | Recognises 'EZero' and 'EOneHot'.
recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t)
recogniseMonoid _ e@EOneHot{} = return e
-recogniseMonoid SMTNil (ENil _) = acted' $ return $ EZero ext SMTNil (ENil ext)
+recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext)
recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) =
((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case
- (EZero _ _ ezi1, EZero _ _ ezi2) -> acted' $ return $ EZero ext typ (EPair ext ezi1 ezi2)
- (a', EZero _ _ ezi2) -> acted' $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
- (EZero _ _ ezi1, b') -> acted' $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
+ (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2)
+ (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a'
+ (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b'
(a', b') -> return $ EPair ext a' b'
recogniseMonoid typ@(SMTLEither t1 t2) expr =
case expr of
- ELNil{} -> acted' $ return $ EZero ext typ (ENil ext)
- ELInl _ _ e -> acted' $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
- ELInr _ _ e -> acted' $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
+ ELNil{} -> acted $ return $ EZero ext typ (ENil ext)
+ ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e
_ -> return expr
recogniseMonoid typ@(SMTMaybe t1) expr =
case expr of
- ENothing{} -> acted' $ return $ EZero ext typ (ENil ext)
- EJust _ e -> acted' $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
+ ENothing{} -> acted $ return $ EZero ext typ (ENil ext)
+ EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e
_ -> return expr
recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
- acted' $ do
+ acted $ do
e' <- recogniseMonoid t e
return $
ELet ext e' $
@@ -426,61 +567,21 @@ recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) =
(ENil ext))
(EVar ext (fromSMTy t) IZ)
recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of
- (STI32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STI64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STF32, 0) -> acted' $ return $ EZero ext typ (ENil ext)
- (STF64, 0) -> acted' $ return $ EZero ext typ (ENil ext)
+ (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext)
+ (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext)
_ -> return e
recogniseMonoid _ e = return e
-concatOneHots :: SStillDense dense -> SMTy a
- -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
- -> SAcPrj p2 b c -> Ex env (AcIdxS p2 b)
- -> (forall p12. SAcPrj p12 a c -> Ex env (AcIdx dense p12 a) -> r) -> r
-concatOneHots dense t1 prj1 idx1 prj2 idx2 k = case (dense, t1, prj1) of
- (SAI_D, _, SAPHere) -> k prj2 (reduceAcIdx t1 prj2 idx2)
- (SAI_S, _, SAPHere) -> k prj2 idx2
-
- (SAI_D, SMTPair a _, SAPFst prj1') ->
- concatOneHots SAI_D a prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
- k (SAPFst prj12) idx12
- (SAI_S, SMTPair a _, SAPFst prj1') ->
- concatOneHots SAI_S a prj1' (EFst ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPFst prj12) (ELet ext idx1 $ EPair ext idx12 (ESnd ext (EVar ext (typeOf idx1) IZ)))
- (SAI_D, SMTPair _ b, SAPSnd prj1') ->
- concatOneHots dense b prj1' idx1 prj2 idx2 $ \prj12 idx12 ->
- k (SAPSnd prj12) idx12
- (SAI_S, SMTPair _ b, SAPSnd prj1') ->
- concatOneHots dense b prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPSnd prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-
- (_, SMTLEither a _, SAPLeft prj1') ->
- concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPLeft prj12) idx12
- (_, SMTLEither _ b, SAPRight prj1') ->
- concatOneHots SAI_S b prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPRight prj12) idx12
-
- (_, SMTMaybe a, SAPJust prj1') ->
- concatOneHots SAI_S a prj1' idx1 prj2 idx2 $ \prj12 idx12 -> k (SAPJust prj12) idx12
-
- -- yes, twice the same code, but we need a concrete denseness indicator to
- -- reduce AcIdx (the only difference between the dense and sparse versions is
- -- whether there extra info also contains an array shape, and this code
- -- handles the extra info uniformly)
- (SAI_D, SMTArr _ a, SAPArrIdx prj1') ->
- concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
- (SAI_S, SMTArr _ a, SAPArrIdx prj1') ->
- concatOneHots dense a prj1' (ESnd ext (EVar ext (typeOf idx1) IZ)) prj2 (weakenExpr WSink idx2) $ \prj12 idx12 ->
- k (SAPArrIdx prj12) (ELet ext idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx12)
-
-reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdx AI_S p a) -> Ex env (AcIdx AI_D p a)
+reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a)
reduceAcIdx topty topprj e = case (topty, topprj) of
(_, SAPHere) -> ENil ext
(SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e)
(SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e)
- (SMTLEither{}, SAPLeft{}) -> e
- (SMTLEither{}, SAPRight{}) -> e
- (SMTMaybe{}, SAPJust{}) -> e
+ (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e
+ (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e
+ (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e
(SMTArr _ t, SAPArrIdx p) ->
eunPair e $ \_ e1 e2 ->
EPair ext (efst e1) (reduceAcIdx t p e2)
diff --git a/test/Main.hs b/test/Main.hs
index 1b83a2e..d79e63f 100644
--- a/test/Main.hs
+++ b/test/Main.hs
@@ -435,11 +435,22 @@ gen_neural = do
lay3 <- genArray tR (ShNil `ShCons` n2)
return (input `SCons` lay3 `SCons` lay2 `SCons` lay1 `SCons` SNil)
+term_build0 :: Ex '[TArr N0 R] R
+term_build0 = fromNamed $ lambda @(TArr N0 _) #x $ body $
+ idx0 $
+ build SZ (shape #x) $ #idx :-> #x ! #idx
+
term_build1_sum :: Ex '[TVec R] R
term_build1_sum = fromNamed $ lambda #x $ body $
idx0 $ sum1i $
build (SS SZ) (shape #x) $ #idx :-> #x ! #idx
+term_build1_idx :: Ex '[TVec R] R
+term_build1_idx = fromNamed $ lambda @(TVec _) #x $ body $
+ let_ #n (snd_ (shape #x)) $
+ idx0 $ sum1i $
+ build1 (#n `idiv` 2) (#i :-> #x ! pair nil (2 * #i))
+
term_pairs :: Ex [R, R] R
term_pairs = fromNamed $ lambda #x $ lambda #y $ body $
let_ #p (pair #x #y) $
@@ -502,22 +513,22 @@ tests_Compile = testGroup "Compile"
,compileTest "accum (f64,f64)" $ fromNamed $ lambda #b $ lambda #x $ body $
with @(TPair R R) (pair 0.0 0.0) $ #ac :->
- let_ #_ (if_ #b (accum (SAPFst SAPHere) (pair nil nil) 3.0 #ac) nil) $
+ let_ #_ (if_ #b (accum (SAPFst SAPHere) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPSnd SAPHere) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPSnd SAPHere) nil 4.0 #ac) $
nil
,compileTest "accum (Maybe (f64,f64))" $ fromNamed $ lambda #b $ lambda #x $ body $
- with @(TMaybe (TPair R R)) nothing $ #ac :->
- let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) (pair nil nil) 3.0 #ac) nil) $
+ with @(TMaybe (TPair R R)) (just (pair 0 0)) $ #ac :->
+ let_ #_ (if_ #b (accum (SAPJust (SAPFst SAPHere)) nil 3.0 #ac) nil) $
let_ #_ (accum SAPHere nil #x #ac) $
- let_ #_ (accum (SAPJust (SAPSnd SAPHere)) (pair nil nil) 4.0 #ac) $
+ let_ #_ (accum (SAPJust (SAPSnd SAPHere)) nil 4.0 #ac) $
nil
,compileTestTp "accum Arr 1 f64" (() :& C "" 3) $ fromNamed $ lambda #b $ lambda @(TVec R) #x $ body $
let_ #len (snd_ (shape #x)) $
with @(TVec R) (build1 #len (#_ :-> 0)) $ #ac :->
- let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair (pair nil 2) (build1 #len (#_ :-> nil))) nil) 6.0 #ac)
+ let_ #_ (if_ #b (accum (SAPArrIdx SAPHere) (pair (pair nil 2) nil) 6.0 #ac)
nil) $
let_ #_ (accum SAPHere nil #x #ac) $
nil
@@ -556,9 +567,7 @@ tests_AD = testGroup "AD"
,adTest "build0 const" $ fromNamed $ lambda @R #x $ body $
idx0 $ build SZ nil $ #idx :-> const_ 0.0
- ,adTest "build0" $ fromNamed $ lambda @(TArr N0 _) #x $ body $
- idx0 $
- build SZ (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build0" term_build0
,adTest "build1-sum" term_build1_sum
@@ -566,6 +575,8 @@ tests_AD = testGroup "AD"
idx0 $ sum1i . sum1i $
build (SS (SS SZ)) (shape #x) $ #idx :-> #x ! #idx
+ ,adTest "build1-idx" term_build1_idx
+
,adTestCon "maximum" (\(Value a `SCons` _) -> let _ `ShCons` n = arrayShape a in n > 0) $
fromNamed $ lambda @(TMat R) #x $ body $
idx0 $ sum1i $ maximum1i #x