summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <t.j.smeding@uu.nl>2023-09-21 15:52:22 +0200
committerTom Smeding <t.j.smeding@uu.nl>2023-09-21 15:52:22 +0200
commit574569ee96a01d623baf8efdcd3908eef42b8007 (patch)
treecffc7d81e32b2b52430ac7da97a2a31102f55c97
parent8d07a43f0b364156433dc453b9d1cc762c032634 (diff)
Storage policy (accum / merge)
-rw-r--r--src/AST.hs2
-rw-r--r--src/CHAD.hs185
-rw-r--r--src/Example.hs7
-rw-r--r--src/Simplify.hs6
4 files changed, 142 insertions, 58 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 8d795bf..e39c74f 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -28,7 +28,7 @@ data Nat = Z | S Nat
data SNat n where
SZ :: SNat Z
SS :: SNat n -> SNat (S n)
-deriving instance (Show (SNat n))
+deriving instance Show (SNat n)
data Vec n t where
VNil :: Vec Z t
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 0c856b1..b074470 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -15,6 +15,7 @@ module CHAD where
import Data.Bifunctor (first, second)
import Data.Kind (Type)
+import Data.Some
import GHC.TypeLits (Symbol)
import AST
@@ -142,11 +143,11 @@ conv1Idx IZ = IZ
conv1Idx (IS i) = IS (conv1Idx i)
conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t))
- (Idx (D2E (Select env sto "merge")) (D2 t))
-conv2Idx (DPush _ _ SAccum) IZ = Left IZ
-conv2Idx (DPush _ _ SMerge) IZ = Right IZ
-conv2Idx (DPush des _ SAccum) (IS i) = first IS (conv2Idx des i)
-conv2Idx (DPush des _ SMerge) (IS i) = second IS (conv2Idx des i)
+ (Idx (Select env sto "merge") t)
+conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ
+conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ
+conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i)
+conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i)
conv2Idx DTop i = case i of {}
zero :: STy t -> Ex env (D2 t)
@@ -162,17 +163,73 @@ zero (STScal t) = case t of
STBool -> ENil ext
zero STEVM{} = error "EVM not allowed in input program"
+plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t)
+plus STNil _ _ = ENil ext
+plus (STPair t1 t2) a b =
+ let t = STPair (d2 t1) (d2 t2)
+ in plusSparse t a b $
+ EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ)))
+ (EFst ext (EVar ext t IZ)))
+ (plus t2 (ESnd ext (EVar ext t (IS IZ)))
+ (ESnd ext (EVar ext t IZ)))
+plus (STEither t1 t2) a b =
+ let t = STEither (d2 t1) (d2 t2)
+ in plusSparse t a b $
+ ECase ext (EVar ext t (IS IZ))
+ (ECase ext (EVar ext t (IS IZ))
+ (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ)))
+ (EError t "plus l+r"))
+ (ECase ext (EVar ext t (IS IZ))
+ (EError t "plus r+l")
+ (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ))))
+plus STArr{} _ _ = error "TODO arrays"
+plus (STScal t) a b = case t of
+ STI32 -> ENil ext
+ STI64 -> ENil ext
+ STF32 -> EOp ext (OAdd STF32) (EPair ext a b)
+ STF64 -> EOp ext (OAdd STF64) (EPair ext a b)
+ STBool -> ENil ext
+plus STEVM{} _ _ = error "EVM not allowed in input program"
+
+plusSparse :: STy a
+ -> Ex env (TEither TNil a) -> Ex env (TEither TNil a)
+ -> Ex (a : a : env) a
+ -> Ex env (TEither TNil a)
+plusSparse t a b adder =
+ ELet ext b $
+ ECase ext (weakenExpr WSink a)
+ (EVar ext (STEither STNil t) (IS IZ))
+ (EInr ext STNil
+ (ECase ext (EVar ext (STEither STNil t) (IS IZ))
+ (EVar ext t (IS IZ))
+ (weakenExpr (WCopy (WCopy WSink)) adder)))
+
type family Tup env where
Tup '[] = TNil
- Tup (t : ts) = TPair t (Tup ts)
+ Tup (t : ts) = TPair (Tup ts) t
tTup :: SList STy env -> STy (Tup env)
tTup SNil = STNil
-tTup (SCons t ts) = STPair t (tTup ts)
+tTup (SCons t ts) = STPair (tTup ts) t
zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0))
zeroTup SNil = ENil ext
-zeroTup (SCons t env) = EPair ext (zero t) (zeroTup env)
+zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t)
+
+onehotTup :: SList STy env0 -> Idx env0 t -> Ex env (D2 t) -> Ex env (Tup (D2E env0))
+onehotTup (SCons _ env) IZ d = EPair ext (zeroTup env) d
+onehotTup (SCons t env) (IS i) d = EPair ext (onehotTup env i d) (zero t)
+onehotTup SNil i _ = case i of {}
+
+plusTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0))
+plusTup SNil _ _ = ENil ext
+plusTup env0@(SCons t env) a b =
+ ELet ext a $
+ ELet ext (weakenExpr WSink b) $
+ EPair ext (plusTup env (EFst ext (EVar ext (tTup (d2e env0)) (IS IZ)))
+ (EFst ext (EVar ext (tTup (d2e env0)) IZ)))
+ (plus t (ESnd ext (EVar ext (tTup (d2e env0)) (IS IZ)))
+ (ESnd ext (EVar ext (tTup (d2e env0)) IZ)))
data Ret env sto t =
forall env'.
@@ -277,58 +334,70 @@ deriving instance Show (Storage s)
-- | Environment description
data Descr env sto where
DTop :: Descr '[] '[]
- DPush :: Descr env sto -> STy t -> Storage s -> Descr (t : env) (s : sto)
+ DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
deriving instance Show (Descr env sto)
select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
select _ DTop = SNil
-select s@SAccum (DPush des t SAccum) = SCons t (select s des)
-select s@SMerge (DPush des _ SAccum) = select s des
-select s@SAccum (DPush des _ SMerge) = select s des
-select s@SMerge (DPush des t SMerge) = SCons t (select s des)
+select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
+select s@SMerge (DPush des (_, SAccum)) = select s des
+select s@SAccum (DPush des (_, SMerge)) = select s des
+select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
d2e :: SList STy env -> SList STy (D2E env)
d2e SNil = SNil
d2e (SCons t ts) = SCons (d2 t) (d2e ts)
-drev :: Descr env sto -> Ex env t -> Ret env sto t
-drev des = \case
+drev :: Descr env sto
+ -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage)
+ -> Ex env t -> Ret env sto t
+drev des policy = \case
EVar _ t i ->
- case conv2Idx des i of
- Left accumI ->
- Ret BTop
- (EVar ext (d1 t) (conv1Idx i))
- (EMBind
- (EMOne d2mon accumI (EVar ext (d2 t) IZ))
- (EMReturn d2mon (zeroTup (select SMerge des))))
- Right tupI ->
- _
+ Ret BTop
+ (EVar ext (d1 t) (conv1Idx i))
+ (case conv2Idx des i of
+ Left accumI ->
+ EMBind
+ (EMOne d2acc accumI (EVar ext (d2 t) IZ))
+ (EMReturn d2acc (zeroTup (select SMerge des)))
+ Right tupI ->
+ EMReturn d2acc (onehotTup (select SMerge des) tupI (EVar ext (d2 t) IZ)))
ELet _ rhs body
- | Ret rhs0 rhs1 rhs2 <- drev des rhs
- , Ret body0 body1 body2 <- drev (DPush des (typeOf rhs) SMerge) body ->
+ | Ret rhs0 rhs1 rhs2 <- drev des policy rhs
+ , Some storage <- policy des (typeOf rhs)
+ , Ret body0 body1 body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body ->
weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' ->
Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0')
(weakenExpr wbody0' body1)
- (EMBind (EMScope (weakenExpr (WCopy wbody0') body2))
- (ELet ext (ESnd ext (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $
- weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2))
+ (EMBind
+ (weakenExpr (WCopy wbody0') $ case storage of SAccum -> EMScope body2 ; SMerge -> body2)
+ (ELet ext (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) IZ)) $
+ EMBind
+ (weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2)
+ (EMReturn d2acc (plusTup (select SMerge des)
+ (EFst ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) (IS (IS IZ))))
+ (EVar ext (tTup (d2e (select SMerge des))) IZ)))))
EPair _ a b
| Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil)
- <- retConcat $ drev des a `SCons` drev des b `SCons` SNil
+ <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil
, let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) ->
Ret binds
(EPair ext a1 b1)
(ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
- (EMReturn d2mon (ENil ext))
+ (EMReturn d2acc (zeroTup (select SMerge des)))
(EMBind (ELet ext (EFst ext (EVar ext dt IZ))
- (weakenExpr (WCopy (wSinks @[_,_])) a2))
- (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
- (weakenExpr (WCopy (wSinks @[_,_,_])) b2))))
+ (weakenExpr (WCopy (wSinks @[_,_])) a2)) $
+ EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ)))
+ (weakenExpr (WCopy (wSinks @[_,_,_])) b2)) $
+ EMReturn d2acc
+ (plusTup (select SMerge des)
+ (EVar ext (tTup (d2e (select SMerge des))) (IS IZ))
+ (EVar ext (tTup (d2e (select SMerge des))) IZ))))
EFst _ e
- | Ret e0 e1 e2 <- drev des e
+ | Ret e0 e1 e2 <- drev des policy e
, STPair t1 t2 <- typeOf e ->
Ret e0
(EFst ext e1)
@@ -336,40 +405,42 @@ drev des = \case
weakenExpr (WCopy WSink) e2)
ESnd _ e
- | Ret e0 e1 e2 <- drev des e
+ | Ret e0 e1 e2 <- drev des policy e
, STPair t1 t2 <- typeOf e ->
Ret e0
(ESnd ext e1)
(ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
- ENil _ -> Ret BTop (ENil ext) (EMReturn d2mon (ENil ext))
+ ENil _ -> Ret BTop (ENil ext) (EMReturn d2acc (zeroTup (select SMerge des)))
EInl _ t2 e
- | Ret e0 e1 e2 <- drev des e ->
+ | Ret e0 e1 e2 <- drev des policy e ->
Ret e0
(EInl ext (d1 t2) e1)
(ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
- (EMReturn d2mon (ENil ext))
+ (EMReturn d2acc (zeroTup (select SMerge des)))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks @[_,_])) e2)
- (EError (STEVM d2mon STNil) "inl<-dinr")))
+ (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inl<-dinr")))
EInr _ t1 e
- | Ret e0 e1 e2 <- drev des e ->
+ | Ret e0 e1 e2 <- drev des policy e ->
Ret e0
(EInr ext (d1 t1) e1)
(ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
- (EMReturn d2mon (ENil ext))
+ (EMReturn d2acc (zeroTup (select SMerge des)))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
- (EError (STEVM d2mon STNil) "inr<-dinl")
+ (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inr<-dinl")
(weakenExpr (WCopy (wSinks @[_,_])) e2)))
ECase _ e a b
| STEither t1 t2 <- typeOf e
- , Ret e0 e1 e2 <- drev des e
- , Ret a0 a1 a2 <- drev (DPush des t1 SMerge) a
- , Ret b0 b1 b2 <- drev (DPush des t2 SMerge) b
+ , Ret e0 e1 e2 <- drev des policy e
+ , Some storageA <- policy des t1
+ , Some storageB <- policy des t2
+ , Ret a0 a1 a2 <- drev (des `DPush` (t1, storageA)) policy a
+ , Ret b0 b1 b2 <- drev (des `DPush` (t2, storageB)) policy b
, TupBinds tapeA collectA reconA <- tupBinds a0
, TupBinds tapeB collectB reconB <- tupBinds b0
, let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) ->
@@ -389,30 +460,30 @@ drev des = \case
TupBindsReconstruct rebinds wrebinds ->
letBinds rebinds $
ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
- EMBind (weakenExpr (WCopy wrebinds) (EMScope a2))
- (EMReturn d2mon
+ EMBind (weakenExpr (WCopy wrebinds) $ case storageA of SAccum -> EMScope a2 ; SMerge -> a2)
+ (EMReturn d2acc
(EInr ext STNil (EInl ext (d2 t2)
- (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ))))))
- (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
+ (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t1)) IZ))))))
+ (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape"))
(ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ))))
- (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
+ (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape")
(case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of
TupBindsReconstruct rebinds wrebinds ->
letBinds rebinds $
ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $
- EMBind (weakenExpr (WCopy wrebinds) (EMScope b2))
- (EMReturn d2mon
+ EMBind (weakenExpr (WCopy wrebinds) $ case storageB of SAccum -> EMScope b2 ; SMerge -> b2)
+ (EMReturn d2acc
(EInr ext STNil (EInr ext (d2 t1)
- (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ))))))))
+ (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t2)) IZ))))))))
(weakenExpr (WCopy (wSinks @[_,_,_])) e2))
EConst _ t val ->
Ret BTop
(EConst ext t val)
- (EMReturn d2mon (ENil ext))
+ (EMReturn d2acc (zeroTup (select SMerge des)))
EOp _ op e
- | Ret e0 e1 e2 <- drev des e ->
+ | Ret e0 e1 e2 <- drev des policy e ->
case d2op op of
Linear d2opfun ->
Ret e0
@@ -429,4 +500,4 @@ drev des = \case
e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e)
where
- d2mon = d2e (select SAccum des)
+ d2acc = d2e (select SAccum des)
diff --git a/src/Example.hs b/src/Example.hs
index f2e5966..ee07edf 100644
--- a/src/Example.hs
+++ b/src/Example.hs
@@ -1,6 +1,8 @@
{-# LANGUAGE DataKinds #-}
module Example where
+import Data.Some
+
import AST
import AST.Pretty
import CHAD
@@ -13,6 +15,11 @@ bin op a b = EOp ext op (EPair ext a b)
senv1 :: SList STy [TScal TF32, TScal TF32]
senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil
+descr1 :: Storage a -> Storage b
+ -> Descr [TScal TF32, TScal TF32] [b, a]
+descr1 a b = DTop `DPush` (t, a) `DPush` (t, b)
+ where t = STScal STF32
+
-- x y |- x * y + x
--
-- let x3 = (x1, x2)
diff --git a/src/Simplify.hs b/src/Simplify.hs
index acc2392..16a3e1d 100644
--- a/src/Simplify.hs
+++ b/src/Simplify.hs
@@ -48,6 +48,12 @@ simplify = \case
-- eta rule for return+bind
EMBind (EMReturn _ a) b -> simplify (ELet ext a b)
+ -- associativity of bind
+ EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c)))
+
+ -- bind-let commute
+ EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c)))
+
EVar _ t i -> EVar ext t i
ELet _ a b -> ELet ext (simplify a) (simplify b)
EPair _ a b -> EPair ext (simplify a) (simplify b)