summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
committerTom Smeding <tom@tomsmeding.com>2024-08-30 17:48:15 +0200
commit8b047ff11ebd4715647bfc041a190f72dcf4d5a9 (patch)
treee8440120b7bbd4e45b367acb3f7185d25e7f3766
parentf4b94d7cc2cb05611b462ba278e4f12f7a7a5e5e (diff)
Migrate to accumulators (mostly removing EVM code)
-rw-r--r--chad-fast.cabal1
-rw-r--r--src/AST.hs45
-rw-r--r--src/AST/Count.hs18
-rw-r--r--src/AST/Pretty.hs104
-rw-r--r--src/AST/Weaken.hs23
-rw-r--r--src/CHAD.hs281
-rw-r--r--src/Data.hs6
-rw-r--r--src/Example.hs7
-rw-r--r--src/Simplify.hs125
9 files changed, 314 insertions, 296 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal
index c38c270..2e7ee22 100644
--- a/chad-fast.cabal
+++ b/chad-fast.cabal
@@ -27,7 +27,6 @@ library
containers,
template-haskell,
transformers,
- some
hs-source-dirs:
src
default-language:
diff --git a/src/AST.hs b/src/AST.hs
index aeab1b7..2267672 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -30,7 +30,7 @@ data Ty
| TEither Ty Ty
| TArr Nat Ty -- ^ rank, element type
| TScal ScalTy
- | TEVM [Ty] Ty
+ | TAccum Nat Ty -- ^ rank and element type of the array being accumulated to
deriving (Show, Eq, Ord)
data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool
@@ -43,7 +43,7 @@ data STy t where
STEither :: STy a -> STy b -> STy (TEither a b)
STArr :: SNat n -> STy t -> STy (TArr n t)
STScal :: SScalTy t -> STy (TScal t)
- STEVM :: SList STy env -> STy t -> STy (TEVM env t)
+ STAccum :: SNat n -> STy t -> STy (TAccum n t)
deriving instance Show (STy t)
data SScalTy t where
@@ -97,11 +97,9 @@ data Expr x env t where
EIdx :: x t -> Expr x env (TArr n t) -> Vec n (Expr x env TIx) -> Expr x env t
EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
- -- EVM operations
- EMOne :: SList STy venv -> Idx venv t -> Expr x env t -> Expr x env (TEVM venv TNil)
- EMScope :: Expr x env (TEVM (t : venv) a) -> Expr x env (TEVM venv (TPair a t))
- EMReturn :: SList STy venv -> Expr x env t -> Expr x env (TEVM venv t)
- EMBind :: Expr x env (TEVM venv a) -> Expr x (a : env) (TEVM venv b) -> Expr x env (TEVM venv b)
+ -- accumulation effect
+ EWith :: Expr x env (TArr n t) -> Expr x (TAccum n t : env) a -> Expr x env (TPair a (TArr n t))
+ EAccum :: Expr x env TIx -> Expr x env t -> Expr x env (TAccum n t) -> Expr x env TNil
-- partiality
EError :: STy a -> String -> Expr x env a
@@ -157,10 +155,8 @@ typeOf = \case
EIdx _ e _ | STArr _ t <- typeOf e -> t
EOp _ op _ -> opt2 op
- EMOne t _ _ -> STEVM t STNil
- EMScope e | STEVM (SCons t env) a <- typeOf e -> STEVM env (STPair a t)
- EMReturn env e -> STEVM env (typeOf e)
- EMBind _ e -> typeOf e
+ EWith e1 e2 -> STPair (typeOf e2) (typeOf e1)
+ EAccum _ _ _ -> STNil
EError t _ -> t
@@ -175,7 +171,7 @@ unSTy = \case
STEither a b -> TEither (unSTy a) (unSTy b)
STArr n t -> TArr (unSNat n) (unSTy t)
STScal t -> TScal (unSScalTy t)
- STEVM l t -> TEVM (unSList l) (unSTy t)
+ STAccum n t -> TAccum (unSNat n) (unSTy t)
unSList :: SList STy env -> [Ty]
unSList SNil = []
@@ -207,10 +203,8 @@ weakenExpr w = \case
EIdx1 x e1 e2 -> EIdx1 x (weakenExpr w e1) (weakenExpr w e2)
EIdx x e1 es -> EIdx x (weakenExpr w e1) (weakenExpr w <$> es)
EOp x op e -> EOp x op (weakenExpr w e)
- EMOne t i e -> EMOne t i (weakenExpr w e)
- EMScope e -> EMScope (weakenExpr w e)
- EMReturn t e -> EMReturn t (weakenExpr w e)
- EMBind e1 e2 -> EMBind (weakenExpr w e1) (weakenExpr (WCopy w) e2)
+ EWith e1 e2 -> EWith (weakenExpr w e1) (weakenExpr (WCopy w) e2)
+ EAccum e1 e2 e3 -> EAccum (weakenExpr w e1) (weakenExpr w e2) (weakenExpr w e3)
EError t s -> EError t s
wsinkN :: SNat n -> env :> ConsN n TIx env
@@ -233,3 +227,22 @@ slistIdx SNil i = case i of {}
idx2int :: Idx env t -> Int
idx2int IZ = 0
idx2int (IS n) = 1 + idx2int n
+
+class KnownScalTy t where knownScalTy :: SScalTy t
+instance KnownScalTy TI32 where knownScalTy = STI32
+instance KnownScalTy TI64 where knownScalTy = STI64
+instance KnownScalTy TF32 where knownScalTy = STF32
+instance KnownScalTy TF64 where knownScalTy = STF64
+instance KnownScalTy TBool where knownScalTy = STBool
+
+class KnownTy t where knownTy :: STy t
+instance KnownTy TNil where knownTy = STNil
+instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
+instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
+instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
+instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
+instance (KnownNat n, KnownTy t) => KnownTy (TAccum n t) where knownTy = STAccum knownNat knownTy
+
+class KnownEnv env where knownEnv :: SList STy env
+instance KnownEnv '[] where knownEnv = SNil
+instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
diff --git a/src/AST/Count.hs b/src/AST/Count.hs
index de04b5f..f66b809 100644
--- a/src/AST/Count.hs
+++ b/src/AST/Count.hs
@@ -1,7 +1,12 @@
-{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE LambdaCase #-}
module AST.Count where
+import GHC.Generics (Generic, Generically(..))
+
import AST
import Data
@@ -18,9 +23,8 @@ instance Monoid Count where
data Occ = Occ { _occLexical :: Count
, _occRuntime :: Count }
- deriving (Eq)
-instance Semigroup Occ where Occ a b <> Occ c d = Occ (a <> c) (b <> d)
-instance Monoid Occ where mempty = Occ mempty mempty
+ deriving (Eq, Generic)
+ deriving (Semigroup, Monoid) via Generically Occ
-- | One of the two branches is taken
(<||>) :: Occ -> Occ -> Occ
@@ -49,8 +53,6 @@ occCount idx = \case
EIdx1 _ a b -> occCount idx a <> occCount idx b
EIdx _ e es -> occCount idx e <> foldMap (occCount idx) es
EOp _ _ e -> occCount idx e
- EMOne _ _ e -> occCount idx e
- EMScope e -> occCount idx e
- EMReturn _ e -> occCount idx e
- EMBind a b -> occCount idx a <> occCount (IS idx) b
+ EWith a b -> occCount idx a <> occCount (IS idx) b
+ EAccum a b e -> occCount idx a <> occCount idx b <> occCount idx e
EError{} -> mempty
diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs
index 1ffa980..3473131 100644
--- a/src/AST/Pretty.hs
+++ b/src/AST/Pretty.hs
@@ -41,14 +41,20 @@ instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j)
genId :: M Int
genId = M (\i -> (i, i + 1))
+genName' :: String -> M String
+genName' prefix = (prefix ++) . show <$> genId
+
genName :: M String
-genName = ('x' :) . show <$> genId
+genName = genName' "x"
-genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
-genNameIfUsedIn ty idx ex
+genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn' prefix ty idx ex
| occCount idx ex == mempty = case ty of STNil -> return "()"
_ -> return "_"
- | otherwise = genName
+ | otherwise = genName' prefix
+
+genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String
+genNameIfUsedIn = genNameIfUsedIn' "x"
ppExpr :: SList STy env -> Expr x env t -> String
ppExpr senv e = fst (runM (mkVal senv >>= \val -> ppExpr' 0 val e) 1) ""
@@ -64,6 +70,8 @@ ppExpr' :: Int -> SVal env -> Expr x env t -> M ShowS
ppExpr' d val = \case
EVar _ _ i -> return $ showString $ getConst $ valprj val i
+ e@ELet{} -> ppExprLet d val e
+
EPair _ a b -> do
a' <- ppExpr' 0 val a
b' <- ppExpr' 0 val b
@@ -155,80 +163,46 @@ ppExpr' d val = \case
(Prefix, s) -> s
return $ showParen (d > 10) $ showString (ops ++ " ") . e'
- EMOne venv i e -> do
- let venvlen = length (unSList venv)
- varname = 'v' : show (venvlen - idx2int i)
- e' <- ppExpr' 11 val e
+ EWith e1 e2 -> do
+ e1' <- ppExpr' 11 val e1
+ let STArr n t = typeOf e1
+ name <- genNameIfUsedIn' "ac" (STAccum n t) IZ e2
+ e2' <- ppExpr' 11 (VPush (Const name) val) e2
return $ showParen (d > 10) $
- showString ("one " ++ show varname ++ " ") . e'
+ showString "with " . e1' . showString (" (\\" ++ name ++ " -> ")
+ . e2' . showString ")"
- EMScope e -> do
- let venv = case typeOf e of STEVM v _ -> v
- venvlen = length (unSList venv)
- varname = 'v' : show venvlen
- e' <- ppExpr' 11 val e
+ EAccum e1 e2 e3 -> do
+ e1' <- ppExpr' 11 val e1
+ e2' <- ppExpr' 11 val e2
+ e3' <- ppExpr' 11 val e3
return $ showParen (d > 10) $
- showString ("scope " ++ show varname ++ " ") . e'
-
- EMReturn _ e -> do
- e' <- ppExpr' 11 val e
- return $ showParen (d > 10) $ showString ("return ") . e'
-
- e@EMBind{} -> ppExprDo d val e
- e@ELet{} -> ppExprDo d val e
-
- -- EMBind a b -> do
- -- let STEVM _ t = typeOf a
- -- a' <- ppExpr' 0 val a
- -- name <- genNameIfUsedIn t IZ b
- -- b' <- ppExpr' 0 (VPush (Const name) val) b
- -- return $ showParen (d > 10) $ a' . showString (" >>= \\" ++ name ++ " -> ") . b'
+ showString "accum " . e1' . showString " " . e2' . showString " " . e3'
EError _ s -> return $ showParen (d > 10) $ showString ("error " ++ show s)
-data Binding = MonadBind String ShowS
- | LetBind String ShowS
-
-ppExprDo :: Int -> SVal env -> Expr x env t -> M ShowS
-ppExprDo d val etop = do
- let collect :: SVal env -> Expr x env t -> M ([Binding], ShowS)
- collect val' (EMBind lhs body) = do
- let STEVM _ t = typeOf lhs
- name <- genNameIfUsedIn t IZ body
- (binds, core) <- collect (VPush (Const name) val') body
- lhs' <- ppExpr' 0 val' lhs
- return (MonadBind name lhs' : binds, core)
+ppExprLet :: Int -> SVal env -> Expr x env t -> M ShowS
+ppExprLet d val etop = do
+ let collect :: SVal env -> Expr x env t -> M ([(String, ShowS)], ShowS)
collect val' (ELet _ rhs body) = do
name <- genNameIfUsedIn (typeOf rhs) IZ body
- (binds, core) <- collect (VPush (Const name) val') body
rhs' <- ppExpr' 0 val' rhs
- return (LetBind name rhs' : binds, core)
+ (binds, core) <- collect (VPush (Const name) val') body
+ return ((name, rhs') : binds, core)
collect val' e = ([],) <$> ppExpr' 0 val' e
- fromLet = \case LetBind n s -> Just (n, s) ; _ -> Nothing
-
(binds, core) <- collect val etop
- return $ showParen (d > 0) $ case traverse fromLet binds of
- Just lbinds ->
- let (open, close) = case lbinds of
- [_] -> ("{ ", " }")
- _ -> ("", "")
- in showString ("let " ++ open)
- . foldr (.) id
- (intersperse (showString " ; ")
- (map (\(name, rhs) -> showString (name ++ " = ") . rhs) lbinds))
- . showString (close ++ " in ")
- . core
- Nothing ->
- showString "do { "
- . foldr (.) id
- (intersperse (showString " ; ")
- (map (\case MonadBind name rhs -> showString (name ++ " <- ") . rhs
- LetBind name rhs -> showString ("let { " ++ name ++ " = ") . rhs
- . showString " }")
- binds))
- . showString " ; " . core . showString " }"
+ let (open, close) = case binds of
+ [_] -> ("{ ", " }")
+ _ -> ("", "")
+ return $ showParen (d > 0) $
+ showString ("let " ++ open)
+ . foldr (.) id
+ (intersperse (showString " ; ")
+ (map (\(name, rhs) -> showString (name ++ " = ") . rhs) binds))
+ . showString (close ++ " in ")
+ . core
data Fixity = Prefix | Infix
deriving (Show)
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 4b3016d..d992404 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -36,7 +36,7 @@ data env :> env' where
WIdx :: Idx env t -> (t : env) :> env
deriving instance Show (env :> env')
-infixr @>
+infixr 2 @>
(@>) :: env :> env' -> Idx env t -> Idx env' t
WId @> i = i
WSink @> i = IS i
@@ -48,6 +48,7 @@ WClosed _ @> i = case i of {}
WIdx j @> IZ = j
WIdx _ @> IS i = i
+infixr 3 .>
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
(.>) = flip WThen
@@ -70,14 +71,18 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
wCopies SNil w = w
wCopies (SCons _ spine) w = WCopy (wCopies spine w)
--- wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
--- wStack WId = WId
--- wStack WSink = WSink
--- wStack (WCopy w) = WCopy (wStack @env w)
--- wStack (WPop w) = WPop (wStack @env w)
--- wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
--- wStack (WClosed s) = wSinks s
--- wStack (WIdx i) = WIdx (_ i)
+wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
+wStack WId = WId
+wStack WSink = WSink
+wStack (WCopy w) = WCopy (wStack @env w)
+wStack (WPop w) = WPop (wStack @env w)
+wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
+wStack (WClosed s) = wSinks s
+wStack (WIdx i) = WIdx (goIdx i)
+ where
+ goIdx :: Idx b t -> Idx (Append b env) t
+ goIdx IZ = IZ
+ goIdx (IS i') = IS (goIdx i')
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 2513f84..e209b67 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -1,16 +1,16 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
-- I want to bring various type variables in scope using type annotations in
@@ -29,7 +29,6 @@ module CHAD (
import Data.Bifunctor (first, second)
import Data.Functor.Const
import Data.Kind (Type)
-import Data.Some
import GHC.TypeLits (Symbol)
import AST
@@ -254,7 +253,7 @@ type family D2 t where
D2 TNil = TNil
D2 (TPair a b) = TEither TNil (TPair (D2 a) (D2 b))
D2 (TEither a b) = TEither TNil (TEither (D2 a) (D2 b))
- -- D2 (TArr n t) = _
+ D2 (TArr n t) = TArr n (D2 t)
D2 (TScal t) = D2s t
type family D2s t where
@@ -264,6 +263,9 @@ type family D2s t where
D2s TF64 = TScal TF64
D2s TBool = TNil
+type family D2Ac t where
+ D2Ac (TArr n t) = TAccum n t
+
type family D1E env where
D1E '[] = '[]
D1E (t : env) = D1 t : D1E env
@@ -272,6 +274,10 @@ type family D2E env where
D2E '[] = '[]
D2E (t : env) = D2 t : D2E env
+type family D2AcE env where
+ D2AcE '[] = '[]
+ D2AcE (t : env) = D2Ac t : D2AcE env
+
-- | Select only the types from the environment that have the specified storage
type family Select env sto s where
Select '[] '[] _ = '[]
@@ -284,20 +290,24 @@ d1 (STPair a b) = STPair (d1 a) (d1 b)
d1 (STEither a b) = STEither (d1 a) (d1 b)
d1 (STArr n t) = STArr n (d1 t)
d1 (STScal t) = STScal t
-d1 STEVM{} = error "EVM not allowed in input program"
+d1 STAccum{} = error "Accumulators not allowed in input program"
d2 :: STy t -> STy (D2 t)
d2 STNil = STNil
d2 (STPair a b) = STEither STNil (STPair (d2 a) (d2 b))
d2 (STEither a b) = STEither STNil (STEither (d2 a) (d2 b))
-d2 STArr{} = error "TODO arrays"
+d2 (STArr n t) = STArr n (d2 t)
d2 (STScal t) = case t of
STI32 -> STNil
STI64 -> STNil
STF32 -> STScal STF32
STF64 -> STScal STF64
STBool -> STNil
-d2 STEVM{} = error "EVM not allowed in input program"
+d2 STAccum{} = error "Accumulators not allowed in input program"
+
+d2ac :: STy t -> STy (D2Ac t)
+d2ac (STArr n t) = STAccum n t
+d2ac _ = error "Only arrays may appear in the accumulator environment"
conv1Idx :: Idx env t -> Idx (D1E env) (D1 t)
conv1Idx IZ = IZ
@@ -322,7 +332,7 @@ zero (STScal t) = case t of
STF32 -> EConst ext STF32 0.0
STF64 -> EConst ext STF64 0.0
STBool -> ENil ext
-zero STEVM{} = error "EVM not allowed in input program"
+zero STAccum{} = error "Accumulators not allowed in input program"
plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
plus STNil _ _ = ENil ext
@@ -350,7 +360,7 @@ plus (STScal t) a b = case t of
STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
STBool -> ENil ext
-plus STEVM{} _ _ = error "EVM not allowed in input program"
+plus STAccum{} _ _ = error "Accumulators not allowed in input program"
plusSparse :: STy a
-> Ex env (TEither TNil a) -> Ex env (TEither TNil a)
@@ -388,14 +398,14 @@ data Ret env0 sto t =
Ret (Bindings Ex (D1E env0) shbinds) -- shared binds
(Ex (Append shbinds (D1E env0)) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge))))
+ (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (Ret env0 sto t)
data RetPair env0 sto env shbinds t =
forall env0Merge.
RetPair (Ex (Append shbinds env) (D1 t))
(Subenv (Select env0 sto "merge") env0Merge)
- (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge))))
+ (Ex (D2 t : Append shbinds (D2AcE (Select env0 sto "accum"))) (Tup (D2E env0Merge)))
deriving instance Show (RetPair env0 sto env shbinds t)
data Rets env0 sto env list =
@@ -430,7 +440,8 @@ subenvPlus :: SList STy env
-> r
subenvPlus SNil SETop SETop k = k SETop SETop SETop (\_ _ -> ENil ext)
subenvPlus (SCons _ env) (SENo sub1) (SENo sub2) k =
- subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl -> k (SENo sub3) s31 s32 pl
+ subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
+ k (SENo sub3) s31 s32 pl
subenvPlus (SCons _ env) (SEYes sub1) (SENo sub2) k =
subenvPlus env sub1 sub2 $ \sub3 s31 s32 pl ->
k (SEYes sub3) (SEYes s31) (SENo s32) $ \e1 e2 ->
@@ -464,24 +475,19 @@ expandSubenvZeros (SCons t ts) (SEYes sub) e =
in EPair ext (expandSubenvZeros ts sub (EFst ext var)) (ESnd ext var)
expandSubenvZeros (SCons t ts) (SENo sub) e = EPair ext (expandSubenvZeros ts sub e) (zero t)
-unscope :: Descr env0 sto
- -> STy a -> Storage s
- -> Subenv (Select (a : env0) (s : sto) "merge") envSub
- -> Ex env (TEVM (D2E (Select (a : env0) (s : sto) "accum")) (Tup (D2E envSub)))
- -> (forall envSub'.
- Subenv (Select env0 sto "merge") envSub'
- -> Ex env (TEVM (D2E (Select env0 sto "accum")) (TPair (Tup (D2E envSub')) (D2 a)))
- -> r)
- -> r
-unscope des ty s sub e k = case s of
- SAccum -> k sub (EMScope e)
- SMerge -> case sub of
- SEYes sub' -> k sub' e
- SENo sub' -> k sub' $
- EMBind e $
- EMReturn (d2e (select SAccum des)) $
- EPair ext (EVar ext (tTup (d2e (subList (select SMerge des) sub'))) IZ)
- (zero ty)
+popFromScope
+ :: Descr env0 sto
+ -> STy a
+ -> Subenv (Select (a : env0) ("merge" : sto) "merge") envSub
+ -> Ex env (Tup (D2E envSub))
+ -> (forall envSub'.
+ Subenv (Select env0 sto "merge") envSub'
+ -> Ex env (TPair (Tup (D2E envSub')) (D2 a))
+ -> r)
+ -> r
+popFromScope _ ty sub e k = case sub of
+ SEYes sub' -> k sub' e
+ SENo sub' -> k sub' $ EPair ext e (zero ty)
-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
@@ -501,7 +507,7 @@ weakenRets w (Rets binds list) =
rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t
rebaseRetPair b1 b2 (RetPair p sub d)
| Refl <- lemAppendAssoc @b2 @b1 @env =
- RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 b1)) d)
+ RetPair p sub (weakenExpr (WCopy (wStack @(D2AcE (Select env0 sto "accum")) (wRaiseAbove b2 b1))) d)
retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list
retConcat SNil = Rets BTop SNil
@@ -509,37 +515,12 @@ retConcat (SCons (Ret (b :: Bindings _ _ shbinds) p sub d) list)
| Rets binds1 pairs1 <- retConcat list
, Rets (binds :: Bindings _ _ shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1)
, Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0)
+ , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D2AcE (Select env0 sto "accum"))
= Rets (bconcat b binds)
(SCons (RetPair (weakenExpr (sinkWithBindings binds) p)
sub
(weakenExpr (WCopy (sinkWithBindings binds)) d))
(slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs))
--- list ~ a : list'
--- SCons (Ret b p sub d) list :: SList (Ret env0 sto) list
--- Ret b p sub d :: Ret env0 sto a <- existential shbinds
--- b :: Bindings Ex (D1E env0) shbinds
--- p :: Ex (Append shbinds (D1E env0)) (D1 a)
--- d :: Ex (D2 a : shbinds) (TEVM ...)
---
--- list :: SList (Ret env0 sto) list'
--- retConcat list :: Rets env0 sto (D1E env0) list' <- existential shbinds1
--- binds1 :: Bindings Ex (D1E env0) shbinds1
--- pairs1 :: SList (RetPair env0 sto (D1E env0) shbinds1) list'
---
--- sinkWithBindings b :: forall e. e :> Append shbinds e
--- Rets binds pairs :: Rets env0 sto (Append shbinds (D1E env0)) list' <- existential shbinds2
--- binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2
--- pairs :: SList (RetPair env0 sto (Append shbinds (D1E env0)) shbinds2) list'
---
--- we choose shbindsR ~ Append shbinds2 shbinds
--- result :: Rets env0 sto (D1E env0) list
--- result.1 :: Bindings Ex (D1E env0) shbindsR == Bindings Ex (D1E env0) (Append shbinds2 shbinds)
--- result.2 :: SList (RetPair env0 sto (D1E env0) shbindsR) list
--- result.2.head :: RetPair env0 sto (D1E env0) shbindsR a
--- result.2.tail :: SList (RetPair env0 sto (D1E env0) shbindsR) list'
--- = SList (RetPair env0 sto (D1E env0) (Append shbinds2 shbinds)) list'
---
--- wanted: shbinds1 :> shbindsR
d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
d1op (OAdd t) e = EOp ext (OAdd t) e
@@ -557,7 +538,7 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd _ -> Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
(EOp ext (OMul t) (EPair ext (EFst ext e) d)))
@@ -611,86 +592,89 @@ sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env)
sD1eEnv DTop = SNil
sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d)
+d2e :: SList STy env -> SList STy (D2E env)
+d2e SNil = SNil
+d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+
+d2ace :: SList STy env -> SList STy (D2AcE env)
+d2ace SNil = SNil
+d2ace (SCons t ts) = SCons (d2ac t) (d2ace ts)
+
freezeRet :: Descr env sto
-> Ret env sto t
-> Ex (D1E env) (D2 t) -- the incoming cotangent value
- -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge")))))
+ -> Ex (Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
freezeRet descr (Ret e0 e1 sub e2) d =
- let e2' = weakenExpr (WCopy (wRaiseAbove (bindingsBinds e0) (sD1eEnv descr))) e2
- in letBinds e0 $
+ let (e0', wInsertD2Ac) = weakenBindings weakenExpr (wSinks (d2ace (select SAccum descr))) e0
+ e2' = weakenExpr (WCopy (wCopies (bindingsBinds e0) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ in letBinds e0' $
EPair ext
- e1
- (ELet ext (weakenExpr (sinkWithBindings e0) d)
- (EMBind e2'
- (EMReturn (d2e (select SAccum descr))
- (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ)))))
-
-d2e :: SList STy env -> SList STy (D2E env)
-d2e SNil = SNil
-d2e (SCons t ts) = SCons (d2 t) (d2e ts)
+ (weakenExpr wInsertD2Ac e1)
+ (ELet ext (weakenExpr (sinkWithBindings e0 .> wSinks (d2ace (select SAccum descr))) d) $
+ ELet ext e2' $
+ expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))
drev :: forall env sto t.
Descr env sto
- -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage)
-> Ex env t -> Ret env sto t
-drev des policy = \case
+drev des = \case
EVar _ t i ->
case conv2Idx des i of
- Left accumI ->
+ Left _ ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvNone (select SMerge des))
- (EMOne d2acc accumI (EVar ext (d2 t) IZ))
+ (ENil ext)
Right tupI ->
Ret BTop
(EVar ext (d1 t) (conv1Idx i))
(subenvOnehot (select SMerge des) tupI)
- (EMReturn d2acc (EPair ext (ENil ext) (EVar ext (d2 t) IZ)))
+ (EPair ext (ENil ext) (EVar ext (d2 t) IZ))
ELet _ rhs body
- | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des policy rhs
- , Some storage <- policy des (typeOf rhs)
- , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body
+ | Ret (rhs0 :: Bindings _ _ rhs_shbinds) (rhs1 :: Ex _ d1_a) subRHS rhs2 <- drev des rhs
+ , Ret (body0 :: Bindings _ _ body_shbinds) body1 subBody body2 <- drev (des `DPush` (typeOf rhs, SMerge)) body
, let (body0', wbody0') = weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0
, Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D1E env)
+ , Refl <- lemAppendAssoc @body_shbinds @(d1_a : rhs_shbinds) @(D2AcE (Select env sto "accum"))
, Refl <- lemAppendNil @body_shbinds ->
- unscope des (typeOf rhs) storage subBody body2 $ \subBody' body2' ->
+ popFromScope des (typeOf rhs) subBody body2 $ \subBody' body2' ->
subenvPlus (select SMerge des) subRHS subBody' $ \subBoth _ _ plus_RHS_Body ->
let bodyResType = STPair (tTup (d2e (subList (select SMerge des) subBody'))) (d2 (typeOf rhs)) in
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(weakenExpr wbody0' body1)
subBoth
- (EMBind
- (weakenExpr (WCopy (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))) body2')
- (EMBind
- (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
- weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2)
- (EMReturn d2acc (plus_RHS_Body
- (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
- (EFst ext (EVar ext bodyResType (IS IZ)))))))
+ (ELet ext
+ (weakenExpr (WCopy (wStack @(D2AcE (Select env sto "accum")) (wRaiseAbove (bindingsBinds body0) (SCons (typeOf rhs1) (bindingsBinds rhs0)))))
+ body2') $
+ ELet ext
+ (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $
+ weakenExpr (WCopy (wSinks' @[_,_] .> WPop @d1_a (sinkWithBindings body0'))) rhs2) $
+ plus_RHS_Body
+ (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ)
+ (EFst ext (EVar ext bodyResType (IS IZ))))
EPair _ a b
| Rets binds (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil)
- <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil
+ <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
subenvPlus (select SMerge des) subA subB $ \subBoth _ _ plus_A_B ->
Ret binds
(EPair ext a1 b1)
subBoth
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth)))
- (EMBind (ELet ext (EFst ext (EVar ext dt IZ))
+ (zeroTup (subList (select SMerge des) subBoth))
+ (ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
(weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
- EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
(weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
- EMReturn d2acc
- (plus_A_B
- (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))))
+ plus_A_B
+ (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
EFst _ e
- | Ret e0 e1 sub e2 <- drev des policy e
+ | Ret e0 e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(EFst ext e1)
@@ -699,7 +683,7 @@ drev des policy = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 sub e2 <- drev des policy e
+ | Ret e0 e1 sub e2 <- drev des e
, STPair t1 t2 <- typeOf e ->
Ret e0
(ESnd ext e1)
@@ -707,46 +691,47 @@ drev des policy = \case
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (EMReturn d2acc (ENil ext))
+ ENil _ -> Ret BTop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
EInl _ t2 e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
Ret e0
(EInl ext (d1 t2) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) sub)))
+ (zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr")))
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
Ret e0
(EInr ext (d1 t1) e1)
sub
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
- (EMReturn d2acc (zeroTup (subList (select SMerge des) sub)))
+ (zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl")
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
(weakenExpr (WCopy (wSinks' @[_,_])) e2)))
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
- , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des policy e
- , Some storageA <- policy des t1
- , Some storageB <- policy des t2
- , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, storageA)) policy a
- , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, storageB)) policy b
+ , Ret (e0 :: Bindings _ _ e_binds) e1 subE e2 <- drev des e
+ , Ret (a0 :: Bindings _ _ rhs_a_binds) a1 subA a2 <- drev (des `DPush` (t1, SMerge)) a
+ , Ret (b0 :: Bindings _ _ rhs_b_binds) b1 subB b2 <- drev (des `DPush` (t2, SMerge)) b
+ , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
+ , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum"))
, let tapeA = tapeTy (bindingsBinds a0)
, let tapeB = tapeTy (bindingsBinds b0)
, let collectA = bindingsCollect a0
, let collectB = bindingsCollect b0
, (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB)
, let (a0', wa0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) a0
- , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0 ->
- unscope des t1 storageA subA a2 $ \subA' a2' ->
- unscope des t2 storageB subB b2 $ \subB' b2' ->
+ , let (b0', wb0') = weakenBindings weakenExpr (WCopy (sinkWithBindings e0)) b0
+ ->
+ popFromScope des t1 subA a2 $ \subA' a2' ->
+ popFromScope des t2 subB b2 $ \subB' b2' ->
subenvPlus (select SMerge des) subA' subB' $ \subAB sAB_A sAB_B _ ->
subenvPlus (select SMerge des) subAB subE $ \subOut _ _ plus_AB_E ->
let tCaseRet = STPair (tTup (d2e (subList (select SMerge des) subAB))) (STEither (d2 t1) (d2 t2)) in
@@ -757,43 +742,49 @@ drev des policy = \case
(letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0'))))))
(EFst ext (EVar ext tPrimal IZ))
subOut
- (EMBind
+ (ELet ext
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
(let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
- (EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))))
+ ELet ext
+ (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
+ ELet ext
+ (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
+ WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds))
+ a2') $
+ EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
+ (EInl ext (d2 t2)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))
(let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2')
- (EMReturn d2acc
- (EPair ext
- (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
- EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
- (EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)))))))
- (EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
- weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
- EMReturn d2acc $
- plus_AB_E
- (EFst ext (EVar ext tCaseRet (IS IZ)))
- (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ)))
+ ELet ext
+ (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : Append e_binds (D2AcE (Select env sto "accum"))) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
+ ELet ext
+ (weakenExpr (wStack @(D2AcE (Select env sto "accum")) $
+ WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds))
+ b2') $
+ EPair ext
+ (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $
+ EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))
+ (EInr ext (d2 t1)
+ (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ))))) $
+ ELet ext
+ (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
+ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
+ plus_AB_E
+ (EFst ext (EVar ext tCaseRet (IS IZ)))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subE))) IZ))
EConst _ t val ->
Ret BTop
(EConst ext t val)
(subenvNone (select SMerge des))
- (EMReturn d2acc (ENil ext))
+ (ENil ext)
EOp _ op e
- | Ret e0 e1 sub e2 <- drev des policy e ->
+ | Ret e0 e1 sub e2 <- drev des e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -813,7 +804,7 @@ drev des policy = \case
Ret BTop
(EError (d1 t) s)
(subenvNone (select SMerge des))
- (EMReturn d2acc (ENil ext))
+ (ENil ext)
-- These should be the next to be implemented, I think
EBuild1{} -> err_unsupported "EBuild1"
@@ -823,13 +814,9 @@ drev des policy = \case
EBuild{} -> err_unsupported "EBuild"
EIdx{} -> err_unsupported "EIdx"
- EMOne{} -> err_evm
- EMScope{} -> err_evm
- EMReturn{} -> err_evm
- EMBind{} -> err_evm
+ EWith{} -> err_accum
+ EAccum{} -> err_accum
where
- d2acc = d2e (select SAccum des)
-
- err_evm = error "EVM operations unsupported in the source program"
+ err_accum = error "Accumulator operations unsupported in the source program"
err_unsupported s = error $ "CHAD: unsupported " ++ s
diff --git a/src/Data.hs b/src/Data.hs
index 728dafe..c3381d5 100644
--- a/src/Data.hs
+++ b/src/Data.hs
@@ -1,11 +1,11 @@
{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DeriveTraversable #-}
module Data where
@@ -31,6 +31,10 @@ fromSNat :: SNat n -> Int
fromSNat SZ = 0
fromSNat (SS n) = succ (fromSNat n)
+class KnownNat n where knownNat :: SNat n
+instance KnownNat Z where knownNat = SZ
+instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat
+
data Vec n t where
VNil :: Vec Z t
(:<) :: t -> Vec n t -> Vec (S n) t
diff --git a/src/Example.hs b/src/Example.hs
index 30031c0..572d67e 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -1,8 +1,6 @@
{-# LANGUAGE DataKinds #-}
module Example where
-import Data.Some
-
import AST
import AST.Pretty
import CHAD
@@ -10,7 +8,7 @@ import Data
import Simplify
--- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SAccum SAccum in freezeRet d (drev d (\_ _ -> Some SAccum) ex5) (EConst ext STF32 1.0)
+-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0)
bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c
@@ -104,7 +102,8 @@ descr5 a b = DTop `DPush` (STEither (STScal STF32) (STScal STF32), a) `DPush` (S
ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32)
ex5 =
ECase ext (EVar ext (STEither (STScal STF32) (STScal STF32)) (IS IZ))
- (EVar ext (STScal STF32) IZ)
+ (bin (OMul STF32) (EVar ext (STScal STF32) IZ)
+ (EVar ext (STScal STF32) (IS IZ)))
(bin (OMul STF32) (EVar ext (STScal STF32) IZ)
(bin (OAdd STF32) (EVar ext (STScal STF32) (IS IZ))
(EConst ext STF32 1.0)))
diff --git a/src/Simplify.hs b/src/Simplify.hs
index 44de164..a5f90b3 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -1,83 +1,84 @@
+{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE DataKinds #-}
module Simplify where
+import Data.Monoid
+
import AST
import AST.Count
import Data
-simplifyN :: Int -> Ex env t -> Ex env t
+simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t
simplifyN 0 = id
simplifyN n = simplifyN (n - 1) . simplify
-simplify :: Ex env t -> Ex env t
-simplify = \case
+simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t
+simplify = let ?accumInScope = checkAccumInScope @env knownEnv in simplify'
+
+simplify' :: (?accumInScope :: Bool) => Ex env t -> Ex env t
+simplify' = \case
-- inlining
ELet _ rhs body
- | Occ lexOcc runOcc <- occCount IZ body
+ | not ?accumInScope || not (hasAdds rhs) -- cannot discard effectful computations
+ , Occ lexOcc runOcc <- occCount IZ body
, lexOcc <= One -- prevent code size blowup
, runOcc <= One -- prevent runtime increase
- -> simplify (subst1 rhs body)
+ -> simplify' (subst1 rhs body)
| cheapExpr rhs
- -> simplify (subst1 rhs body)
+ -> simplify' (subst1 rhs body)
-- let splitting
ELet _ (EPair _ a b) body ->
- simplify $
+ simplify' $
ELet ext a $
ELet ext (weakenExpr WSink b) $
subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ)
IS i -> EVar ext t (IS (IS i)))
body
+ -- let rotation
+ ELet _ (ELet _ rhs a) b ->
+ ELet ext (simplify' rhs) $
+ ELet ext (simplify' a) $
+ weakenExpr (WCopy WSink) (simplify' b)
+
-- beta rules for products
- EFst _ (EPair _ e _) -> simplify e
- ESnd _ (EPair _ _ e) -> simplify e
+ EFst _ (EPair _ e _) -> simplify' e
+ ESnd _ (EPair _ _ e) -> simplify' e
-- beta rules for coproducts
- ECase _ (EInl _ _ e) rhs _ -> simplify (ELet ext e rhs)
- ECase _ (EInr _ _ e) _ rhs -> simplify (ELet ext e rhs)
+ ECase _ (EInl _ _ e) rhs _ -> simplify' (ELet ext e rhs)
+ ECase _ (EInr _ _ e) _ rhs -> simplify' (ELet ext e rhs)
-- TODO: array indexing (index of build, index of fold)
-- TODO: constant folding for operations
- -- eta rule for return+bind
- EMBind (EMReturn _ a) b -> simplify (ELet ext a b)
-
- -- associativity of bind
- EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c)))
-
- -- bind-let commute
- EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c)))
-
- -- return-let commute
- EMReturn env (ELet _ a b) -> simplify (ELet ext a (EMReturn env b))
-
EVar _ t i -> EVar ext t i
- ELet _ a b -> ELet ext (simplify a) (simplify b)
- EPair _ a b -> EPair ext (simplify a) (simplify b)
- EFst _ e -> EFst ext (simplify e)
- ESnd _ e -> ESnd ext (simplify e)
+ ELet _ a b -> ELet ext (simplify' a) (simplify' b)
+ EPair _ a b -> EPair ext (simplify' a) (simplify' b)
+ EFst _ e -> EFst ext (simplify' e)
+ ESnd _ e -> ESnd ext (simplify' e)
ENil _ -> ENil ext
- EInl _ t e -> EInl ext t (simplify e)
- EInr _ t e -> EInr ext t (simplify e)
- ECase _ e a b -> ECase ext (simplify e) (simplify a) (simplify b)
- EBuild1 _ a b -> EBuild1 ext (simplify a) (simplify b)
- EBuild _ es e -> EBuild ext (fmap simplify es) (simplify e)
- EFold1 _ a b -> EFold1 ext (simplify a) (simplify b)
+ EInl _ t e -> EInl ext t (simplify' e)
+ EInr _ t e -> EInr ext t (simplify' e)
+ ECase _ e a b -> ECase ext (simplify' e) (simplify' a) (simplify' b)
+ EBuild1 _ a b -> EBuild1 ext (simplify' a) (simplify' b)
+ EBuild _ es e -> EBuild ext (fmap simplify' es) (simplify' e)
+ EFold1 _ a b -> EFold1 ext (simplify' a) (simplify' b)
EConst _ t v -> EConst ext t v
- EIdx1 _ a b -> EIdx1 ext (simplify a) (simplify b)
- EIdx _ e es -> EIdx ext (simplify e) (fmap simplify es)
- EOp _ op e -> EOp ext op (simplify e)
- EMOne t i e -> EMOne t i (simplify e)
- EMScope e -> EMScope (simplify e)
- EMReturn t e -> EMReturn t (simplify e)
- EMBind a b -> EMBind (simplify a) (simplify b)
+ EIdx1 _ a b -> EIdx1 ext (simplify' a) (simplify' b)
+ EIdx _ e es -> EIdx ext (simplify' e) (fmap simplify' es)
+ EOp _ op e -> EOp ext op (simplify' e)
+ EWith e1 e2 -> EWith (simplify' e1) (let ?accumInScope = True in simplify' e2)
+ EAccum e1 e2 e3 -> EAccum (simplify' e1) (simplify' e2) (simplify' e3)
EError t s -> EError t s
cheapExpr :: Expr x env t -> Bool
@@ -116,10 +117,8 @@ subst' f w = \case
EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
EIdx x e es -> EIdx x (subst' f w e) (fmap (subst' f w) es)
EOp x op e -> EOp x op (subst' f w e)
- EMOne t i e -> EMOne t i (subst' f w e)
- EMScope e -> EMScope (subst' f w e)
- EMReturn t e -> EMReturn t (subst' f w e)
- EMBind a b -> EMBind (subst' f w a) (subst' (sinkF f) (WCopy w) b)
+ EWith e1 e2 -> EWith (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
+ EAccum e1 e2 e3 -> EAccum (subst' f w e1) (subst' f w e2) (subst' f w e3)
EError t s -> EError t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -134,3 +133,39 @@ subst' f w = \case
sinkFN SZ f' x t w' i = f' x t w' i
sinkFN (SS _) _ x t w' IZ = EVar x t (w' @> IZ)
sinkFN (SS n) f' x t w' (IS i) = sinkFN n f' x t (WPop w') i
+
+-- | This can be made more precise by tracking (and not counting) adds on
+-- locally eliminated accumulators.
+hasAdds :: Expr x env t -> Bool
+hasAdds = \case
+ EVar _ _ _ -> False
+ ELet _ rhs body -> hasAdds rhs || hasAdds body
+ EPair _ a b -> hasAdds a || hasAdds b
+ EFst _ e -> hasAdds e
+ ESnd _ e -> hasAdds e
+ ENil _ -> False
+ EInl _ _ e -> hasAdds e
+ EInr _ _ e -> hasAdds e
+ ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b
+ EBuild1 _ a b -> hasAdds a || hasAdds b
+ EBuild _ es e -> getAny (foldMap (Any . hasAdds) es) || hasAdds e
+ EFold1 _ a b -> hasAdds a || hasAdds b
+ EConst _ _ _ -> False
+ EIdx1 _ a b -> hasAdds a || hasAdds b
+ EIdx _ e es -> hasAdds e || getAny (foldMap (Any . hasAdds) es)
+ EOp _ _ e -> hasAdds e
+ EWith a b -> hasAdds a || hasAdds b
+ EAccum _ _ _ -> True
+ EError _ _ -> False
+
+checkAccumInScope :: SList STy env -> Bool
+checkAccumInScope = \case SNil -> False
+ SCons t env -> check t || checkAccumInScope env
+ where
+ check :: STy t -> Bool
+ check STNil = False
+ check (STPair s t) = check s || check t
+ check (STEither s t) = check s || check t
+ check (STArr _ t) = check t
+ check (STScal _) = False
+ check STAccum{} = True