summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs32
-rw-r--r--src/AST/Count.hs28
-rw-r--r--src/AST/Env.hs43
-rw-r--r--src/AST/Pretty.hs16
-rw-r--r--src/AST/Weaken.hs6
-rw-r--r--src/CHAD.hs209
-rw-r--r--src/Example.hs15
-rw-r--r--src/Simplify.hs6
9 files changed, 241 insertions, 115 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index 19c2852..1bff84b 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -12,6 +12,7 @@ library
exposed-modules:
AST
AST.Count
+ AST.Env
AST.Pretty
AST.Weaken
AST.Weaken.Auto
diff --git a/src/AST.hs b/src/AST.hs
index 6c90be3..802ee2a 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -19,6 +19,7 @@ import Data.Functor.Const
import Data.Kind (Type)
import Data.Int
+import AST.Env
import AST.Weaken
import Data
@@ -90,15 +91,17 @@ data Expr x env t where
-- array operations
EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t)
- EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
+ EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t)
EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
+ EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused
-- expression operations
EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
+ EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-- accumulation effect
@@ -114,6 +117,18 @@ type Ex = Expr (Const ())
ext :: Const () a
ext = Const ()
+type family Replicate n x where
+ Replicate Z x = '[]
+ Replicate (S n) x = x : Replicate n x
+
+type family Tup env where
+ Tup '[] = TNil
+ Tup (t : ts) = TPair (Tup ts) t
+
+tTup :: SList STy env -> STy (Tup env)
+tTup SNil = STNil
+tTup (SCons t ts) = STPair (tTup ts) t
+
type SOp :: Ty -> Ty -> Type
data SOp a t where
OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
@@ -151,14 +166,16 @@ typeOf = \case
ECase _ _ a _ -> typeOf a
EBuild1 _ _ e -> STArr (SS SZ) (typeOf e)
- EBuild _ es e -> STArr (vecLength es) (typeOf e)
+ EBuild _ n _ e -> STArr n (typeOf e)
EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
EUnit _ e -> STArr SZ (typeOf e)
+ EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t
EConst _ t _ -> STScal t
EIdx0 _ e | STArr _ t <- typeOf e -> t
EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
EIdx _ e _ | STArr _ t <- typeOf e -> t
+ -- EShape _ e | STArr n _ <- typeOf e -> _
EOp _ op _ -> opt2 op
EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
@@ -214,9 +231,10 @@ subst' f w = \case
EInr x t e -> EInr x t (subst' f w e)
ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e)
+ EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b)
EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b)
EUnit x e -> EUnit x (subst' f w e)
+ EReplicate x e -> EReplicate x (subst' f w e)
EConst x t v -> EConst x t v
EIdx0 x e -> EIdx0 x (subst' f w e)
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
@@ -254,6 +272,11 @@ wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env'
wpopN SZ w = w
wpopN (SS n) w = wpopN n (WPop w)
+wUndoSubenv :: Subenv env env' -> env' :> env
+wUndoSubenv SETop = WId
+wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub)
+wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub
+
slistIdx :: SList f list -> Idx list t -> f t
slistIdx (SCons x _) IZ = x
slistIdx (SCons _ list) (IS i) = slistIdx list i
@@ -281,3 +304,6 @@ instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
+
+ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
+ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f)
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index 289c1fb..a4ff9f2 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -2,12 +2,14 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module AST.Count where
@@ -15,6 +17,7 @@ import Data.Functor.Const
import GHC.Generics (Generic, Generically(..))
import AST
+import AST.Env
import Data
@@ -110,9 +113,10 @@ occCountGeneral onehot unpush unpushN alter many = go
EInr _ _ e -> go e
ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b))
EBuild1 _ a b -> go a <> many (unpush (go b))
- EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e))
+ EBuild _ n a b -> go a <> many (unpushN n (go b))
EFold1 _ a b -> many (unpush (unpush (go a))) <> go b
EUnit _ e -> go e
+ EReplicate _ e -> go e
EConst{} -> mempty
EIdx0 _ e -> go e
EIdx1 _ a b -> go a <> go b
@@ -121,3 +125,25 @@ occCountGeneral onehot unpush unpushN alter many = go
EWith a b -> go a <> unpush (go b)
EAccum1 a b e -> go a <> go b <> go e
EError{} -> mempty
+
+
+deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r
+deleteUnused SNil OccEnd k = k SETop
+deleteUnused (_ `SCons` env) OccEnd k =
+ deleteUnused env OccEnd $ \sub -> k (SENo sub)
+deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k =
+ deleteUnused env occenv $ \sub ->
+ case count of Zero -> k (SENo sub)
+ _ -> k (SEYes sub)
+
+unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t
+unsafeWeakenWithSubenv = \sub ->
+ subst (\x t i -> case sinkWithSubenv i sub of
+ Just i' -> EVar x t i'
+ Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away")
+ where
+ sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t)
+ sinkWithSubenv IZ (SEYes _) = Just IZ
+ sinkWithSubenv IZ (SENo _) = Nothing
+ sinkWithSubenv (IS i) (SEYes sub) = IS <$> sinkWithSubenv i sub
+ sinkWithSubenv (IS i) (SENo sub) = sinkWithSubenv i sub
diff --git a/src/AST/Env.hs b/src/AST/Env.hs
new file mode 100644
index 0000000..c33bad3
--- /dev/null
+++ b/src/AST/Env.hs
@@ -0,0 +1,43 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeOperators #-}
+module AST.Env where
+
+import AST.Weaken
+import Data
+
+
+-- | @env'@ is a subset of @env@: each element of @env@ is either included in
+-- @env'@ ('SEYes') or not included in @env'@ ('SENo').
+data Subenv env env' where
+ SETop :: Subenv '[] '[]
+ SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
+ SENo :: Subenv env env' -> Subenv (t : env) env'
+deriving instance Show (Subenv env env')
+
+subList :: SList f env -> Subenv env env' -> SList f env'
+subList SNil SETop = SNil
+subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
+subList (SCons _ xs) (SENo sub) = subList xs sub
+
+subenvAll :: SList f env -> Subenv env env
+subenvAll SNil = SETop
+subenvAll (SCons _ env) = SEYes (subenvAll env)
+
+subenvNone :: SList f env -> Subenv env '[]
+subenvNone SNil = SETop
+subenvNone (SCons _ env) = SENo (subenvNone env)
+
+subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
+subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
+subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
+subenvOnehot SNil i = case i of {}
+
+subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3
+subenvCompose SETop SETop = SETop
+subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2)
+subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2)
+subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2)
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 1dc9dd3..dbbc021 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -113,14 +113,12 @@ ppExpr' d val = \case
return $ showParen (d > 10) $
showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")"
- EBuild _ es e -> do
- es' <- mapM (ppExpr' 0 val) es
- names <- mapM (const genName) es -- TODO generate underscores
- e' <- ppExpr' 0 (vpushN names val) e
+ EBuild _ n a b -> do
+ a' <- ppExpr' 11 val a
+ names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores
+ e' <- ppExpr' 0 (vpushN names val) b
return $ showParen (d > 10) $
- showString "build ["
- . foldr (.) id (intersperse (showString ", ") (reverse (toList es')))
- . showString "] (\\["
+ showString "build " . a' . showString " (\\["
. foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names))))
. showString ("] -> ") . e' . showString ")"
@@ -137,6 +135,10 @@ ppExpr' d val = \case
e' <- ppExpr' 11 val e
return $ showParen (d > 10) $ showString "unit " . e'
+ EReplicate _ e -> do
+ e' <- ppExpr' 11 val e
+ return $ showParen (d > 10) $ showString "replicate " . e'
+
EConst _ ty v -> return $ showString $ case ty of
STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index e0b5232..78276ca 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -39,7 +39,7 @@ splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i)
data env :> env' where
WId :: env :> env
WSink :: forall t env. env :> (t : env)
- WCopy :: env :> env' -> (t : env) :> (t : env')
+ WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env')
WPop :: (t : env) :> env' -> env :> env'
WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
WClosed :: SList (Const ()) env -> '[] :> env
@@ -95,6 +95,10 @@ wSinks :: forall env bs f. SList f bs -> env :> Append bs env
wSinks SNil = WId
wSinks (SCons _ spine) = WSink .> wSinks spine
+wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env'
+wSinksAnd SNil w = w
+wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w
+
wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
wCopies SNil w = w
wCopies (SCons _ spine) w = WCopy (wCopies spine w)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 007ffe3..087a26e 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -30,10 +30,12 @@ module CHAD (
import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
+import GHC.Stack (HasCallStack)
import GHC.TypeLits (Symbol)
import AST
import AST.Count
+import AST.Env
import AST.Weaken.Auto
import Data
import Lemmas
@@ -422,14 +424,6 @@ plusSparse t a b adder =
(EVar ext t (IS IZ))
(weakenExpr (WCopy (WCopy WSink)) adder)))
-type family Tup env where
- Tup '[] = TNil
- Tup (t : ts) = TPair (Tup ts) t
-
-tTup :: SList STy env -> STy (Tup env)
-tTup SNil = STNil
-tTup (SCons t ts) = STPair (tTup ts) t
-
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
@@ -437,18 +431,20 @@ zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
accumPromote :: forall dt env sto proxy r.
proxy dt
-> Descr env sto
- -> OccEnv env
-> (forall stoRepl envPro.
- Descr env stoRepl
+ (Select env stoRepl "merge" ~ '[])
+ => Descr env stoRepl
-- ^ A revised environment description that switches
-- arrays (used in the OccEnv) that are currently on
- -- "merge" storage, to "accum" storage.
- -> Subenv (Select env sto "merge") (Select env stoRepl "merge")
- -- ^ The new storage has fewer "merge"-storage entries.
+ -- "merge" storage, to "accum" storage. Any other "merge"
+ -- entries are deleted.
-> SList STy envPro
-- ^ New entries on top of the original dual environment,
-- that house the accumulators for the promoted arrays in
-- the original environment.
+ -> Subenv (Select env sto "merge") envPro
+ -- ^ The promoted entries were merge entries in the
+ -- original environment.
-> (forall shbinds.
SList STy shbinds
-> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum")))
@@ -458,16 +454,15 @@ accumPromote :: forall dt env sto proxy r.
-- extended with some accumulators.
-> r)
-> r
-accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId)
-accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId)
-accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
- accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf ->
- case (t, sto, occ) of
+accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId)
+accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
+ accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf ->
+ case sto of
-- Accumulators are left as-is
- (_, SAccum, _) ->
+ SAccum ->
k (storepl `DPush` (t, SAccum))
- mergesub
envpro
+ prosub
(\shbinds ->
autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr)))
(#acc :++: (#pro :++: #d :++: #shb :++: #tl))
@@ -477,34 +472,29 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k =
(#d :++: #shb :++: #acc :++: #tl)
(#acc :++: (#d :++: #shb :++: #tl)))
- -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro
- (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero ->
- k (storepl `DPush` (t, SAccum))
- (SENo mergesub)
- (STArr arrn arrt `SCons` envpro)
- (\(shbinds :: SList _ shbinds) ->
- let shbindsC = slistMap (\_ -> Const ()) shbinds
- in
- -- wf:
- -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WCopy wf:
- -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- -- WPICK: ^ THESE TWO ||
- -- goal: | ARE EQUAL ||
- -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
- WCopy (wf shbinds)
- .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
- (WId @(D2AcE (Select env1 stoRepl "accum"))))
-
- -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this)
- (_, SMerge, Occ _ c)
- | c > Zero ->
- error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
- | otherwise ->
- k (storepl `DPush` (t, SMerge))
- (SEYes mergesub)
- envpro
- wf
+ SMerge -> case t of
+ -- Arrays with "merge" storage are promoted to an accumulator in envPro
+ STArr (arrn :: SNat arrn) (arrt :: STy arrt) ->
+ k (storepl `DPush` (t, SAccum))
+ (STArr arrn arrt `SCons` envpro)
+ (SEYes prosub)
+ (\(shbinds :: SList _ shbinds) ->
+ let shbindsC = slistMap (\_ -> Const ()) shbinds
+ in
+ -- wf:
+ -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WCopy wf:
+ -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ -- WPICK: ^ THESE TWO ||
+ -- goal: | ARE EQUAL ||
+ -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum")))
+ WCopy (wf shbinds)
+ .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC)
+ (WId @(D2AcE (Select env1 stoRepl "accum"))))
+
+ -- "merge" values must be an array, so reject everything else. (TODO: generalise this)
+ _ ->
+ error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t
-- where
-- containsTArr :: STy t' -> Bool
-- containsTArr = \case
@@ -537,14 +527,6 @@ uninvertTup (t `SCons` list) tcore e =
(ESnd ext (EVar ext recT IZ))
(ESnd ext (EFst ext (EVar ext recT IZ))))
--- | @env'@ is a subset of @env@: each element of @env@ is either included in
--- @env'@ ('SEYes') or not included in @env'@ ('SENo').
-data Subenv env env' where
- SETop :: Subenv '[] '[]
- SEYes :: Subenv env env' -> Subenv (t : env) (t : env')
- SENo :: Subenv env env' -> Subenv (t : env) env'
-deriving instance Show (Subenv env env')
-
data Ret env0 sto t =
forall shbinds env0Merge.
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
@@ -566,24 +548,6 @@ data Rets env0 sto env list =
(SList (RetPair env0 sto env shbinds) list)
deriving instance Show (Rets env0 sto env list)
-subList :: SList f env -> Subenv env env' -> SList f env'
-subList SNil SETop = SNil
-subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub)
-subList (SCons _ xs) (SENo sub) = subList xs sub
-
-subenvAll :: SList f env -> Subenv env env
-subenvAll SNil = SETop
-subenvAll (SCons _ env) = SEYes (subenvAll env)
-
-subenvNone :: SList f env -> Subenv env '[]
-subenvNone SNil = SETop
-subenvNone (SCons _ env) = SENo (subenvNone env)
-
-subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t]
-subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env)
-subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i)
-subenvOnehot SNil i = case i of {}
-
subenvPlus :: SList STy env
-> Subenv env env1 -> Subenv env env2
-> (forall env3. Subenv env env3
@@ -631,7 +595,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-assertSubenvEmpty :: Subenv env env' -> env' :~: '[]
+assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[]
assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl
assertSubenvEmpty SETop = Refl
assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty"
@@ -748,6 +712,10 @@ data Descr env sto where
DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
deriving instance Show (Descr env sto)
+descrList :: Descr env sto -> SList STy env
+descrList DTop = SNil
+descrList (des `DPush` (t, _)) = t `SCons` descrList des
+
select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
select _ DTop = SNil
select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
@@ -755,6 +723,26 @@ select s@SMerge (DPush des (_, SAccum)) = select s des
select s@SAccum (DPush des (_, SMerge)) = select s des
select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
+-- | This could have more precise typing on the output storage.
+subDescr :: Descr env sto -> Subenv env env'
+ -> (forall sto'. Descr env' sto'
+ -> Subenv (Select env sto "merge") (Select env' sto' "merge")
+ -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
+ -> Subenv (D1E env) (D1E env')
+ -> r)
+ -> r
+subDescr DTop SETop k = k DTop SETop SETop SETop
+subDescr (des `DPush` (t, sto)) (SEYes sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
+ SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
+subDescr (des `DPush` (_, sto)) (SENo sub) k =
+ subDescr des sub $ \des' submerge subaccum subd1e ->
+ case sto of
+ SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
+ SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
+
sD1eEnv :: Descr env sto -> SList STy (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
@@ -990,16 +978,18 @@ drev des = \case
(subenvNone (select SMerge des))
(ENil ext)
- EBuild1 _ ne e
- | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne
- , let eltty = typeOf e ->
- accumPromote eltty des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro ->
- case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 ->
+ EBuild1 _ ne (orige :: Ex _ eltty)
+ | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result
+ , let eltty = typeOf orige ->
+ deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') ->
+ let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in
+ subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed ->
+ accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro ->
+ case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 ->
case assertSubenvEmpty sub of { Refl ->
- case assertSubenvEmpty proSub of { Refl ->
- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in
+ let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in
Ret (bconcat (ne0 `BPush` (tIx, ne1))
- (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0)))
+ (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0)))
(EBuild1 ext
(weakenExpr (autoWeak (#ve0 (bindingsBinds ve0)
&. #binds (tIx `SCons` bindingsBinds ne0)
@@ -1007,7 +997,7 @@ drev des = \case
#binds
((#ve0 :++: #binds) :++: #tl))
(EVar ext tIx IZ))
- (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of
+ (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of
Left ibind ->
let ibind' =
autoWeak (#ix (auto1 @TIx)
@@ -1020,9 +1010,9 @@ drev des = \case
in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind')
(EVar ext tIx IZ))
Right IZ -> EVar ext tIx IZ -- build lambda index argument
- Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv)))
+ Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv)))
e1))
- nsub
+ (subenvCompose subMergeUsed proSub)
(ELet ext
(uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $
makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $
@@ -1035,8 +1025,21 @@ drev des = \case
#binds
(#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl))
(EVar ext tIx IZ))
- -- TODO: use vectoriseExpr
- (_ $
+ (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty))
+ (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum")))
+ (d2ace envPro)
+ @> IZ)))
+ (EVar ext tIx IZ))) $
+ weakenExpr (autoWeak (#i (auto1 @TIx)
+ &. #dpro (d2ace envPro)
+ &. #d (d2 eltty `SCons` SNil)
+ &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil)
+ &. #n (auto1 @TIx)
+ &. #vbinds (bindingsBinds ve0)
+ &. #ne0 (bindingsBinds ne0)
+ &. #tl (d2ace (select SAccum des)))
+ (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl)
+ (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $
vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $
weakenExpr (autoWeak (#dpro (d2ace envPro)
&. #d (d2 eltty `SCons` SNil)
@@ -1044,19 +1047,12 @@ drev des = \case
&. #tl (d2ace (select SAccum des)))
(#dpro :++: #d :++: #binds :++: #tl)
((#dpro :++: #d) :++: #binds :++: #tl)) $
- weakenExpr (wPro (bindingsBinds e0)) e2)) $
+ weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $
+ weakenExpr (wPro (bindingsBinds e0)) $
+ e2)) $
ELet ext (ENil ext) $
- weakenExpr (autoWeak (#nil (auto1 @TNil)
- &. #d (auto1 @(D2 t))
- &. #nilarr (auto1 @(TArr (S Z) TNil))
- &. #ve0 (bindingsBinds ve0)
- &. #n (auto1 @TIx)
- &. #binds (bindingsBinds ne0)
- &. #tl (d2ace (select SAccum des)))
- (#nil :++: #binds :++: #tl)
- (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl))
- ne2)
- }}}
+ ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ)))
+ }}
EUnit _ e
| Ret e0 e1 sub e2 <- drev des e ->
@@ -1075,9 +1071,20 @@ drev des = \case
(ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $
weakenExpr (WCopy WSink) e2)
+ EIdx1 _ e ei
+ -- We're allowed to ignore ei2 here because the output of 'ei' is discrete.
+ | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil)
+ <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil
+ ->
+ Ret binds
+ (EIdx1 ext e1 ei1)
+ sub
+ (_ e2)
+
-- These should be the next to be implemented, I think
- EIdx1{} -> err_unsupported "EIdx1"
EFold1{} -> err_unsupported "EFold1"
+ EShape{} -> err_unsupported "EShape"
+ EReplicate{} -> err_unsupported "EReplicate"
EIdx{} -> err_unsupported "EIdx"
EBuild{} -> err_unsupported "EBuild"
diff --git a/src/Example.hs b/src/Example.hs
index 572d67e..86264e1 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -107,3 +107,18 @@ ex5 =
(bin (OMul STF32) (EVar ext (STScal STF32) IZ)
(bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
(EConst ext STF32 1.0)))
+
+senv6 :: SList STy [TScal TI64, TScal TF32]
+senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil
+
+descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"]
+descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge)
+
+ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32)
+ex6 =
+ ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $
+ ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $
+ ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $
+ bin (OMul STF32) (EVar ext (STScal STF32) IZ)
+ (EVar ext (STScal STF32) IZ)) $
+ (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3)))
diff --git a/src/Simplify.hs b/src/Simplify.hs
index f2fc54a..698c667 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -71,9 +71,10 @@ simplify' = \case
EInr _ t e -> EInr ext t (simplify' e)
ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)
EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
- EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)
+ EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b)
EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
EUnit _ e -> EUnit ext (simplify' e)
+ EReplicate _ e -> EReplicate ext (simplify' e)
EConst _ t v -> EConst ext t v
EIdx0 _ e -> EIdx0 ext (simplify' e)
EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
@@ -104,9 +105,10 @@ hasAdds = \case
EInr _ _ e -> hasAdds e
ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
EBuild1 _ a b -> hasAdds a || hasAdds b
- EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e
+ EBuild _ _ a b -> hasAdds a || hasAdds b
EFold1 _ a b -> hasAdds a || hasAdds b
EUnit _ e -> hasAdds e
+ EReplicate _ e -> hasAdds e
EConst _ _ _ -> False
EIdx0 _ e -> hasAdds e
EIdx1 _ a b -> hasAdds a || hasAdds b