aboutsummaryrefslogtreecommitdiff
path: root/src/AST
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST')
-rw-r--r--src/AST/Accum.hs75
-rw-r--r--src/AST/Bindings.hs20
-rw-r--r--src/AST/Count.hs952
-rw-r--r--src/AST/Env.hs84
-rw-r--r--src/AST/Pretty.hs93
-rw-r--r--src/AST/Sparse.hs287
-rw-r--r--src/AST/Sparse/Types.hs107
-rw-r--r--src/AST/SplitLets.hs49
-rw-r--r--src/AST/Types.hs42
-rw-r--r--src/AST/UnMonoid.hs125
-rw-r--r--src/AST/Weaken.hs8
-rw-r--r--src/AST/Weaken/Auto.hs2
12 files changed, 1653 insertions, 191 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index e84034b..988a450 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,14 +1,13 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
import AST.Types
-import CHAD.Types
import Data
@@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
- AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
- AcIdx (APLeft p) (TLEither a b) = AcIdx p a
- AcIdx (APRight p) (TLEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, shapes info), recursive info)
+type data AIDense = AID | AIS
+
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
+
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
@@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
-lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
-lemZeroInfoD2 STNil = Refl
-lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
-lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
-lemZeroInfoD2 (STScal STI32) = Refl
-lemZeroInfoD2 (STScal STI64) = Refl
-lemZeroInfoD2 (STScal STF32) = Refl
-lemZeroInfoD2 (STScal STF64) = Refl
-lemZeroInfoD2 (STScal STBool) = Refl
-lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
-lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs
index 3d99afe..463586a 100644
--- a/src/AST/Bindings.hs
+++ b/src/AST/Bindings.hs
@@ -16,6 +16,7 @@
module AST.Bindings where
import AST
+import AST.Env
import Data
import Lemmas
@@ -27,6 +28,10 @@ data Bindings f env binds where
deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
infixl `BPush`
+bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds)
+bpush b e = b `BPush` (typeOf e, e)
+infixl `bpush`
+
mapBindings :: (forall env' t'. f env' t' -> g env' t')
-> Bindings f env binds -> Bindings g env binds
mapBindings _ BTop = BTop
@@ -41,6 +46,11 @@ weakenBindings wf w (BPush b (t, x)) =
let (b', w') = weakenBindings wf w b
in (BPush b' (t, wf w' x), WCopy w')
+weakenBindingsE :: env1 :> env2
+ -> Bindings (Expr x) env1 binds
+ -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2)
+weakenBindingsE = weakenBindings weakenExpr
+
weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
weakenOver SNil w = w
weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
@@ -62,3 +72,13 @@ bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds)
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs
+
+collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env'
+collectBindings = \env -> fst . go env WId
+ where
+ go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0)
+ go _ _ SETop = (BTop, WId)
+ go (ty `SCons` env) w (SEYesR sub) =
+ let (bs, w') = go env (WPop w) sub
+ in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w')
+ go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index feaaa1e..bc02417 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -3,6 +3,7 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
@@ -10,17 +11,31 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE PatternSynonyms #-}
module AST.Count where
-import Data.Functor.Const
+import Data.Functor.Product
+import Data.Some
+import Data.Type.Equality
import GHC.Generics (Generic, Generically(..))
+import Array
import AST
import AST.Env
import Data
+-- | The monoid operation combines assuming that /both/ branches are taken.
+class Monoid a => Occurrence a where
+ -- | One of the two branches is taken
+ (<||>) :: a -> a -> a
+ -- | This code is executed many times
+ scaleMany :: a -> a
+
+
data Count = Zero | One | Many
deriving (Show, Eq, Ord)
@@ -30,6 +45,10 @@ instance Semigroup Count where
_ <> _ = Many
instance Monoid Count where
mempty = Zero
+instance Occurrence Count where
+ (<||>) = max
+ scaleMany Zero = Zero
+ scaleMany _ = Many
data Occ = Occ { _occLexical :: Count
, _occRuntime :: Count }
@@ -40,120 +59,855 @@ instance Show Occ where
showsPrec d (Occ l r) = showParen (d > 10) $
showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r
--- | One of the two branches is taken
-(<||>) :: Occ -> Occ -> Occ
-Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (max r1 r2)
+instance Occurrence Occ where
+ Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2)
+ scaleMany (Occ l c) = Occ l (scaleMany c)
+
+
+data Substruc t t' where
+ -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below
+ SsFull :: Substruc t t
+ SsNone :: Substruc t TNil
+ SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b')
+ SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b')
+ SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b')
+ SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a')
+ SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements
+ SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a')
+
+pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t'
+pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2)
+ where SsPair' = SsPair
+{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-}
+
+pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t'
+pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s)
+ where SsArr' = SsArr
+{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-}
+
+instance Semigroup (Some (Substruc t)) where
+ Some SsFull <> _ = Some SsFull
+ _ <> Some SsFull = Some SsFull
+ Some SsNone <> s = s
+ s <> Some SsNone = s
+ Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2)
+ Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2)
+ Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2)
+ Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2)
+ Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2)
+ Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2)
+instance Monoid (Some (Substruc t)) where
+ mempty = Some SsNone
+
+instance TestEquality (Substruc t) where
+ testEquality SsFull s = isFull s
+ testEquality s SsFull = sym <$> isFull s
+ testEquality SsNone SsNone = Just Refl
+ testEquality SsNone _ = Nothing
+ testEquality _ SsNone = Nothing
+ testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing
+ testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+ testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing
+
+isFull :: Substruc t t' -> Maybe (t :~: t')
+isFull SsFull = Just Refl
+isFull SsNone = Nothing -- TODO: nil?
+isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing
+isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing
+
+applySubstruc :: Substruc t t' -> STy t -> STy t'
+applySubstruc SsFull t = t
+applySubstruc SsNone _ = STNil
+applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b)
+applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t)
+applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t)
+applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t)
+
+applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t'
+applySubstrucM SsFull t = t
+applySubstrucM SsNone _ = SMTNil
+applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b)
+applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t)
+applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t)
+applySubstrucM _ t = case t of {}
+
+data ExMap a b = ExMap (forall env. Ex env a -> Ex env b)
+ | a ~ b => ExMapId
+
+fromExMap :: ExMap a b -> Ex env a -> Ex env b
+fromExMap (ExMap f) = f
+fromExMap ExMapId = id
+
+simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t'
+simplifySubstruc STNil SsNone = SsFull
+
+simplifySubstruc _ SsFull = SsFull
+simplifySubstruc _ SsNone = SsNone
+simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2)
+simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s)
+simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s)
+simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s)
+
+-- simplifySubstruc' :: Substruc t t'
+-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r
+-- simplifySubstruc' SsFull k = k SsFull ExMapId
+-- simplifySubstruc' SsNone k = k SsNone ExMapId
+-- simplifySubstruc' (SsPair s1 s2) k =
+-- simplifySubstruc' s1 $ \s1' f1 ->
+-- simplifySubstruc' s2 $ \s2' f2 ->
+-- case (s1', s2') of
+-- (SsFull, SsFull) ->
+-- k SsFull (case (f1, f2) of
+-- (ExMapId, ExMapId) -> ExMapId
+-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 ->
+-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2)))
+-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext))))
+-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ)))))
+-- simplifySubstruc' _ _ = _
+
+-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b)
+-- ssUnpair SsFull = (SsFull, SsFull)
+-- ssUnpair SsNone = (SsNone, SsNone)
+-- ssUnpair (SsPair a b) = (a, b)
+
+-- ssUnleft :: Substruc (TEither a b) -> Substruc a
+-- ssUnleft SsFull = SsFull
+-- ssUnleft SsNone = SsNone
+-- ssUnleft (SsEither a _) = a
+
+-- ssUnright :: Substruc (TEither a b) -> Substruc b
+-- ssUnright SsFull = SsFull
+-- ssUnright SsNone = SsNone
+-- ssUnright (SsEither _ b) = b
+
+-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a
+-- ssUnlleft SsFull = SsFull
+-- ssUnlleft SsNone = SsNone
+-- ssUnlleft (SsLEither a _) = a
+
+-- ssUnlright :: Substruc (TLEither a b) -> Substruc b
+-- ssUnlright SsFull = SsFull
+-- ssUnlright SsNone = SsNone
+-- ssUnlright (SsLEither _ b) = b
+
+-- ssUnjust :: Substruc (TMaybe a) -> Substruc a
+-- ssUnjust SsFull = SsFull
+-- ssUnjust SsNone = SsNone
+-- ssUnjust (SsMaybe a) = a
+
+-- ssUnarr :: Substruc (TArr n a) -> Substruc a
+-- ssUnarr SsFull = SsFull
+-- ssUnarr SsNone = SsNone
+-- ssUnarr (SsArr a) = a
+
+-- ssUnaccum :: Substruc (TAccum a) -> Substruc a
+-- ssUnaccum SsFull = SsFull
+-- ssUnaccum SsNone = SsNone
+-- ssUnaccum (SsAccum a) = a
+
+
+type family MapEmpty env where
+ MapEmpty '[] = '[]
+ MapEmpty (t : env) = TNil : MapEmpty env
+
+data OccEnv a env env' where
+ OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top!
+ OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env')
+
+instance Semigroup a => Semigroup (Some (OccEnv a env)) where
+ Some OccEnd <> e = e
+ e <> Some OccEnd = e
+ Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2)
+
+instance Semigroup a => Monoid (Some (OccEnv a env)) where
+ mempty = Some OccEnd
+
+instance Occurrence a => Occurrence (Some (OccEnv a env)) where
+ Some OccEnd <||> e = e
+ e <||> Some OccEnd = e
+ Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2)
+
+ scaleMany (Some OccEnd) = Some OccEnd
+ scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s)
+
+onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env)
+onehotOccEnv IZ v s = Some (OccPush OccEnd v s)
+onehotOccEnv (IS i) v s
+ | Some env' <- onehotOccEnv i v s
+ = Some (OccPush env' mempty SsNone)
+
+occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t')
+occEnvPop (OccPush e _ s) = (e, s)
+occEnvPop OccEnd = (OccEnd, SsNone)
+
+occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r
+occEnvPop' (OccPush e _ s) k = k e s
+occEnvPop' OccEnd k = k OccEnd SsNone
--- | This code is executed many times
-scaleMany :: Occ -> Occ
-scaleMany (Occ l Zero) = Occ l Zero
-scaleMany (Occ l _) = Occ l Many
+occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env)
+occEnvPopSome (Some (OccPush e _ _)) = Some e
+occEnvPopSome (Some OccEnd) = Some OccEnd
+
+occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t))
+occEnvPrj OccEnd _ = mempty
+occEnvPrj (OccPush _ o s) IZ = (o, Some s)
+occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i
+
+occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env'))
+occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ)
+occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i'))
+occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ)
+occEnvPrjS (OccPush e _ _) (IS i)
+ | Some (Pair s' i') <- occEnvPrjS e i
+ = Some (Pair s' (IS i'))
+
+projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small
+projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of
+ _ | Just Refl <- testEquality topsbig topssmall -> ex
+
+ (SsFull, SsFull) -> ex
+ (SsNone, SsNone) -> ex
+ (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller"
+ (_, SsNone) ->
+ case typeOf ex of
+ STNil -> ex
+ _ -> use ex $ ENil ext
+
+ (SsPair s1 s2, SsPair s1' s2') ->
+ eunPair ex $ \_ e1 e2 ->
+ EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2)
+ (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex
+ (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex
+
+ (SsEither s1 s2, SsEither s1' s2')
+ | STEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in ecase ex
+ (EInl ext (typeOf e2) e1)
+ (EInr ext (typeOf e1) e2)
+ (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex
+ (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex
+
+ (SsLEither s1 s2, SsLEither s1' s2')
+ | STLEither t1 t2 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ)
+ in elcase ex
+ (ELNil ext (typeOf e1) (typeOf e2))
+ (ELInl ext (typeOf e2) e1)
+ (ELInr ext (typeOf e1) e2)
+ (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex
+ (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex
+
+ (SsMaybe s1, SsMaybe s1')
+ | STMaybe t1 <- typeOf ex ->
+ let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ)
+ in emaybe ex
+ (ENothing ext (typeOf e1))
+ (EJust ext e1)
+ (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex
+ (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex
+
+ (SsArr s1, SsArr s2)
+ | STArr n t <- typeOf ex ->
+ elet ex $
+ EBuild ext n (EShape ext (evar IZ)) $
+ projectSmallerSubstruc s1 s2
+ (EIdx ext (EVar ext (STArr n t) (IS IZ))
+ (EVar ext (tTup (sreplicate n tIx)) IZ))
+ (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex
+ (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex
+
+ (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum"
+ (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex
+ (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex
+
+
+-- | A boolean for each entry in the environment, with the ability to uniformly
+-- mask the top part above a certain index.
+data EnvMask env where
+ EMRest :: Bool -> EnvMask env
+ EMPush :: EnvMask env -> Bool -> EnvMask (t : env)
+
+envMaskPrj :: EnvMask env -> Idx env t -> Bool
+envMaskPrj (EMRest b) _ = b
+envMaskPrj (_ `EMPush` b) IZ = b
+envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i
occCount :: Idx env a -> Expr x env t -> Occ
-occCount idx =
- getConst . occCountGeneral
- (\w i o -> if idx2int i == idx2int (w @> idx) then Const o else mempty)
- (\(Const o) -> Const o)
- (\(Const o1) (Const o2) -> Const (o1 <||> o2))
- (\(Const o) -> Const (scaleMany o))
+occCount idx ex
+ | Some env <- occCountAll ex
+ = fst (occEnvPrj env idx)
+
+occCountAll :: Expr x env t -> Some (OccEnv Occ env)
+occCountAll ex = occCountX SsFull ex $ \env _ -> Some env
+
+pruneExpr :: SList f env -> Expr x env t -> Ex env t
+pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env)
+ where
+ fullOccEnv :: SList f env -> OccEnv () env env
+ fullOccEnv SNil = OccEnd
+ fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull
+
+-- In one traversal, count occurrences of variables and determine what parts of
+-- expressions are actually used. These two results are computed independently:
+-- even if (almost) nothing of a particular term is actually used, variable
+-- references in that term still count as usual.
+--
+-- In @occCountX s t k@:
+-- * s: how much of the result of this term is required
+-- * t: the term to analyse
+-- * k: is passed the actual environment usage of this expression, including
+-- occurrence counts. The callback reconstructs a new expression in an
+-- updated "response" environment. The response must be at least as large as
+-- the computed usages.
+occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t
+ -> (forall env'. OccEnv Occ env env'
+ -- response OccEnv must be at least as large as the OccEnv returned above
+ -> (forall env''. OccEnv () env env'' -> Ex env'' t')
+ -> r)
+ -> r
+occCountX initialS topexpr k = case topexpr of
+ EVar _ t i ->
+ withSome (onehotOccEnv i (Occ One One) s) $ \env ->
+ k env $ \env' ->
+ withSome (occEnvPrjS env' i) $ \(Pair s' i') ->
+ projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i')
+ ELet _ rhs body ->
+ occCountX s body $ \envB mkbody ->
+ occEnvPop' envB $ \envB' s1 ->
+ occCountX s1 rhs $ \envR mkrhs ->
+ withSome (Some envB' <> Some envR) $ \env ->
+ k env $ \env' ->
+ ELet ext (mkrhs env') (mkbody (OccPush env' () s1))
+ EPair _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsPair' s1 s2 ->
+ occCountX s1 a $ \env1 mka ->
+ occCountX s2 b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPair ext (mka env') (mkb env')
+ EFst _ e ->
+ occCountX (SsPair s SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EFst ext (mke env')
+ ESnd _ e ->
+ occCountX (SsPair SsNone s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ ESnd ext (mke env')
+ ENil _ ->
+ case s of
+ SsFull -> k OccEnd (\_ -> ENil ext)
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ EInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ EInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ EInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsEither SsFull SsFull) topexpr k
+ ECase _ e a b ->
+ occCountX s a $ \env1' mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env1' $ \env1 s1 ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2))
+ ENothing _ t ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t))
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EJust _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsMaybe s' ->
+ occCountX s' e $ \env1 mke ->
+ k env1 $ \env' ->
+ EJust ext (mke env')
+ SsFull -> occCountX (SsMaybe SsFull) topexpr k
+ EMaybe _ a b e ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s2 ->
+ occCountX (SsMaybe s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env ->
+ k env $ \env' ->
+ EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env')
+ ELNil _ t1 t2 ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2))
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInl _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s1 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInl ext (applySubstruc s2 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELInr _ t e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ SsLEither s1 s2 ->
+ occCountX s2 e $ \env1 mke ->
+ k env1 $ \env' ->
+ ELInr ext (applySubstruc s1 t) (mke env')
+ SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k
+ ELCase _ e a b c ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2' mkb ->
+ occCountX s c $ \env3' mkc ->
+ occEnvPop' env2' $ \env2 s1 ->
+ occEnvPop' env3' $ \env3 s2 ->
+ occCountX (SsLEither s1 s2) e $ \env0 mke ->
+ withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env ->
+ k env $ \env' ->
+ ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2))
+
+ EConstArr _ n t x ->
+ case s of
+ SsNone -> k OccEnd (\_ -> ENil ext)
+ SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext))
+ SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x)
+
+ EBuild _ n a b ->
+ case s of
+ SsNone ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsNone b $ \env2'' mkb ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
+ k env $ \env' ->
+ use (EBuild ext n (mka env') $
+ use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $
+ ENil ext) $
+ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX s' b $ \env2'' mkb ->
+ occEnvPop' env2'' $ \env2' s2 ->
+ withSome (Some env1 <> scaleMany (Some env2')) $ \env ->
+ k env $ \env' ->
+ EBuild ext n (mka env') $
+ elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $
+ weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))
+
+ EMap _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $
+ ENil ext
+ SsArr' s' ->
+ occCountX s' a $ \env1'' mka ->
+ occEnvPop' env1'' $ \env1' s1 ->
+ occCountX (SsArr s1) b $ \env2 mkb ->
+ withSome (scaleMany (Some env1') <> Some env2) $ \env ->
+ k env $ \env' ->
+ EMap ext (mka (OccPush env' () s1)) (mkb env')
+
+ EFold1Inner _ commut a b c ->
+ occCountX SsFull a $ \env1''' mka ->
+ withSome (scaleMany (Some env1''')) $ \env1'' ->
+ occEnvPop' env1'' $ \env1' s2 ->
+ occEnvPop' env1' $ \env1 s1 ->
+ let s0 = case s of
+ SsNone -> Some SsNone
+ SsArr' s' -> Some s' in
+ withSome (Some s1 <> Some s2 <> s0) $ \sElt ->
+ occCountX sElt b $ \env2 mkb ->
+ occCountX (SsArr sElt) c $ \env3 mkc ->
+ withSome (Some env1 <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsArr sElt) s $
+ EFold1Inner ext commut
+ (projectSmallerSubstruc SsFull sElt $
+ mka (OccPush (OccPush env' () sElt) () sElt))
+ (mkb env') (mkc env')
+
+ ESum1Inner _ e -> handleReduction (ESum1Inner ext) e
+
+ EUnit _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX s' e $ \env mke ->
+ k env $ \env' ->
+ EUnit ext (mke env')
+
+ EReplicate1Inner _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX (SsArr s') b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReplicate1Inner ext (mka env') (mkb env')
+
+ EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e
+ EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e
+
+ EReshape _ n esh e ->
+ case s of
+ SsNone ->
+ occCountX SsNone esh $ \env1 mkesh ->
+ occCountX SsNone e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkesh env') $ use (mke env') $ ENil ext
+ SsArr' s' ->
+ occCountX SsFull esh $ \env1 mkesh ->
+ occCountX (SsArr s') e $ \env2 mke ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EReshape ext n (mkesh env') (mke env')
+
+ EZip _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mkb env') $ mka env'
+ SsArr' (SsPair' s1 s2) ->
+ occCountX (SsArr s1) a $ \env1 mka ->
+ occCountX (SsArr s2) b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EZip ext (mka env') (mkb env')
+
+ EFold1InnerD1 _ cm e1 e2 e3 ->
+ case s of
+ -- If nothing is necessary, we can execute a fold and then proceed to ignore it
+ SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex
+ -- If we don't need the stores, still a fold suffices
+ SsPair' sP SsNone ->
+ let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1))
+ (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3)
+ in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext)
+ -- If for whatever reason the additional stores themselves are
+ -- unnecessary but the shape of the array is, then oblige
+ SsPair' sP (SsArr' SsNone) ->
+ let STArr sn _ = typeOf e3
+ foldex =
+ elet (mapExt (\_ -> ext) e3) $
+ EPair ext
+ (EShape ext (evar IZ))
+ (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy (WCopy WSink)) e1)))
+ (mapExt (\_ -> ext) (weakenExpr WSink e2))
+ (evar IZ))
+ in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex ->
+ k env1 $ \env' ->
+ eunPair (mkfoldex env') $ \_ eshape earr ->
+ EPair ext earr (EBuild ext sn eshape (ENil ext))
+ -- If at least some of the additional stores are required, we need to keep this a mapAccum
+ SsPair' _ (SsArr' sB) ->
+ -- TODO: propagate usage of primals
+ occCountX (SsPair SsFull sB) e1 $ \env1_2' mka ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' _ ->
+ occCountX SsFull e2 $ \env2 mkb ->
+ occCountX SsFull e3 $ \env3 mkc ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $
+ EFold1InnerD1 ext cm (mka (OccPush (OccPush env' () SsFull) () SsFull))
+ (mkb env') (mkc env')
+
+ EFold1InnerD2 _ cm ef ebog ed ->
+ -- TODO: propagate usage of duals
+ occCountX SsFull ef $ \env1_2' mkef ->
+ occEnvPop' env1_2' $ \env1_1' _ ->
+ occEnvPop' env1_1' $ \env1' sB ->
+ occCountX (SsArr sB) ebog $ \env2 mkebog ->
+ occCountX SsFull ed $ \env3 mked ->
+ withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $
+ EFold1InnerD2 ext cm
+ (mkef (OccPush (OccPush env' () sB) () SsFull))
+ (mkebog env') (mked env')
+
+ EConst _ t x ->
+ k OccEnd $ \_ ->
+ case s of
+ SsNone -> ENil ext
+ SsFull -> EConst ext t x
+
+ EIdx0 _ e ->
+ occCountX (SsArr s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EIdx0 ext (mke env')
+
+ EIdx1 _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ SsArr' s' ->
+ occCountX (SsArr s') a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx1 ext (mka env') (mkb env')
+
+ EIdx _ a b ->
+ case s of
+ SsNone ->
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ _ ->
+ occCountX (SsArr s) a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EIdx ext (mka env') (mkb env')
+
+ EShape _ e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX (SsArr SsNone) e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EShape ext (mke env')
+ EOp _ op e ->
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env1 mke ->
+ k env1 $ \env' ->
+ use (mke env') $ ENil ext
+ _ ->
+ occCountX SsFull e $ \env1 mke ->
+ k env1 $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOp ext op (mke env')
-data OccEnv env where
- OccEnd :: OccEnv env -- not necessarily top!
- OccPush :: OccEnv env -> Occ -> OccEnv (t : env)
+ ECustom _ t1 t2 t3 e1 e2 e3 a b
+ | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 ->
+ error "Accumulators not allowed in input/output/tape of an ECustom"
+ | otherwise ->
+ case s of
+ SsNone ->
+ -- Allowed to ignore e1/e2/e3 here because no accumulators are
+ -- communicated, and hence no relevant effects exist
+ occCountX SsNone a $ \env1 mka ->
+ occCountX SsNone b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (mka env') $ use (mkb env') $ ENil ext
+ s' -> -- Let's be pessimistic for safety
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s' $
+ ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env')
-instance Semigroup (OccEnv env) where
- OccEnd <> e = e
- e <> OccEnd = e
- OccPush e o <> OccPush e' o' = OccPush (e <> e') (o <> o')
+ ERecompute _ e ->
+ occCountX s e $ \env1 mke ->
+ k env1 $ \env' ->
+ ERecompute ext (mke env')
-instance Monoid (OccEnv env) where
- mempty = OccEnd
+ EWith _ t a b ->
+ case s of
+ SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with
+ occCountX SsNone b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ withSome (case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone) $ \s1' ->
+ occCountX s1' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $
+ ENil ext
+ SsPair sB sA ->
+ occCountX sB b $ \env2' mkb ->
+ occEnvPop' env2' $ \env2 s1 ->
+ let s1' = case s1 of
+ SsFull -> Some SsFull
+ SsAccum s' -> Some s'
+ SsNone -> Some SsNone in
+ withSome (Some sA <> s1') $ \sA' ->
+ occCountX sA' a $ \env1 mka ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $
+ EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA')))
+ SsFull -> occCountX (SsPair SsFull SsFull) topexpr k
-onehotOccEnv :: Idx env t -> Occ -> OccEnv env
-onehotOccEnv IZ v = OccPush OccEnd v
-onehotOccEnv (IS i) v = OccPush (onehotOccEnv i v) mempty
+ EAccum _ t p a sp b e ->
+ -- TODO: do better!
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb ->
+ occCountX SsFull e $ \env3 mke ->
+ withSome (Some env1 <> Some env2) $ \env12 ->
+ withSome (Some env12 <> Some env3) $ \env ->
+ k env $ \env' ->
+ case s of {SsFull -> id; SsNone -> id} $
+ EAccum ext t p (mka env') sp (mkb env') (mke env')
-(<||>!) :: OccEnv env -> OccEnv env -> OccEnv env
-OccEnd <||>! e = e
-e <||>! OccEnd = e
-OccPush e o <||>! OccPush e' o' = OccPush (e <||>! e') (o <||> o')
+ EZero _ t e ->
+ occCountX (subZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EZero ext (applySubstrucM s t) (mke env')
+ where
+ subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2)
+ subZeroInfo SsFull = SsFull
+ subZeroInfo SsNone = SsNone
+ subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2)
+ subZeroInfo SsEither{} = error "Either is not a monoid"
+ subZeroInfo SsLEither{} = SsNone
+ subZeroInfo SsMaybe{} = SsNone
+ subZeroInfo (SsArr s') = SsArr (subZeroInfo s')
+ subZeroInfo SsAccum{} = error "Accum is not a monoid"
-scaleManyOccEnv :: OccEnv env -> OccEnv env
-scaleManyOccEnv OccEnd = OccEnd
-scaleManyOccEnv (OccPush e o) = OccPush (scaleManyOccEnv e) (scaleMany o)
+ EDeepZero _ t e ->
+ occCountX (subDeepZeroInfo s) e $ \env1 mke ->
+ k env1 $ \env' ->
+ EDeepZero ext (applySubstrucM s t) (mke env')
+ where
+ subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2)
+ subDeepZeroInfo SsFull = SsFull
+ subDeepZeroInfo SsNone = SsNone
+ subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo SsEither{} = error "Either is not a monoid"
+ subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2)
+ subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s')
+ subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s')
+ subDeepZeroInfo SsAccum{} = error "Accum is not a monoid"
-occEnvPop :: OccEnv (t : env) -> OccEnv env
-occEnvPop (OccPush o _) = o
-occEnvPop OccEnd = OccEnd
+ EPlus _ t a b ->
+ occCountX s a $ \env1 mka ->
+ occCountX s b $ \env2 mkb ->
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ EPlus ext (applySubstrucM s t) (mka env') (mkb env')
-occCountAll :: Expr x env t -> OccEnv env
-occCountAll = occCountGeneral (const onehotOccEnv) occEnvPop (<||>!) scaleManyOccEnv
+ EOneHot _ t p a b ->
+ occCountX SsFull a $ \env1 mka ->
+ occCountX SsFull b $ \env2 mkb -> -- TODO: do better
+ withSome (Some env1 <> Some env2) $ \env ->
+ k env $ \env' ->
+ projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env')
-occCountGeneral :: forall r env t x.
- (forall env'. Monoid (r env'))
- => (forall env' a. env :> env' -> Idx env' a -> Occ -> r env') -- ^ one-hot
- -> (forall env' a. r (a : env') -> r env') -- ^ unpush
- -> (forall env'. r env' -> r env' -> r env') -- ^ alternation
- -> (forall env'. r env' -> r env') -- ^ scale-many
- -> Expr x env t -> r env
-occCountGeneral onehot unpush alter many = go WId
+ EError _ t msg ->
+ k OccEnd $ \_ -> EError ext (applySubstruc s t) msg
where
- go :: forall env' t'. Monoid (r env') => env :> env' -> Expr x env' t' -> r env'
- go w = \case
- EVar _ _ i -> onehot w i (Occ One One)
- ELet _ rhs body -> re rhs <> re1 body
- EPair _ a b -> re a <> re b
- EFst _ e -> re e
- ESnd _ e -> re e
- ENil _ -> mempty
- EInl _ _ e -> re e
- EInr _ _ e -> re e
- ECase _ e a b -> re e <> (re1 a `alter` re1 b)
- ENothing _ _ -> mempty
- EJust _ e -> re e
- EMaybe _ a b e -> re a <> re1 b <> re e
- ELNil _ _ _ -> mempty
- ELInl _ _ e -> re e
- ELInr _ _ e -> re e
- ELCase _ e a b c -> re e <> (re a `alter` re1 b `alter` re1 c)
- EConstArr{} -> mempty
- EBuild _ _ a b -> re a <> many (re1 b)
- EFold1Inner _ _ a b c -> many (unpush (unpush (go (WSink .> WSink .> w) a))) <> re b <> re c
- ESum1Inner _ e -> re e
- EUnit _ e -> re e
- EReplicate1Inner _ a b -> re a <> re b
- EMaximum1Inner _ e -> re e
- EMinimum1Inner _ e -> re e
- EConst{} -> mempty
- EIdx0 _ e -> re e
- EIdx1 _ a b -> re a <> re b
- EIdx _ a b -> re a <> re b
- EShape _ e -> re e
- EOp _ _ e -> re e
- ECustom _ _ _ _ _ _ _ a b -> re a <> re b
- EWith _ _ a b -> re a <> re1 b
- EAccum _ _ _ a b e -> re a <> re b <> re e
- EZero _ _ e -> re e
- EPlus _ _ a b -> re a <> re b
- EOneHot _ _ _ a b -> re a <> re b
- EError{} -> mempty
- where
- re :: Monoid (r env') => Expr x env' t'' -> r env'
- re = go w
+ s = simplifySubstruc (typeOf topexpr) initialS
- re1 :: Monoid (r env') => Expr x (a : env') t'' -> r env'
- re1 = unpush . go (WSink .> w)
+ handleReduction :: t ~ TArr n (TScal t2)
+ => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2)))
+ -> Expr x env (TArr (S n) (TScal t2))
+ -> r
+ handleReduction reduce e
+ | STArr (SS n) _ <- typeOf e =
+ case s of
+ SsNone ->
+ occCountX SsNone e $ \env mke ->
+ k env $ \env' ->
+ use (mke env') $ ENil ext
+ SsArr' SsNone ->
+ occCountX (SsArr SsNone) e $ \env mke ->
+ k env $ \env' ->
+ elet (mke env') $
+ EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext)
+ SsArr' SsFull ->
+ occCountX (SsArr SsFull) e $ \env mke ->
+ k env $ \env' ->
+ reduce (mke env')
-deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
-deleteUnused SNil OccEnd k = k SETop
-deleteUnused (_ `SCons` env) OccEnd k =
- deleteUnused env OccEnd $ \sub -> k (SENo sub)
-deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
- deleteUnused env occenv $ \sub ->
+deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r
+deleteUnused SNil (Some OccEnd) k = k SETop
+deleteUnused (_ `SCons` env) (Some OccEnd) k =
+ deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub)
+deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k =
+ deleteUnused env (Some occenv) $ \sub ->
case count of Zero -> k (SENo sub)
- _ -> k (SEYes sub)
+ _ -> k (SEYesR sub)
unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
unsafeWeakenWithSubenv = \sub ->
@@ -162,7 +916,7 @@ unsafeWeakenWithSubenv = \sub ->
Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
where
sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
- sinkViaSubenv IZ (SEYes _) = Just IZ
+ sinkViaSubenv IZ (SEYesR _) = Just IZ
sinkViaSubenv IZ (SENo _) = Nothing
- sinkViaSubenv (IS i) (SEYes sub) = IS <$> sinkViaSubenv i sub
+ sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub
sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
index 4f34166..85faba3 100644
--- a/src/AST/Env.hs
+++ b/src/AST/Env.hs
@@ -1,59 +1,95 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE QuantifiedConstraints #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Env where
+import Data.Type.Equality
+
+import AST.Sparse
import AST.Weaken
+import CHAD.Types
import Data
-- | @env'@ is a subset of @env@: each element of @env@ is either included in
-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: forall t env env'. Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: forall t env env'. Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
+data Subenv' s env env' where
+ SETop :: Subenv' s '[] '[]
+ SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env')
+ SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env'
+deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env')
+
+type Subenv = Subenv' (:~:)
+type SubenvS = Subenv' Sparse
+
+pattern SEYesR :: forall tenv tenv'. ()
+ => forall t env env'. (tenv ~ t : env, tenv' ~ t : env')
+ => Subenv env env' -> Subenv tenv tenv'
+pattern SEYesR s = SEYes Refl s
+
+{-# COMPLETE SETop, SEYesR, SENo #-}
-subList :: SList f env -> Subenv env env' -> SList f env'
+subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env'
subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
+subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub)
subList (SCons _ xs) (SENo sub) = subList xs sub
-subenvAll :: SList f env -> Subenv env env
+subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env
subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
+subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env)
-subenvNone :: SList f env -> Subenv env '[]
+subenvNone :: SList f env -> Subenv' s env '[]
subenvNone SNil = SETop
subenvNone (SCons _ env) = SENo (subenvNone env)
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
-subenvOnehot SNil i = case i of {}
+subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t']
+subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env)
+subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp)
+subenvOnehot SNil i _ = case i of {}
-subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3
+subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3
subenvCompose SETop SETop = SETop
-subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
-subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2)
+subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
-subenvConcat :: Subenv env1 env1' -> Subenv env2 env2' -> Subenv (Append env2 env1) (Append env2' env1')
+subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1')
subenvConcat sub1 SETop = sub1
-subenvConcat sub1 (SEYes sub2) = SEYes (subenvConcat sub1 sub2)
+subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2)
subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2)
-sinkWithSubenv :: Subenv env env' -> env0 :> Append env' env0
+-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2
+-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r
+-- subenvSplit SNil sub k = k SETop sub
+-- subenvSplit (SCons _ list) (SENo sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SENo sub1) sub2
+-- subenvSplit (SCons _ list) (SEYes s sub) k =
+-- subenvSplit list sub $ \sub1 sub2 ->
+-- k (SEYes s sub1) sub2
+
+sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0
sinkWithSubenv SETop = WId
-sinkWithSubenv (SEYes sub) = WSink .> sinkWithSubenv sub
+sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub
sinkWithSubenv (SENo sub) = sinkWithSubenv sub
-wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env
wUndoSubenv SETop = WId
-wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub)
wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
+
+subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env'
+subenvMap _ SNil SETop = SETop
+subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub)
+subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub)
+
+subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env')
+subenvD2E SETop = SETop
+subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub)
+subenvD2E (SENo sub) = SENo (subenvD2E sub)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index e09f3ae..2c51b85 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where
@@ -25,6 +26,7 @@ import System.IO.Unsafe (unsafePerformIO)
import AST
import AST.Count
+import AST.Sparse.Types
import CHAD.Types
import Data
@@ -70,6 +72,7 @@ genNameIfUsedIn' prefix ty idx ex
_ -> return "_"
| otherwise = genName' prefix
+-- TODO: let this return a type-tagged thing so that name environments are more typed than Const
genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t
@@ -145,12 +148,20 @@ ppExpr' d val expr = case expr of
EMaybe _ a b e -> do
let STMaybe t = typeOf e
- a' <- ppExpr' 11 val a
+ e' <- ppExpr' 0 val e
+ a' <- ppExpr' 0 val a
name <- genNameIfUsedIn t IZ b
b' <- ppExpr' 0 (Const name `SCons` val) b
- e' <- ppExpr' 11 val e
- return $ ppParen (d > 10) $
- ppApp (ppString "maybe" <> ppX expr) [a', ppLam [ppString name] b', e']
+ return $ ppParen (d > 0) $
+ align $
+ group (flatAlt
+ (annotate AKey (ppString "case") <> ppX expr <+> e'
+ <> hardline <> annotate AKey (ppString "of"))
+ (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of")))
+ <> hardline
+ <> indent 2
+ (ppString "Nothing" <+> ppString "->" <+> a'
+ <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b')
ELNil _ _ _ -> return (ppString "LNil")
@@ -193,14 +204,21 @@ ppExpr' d val expr = case expr of
<> hardline <> e')
(ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e'])
+ EMap _ a b -> do
+ let STArr _ t1 = typeOf b
+ name <- genNameIfUsedIn' "i" t1 IZ a
+ a' <- ppExpr' 0 (Const name `SCons` val) a
+ b' <- ppExpr' 11 val b
+ return $ ppParen (d > 0) $
+ ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b']
+
EFold1Inner _ cm a b c -> do
name1 <- genNameIfUsedIn (typeOf a) (IS IZ) a
name2 <- genNameIfUsedIn (typeOf a) IZ a
a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
b' <- ppExpr' 11 val b
c' <- ppExpr' 11 val c
- let opname = case cm of Commut -> "fold1i(C)"
- Noncommut -> "fold1i"
+ let opname = "fold1i" ++ ppCommut cm
return $ ppParen (d > 10) $
ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
@@ -225,6 +243,39 @@ ppExpr' d val expr = case expr of
e' <- ppExpr' 11 val e
return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e'
+ EReshape _ n esh e -> do
+ esh' <- ppExpr' 11 val esh
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e']
+
+ EZip _ e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2']
+
+ EFold1InnerD1 _ cm a b c -> do
+ name1 <- genNameIfUsedIn (typeOf b) (IS IZ) a
+ name2 <- genNameIfUsedIn (typeOf b) IZ a
+ a' <- ppExpr' 0 (Const name2 `SCons` Const name1 `SCons` val) a
+ b' <- ppExpr' 11 val b
+ c' <- ppExpr' 11 val c
+ let opname = "fold1iD1" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name1, ppString name2] a', b', c']
+
+ EFold1InnerD2 _ cm ef ebog ed -> do
+ let STArr _ tB = typeOf ebog
+ STArr _ t2 = typeOf ed
+ namef1 <- genNameIfUsedIn tB (IS IZ) ef
+ namef2 <- genNameIfUsedIn t2 IZ ef
+ ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef
+ ebog' <- ppExpr' 11 val ebog
+ ed' <- ppExpr' 11 val ed
+ let opname = "fold1iD2" ++ ppCommut cm
+ return $ ppParen (d > 10) $
+ ppApp (annotate AHighlight (ppString opname) <> ppX expr)
+ [ppLam [ppString namef1, ppString namef2] ef', ebog', ed']
+
EConst _ ty v
| Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr
@@ -280,6 +331,10 @@ ppExpr' d val expr = case expr of
,e1'
,e2']
+ ERecompute _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e']
+
EWith _ t e1 e2 -> do
e1' <- ppExpr' 11 val e1
name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2
@@ -292,18 +347,24 @@ ppExpr' d val expr = case expr of
<> hardline <> e2')
(ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2'])
- EAccum _ t prj e1 e2 e3 -> do
+ EAccum _ t prj e1 sp e2 e3 -> do
e1' <- ppExpr' 11 val e1
e2' <- ppExpr' 11 val e2
e3' <- ppExpr' 11 val e3
return $ ppParen (d > 10) $
- ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), e1', e2', e3']
+ ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t)))
+ [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3']
EZero _ t e1 -> do
e1' <- ppExpr' 11 val e1
return $ ppParen (d > 0) $
annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+ EDeepZero _ t e1 -> do
+ e1' <- ppExpr' 11 val e1
+ return $ ppParen (d > 0) $
+ annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1'
+
EPlus _ t a b -> do
a' <- ppExpr' 11 val a
b' <- ppExpr' 11 val b
@@ -356,6 +417,20 @@ ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")"
ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj
ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n)
+ppSparse :: SMTy a -> Sparse a b -> String
+ppSparse t sp | Just Refl <- isDense t sp = "D"
+ppSparse _ SpAbsent = "A"
+ppSparse t (SpSparse s) = "S" ++ ppSparse t s
+ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")"
+ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s
+ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s
+ppSparse (SMTScal _) SpScal = "."
+
+ppCommut :: Commutative -> String
+ppCommut Commut = "(C)"
+ppCommut Noncommut = ""
+
ppX :: PrettyX x => Expr x env t -> ADoc
ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr)
@@ -388,6 +463,7 @@ ppSTy' :: Int -> STy t -> Doc q
ppSTy' _ STNil = ppString "1"
ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b
ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b
+ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t
ppSTy' d (STArr n t) = ppParen (d > 10) $
ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t
@@ -398,7 +474,6 @@ ppSTy' _ (STScal sty) = ppString $ case sty of
STF64 -> "f64"
STBool -> "bool"
ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t
-ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b
ppSMTy :: Int -> SMTy t -> String
ppSMTy d ty = render $ ppSMTy' d ty
diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs
new file mode 100644
index 0000000..2a29799
--- /dev/null
+++ b/src/AST/Sparse.hs
@@ -0,0 +1,287 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImpredicativeTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+
+{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-}
+module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where
+
+import Data.Type.Equality
+
+import AST
+import AST.Sparse.Types
+import Data (SBool(..))
+
+
+sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t'
+sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext
+sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2
+sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh
+sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 =
+ eunPair e1 $ \w1 e1a e1b ->
+ eunPair (weakenExpr w1 e2) $ \w2 e2a e2b ->
+ EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a)
+ (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b)
+sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 =
+ elet e2 $
+ elcase (weakenExpr WSink e1)
+ (evar IZ)
+ (elcase (evar (IS IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ))
+ (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ)))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr"))
+ (elcase (evar (IS IZ))
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ))
+ (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll")
+ (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 =
+ elet e2 $
+ emaybe (weakenExpr WSink e1)
+ (evar IZ)
+ (emaybe (evar (IS IZ))
+ (EJust ext (evar IZ))
+ (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ))))
+sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2
+sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2
+
+
+cheapZero :: SMTy t -> Maybe (forall env. Ex env t)
+cheapZero SMTNil = Just (ENil ext)
+cheapZero (SMTPair t1 t2)
+ | Just e1 <- cheapZero t1
+ , Just e2 <- cheapZero t2
+ = Just (EPair ext e1 e2)
+ | otherwise
+ = Nothing
+cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2))
+cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t))
+cheapZero SMTArr{} = Nothing
+cheapZero (SMTScal t) = case t of
+ STI32 -> Just (EConst ext t 0)
+ STI64 -> Just (EConst ext t 0)
+ STF32 -> Just (EConst ext t 0.0)
+ STF64 -> Just (EConst ext t 0.0)
+
+
+data Injection sp a b where
+ -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that
+ -- 'sparsePlusS' can provide injections even if the caller doesn't require
+ -- them. This simplifies the sparsePlusS code.
+ Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b
+ Noinj :: Injection False a b
+
+withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b'
+withInj (Inj f) k = Inj (k f)
+withInj Noinj _ = Noinj
+
+withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2
+ -> ((forall e. Ex e a1 -> Ex e b1)
+ -> (forall e. Ex e a2 -> Ex e b2)
+ -> (forall e'. Ex e' a' -> Ex e' b'))
+ -> Injection sp a' b'
+withInj2 (Inj f) (Inj g) k = Inj (k f g)
+withInj2 Noinj _ _ = Noinj
+withInj2 _ Noinj _ = Noinj
+
+-- | This function produces quadratically-sized code in the presence of nested
+-- dynamic sparsity. TODO can this be improved?
+sparsePlusS
+ :: SBool inj1 -> SBool inj2
+ -> SMTy t -> Sparse t t1 -> Sparse t t2
+ -> (forall t3. Sparse t t3
+ -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent)
+ -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent)
+ -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3)
+ -> r)
+ -> r
+-- nil override (but don't destroy effects!)
+sparsePlusS _ _ SMTNil _ _ k =
+ k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext)
+
+-- simplifications
+sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k =
+ sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b)
+sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k =
+ sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus ->
+ k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext))
+
+sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k =
+ let ta = applySparse sp1 (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k =
+ let tb = applySparse sp2 (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k =
+ let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k =
+ let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+ (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)))
+
+sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k =
+ let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in
+ sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus ->
+ k sp3
+ (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ)))
+ minj2
+ (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b)
+sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k =
+ let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in
+ sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus ->
+ k sp3
+ minj1
+ (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ)))
+ (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ)))
+sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k
+sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k
+
+-- TODO: sparse of Just is just Maybe
+
+-- dense plus
+sparsePlusS _ _ t sp1 sp2 k
+ | Just Refl <- isDense t sp1
+ , Just Refl <- isDense t sp2
+ = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b)
+
+-- handle absents
+sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b)
+sparsePlusS ST _ t SpAbsent sp2 k
+ | Just zero2 <- cheapZero (applySparse sp2 t) =
+ k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b)
+ | otherwise =
+ k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b)
+
+sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a)
+sparsePlusS _ ST t sp1 SpAbsent k
+ | Just zero1 <- cheapZero (applySparse sp1 t) =
+ k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a)
+ | otherwise =
+ k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a)
+
+-- double sparse yields sparse
+sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- single sparse can yield non-sparse if the other argument is always present
+sparsePlusS SF _ t (SpSparse sp1) sp2 k =
+ sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus ->
+ k sp3 Noinj (Inj inj2)
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (inj2 (evar IZ))
+ (plus (evar IZ) (evar (IS IZ))))
+sparsePlusS ST _ t (SpSparse sp1) sp2 k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpSparse sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> EJust ext (inj2 b))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (EJust ext (inj2 (evar IZ)))
+ (EJust ext (plus (evar IZ) (evar (IS IZ)))))
+sparsePlusS req1 req2 t sp1 (SpSparse sp2) k =
+ sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus ->
+ k sp3 inj2 inj1 (flip plus)
+
+-- products
+sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k =
+ sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa ->
+ sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb ->
+ k (SpPair sp3a sp3b)
+ (withInj2 minj13a minj13b $ \inj13a inj13b ->
+ \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b))
+ (withInj2 minj23a minj23b $ \inj23a inj23b ->
+ \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b))
+ (\x1 x2 ->
+ eunPair x1 $ \w1 x1a x1b ->
+ eunPair (weakenExpr w1 x2) $ \w2 x2a x2b ->
+ EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b))
+
+-- coproducts
+sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k =
+ sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa ->
+ sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb ->
+ let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb))
+ inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb))
+ inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta))
+ in
+ k (SpLEither sp3a sp3b)
+ (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ))))
+ (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ))))
+ (\x1 x2 ->
+ elet x2 $
+ elcase (weakenExpr WSink x1)
+ (elcase (evar IZ)
+ nil
+ (inl (inj23a (evar IZ)))
+ (inr (inj23b (evar IZ))))
+ (elcase (evar (IS IZ))
+ (inl (inj13a (evar IZ)))
+ (inl (plusa (evar (IS IZ)) (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr"))
+ (elcase (evar (IS IZ))
+ (inr (inj13b (evar IZ)))
+ (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll")
+ (inr (plusb (evar (IS IZ)) (evar IZ)))))
+
+-- maybe
+sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k =
+ sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus ->
+ k (SpMaybe sp3)
+ (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ))))
+ (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ))))
+ (\a b ->
+ elet b $
+ emaybe (weakenExpr WSink a)
+ (emaybe (evar IZ)
+ (ENothing ext (applySparse sp3 (fromSMTy t)))
+ (EJust ext (inj2 (evar IZ))))
+ (emaybe (evar (IS IZ))
+ (EJust ext (inj1 (evar IZ)))
+ (EJust ext (plus (evar (IS IZ)) (evar IZ)))))
+
+-- dense array cotangents simply recurse
+sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k =
+ sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus ->
+ k (SpArr sp3)
+ (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ)))
+ (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+ (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ))
+ (EVar ext (applySparse sp2 (fromSMTy t)) IZ)))
+
+-- scalars
+sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t))
diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs
new file mode 100644
index 0000000..10cac4e
--- /dev/null
+++ b/src/AST/Sparse/Types.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.Sparse.Types where
+
+import AST.Types
+
+import Data.Kind (Type, Constraint)
+import Data.Type.Equality
+
+
+data Sparse t t' where
+ SpSparse :: Sparse t t' -> Sparse t (TMaybe t')
+ SpAbsent :: Sparse t TNil
+
+ SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b')
+ SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b')
+ SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t')
+ SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t')
+ SpScal :: Sparse (TScal t) (TScal t)
+deriving instance Show (Sparse t t')
+
+class ApplySparse f where
+ applySparse :: Sparse t t' -> f t -> f t'
+
+instance ApplySparse STy where
+ applySparse (SpSparse s) t = STMaybe (applySparse s t)
+ applySparse SpAbsent _ = STNil
+ applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t)
+ applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t)
+ applySparse SpScal t = t
+
+instance ApplySparse SMTy where
+ applySparse (SpSparse s) t = SMTMaybe (applySparse s t)
+ applySparse SpAbsent _ = SMTNil
+ applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2)
+ applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t)
+ applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t)
+ applySparse SpScal t = t
+
+
+class IsSubType s where
+ type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint
+ subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t'
+ subtTrans :: s a b -> s b c -> s a c
+ subtFull :: IsSubTypeSubject s f => f t -> s t t
+
+instance IsSubType (:~:) where
+ type IsSubTypeSubject (:~:) f = ()
+ subtApply = gcastWith
+ subtTrans = trans
+ subtFull _ = Refl
+
+instance IsSubType Sparse where
+ type IsSubTypeSubject Sparse f = f ~ SMTy
+ subtApply = applySparse
+
+ subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2)
+ subtTrans _ SpAbsent = SpAbsent
+ subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b)
+ subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2)
+ subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2)
+ subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2)
+ subtTrans SpScal SpScal = SpScal
+
+ subtFull = spDense
+
+spDense :: SMTy t -> Sparse t t
+spDense SMTNil = SpAbsent
+spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2)
+spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2)
+spDense (SMTMaybe t) = SpMaybe (spDense t)
+spDense (SMTArr _ t) = SpArr (spDense t)
+spDense (SMTScal _) = SpScal
+
+isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t')
+isDense SMTNil SpAbsent = Just Refl
+isDense _ SpSparse{} = Nothing
+isDense _ SpAbsent = Nothing
+isDense (SMTPair t1 t2) (SpPair s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTLEither t1 t2) (SpLEither s1 s2)
+ | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl
+ | otherwise = Nothing
+isDense (SMTMaybe t) (SpMaybe s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTArr _ t) (SpArr s)
+ | Just Refl <- isDense t s = Just Refl
+ | otherwise = Nothing
+isDense (SMTScal _) SpScal = Just Refl
+
+isAbsent :: Sparse t t' -> Bool
+isAbsent (SpSparse s) = isAbsent s
+isAbsent SpAbsent = True
+isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2
+isAbsent (SpMaybe s) = isAbsent s
+isAbsent (SpArr s) = isAbsent s
+isAbsent SpScal = False
diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs
index 159934d..d276e44 100644
--- a/src/AST/SplitLets.hs
+++ b/src/AST/SplitLets.hs
@@ -22,7 +22,7 @@ splitLets = splitLets' (\t i w -> EVar ext t (w @> i))
splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t
splitLets' = \sub -> \case
EVar _ t i -> sub t i WId
- ELet _ (rhs :: Ex env t1) body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
+ ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body)
ECase x e a b ->
let STEither t1 t2 = typeOf e
in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b)
@@ -35,6 +35,13 @@ splitLets' = \sub -> \case
EFold1Inner x cm a b c ->
let STArr _ t1 = typeOf c
in EFold1Inner x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD1 x cm a b c ->
+ let STArr _ t1 = typeOf c
+ in EFold1InnerD1 x cm (split2 sub t1 t1 a) (splitLets' sub b) (splitLets' sub c)
+ EFold1InnerD2 x cm a b c ->
+ let STArr _ tB = typeOf b
+ STArr _ t2 = typeOf c
+ in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c)
EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b)
EFst x e -> EFst x (splitLets' sub e)
@@ -54,6 +61,7 @@ splitLets' = \sub -> \case
EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b)
EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e)
EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e)
+ EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (splitLets' sub e)
EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b)
@@ -61,9 +69,11 @@ splitLets' = \sub -> \case
EShape x e -> EShape x (splitLets' sub e)
EOp x op e -> EOp x op (splitLets' sub e)
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2)
+ ERecompute x e -> ERecompute x (splitLets' sub e)
EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (splitLets' sub e1) (splitLets' sub e2) (splitLets' sub e3)
+ EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3)
EZero x t ezi -> EZero x t (splitLets' sub ezi)
+ EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi)
EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b)
EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b)
EError x t s -> EError x t s
@@ -87,15 +97,42 @@ splitLets' = \sub -> \case
-> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t
split2 sub tbind1 tbind2 body =
let (ptrs1', bs1') = split @env' tbind1
- bs1 = fst (weakenBindings weakenExpr WSink bs1')
+ bs1 = fst (weakenBindingsE WSink bs1')
(ptrs2, bs2) = split @(bind1 : env') tbind2
in letBinds bs1 $
- letBinds (fst (weakenBindings weakenExpr (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $
splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1)))
_ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env')))
t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w)))))
body
+ -- TODO: abstract this to splitN lol wtf
+ _split4 :: forall bind1 bind2 bind3 bind4 env' env t.
+ (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a)
+ -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t
+ _split4 sub tbind1 tbind2 tbind3 tbind4 body =
+ let (ptrs1, bs1') = split @env' tbind1
+ (ptrs2, bs2') = split @(bind1 : env') tbind2
+ (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3
+ (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4
+ bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1')
+ bs2 = fst (weakenBindingsE (WSink .> WSink) bs2')
+ bs3 = fst (weakenBindingsE WSink bs3')
+ b1 = bindingsBinds bs1
+ b2 = bindingsBinds bs2
+ b3 = bindingsBinds bs3
+ b4 = bindingsBinds bs4
+ in letBinds bs1 $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $
+ letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $
+ letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $
+ splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1))
+ _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink))
+ _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink))
+ _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env')))
+ t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w)))))))))
+ body
+
type family Split t where
Split (TPair a b) = SplitRec (TPair a b)
Split _ = '[]
@@ -123,11 +160,11 @@ split typ = case typ of
STPair{} -> splitRec (EVar ext typ IZ) typ
STNil -> other
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex (t : env) '[])
other = (Point typ IZ, BTop)
@@ -142,11 +179,11 @@ splitRec rhs typ = case typ of
(p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b
in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2)
STEither{} -> other
+ STLEither{} -> other
STMaybe{} -> other
STArr{} -> other
STScal{} -> other
STAccum{} -> other
- STLEither{} -> other
where
other :: (Pointers (t : env) t, Bindings Ex env '[t])
other = (Point typ IZ, BPush BTop (typ, rhs))
diff --git a/src/AST/Types.hs b/src/AST/Types.hs
index efb1e04..4ddcb50 100644
--- a/src/AST/Types.hs
+++ b/src/AST/Types.hs
@@ -5,9 +5,9 @@
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE TypeData #-}
module AST.Types where
import Data.Int (Int32, Int64)
@@ -23,12 +23,11 @@ type data Ty
= TNil
| TPair Ty Ty
| TEither Ty Ty
+ | TLEither Ty Ty
| TMaybe Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
| TAccum Ty -- ^ contained type must be a monoid type
- -- sparse monoid types
- | TLEither Ty Ty
type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -37,12 +36,11 @@ data STy t where
STNil :: STy TNil
STPair :: STy a -> STy b -> STy (TPair a b)
STEither :: STy a -> STy b -> STy (TEither a b)
+ STLEither :: STy a -> STy b -> STy (TLEither a b)
STMaybe :: STy a -> STy (TMaybe a)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
STAccum :: SMTy t -> STy (TAccum t)
- -- sparse monoid types
- STLEither :: STy a -> STy b -> STy (TLEither a b)
deriving instance Show (STy t)
instance GCompare STy where
@@ -53,6 +51,8 @@ instance GCompare STy where
STPair{} _ -> GLT ; _ STPair{} -> GGT
(STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
STEither{} _ -> GLT ; _ STEither{} -> GGT
+ (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
+ STLEither{} _ -> GLT ; _ STLEither{} -> GGT
(STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a')
STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT
(STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t')
@@ -60,9 +60,7 @@ instance GCompare STy where
(STScal t) (STScal t') -> gorderingLift1 (gcompare t t')
STScal{} _ -> GLT ; _ STScal{} -> GGT
(STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t')
- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
- (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b')
- -- STLEither{} _ -> GLT ; _ STLEither{} -> GGT
+ -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT
instance TestEquality STy where testEquality = geq
instance GEq STy where geq = defaultGeq
@@ -173,15 +171,25 @@ type family ScalIsIntegral t where
ScalIsIntegral TBool = False
-- | Returns true for arrays /and/ accumulators.
-hasArrays :: STy t' -> Bool
-hasArrays STNil = False
-hasArrays (STPair a b) = hasArrays a || hasArrays b
-hasArrays (STEither a b) = hasArrays a || hasArrays b
-hasArrays (STMaybe t) = hasArrays t
-hasArrays STArr{} = True
-hasArrays STScal{} = False
-hasArrays STAccum{} = True
-hasArrays (STLEither a b) = hasArrays a || hasArrays b
+typeHasArrays :: STy t' -> Bool
+typeHasArrays STNil = False
+typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b
+typeHasArrays (STMaybe t) = typeHasArrays t
+typeHasArrays STArr{} = True
+typeHasArrays STScal{} = False
+typeHasArrays STAccum{} = True
+
+typeHasAccums :: STy t' -> Bool
+typeHasAccums STNil = False
+typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b
+typeHasAccums (STMaybe t) = typeHasAccums t
+typeHasAccums STArr{} = False
+typeHasAccums STScal{} = False
+typeHasAccums STAccum{} = True
type family Tup env where
Tup '[] = TNil
diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs
index 3d5f544..1712ba5 100644
--- a/src/AST/UnMonoid.hs
+++ b/src/AST/UnMonoid.hs
@@ -1,18 +1,22 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
-module AST.UnMonoid (unMonoid, zero, plus) where
+module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where
import AST
+import AST.Sparse.Types
import Data
--- | Remove 'EZero', 'EPlus' and 'EOneHot' from the program by expanding them
--- into their concrete implementations.
+-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by
+-- expanding them into their concrete implementations. Also ensure that
+-- 'EAccum' has a dense sparsity.
unMonoid :: Ex env t -> Ex env t
unMonoid = \case
EZero _ t e -> zero t e
+ EDeepZero _ t e -> deepZero t e
EPlus _ t a b -> plus t (unMonoid a) (unMonoid b)
EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b)
@@ -34,12 +38,17 @@ unMonoid = \case
ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c)
EConstArr _ n t x -> EConstArr ext n t x
EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b)
+ EMap _ a b -> EMap ext (unMonoid a) (unMonoid b)
EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c)
ESum1Inner _ e -> ESum1Inner ext (unMonoid e)
EUnit _ e -> EUnit ext (unMonoid e)
EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b)
EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e)
EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e)
+ EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b)
+ EZip _ a b -> EZip ext (unMonoid a) (unMonoid b)
+ EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
+ EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c)
EConst _ t x -> EConst ext t x
EIdx0 _ e -> EIdx0 ext (unMonoid e)
EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b)
@@ -47,12 +56,17 @@ unMonoid = \case
EShape _ e -> EShape ext (unMonoid e)
EOp _ op e -> EOp ext op (unMonoid e)
ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2)
+ ERecompute _ e -> ERecompute ext (unMonoid e)
EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b)
- EAccum _ t p a b e -> EAccum ext t p (unMonoid a) (unMonoid b) (unMonoid e)
+ EAccum _ t p eidx sp eval eacc ->
+ accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 ->
+ acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' ->
+ EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc))
EError _ t s -> EError ext t s
zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t
-zero SMTNil _ = ENil ext
+-- don't destroy the effects!
+zero SMTNil e = ELet ext e $ ENil ext
zero (SMTPair t1 t2) e =
ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ)))
(zero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
@@ -65,8 +79,30 @@ zero (SMTScal t) _ = case t of
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
+deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t
+deepZero SMTNil e = elet e $ ENil ext
+deepZero (SMTPair t1 t2) e =
+ ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ)))
+ (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ)))
+deepZero (SMTLEither t1 t2) e =
+ elcase e
+ (ELNil ext (fromSMTy t1) (fromSMTy t2))
+ (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ)))
+ (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ)))
+deepZero (SMTMaybe t) e =
+ emaybe e
+ (ENothing ext (fromSMTy t))
+ (EJust ext (deepZero t (evar IZ)))
+deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e
+deepZero (SMTScal t) _ = case t of
+ STI32 -> EConst ext STI32 0
+ STI64 -> EConst ext STI64 0
+ STF32 -> EConst ext STF32 0.0
+ STF64 -> EConst ext STF64 0.0
+
plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t
-plus SMTNil _ _ = ENil ext
+-- don't destroy the effects!
+plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext
plus (SMTPair t1 t2) a b =
let t = STPair (fromSMTy t1) (fromSMTy t2)
in ELet ext a $
@@ -104,7 +140,7 @@ plus (SMTArr _ t) a b =
a b
plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b)
-onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdx p t) -> Ex env a -> Ex env t
+onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t
onehot typ topprj idx arg = case (typ, topprj) of
(_, SAPHere) ->
ELet ext arg $
@@ -142,3 +178,78 @@ onehot typ topprj idx arg = case (typ, topprj) of
(onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg))
(ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $
zero t1 (EVar ext (tZeroInfo t1) IZ))
+
+accumulateSparse
+ :: SMTy t -> Sparse t t' -> Ex env t'
+ -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil)
+ -> Ex env TNil
+accumulateSparse topty topsp arg accum = case (topty, topsp) of
+ (_, s) | Just Refl <- isDense topty s ->
+ accum WId SAPHere (ENil ext) arg
+ (SMTScal _, SpScal) ->
+ accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh
+ (_, SpSparse s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w)))
+ (_, SpAbsent) ->
+ ENil ext
+ (SMTPair t1 t2, SpPair s1 s2) ->
+ eunPair arg $ \w1 e1 e2 ->
+ elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $
+ accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj))
+ (SMTLEither t1 t2, SpLEither s1 s2) ->
+ elcase arg
+ (ENil ext)
+ (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj)))
+ (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj)))
+ (SMTMaybe t, SpMaybe s) ->
+ emaybe arg
+ (ENil ext)
+ (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj)))
+ (SMTArr n t, SpArr s) ->
+ let tn = tTup (sreplicate n tIx) in
+ elet arg $
+ elet (EBuild ext n (EShape ext (evar IZ)) $
+ accumulateSparse t s
+ (EIdx ext (evar (IS IZ)) (EVar ext tn IZ))
+ (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $
+ ENil ext
+
+acPrjCompose
+ :: SAIDense dense
+ -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a)
+ -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b)
+ -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r
+acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2
+acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPFst p') idx'
+acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k =
+ acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPSnd p') idx'
+acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ)))
+acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx')
+acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPLeft p') idx'
+acPrjCompose d (SAPRight p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPRight p') idx'
+acPrjCompose d (SAPJust p1) idx1 p2 idx2 k =
+ acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' ->
+ k (SAPJust p') idx'
+acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
+acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k
+ | Dict <- styKnown (typeOf idx1) =
+ acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' ->
+ k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx')
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index d882e28..f0820b8 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -19,6 +19,7 @@ module AST.Weaken (module AST.Weaken, Append) where
import Data.Bifunctor (first)
import Data.Functor.Const
+import Data.GADT.Compare
import Data.Kind (Type)
import Data
@@ -31,6 +32,11 @@ data Idx env t where
IS :: Idx env t -> Idx (a : env) t
deriving instance Show (Idx env t)
+instance GEq (Idx env) where
+ geq IZ IZ = Just Refl
+ geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl
+ geq _ _ = Nothing
+
splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t)
splitIdx SNil i = Right i
splitIdx (SCons _ _) IZ = Left IZ
@@ -123,7 +129,7 @@ wCopies bs w =
let bs' = slistMap (\_ -> Const ()) bs
in WStack bs' bs' WId w
-wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
+wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env
wRaiseAbove SNil _ = WClosed
wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env)
diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs
index 6752c24..c6efe37 100644
--- a/src/AST/Weaken/Auto.hs
+++ b/src/AST/Weaken/Auto.hs
@@ -64,7 +64,7 @@ data SSegments (segments :: [(Symbol, [t])]) where
SSegNil :: SSegments '[]
SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list)
-instance (KnownSymbol name, name ~ name', segs ~ '[ '(name', ts)]) => IsLabel name (SList f ts -> SSegments segs) where
+instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where
fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil
auto :: KnownListSpine list => SList (Const ()) list