diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 17:48:15 +0200 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-08-30 17:48:15 +0200 |
commit | 8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (patch) | |
tree | e8440120b7bbd4e45b367acb3f7185d25e7f3766 | |
parent | f4b94d7cc2cb05611b462ba278e4f12f7a7a5e5e (diff) |
Migrate to accumulators (mostly removing EVM code)
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 45 | ||||
-rw-r--r-- | src/AST/Count.hs | 18 | ||||
-rw-r--r-- | src/AST/Pretty.hs | 104 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 23 | ||||
-rw-r--r-- | src/CHAD.hs | 281 | ||||
-rw-r--r-- | src/Data.hs | 6 | ||||
-rw-r--r-- | src/Example.hs | 7 | ||||
-rw-r--r-- | src/Simplify.hs | 125 |
9 files changed, 314 insertions, 296 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index c38c270..2e7ee22 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -27,7 +27,6 @@ library containers, template-haskell, transformers, - some hs-source-dirs: src default-language: @@ -30,7 +30,7 @@ data Ty | TEither Ty Ty | TArr Nat Ty -- ^ rank, element type | TScal ScalTy - | TEVM [Ty] Ty + | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to deriving (Show, Eq, Ord) data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool @@ -43,7 +43,7 @@ data STy t where STEither :: STy a -> STy b -> STy (TEither a b) STArr :: SNat n -> STy t -> STy (TArr n t) STScal :: SScalTy t -> STy (TScal t) - STEVM :: SList STy env -> STy t -> STy (TEVM env t) + STAccum :: SNat n -> STy t -> STy (TAccum n t) deriving instance Show (STy t) data SScalTy t where @@ -97,11 +97,9 @@ data Expr x env t where EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t - -- EVM operations - EMOne :: SList STy venv -> Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil) - EMScope :: Expr x env (TEVM (t : venv) a) -> Expr x env (TEVM venv (TPair a t)) - EMReturn :: SList STy venv -> Expr x env t -> Expr x env (TEVM venv t) - EMBind :: Expr x env (TEVM venv a) -> Expr x (a : env) (TEVM venv b) -> Expr x env (TEVM venv b) + -- accumulation effect + EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t)) + EAccum :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum n t) -> Expr x env TNil -- partiality EError :: STy a -> String -> Expr x env a @@ -157,10 +155,8 @@ typeOf = \case EIdx _ e _ | STArr _ t <- typeOf e -> t EOp _ op _ -> opt2 op - EMOne t _ _ -> STEVM t STNil - EMScope e | STEVM (SCons t env) a <- typeOf e -> STEVM env (STPair a t) - EMReturn env e -> STEVM env (typeOf e) - EMBind _ e -> typeOf e + EWith e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ -> STNil EError t _ -> t @@ -175,7 +171,7 @@ unSTy = \case STEither a b -> TEither (unSTy a) (unSTy b) STArr n t -> TArr (unSNat n) (unSTy t) STScal t -> TScal (unSScalTy t) - STEVM l t -> TEVM (unSList l) (unSTy t) + STAccum n t -> TAccum (unSNat n) (unSTy t) unSList :: SList STy env -> [Ty] unSList SNil = [] @@ -207,10 +203,8 @@ weakenExpr w = \case EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2) EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenExpr w <$> es) EOp x op e -> EOp x op (weakenExpr w e) - EMOne t i e -> EMOne t i (weakenExpr w e) - EMScope e -> EMScope (weakenExpr w e) - EMReturn t e -> EMReturn t (weakenExpr w e) - EMBind e1 e2 -> EMBind (weakenExpr w e1) (weakenExpr (WCopy w) e2) + EWith e1 e2 -> EWith (weakenExpr w e1) (weakenExpr (WCopy w) e2) + EAccum e1 e2 e3 -> EAccum (weakenExpr w e1) (weakenExpr w e2) (weakenExpr w e3) EError t s -> EError t s wsinkN :: SNat n -> env :> ConsN n TIx env @@ -233,3 +227,22 @@ slistIdx SNil i = case i of {} idx2int :: Idx env t -> Int idx2int IZ = 0 idx2int (IS n) = 1 + idx2int n + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv diff --git a/src/AST/Count.hs b/src/AST/Count.hs index de04b5f..f66b809 100644 --- a/src/AST/Count.hs +++ b/src/AST/Count.hs @@ -1,7 +1,12 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} module AST.Count where +import GHC.Generics (Generic, Generically(..)) + import AST import Data @@ -18,9 +23,8 @@ instance Monoid Count where data Occ = Occ { _occLexical :: Count , _occRuntime :: Count } - deriving (Eq) -instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d) -instance Monoid Occ where mempty = Occ mempty mempty + deriving (Eq, Generic) + deriving (Semigroup, Monoid) via Generically Occ -- | One of the two branches is taken (<||>) :: Occ -> Occ -> Occ @@ -49,8 +53,6 @@ occCount idx = \case EIdx1 _ a b -> occCount idx a <> occCount idx b EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es EOp _ _ e -> occCount idx e - EMOne _ _ e -> occCount idx e - EMScope e -> occCount idx e - EMReturn _ e -> occCount idx e - EMBind a b -> occCount idx a <> occCount (IS idx) b + EWith a b -> occCount idx a <> occCount (IS idx) b + EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e EError{} -> mempty diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs index 1ffa980..3473131 100644 --- a/src/AST/Pretty.hs +++ b/src/AST/Pretty.hs @@ -41,14 +41,20 @@ instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) genId :: M Int genId = M (\i -> (i, i + 1)) +genName' :: String -> M String +genName' prefix = (prefix ++) . show <$> genId + genName :: M String -genName = ('x' :) . show <$> genId +genName = genName' "x" -genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn ty idx ex +genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn' prefix ty idx ex | occCount idx ex == mempty = case ty of STNil -> return "()" _ -> return "_" - | otherwise = genName + | otherwise = genName' prefix + +genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn = genNameIfUsedIn' "x" ppExpr :: SList STy env -> Expr x env t -> String ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) "" @@ -64,6 +70,8 @@ ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS ppExpr' d val = \case EVar _ _ i -> return $ showString $ getConst $ valprj val i + e@ELet{} -> ppExprLet d val e + EPair _ a b -> do a' <- ppExpr' 0 val a b' <- ppExpr' 0 val b @@ -155,80 +163,46 @@ ppExpr' d val = \case (Prefix, s) -> s return $ showParen (d > 10) $ showString (ops ++ " ") . e' - EMOne venv i e -> do - let venvlen = length (unSList venv) - varname = 'v' : show (venvlen - idx2int i) - e' <- ppExpr' 11 val e + EWith e1 e2 -> do + e1' <- ppExpr' 11 val e1 + let STArr n t = typeOf e1 + name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2 + e2' <- ppExpr' 11 (VPush (Const name) val) e2 return $ showParen (d > 10) $ - showString ("one " ++ show varname ++ " ") . e' + showString "with " . e1' . showString (" (\\" ++ name ++ " -> ") + . e2' . showString ")" - EMScope e -> do - let venv = case typeOf e of STEVM v _ -> v - venvlen = length (unSList venv) - varname = 'v' : show venvlen - e' <- ppExpr' 11 val e + EAccum e1 e2 e3 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + e3' <- ppExpr' 11 val e3 return $ showParen (d > 10) $ - showString ("scope " ++ show varname ++ " ") . e' - - EMReturn _ e -> do - e' <- ppExpr' 11 val e - return $ showParen (d > 10) $ showString ("return ") . e' - - e@EMBind{} -> ppExprDo d val e - e@ELet{} -> ppExprDo d val e - - -- EMBind a b -> do - -- let STEVM _ t = typeOf a - -- a' <- ppExpr' 0 val a - -- name <- genNameIfUsedIn t IZ b - -- b' <- ppExpr' 0 (VPush (Const name) val) b - -- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b' + showString "accum " . e1' . showString " " . e2' . showString " " . e3' EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s) -data Binding = MonadBind String ShowS - | LetBind String ShowS - -ppExprDo :: Int -> SVal env -> Expr x env t -> M ShowS -ppExprDo d val etop = do - let collect :: SVal env -> Expr x env t -> M ([Binding], ShowS) - collect val' (EMBind lhs body) = do - let STEVM _ t = typeOf lhs - name <- genNameIfUsedIn t IZ body - (binds, core) <- collect (VPush (Const name) val') body - lhs' <- ppExpr' 0 val' lhs - return (MonadBind name lhs' : binds, core) +ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS +ppExprLet d val etop = do + let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS) collect val' (ELet _ rhs body) = do name <- genNameIfUsedIn (typeOf rhs) IZ body - (binds, core) <- collect (VPush (Const name) val') body rhs' <- ppExpr' 0 val' rhs - return (LetBind name rhs' : binds, core) + (binds, core) <- collect (VPush (Const name) val') body + return ((name, rhs') : binds, core) collect val' e = ([],) <$> ppExpr' 0 val' e - fromLet = \case LetBind n s -> Just (n, s) ; _ -> Nothing - (binds, core) <- collect val etop - return $ showParen (d > 0) $ case traverse fromLet binds of - Just lbinds -> - let (open, close) = case lbinds of - [_] -> ("{ ", " }") - _ -> ("", "") - in showString ("let " ++ open) - . foldr (.) id - (intersperse (showString " ; ") - (map (\(name, rhs) -> showString (name ++ " = ") . rhs) lbinds)) - . showString (close ++ " in ") - . core - Nothing -> - showString "do { " - . foldr (.) id - (intersperse (showString " ; ") - (map (\case MonadBind name rhs -> showString (name ++ " <- ") . rhs - LetBind name rhs -> showString ("let { " ++ name ++ " = ") . rhs - . showString " }") - binds)) - . showString " ; " . core . showString " }" + let (open, close) = case binds of + [_] -> ("{ ", " }") + _ -> ("", "") + return $ showParen (d > 0) $ + showString ("let " ++ open) + . foldr (.) id + (intersperse (showString " ; ") + (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds)) + . showString (close ++ " in ") + . core data Fixity = Prefix | Infix deriving (Show) diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 4b3016d..d992404 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -36,7 +36,7 @@ data env :> env' where WIdx :: Idx env t -> (t : env) :> env deriving instance Show (env :> env') -infixr @> +infixr 2 @> (@>) :: env :> env' -> Idx env t -> Idx env' t WId @> i = i WSink @> i = IS i @@ -48,6 +48,7 @@ WClosed _ @> i = case i of {} WIdx j @> IZ = j WIdx _ @> IS i = i +infixr 3 .> (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 (.>) = flip WThen @@ -70,14 +71,18 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 wCopies SNil w = w wCopies (SCons _ spine) w = WCopy (wCopies spine w) --- wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env --- wStack WId = WId --- wStack WSink = WSink --- wStack (WCopy w) = WCopy (wStack @env w) --- wStack (WPop w) = WPop (wStack @env w) --- wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) --- wStack (WClosed s) = wSinks s --- wStack (WIdx i) = WIdx (_ i) +wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env +wStack WId = WId +wStack WSink = WSink +wStack (WCopy w) = WCopy (wStack @env w) +wStack (WPop w) = WPop (wStack @env w) +wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) +wStack (WClosed s) = wSinks s +wStack (WIdx i) = WIdx (goIdx i) + where + goIdx :: Idx b t -> Idx (Append b env) t + goIdx IZ = IZ + goIdx (IS i') = IS (goIdx i') wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) diff --git a/src/CHAD.hs b/src/CHAD.hs index 2513f84..e209b67 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -1,16 +1,16 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE StandaloneKindSignatures #-} {-# LANGUAGE UndecidableInstances #-} -- I want to bring various type variables in scope using type annotations in @@ -29,7 +29,6 @@ module CHAD ( import Data.Bifunctor (first, second) import Data.Functor.Const import Data.Kind (Type) -import Data.Some import GHC.TypeLits (Symbol) import AST @@ -254,7 +253,7 @@ type family D2 t where D2 TNil = TNil D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b)) D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b)) - -- D2 (TArr n t) = _ + D2 (TArr n t) = TArr n (D2 t) D2 (TScal t) = D2s t type family D2s t where @@ -264,6 +263,9 @@ type family D2s t where D2s TF64 = TScal TF64 D2s TBool = TNil +type family D2Ac t where + D2Ac (TArr n t) = TAccum n t + type family D1E env where D1E '[] = '[] D1E (t : env) = D1 t : D1E env @@ -272,6 +274,10 @@ type family D2E env where D2E '[] = '[] D2E (t : env) = D2 t : D2E env +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = D2Ac t : D2AcE env + -- | Select only the types from the environment that have the specified storage type family Select env sto s where Select '[] '[] _ = '[] @@ -284,20 +290,24 @@ d1 (STPair a b) = STPair (d1 a) (d1 b) d1 (STEither a b) = STEither (d1 a) (d1 b) d1 (STArr n t) = STArr n (d1 t) d1 (STScal t) = STScal t -d1 STEVM{} = error "EVM not allowed in input program" +d1 STAccum{} = error "Accumulators not allowed in input program" d2 :: STy t -> STy (D2 t) d2 STNil = STNil d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b)) d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b)) -d2 STArr{} = error "TODO arrays" +d2 (STArr n t) = STArr n (d2 t) d2 (STScal t) = case t of STI32 -> STNil STI64 -> STNil STF32 -> STScal STF32 STF64 -> STScal STF64 STBool -> STNil -d2 STEVM{} = error "EVM not allowed in input program" +d2 STAccum{} = error "Accumulators not allowed in input program" + +d2ac :: STy t -> STy (D2Ac t) +d2ac (STArr n t) = STAccum n t +d2ac _ = error "Only arrays may appear in the accumulator environment" conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) conv1Idx IZ = IZ @@ -322,7 +332,7 @@ zero (STScal t) = case t of STF32 -> EConst ext STF32 0.0 STF64 -> EConst ext STF64 0.0 STBool -> ENil ext -zero STEVM{} = error "EVM not allowed in input program" +zero STAccum{} = error "Accumulators not allowed in input program" plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) plus STNil _ _ = ENil ext @@ -350,7 +360,7 @@ plus (STScal t) a b = case t of STF32 -> EOp ext (OAdd STF32) (EPair ext a b) STF64 -> EOp ext (OAdd STF64) (EPair ext a b) STBool -> ENil ext -plus STEVM{} _ _ = error "EVM not allowed in input program" +plus STAccum{} _ _ = error "Accumulators not allowed in input program" plusSparse :: STy a -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) @@ -388,14 +398,14 @@ data Ret env0 sto t = Ret (Bindings Ex (D1E env0) shbinds) -- shared binds (Ex (Append shbinds (D1E env0)) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) + (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) deriving instance Show (Ret env0 sto t) data RetPair env0 sto env shbinds t = forall env0Merge. RetPair (Ex (Append shbinds env) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) + (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge))) deriving instance Show (RetPair env0 sto env shbinds t) data Rets env0 sto env list = @@ -430,7 +440,8 @@ subenvPlus :: SList STy env -> r subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext) subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl + subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k = subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 -> @@ -464,24 +475,19 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e = in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var) expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t) -unscope :: Descr env0 sto - -> STy a -> Storage s - -> Subenv (Select (a : env0) (s : sto) "merge") envSub - -> Ex env (TEVM (D2E (Select (a : env0) (s : sto) "accum")) (Tup (D2E envSub))) - -> (forall envSub'. - Subenv (Select env0 sto "merge") envSub' - -> Ex env (TEVM (D2E (Select env0 sto "accum")) (TPair (Tup (D2E envSub')) (D2 a))) - -> r) - -> r -unscope des ty s sub e k = case s of - SAccum -> k sub (EMScope e) - SMerge -> case sub of - SEYes sub' -> k sub' e - SENo sub' -> k sub' $ - EMBind e $ - EMReturn (d2e (select SAccum des)) $ - EPair ext (EVar ext (tTup (d2e (subList (select SMerge des) sub'))) IZ) - (zero ty) +popFromScope + :: Descr env0 sto + -> STy a + -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub + -> Ex env (Tup (D2E envSub)) + -> (forall envSub'. + Subenv (Select env0 sto "merge") envSub' + -> Ex env (TPair (Tup (D2E envSub')) (D2 a)) + -> r) + -> r +popFromScope _ ty sub e k = case sub of + SEYes sub' -> k sub' e + SENo sub' -> k sub' $ EPair ext e (zero ty) -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId @@ -501,7 +507,7 @@ weakenRets w (Rets binds list) = rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t rebaseRetPair b1 b2 (RetPair p sub d) | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 b1)) d) + RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d) retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list retConcat SNil = Rets BTop SNil @@ -509,37 +515,12 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list) | Rets binds1 pairs1 <- retConcat list , Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) + , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum")) = Rets (bconcat b binds) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) sub (weakenExpr (WCopy (sinkWithBindings binds)) d)) (slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs)) --- list ~ a : list' --- SCons (Ret b p sub d) list :: SList (Ret env0 sto) list --- Ret b p sub d :: Ret env0 sto a <- existential shbinds --- b :: Bindings Ex (D1E env0) shbinds --- p :: Ex (Append shbinds (D1E env0)) (D1 a) --- d :: Ex (D2 a : shbinds) (TEVM ...) --- --- list :: SList (Ret env0 sto) list' --- retConcat list :: Rets env0 sto (D1E env0) list' <- existential shbinds1 --- binds1 :: Bindings Ex (D1E env0) shbinds1 --- pairs1 :: SList (RetPair env0 sto (D1E env0) shbinds1) list' --- --- sinkWithBindings b :: forall e. e :> Append shbinds e --- Rets binds pairs :: Rets env0 sto (Append shbinds (D1E env0)) list' <- existential shbinds2 --- binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2 --- pairs :: SList (RetPair env0 sto (Append shbinds (D1E env0)) shbinds2) list' --- --- we choose shbindsR ~ Append shbinds2 shbinds --- result :: Rets env0 sto (D1E env0) list --- result.1 :: Bindings Ex (D1E env0) shbindsR == Bindings Ex (D1E env0) (Append shbinds2 shbinds) --- result.2 :: SList (RetPair env0 sto (D1E env0) shbindsR) list --- result.2.head :: RetPair env0 sto (D1E env0) shbindsR a --- result.2.tail :: SList (RetPair env0 sto (D1E env0) shbindsR) list' --- = SList (RetPair env0 sto (D1E env0) (Append shbinds2 shbinds)) list' --- --- wanted: shbinds1 :> shbindsR d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) d1op (OAdd t) e = EOp ext (OAdd t) e @@ -557,7 +538,7 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) (EOp ext (OMul t) (EPair ext (EFst ext e) d))) @@ -611,86 +592,89 @@ sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env) sD1eEnv DTop = SNil sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d) +d2e :: SList STy env -> SList STy (D2E env) +d2e SNil = SNil +d2e (SCons t ts) = SCons (d2 t) (d2e ts) + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts) + freezeRet :: Descr env sto -> Ret env sto t -> Ex (D1E env) (D2 t) -- the incoming cotangent value - -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) + -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 e1 sub e2) d = - let e2' = weakenExpr (WCopy (wRaiseAbove (bindingsBinds e0) (sD1eEnv descr))) e2 - in letBinds e0 $ + let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 + in letBinds e0' $ EPair ext - e1 - (ELet ext (weakenExpr (sinkWithBindings e0) d) - (EMBind e2' - (EMReturn (d2e (select SAccum descr)) - (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))))) - -d2e :: SList STy env -> SList STy (D2E env) -d2e SNil = SNil -d2e (SCons t ts) = SCons (d2 t) (d2e ts) + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $ + ELet ext e2' $ + expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)) drev :: forall env sto t. Descr env sto - -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage) -> Ex env t -> Ret env sto t -drev des policy = \case +drev des = \case EVar _ t i -> case conv2Idx des i of - Left accumI -> + Left _ -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvNone (select SMerge des)) - (EMOne d2acc accumI (EVar ext (d2 t) IZ)) + (ENil ext) Right tupI -> Ret BTop (EVar ext (d1 t) (conv1Idx i)) (subenvOnehot (select SMerge des) tupI) - (EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ))) + (EPair ext (ENil ext) (EVar ext (d2 t) IZ)) ELet _ rhs body - | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des policy rhs - , Some storage <- policy des (typeOf rhs) - , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body + | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs + , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body , let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum")) , Refl <- lemAppendNil @body_shbinds -> - unscope des (typeOf rhs) storage subBody body2 $ \subBody' body2' -> + popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' -> subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body -> let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (weakenExpr wbody0' body1) subBoth - (EMBind - (weakenExpr (WCopy (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))) body2') - (EMBind - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) - (EMReturn d2acc (plus_RHS_Body - (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) - (EFst ext (EVar ext bodyResType (IS IZ))))))) + (ELet ext + (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0))))) + body2') $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $ + plus_RHS_Body + (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) EPair _ a b | Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil + <- retConcat $ drev des a `SCons` drev des b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B -> Ret binds (EPair ext a1 b1) subBoth (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth))) - (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) + (zeroTup (subList (select SMerge des) subBoth)) + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ - EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ - EMReturn d2acc - (plus_A_B - (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))) + plus_A_B + (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) + (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) EFst _ e - | Ret e0 e1 sub e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (EFst ext e1) @@ -699,7 +683,7 @@ drev des policy = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 sub e2 <- drev des policy e + | Ret e0 e1 sub e2 <- drev des e , STPair t1 t2 <- typeOf e -> Ret e0 (ESnd ext e1) @@ -707,46 +691,47 @@ drev des policy = \case (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (EMReturn d2acc (ENil ext)) + ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) EInl _ t2 e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> Ret e0 (EInl ext (d1 t2) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) + (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr"))) + (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> Ret e0 (EInr ext (d1 t1) e1) sub (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) - (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) + (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl") + (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") (weakenExpr (WCopy (wSinks' @[_,_])) e2))) ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e - , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des policy e - , Some storageA <- policy des t1 - , Some storageB <- policy des t2 - , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a - , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b + , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e + , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a + , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) , let tapeA = tapeTy (bindingsBinds a0) , let tapeB = tapeTy (bindingsBinds b0) , let collectA = bindingsCollect a0 , let collectB = bindingsCollect b0 , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) , let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 -> - unscope des t1 storageA subA a2 $ \subA' a2' -> - unscope des t2 storageB subB b2 $ \subB' b2' -> + , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 + -> + popFromScope des t1 subA a2 $ \subA' a2' -> + popFromScope des t2 subB b2 $ \subB' b2' -> subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ -> subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E -> let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in @@ -757,43 +742,49 @@ drev des policy = \case (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0')))))) (EFst ext (EVar ext tPrimal IZ)) subOut - (EMBind + (ELet ext (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ in letBinds rebinds $ - ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ - EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2') - (EMReturn d2acc - (EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)) - (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))))) + ELet ext + (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ + ELet ext + (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ + WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) + a2') $ + EPair ext + (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $ + EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)) + (EInl ext (d2 t2) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))) (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ in letBinds rebinds $ - ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ - EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2') - (EMReturn d2acc - (EPair ext - (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ - EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)) - (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))))) - (EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ - EMReturn d2acc $ - plus_AB_E - (EFst ext (EVar ext tCaseRet (IS IZ))) - (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))) + ELet ext + (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ + ELet ext + (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $ + WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) + b2') $ + EPair ext + (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ + EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)) + (EInr ext (d2 t1) + (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $ + ELet ext + (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ + plus_AB_E + (EFst ext (EVar ext tCaseRet (IS IZ))) + (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)) EConst _ t val -> Ret BTop (EConst ext t val) (subenvNone (select SMerge des)) - (EMReturn d2acc (ENil ext)) + (ENil ext) EOp _ op e - | Ret e0 e1 sub e2 <- drev des policy e -> + | Ret e0 e1 sub e2 <- drev des e -> case d2op op of Linear d2opfun -> Ret e0 @@ -813,7 +804,7 @@ drev des policy = \case Ret BTop (EError (d1 t) s) (subenvNone (select SMerge des)) - (EMReturn d2acc (ENil ext)) + (ENil ext) -- These should be the next to be implemented, I think EBuild1{} -> err_unsupported "EBuild1" @@ -823,13 +814,9 @@ drev des policy = \case EBuild{} -> err_unsupported "EBuild" EIdx{} -> err_unsupported "EIdx" - EMOne{} -> err_evm - EMScope{} -> err_evm - EMReturn{} -> err_evm - EMBind{} -> err_evm + EWith{} -> err_accum + EAccum{} -> err_accum where - d2acc = d2e (select SAccum des) - - err_evm = error "EVM operations unsupported in the source program" + err_accum = error "Accumulator operations unsupported in the source program" err_unsupported s = error $ "CHAD: unsupported " ++ s diff --git a/src/Data.hs b/src/Data.hs index 728dafe..c3381d5 100644 --- a/src/Data.hs +++ b/src/Data.hs @@ -1,11 +1,11 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DeriveTraversable #-} module Data where @@ -31,6 +31,10 @@ fromSNat :: SNat n -> Int fromSNat SZ = 0 fromSNat (SS n) = succ (fromSNat n) +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + data Vec n t where VNil :: Vec Z t (:<) :: t -> Vec n t -> Vec (S n) t diff --git a/src/Example.hs b/src/Example.hs index 30031c0..572d67e 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,8 +1,6 @@ {-# LANGUAGE DataKinds #-} module Example where -import Data.Some - import AST import AST.Pretty import CHAD @@ -10,7 +8,7 @@ import Data import Simplify --- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SAccum SAccum in freezeRet d (drev d (\_ _ -> Some SAccum) ex5) (EConst ext STF32 1.0) +-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c @@ -104,7 +102,8 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) ex5 = ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ)) - (EVar ext (STScal STF32) IZ) + (bin (OMul STF32) (EVar ext (STScal STF32) IZ) + (EVar ext (STScal STF32) (IS IZ))) (bin (OMul STF32) (EVar ext (STScal STF32) IZ) (bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ)) (EConst ext STF32 1.0))) diff --git a/src/Simplify.hs b/src/Simplify.hs index 44de164..a5f90b3 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -1,83 +1,84 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE DataKinds #-} module Simplify where +import Data.Monoid + import AST import AST.Count import Data -simplifyN :: Int -> Ex env t -> Ex env t +simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t simplifyN 0 = id simplifyN n = simplifyN (n - 1) . simplify -simplify :: Ex env t -> Ex env t -simplify = \case +simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplify = let ?accumInScope = checkAccumInScope @env knownEnv in simplify' + +simplify' :: (?accumInScope :: Bool) => Ex env t -> Ex env t +simplify' = \case -- inlining ELet _ rhs body - | Occ lexOcc runOcc <- occCount IZ body + | not ?accumInScope || not (hasAdds rhs) -- cannot discard effectful computations + , Occ lexOcc runOcc <- occCount IZ body , lexOcc <= One -- prevent code size blowup , runOcc <= One -- prevent runtime increase - -> simplify (subst1 rhs body) + -> simplify' (subst1 rhs body) | cheapExpr rhs - -> simplify (subst1 rhs body) + -> simplify' (subst1 rhs body) -- let splitting ELet _ (EPair _ a b) body -> - simplify $ + simplify' $ ELet ext a $ ELet ext (weakenExpr WSink b) $ subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) IS i -> EVar ext t (IS (IS i))) body + -- let rotation + ELet _ (ELet _ rhs a) b -> + ELet ext (simplify' rhs) $ + ELet ext (simplify' a) $ + weakenExpr (WCopy WSink) (simplify' b) + -- beta rules for products - EFst _ (EPair _ e _) -> simplify e - ESnd _ (EPair _ _ e) -> simplify e + EFst _ (EPair _ e _) -> simplify' e + ESnd _ (EPair _ _ e) -> simplify' e -- beta rules for coproducts - ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs) - ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs) + ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs) + ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs) -- TODO: array indexing (index of build, index of fold) -- TODO: constant folding for operations - -- eta rule for return+bind - EMBind (EMReturn _ a) b -> simplify (ELet ext a b) - - -- associativity of bind - EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c))) - - -- bind-let commute - EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c))) - - -- return-let commute - EMReturn env (ELet _ a b) -> simplify (ELet ext a (EMReturn env b)) - EVar _ t i -> EVar ext t i - ELet _ a b -> ELet ext (simplify a) (simplify b) - EPair _ a b -> EPair ext (simplify a) (simplify b) - EFst _ e -> EFst ext (simplify e) - ESnd _ e -> ESnd ext (simplify e) + ELet _ a b -> ELet ext (simplify' a) (simplify' b) + EPair _ a b -> EPair ext (simplify' a) (simplify' b) + EFst _ e -> EFst ext (simplify' e) + ESnd _ e -> ESnd ext (simplify' e) ENil _ -> ENil ext - EInl _ t e -> EInl ext t (simplify e) - EInr _ t e -> EInr ext t (simplify e) - ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b) - EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b) - EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e) - EFold1 _ a b -> EFold1 ext (simplify a) (simplify b) + EInl _ t e -> EInl ext t (simplify' e) + EInr _ t e -> EInr ext t (simplify' e) + ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b) + EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b) + EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e) + EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b) EConst _ t v -> EConst ext t v - EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b) - EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es) - EOp _ op e -> EOp ext op (simplify e) - EMOne t i e -> EMOne t i (simplify e) - EMScope e -> EMScope (simplify e) - EMReturn t e -> EMReturn t (simplify e) - EMBind a b -> EMBind (simplify a) (simplify b) + EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b) + EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es) + EOp _ op e -> EOp ext op (simplify' e) + EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2) + EAccum e1 e2 e3 -> EAccum (simplify' e1) (simplify' e2) (simplify' e3) EError t s -> EError t s cheapExpr :: Expr x env t -> Bool @@ -116,10 +117,8 @@ subst' f w = \case EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es) EOp x op e -> EOp x op (subst' f w e) - EMOne t i e -> EMOne t i (subst' f w e) - EMScope e -> EMScope (subst' f w e) - EMReturn t e -> EMReturn t (subst' f w e) - EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3) EError t s -> EError t s where sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) @@ -134,3 +133,39 @@ subst' f w = \case sinkFN SZ f' x t w' i = f' x t w' i sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ) sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i + +-- | This can be made more precise by tracking (and not counting) adds on +-- locally eliminated accumulators. +hasAdds :: Expr x env t -> Bool +hasAdds = \case + EVar _ _ _ -> False + ELet _ rhs body -> hasAdds rhs || hasAdds body + EPair _ a b -> hasAdds a || hasAdds b + EFst _ e -> hasAdds e + ESnd _ e -> hasAdds e + ENil _ -> False + EInl _ _ e -> hasAdds e + EInr _ _ e -> hasAdds e + ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b + EBuild1 _ a b -> hasAdds a || hasAdds b + EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e + EFold1 _ a b -> hasAdds a || hasAdds b + EConst _ _ _ -> False + EIdx1 _ a b -> hasAdds a || hasAdds b + EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es) + EOp _ _ e -> hasAdds e + EWith a b -> hasAdds a || hasAdds b + EAccum _ _ _ -> True + EError _ _ -> False + +checkAccumInScope :: SList STy env -> Bool +checkAccumInScope = \case SNil -> False + SCons t env -> check t || checkAccumInScope env + where + check :: STy t -> Bool + check STNil = False + check (STPair s t) = check s || check t + check (STEither s t) = check s || check t + check (STArr _ t) = check t + check (STScal _) = False + check STAccum{} = True |