diff options
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 32 | ||||
-rw-r--r-- | src/AST/Count.hs | 28 | ||||
-rw-r--r-- | src/AST/Env.hs | 43 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 16 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 6 | ||||
-rw-r--r-- | src/CHAD.hs | 209 | ||||
-rw-r--r-- | src/Example.hs | 15 | ||||
-rw-r--r-- | src/Simplify.hs | 6 |
9 files changed, 241 insertions, 115 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index 19c2852..1bff84b 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -12,6 +12,7 @@ library exposed-modules: AST AST.Count + AST.Env AST.Pretty AST.Weaken AST.Weaken.Auto @@ -19,6 +19,7 @@ import Data.Functor.Const import Data.Kind (Type) import Data.Int +import AST.Env import AST.Weaken import Data @@ -90,15 +91,17 @@ data Expr x env t where -- array operations EBuild1 :: x (TArr (S Z) t) -> Expr x env TIx -> Expr x (TIx : env) t -> Expr x env (TArr (S Z) t) - EBuild :: x (TArr n t) -> Vec n (Expr x env TIx) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (ConsN n TIx env) t -> Expr x env (TArr n t) EFold1 :: x (TArr n t) -> Expr x (t : t : env) t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) + EReplicate :: x (TArr (S n) t) -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) -- TODO: unused -- expression operations EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t -- accumulation effect @@ -114,6 +117,18 @@ type Ex = Expr (Const ()) ext :: Const () a ext = Const () +type family Replicate n x where + Replicate Z x = '[] + Replicate (S n) x = x : Replicate n x + +type family Tup env where + Tup '[] = TNil + Tup (t : ts) = TPair (Tup ts) t + +tTup :: SList STy env -> STy (Tup env) +tTup SNil = STNil +tTup (SCons t ts) = STPair (tTup ts) t + type SOp :: Ty -> Ty -> Type data SOp a t where OAdd :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) @@ -151,14 +166,16 @@ typeOf = \case ECase _ _ a _ -> typeOf a EBuild1 _ _ e -> STArr (SS SZ) (typeOf e) - EBuild _ es e -> STArr (vecLength es) (typeOf e) + EBuild _ n _ e -> STArr n (typeOf e) EFold1 _ _ e | STArr (SS n) t <- typeOf e -> STArr n t EUnit _ e -> STArr SZ (typeOf e) + EReplicate _ e | STArr n t <- typeOf e -> STArr (SS n) t EConst _ t _ -> STScal t EIdx0 _ e | STArr _ t <- typeOf e -> t EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t EIdx _ e _ | STArr _ t <- typeOf e -> t + -- EShape _ e | STArr n _ <- typeOf e -> _ EOp _ op _ -> opt2 op EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) @@ -214,9 +231,10 @@ subst' f w = \case EInr x t e -> EInr x t (subst' f w e) ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) EBuild1 x a b -> EBuild1 x (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EBuild x es e -> EBuild x (fmap (subst' f w) es) (subst' (sinkFN (vecLength es) f) (wcopyN (vecLength es) w) e) + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkFN n f) (wcopyN n w) b) EFold1 x a b -> EFold1 x (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) EUnit x e -> EUnit x (subst' f w e) + EReplicate x e -> EReplicate x (subst' f w e) EConst x t v -> EConst x t v EIdx0 x e -> EIdx0 x (subst' f w e) EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) @@ -254,6 +272,11 @@ wpopN :: SNat n -> ConsN n TIx env :> env' -> env :> env' wpopN SZ w = w wpopN (SS n) w = wpopN n (WPop w) +wUndoSubenv :: Subenv env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + slistIdx :: SList f list -> Idx list t -> f t slistIdx (SCons x _) IZ = x slistIdx (SCons _ list) (IS i) = slistIdx list i @@ -281,3 +304,6 @@ instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum class KnownEnv env where knownEnv :: SList STy env instance KnownEnv '[] where knownEnv = SNil instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) +ebuildUp1 n sh size f = EBuild ext (SS n) (EPair ext sh size) (error "TODO" f) diff --git a/src/AST/Count.hs b/src/AST/Count.hs index 289c1fb..a4ff9f2 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -2,12 +2,14 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} module AST.Count where @@ -15,6 +17,7 @@ import Data.Functor.Const import GHC.Generics (Generic, Generically(..)) import AST +import AST.Env import Data @@ -110,9 +113,10 @@ occCountGeneral onehot unpush unpushN alter many = go EInr _ _ e -> go e ECase _ e a b -> go e <> (unpush (go a) `alter` unpush (go b)) EBuild1 _ a b -> go a <> many (unpush (go b)) - EBuild _ es e -> foldMap go es <> many (unpushN (vecLength es) (go e)) + EBuild _ n a b -> go a <> many (unpushN n (go b)) EFold1 _ a b -> many (unpush (unpush (go a))) <> go b EUnit _ e -> go e + EReplicate _ e -> go e EConst{} -> mempty EIdx0 _ e -> go e EIdx1 _ a b -> go a <> go b @@ -121,3 +125,25 @@ occCountGeneral onehot unpush unpushN alter many = go EWith a b -> go a <> unpush (go b) EAccum1 a b e -> go a <> go b <> go e EError{} -> mempty + + +deleteUnused :: SList f env -> OccEnv env -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil OccEnd k = k SETop +deleteUnused (_ `SCons` env) OccEnd k = + deleteUnused env OccEnd $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (OccPush occenv (Occ _ count)) k = + deleteUnused env occenv $ \sub -> + case count of Zero -> k (SENo sub) + _ -> k (SEYes sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> + subst (\x t i -> case sinkWithSubenv i sub of + Just i' -> EVar x t i' + Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") + where + sinkWithSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) + sinkWithSubenv IZ (SEYes _) = Just IZ + sinkWithSubenv IZ (SENo _) = Nothing + sinkWithSubenv (IS i) (SEYes sub) = IS <$> sinkWithSubenv i sub + sinkWithSubenv (IS i) (SENo sub) = sinkWithSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs new file mode 100644 index 0000000..c33bad3 --- /dev/null +++ b/src/AST/Env.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module AST.Env where + +import AST.Weaken +import Data + + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). +data Subenv env env' where + SETop :: Subenv '[] '[] + SEYes :: Subenv env env' -> Subenv (t : env) (t : env') + SENo :: Subenv env env' -> Subenv (t : env) env' +deriving instance Show (Subenv env env') + +subList :: SList f env -> Subenv env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvAll :: SList f env -> Subenv env env +subenvAll SNil = SETop +subenvAll (SCons _ env) = SEYes (subenvAll env) + +subenvNone :: SList f env -> Subenv env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] +subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) +subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) +subenvOnehot SNil i = case i of {} + +subenvCompose :: Subenv env1 env2 -> Subenv env2 env3 -> Subenv env1 env3 +subenvCompose SETop SETop = SETop +subenvCompose (SEYes sub1) (SEYes sub2) = SEYes (subenvCompose sub1 sub2) +subenvCompose (SEYes sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 1dc9dd3..dbbc021 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -113,14 +113,12 @@ ppExpr' d val = \case return $ showParen (d > 10) $ showString "build1 " . a' . showString (" (\\" ++ name ++ " -> ") . b' . showString ")" - EBuild _ es e -> do - es' <- mapM (ppExpr' 0 val) es - names <- mapM (const genName) es -- TODO generate underscores - e' <- ppExpr' 0 (vpushN names val) e + EBuild _ n a b -> do + a' <- ppExpr' 11 val a + names <- sequence (vecGenerate n (\_ -> genName)) -- TODO generate underscores + e' <- ppExpr' 0 (vpushN names val) b return $ showParen (d > 10) $ - showString "build [" - . foldr (.) id (intersperse (showString ", ") (reverse (toList es'))) - . showString "] (\\[" + showString "build " . a' . showString " (\\[" . foldr (.) id (intersperse (showString ",") (map showString (reverse (toList names)))) . showString ("] -> ") . e' . showString ")" @@ -137,6 +135,10 @@ ppExpr' d val = \case e' <- ppExpr' 11 val e return $ showParen (d > 10) $ showString "unit " . e' + EReplicate _ e -> do + e' <- ppExpr' 11 val e + return $ showParen (d > 10) $ showString "replicate " . e' + EConst _ ty v -> return $ showString $ case ty of STI32 -> show v ; STI64 -> show v ; STF32 -> show v ; STF64 -> show v ; STBool -> show v diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index e0b5232..78276ca 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -39,7 +39,7 @@ splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) data env :> env' where WId :: env :> env WSink :: forall t env. env :> (t : env) - WCopy :: env :> env' -> (t : env) :> (t : env') + WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') WPop :: (t : env) :> env' -> env :> env' WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 WClosed :: SList (Const ()) env -> '[] :> env @@ -95,6 +95,10 @@ wSinks :: forall env bs f. SList f bs -> env :> Append bs env wSinks SNil = WId wSinks (SCons _ spine) = WSink .> wSinks spine +wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env' +wSinksAnd SNil w = w +wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w + wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 wCopies SNil w = w wCopies (SCons _ spine) w = WCopy (wCopies spine w) diff --git a/src/CHAD.hs b/src/CHAD.hs index 007ffe3..087a26e 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -30,10 +30,12 @@ module CHAD ( import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) +import GHC.Stack (HasCallStack) import GHC.TypeLits (Symbol) import AST import AST.Count +import AST.Env import AST.Weaken.Auto import Data import Lemmas @@ -422,14 +424,6 @@ plusSparse t a b adder = (EVar ext t (IS IZ)) (weakenExpr (WCopy (WCopy WSink)) adder))) -type family Tup env where - Tup '[] = TNil - Tup (t : ts) = TPair (Tup ts) t - -tTup :: SList STy env -> STy (Tup env) -tTup SNil = STNil -tTup (SCons t ts) = STPair (tTup ts) t - zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) @@ -437,18 +431,20 @@ zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto - -> OccEnv env -> (forall stoRepl envPro. - Descr env stoRepl + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl -- ^ A revised environment description that switches -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> Subenv (Select env sto "merge") (Select env stoRepl "merge") - -- ^ The new storage has fewer "merge"-storage entries. + -- "merge" storage, to "accum" storage. Any other "merge" + -- entries are deleted. -> SList STy envPro -- ^ New entries on top of the original dual environment, -- that house the accumulators for the promoted arrays in -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. -> (forall shbinds. SList STy shbinds -> (D2 dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) @@ -458,16 +454,15 @@ accumPromote :: forall dt env sto proxy r. -- extended with some accumulators. -> r) -> r -accumPromote _ DTop _ k = k DTop SETop SNil (\_ -> WId) -accumPromote _ descr OccEnd k = k descr (subenvAll (select SMerge descr)) SNil (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = - accumPromote pdty descr occenv $ \(storepl :: Descr env1 stoRepl) mergesub (envpro :: SList _ envPro) wf -> - case (t, sto, occ) of +accumPromote _ DTop k = k DTop SNil SETop (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub wf -> + case sto of -- Accumulators are left as-is - (_, SAccum, _) -> + SAccum -> k (storepl `DPush` (t, SAccum)) - mergesub envpro + prosub (\shbinds -> autoWeak (#pro (d2ace envpro) &. #d (auto1 @(D2 dt)) &. #shb shbinds &. #acc (auto1 @(D2Ac t)) &. #tl (d2ace (select SAccum descr))) (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) @@ -477,34 +472,29 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) (occenv `OccPush` occ) k = (#d :++: #shb :++: #acc :++: #tl) (#acc :++: (#d :++: #shb :++: #tl))) - -- Arrays with "merge" storage and non-zero usage are promoted to an accumulator in envPro - (STArr (arrn :: SNat arrn) (arrt :: STy arrt), SMerge, Occ _ c) | c > Zero -> - k (storepl `DPush` (t, SAccum)) - (SENo mergesub) - (STArr arrn arrt `SCons` envpro) - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Used "merge" values must be an array, so reject everything else. (TODO: generalise this) - (_, SMerge, Occ _ c) - | c > Zero -> - error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t - | otherwise -> - k (storepl `DPush` (t, SMerge)) - (SEYes mergesub) - envpro - wf + SMerge -> case t of + -- Arrays with "merge" storage are promoted to an accumulator in envPro + STArr (arrn :: SNat arrn) (arrt :: STy arrt) -> + k (storepl `DPush` (t, SAccum)) + (STArr arrn arrt `SCons` envpro) + (SEYes prosub) + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum arrn (D2 arrt)) @(D2 dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- "merge" values must be an array, so reject everything else. (TODO: generalise this) + _ -> + error $ "Closure variable of 'build'-like thing contains a non-array SMerge value: " ++ show t -- where -- containsTArr :: STy t' -> Bool -- containsTArr = \case @@ -537,14 +527,6 @@ uninvertTup (t `SCons` list) tcore e = (ESnd ext (EVar ext recT IZ)) (ESnd ext (EFst ext (EVar ext recT IZ)))) --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv env env' where - SETop :: Subenv '[] '[] - SEYes :: Subenv env env' -> Subenv (t : env) (t : env') - SENo :: Subenv env env' -> Subenv (t : env) env' -deriving instance Show (Subenv env env') - data Ret env0 sto t = forall shbinds env0Merge. Ret (Bindings Ex (D1E env0) shbinds) -- shared binds @@ -566,24 +548,6 @@ data Rets env0 sto env list = (SList (RetPair env0 sto env shbinds) list) deriving instance Show (Rets env0 sto env list) -subList :: SList f env -> Subenv env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes sub) = SCons x (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: SList f env -> Subenv env env -subenvAll SNil = SETop -subenvAll (SCons _ env) = SEYes (subenvAll env) - -subenvNone :: SList f env -> Subenv env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> Subenv env '[t] -subenvOnehot (SCons _ env) IZ = SEYes (subenvNone env) -subenvOnehot (SCons _ env) (IS i) = SENo (subenvOnehot env i) -subenvOnehot SNil i = case i of {} - subenvPlus :: SList STy env -> Subenv env env1 -> Subenv env env2 -> (forall env3. Subenv env env3 @@ -631,7 +595,7 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) -assertSubenvEmpty :: Subenv env env' -> env' :~: '[] +assertSubenvEmpty :: HasCallStack => Subenv env env' -> env' :~: '[] assertSubenvEmpty (SENo sub) | Refl <- assertSubenvEmpty sub = Refl assertSubenvEmpty SETop = Refl assertSubenvEmpty SEYes{} = error "assertSubenvEmpty: not empty" @@ -748,6 +712,10 @@ data Descr env sto where DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) deriving instance Show (Descr env sto) +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _)) = t `SCons` descrList des + select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) @@ -755,6 +723,26 @@ select s@SMerge (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, sto)) (SEYes sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) + SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) +subDescr (des `DPush` (_, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + sD1eEnv :: Descr env sto -> SList STy (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) @@ -990,16 +978,18 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) - EBuild1 _ ne e - | Ret (ne0 :: Bindings _ _ ne_binds) ne1 nsub ne2 <- drev des ne - , let eltty = typeOf e -> - accumPromote eltty des (occEnvPop (occCountAll e)) $ \vdes proSub envPro wPro -> - case drev (vdes `DPush` (tIx, SMerge)) e of { Ret e0 e1 sub e2 -> + EBuild1 _ ne (orige :: Ex _ eltty) + | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result + , let eltty = typeOf orige -> + deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> + let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> + case drev (prodes `DPush` (tIx, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> case assertSubenvEmpty sub of { Refl -> - case assertSubenvEmpty proSub of { Refl -> - let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv des) IZ e0 in + let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in Ret (bconcat (ne0 `BPush` (tIx, ne1)) - (fst (weakenBindings weakenExpr (WCopy (wSinks (bindingsBinds ne0))) ve0))) + (fst (weakenBindings weakenExpr (WCopy (wSinksAnd (bindingsBinds ne0) (wUndoSubenv subD1eUsed))) ve0))) (EBuild1 ext (weakenExpr (autoWeak (#ve0 (bindingsBinds ve0) &. #binds (tIx `SCons` bindingsBinds ne0) @@ -1007,7 +997,7 @@ drev des = \case #binds ((#ve0 :++: #binds) :++: #tl)) (EVar ext tIx IZ)) - (subst (\_ t i -> case splitIdx @(TIx : D1E env) (bindingsBinds e0) i of + (subst (\_ t i -> case splitIdx @(TIx : D1E env') (bindingsBinds e0) i of Left ibind -> let ibind' = autoWeak (#ix (auto1 @TIx) @@ -1020,9 +1010,9 @@ drev des = \case in EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) t) ibind') (EVar ext tIx IZ)) Right IZ -> EVar ext tIx IZ -- build lambda index argument - Right (IS ienv) -> EVar ext t (IS (wSinks (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) @> ienv))) + Right (IS ienv) -> EVar ext t (IS (wSinksAnd (sappend (bindingsBinds ve0) (tIx `SCons` bindingsBinds ne0)) (wUndoSubenv subD1eUsed) @> ienv))) e1)) - nsub + (subenvCompose subMergeUsed proSub) (ELet ext (uninvertTup (d2e envPro) (STArr (SS SZ) STNil) $ makeAccumulators @_ @_ @(TArr (S Z) TNil) envPro $ @@ -1035,8 +1025,21 @@ drev des = \case #binds (#pro :++: #d :++: (#ve0 :++: #binds) :++: #tl)) (EVar ext tIx IZ)) - -- TODO: use vectoriseExpr - (_ $ + (ELet ext (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (d2 eltty)) + (IS (wSinks @(TArr (S Z) (D2 eltty) : Append (Append (Vectorise (S Z) e_binds) (TIx : ne_binds)) (D2AcE (Select env sto "accum"))) + (d2ace envPro) + @> IZ))) + (EVar ext tIx IZ))) $ + weakenExpr (autoWeak (#i (auto1 @TIx) + &. #dpro (d2ace envPro) + &. #d (d2 eltty `SCons` SNil) + &. #darr (STArr (SS SZ) (d2 eltty) `SCons` SNil) + &. #n (auto1 @TIx) + &. #vbinds (bindingsBinds ve0) + &. #ne0 (bindingsBinds ne0) + &. #tl (d2ace (select SAccum des))) + (#i :++: (#dpro :++: #d) :++: #vbinds :++: #tl) + (#d :++: #i :++: #dpro :++: #darr :++: (#vbinds :++: #n :++: #ne0) :++: #tl)) $ vectoriseExpr (sappend (d2ace envPro) (d2 eltty `SCons` SNil)) (bindingsBinds e0) (d2ace (select SAccum des)) $ weakenExpr (autoWeak (#dpro (d2ace envPro) &. #d (d2 eltty `SCons` SNil) @@ -1044,19 +1047,12 @@ drev des = \case &. #tl (d2ace (select SAccum des))) (#dpro :++: #d :++: #binds :++: #tl) ((#dpro :++: #d) :++: #binds :++: #tl)) $ - weakenExpr (wPro (bindingsBinds e0)) e2)) $ + weakenExpr (wCopies (d2ace envPro) (WCopy @(D2 eltty) (wCopies (bindingsBinds e0) (wUndoSubenv subAccumUsed)))) $ + weakenExpr (wPro (bindingsBinds e0)) $ + e2)) $ ELet ext (ENil ext) $ - weakenExpr (autoWeak (#nil (auto1 @TNil) - &. #d (auto1 @(D2 t)) - &. #nilarr (auto1 @(TArr (S Z) TNil)) - &. #ve0 (bindingsBinds ve0) - &. #n (auto1 @TIx) - &. #binds (bindingsBinds ne0) - &. #tl (d2ace (select SAccum des))) - (#nil :++: #binds :++: #tl) - (#nil :++: #nilarr :++: #d :++: (#ve0 :++: #n :++: #binds) :++: #tl)) - ne2) - }}} + ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) + }} EUnit _ e | Ret e0 e1 sub e2 <- drev des e -> @@ -1075,9 +1071,20 @@ drev des = \case (ELet ext (EUnit ext (EVar ext (d2 t) IZ)) $ weakenExpr (WCopy WSink) e2) + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + -> + Ret binds + (EIdx1 ext e1 ei1) + sub + (_ e2) + -- These should be the next to be implemented, I think - EIdx1{} -> err_unsupported "EIdx1" EFold1{} -> err_unsupported "EFold1" + EShape{} -> err_unsupported "EShape" + EReplicate{} -> err_unsupported "EReplicate" EIdx{} -> err_unsupported "EIdx" EBuild{} -> err_unsupported "EBuild" diff --git a/src/Example.hs b/src/Example.hs index 572d67e..86264e1 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -107,3 +107,18 @@ ex5 = (bin (OMul STF32) (EVar ext (STScal STF32) IZ) (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) (EConst ext STF32 1.0))) + +senv6 :: SList STy [TScal TI64, TScal TF32] +senv6 = STScal STI64 `SCons` STScal STF32 `SCons` SNil + +descr6 :: Descr [TScal TI64, TScal TF32] ["merge", "merge"] +descr6 = DTop `DPush` (STScal STF32, SMerge) `DPush` (STScal STI64, SMerge) + +ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) +ex6 = + ELet ext (EUnit ext (EVar ext (STScal STF32) (IS IZ))) $ + ELet ext (EBuild1 ext (EVar ext tIx (IS IZ)) $ + ELet ext (EIdx0 ext (EVar ext (STArr SZ (STScal STF32)) (IS IZ))) $ + bin (OMul STF32) (EVar ext (STScal STF32) IZ) + (EVar ext (STScal STF32) IZ)) $ + (EIdx0 ext (EIdx1 ext (EVar ext (STArr (SS SZ) (STScal STF32)) IZ) (EConst ext STI64 3))) diff --git a/src/Simplify.hs b/src/Simplify.hs index f2fc54a..698c667 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -71,9 +71,10 @@ simplify' = \case EInr _ t e -> EInr ext t (simplify' e) ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b) EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) - EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e) + EBuild _ n a b -> EBuild ext n (simplify' a) (simplify' b) EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) EUnit _ e -> EUnit ext (simplify' e) + EReplicate _ e -> EReplicate ext (simplify' e) EConst _ t v -> EConst ext t v EIdx0 _ e -> EIdx0 ext (simplify' e) EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) @@ -104,9 +105,10 @@ hasAdds = \case EInr _ _ e -> hasAdds e ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b EBuild1 _ a b -> hasAdds a || hasAdds b - EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e + EBuild _ _ a b -> hasAdds a || hasAdds b EFold1 _ a b -> hasAdds a || hasAdds b EUnit _ e -> hasAdds e + EReplicate _ e -> hasAdds e EConst _ _ _ -> False EIdx0 _ e -> hasAdds e EIdx1 _ a b -> hasAdds a || hasAdds b |