diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-21 15:52:22 +0200 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2023-09-21 15:52:22 +0200 |
commit | 574569ee96a01d623baf8efdcd3908eef42b8007 (patch) | |
tree | cffc7d81e32b2b52430ac7da97a2a31102f55c97 | |
parent | 8d07a43f0b364156433dc453b9d1cc762c032634 (diff) |
Storage policy (accum / merge)
-rw-r--r-- | src/AST.hs | 2 | ||||
-rw-r--r-- | src/CHAD.hs | 185 | ||||
-rw-r--r-- | src/Example.hs | 7 | ||||
-rw-r--r-- | src/Simplify.hs | 6 |
4 files changed, 142 insertions, 58 deletions
@@ -28,7 +28,7 @@ data Nat = Z | S Nat data SNat n where SZ :: SNat Z SS :: SNat n -> SNat (S n) -deriving instance (Show (SNat n)) +deriving instance Show (SNat n) data Vec n t where VNil :: Vec Z t diff --git a/src/CHAD.hs b/src/CHAD.hs index 0c856b1..b074470 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -15,6 +15,7 @@ module CHAD where import Data.Bifunctor (first, second) import Data.Kind (Type) +import Data.Some import GHC.TypeLits (Symbol) import AST @@ -142,11 +143,11 @@ conv1Idx IZ = IZ conv1Idx (IS i) = IS (conv1Idx i) conv2Idx :: Descr env sto -> Idx env t -> Either (Idx (D2E (Select env sto "accum")) (D2 t)) - (Idx (D2E (Select env sto "merge")) (D2 t)) -conv2Idx (DPush _ _ SAccum) IZ = Left IZ -conv2Idx (DPush _ _ SMerge) IZ = Right IZ -conv2Idx (DPush des _ SAccum) (IS i) = first IS (conv2Idx des i) -conv2Idx (DPush des _ SMerge) (IS i) = second IS (conv2Idx des i) + (Idx (Select env sto "merge") t) +conv2Idx (DPush _ (_, SAccum)) IZ = Left IZ +conv2Idx (DPush _ (_, SMerge)) IZ = Right IZ +conv2Idx (DPush des (_, SAccum)) (IS i) = first IS (conv2Idx des i) +conv2Idx (DPush des (_, SMerge)) (IS i) = second IS (conv2Idx des i) conv2Idx DTop i = case i of {} zero :: STy t -> Ex env (D2 t) @@ -162,17 +163,73 @@ zero (STScal t) = case t of STBool -> ENil ext zero STEVM{} = error "EVM not allowed in input program" +plus :: STy t -> Ex env (D2 t) -> Ex env (D2 t) -> Ex env (D2 t) +plus STNil _ _ = ENil ext +plus (STPair t1 t2) a b = + let t = STPair (d2 t1) (d2 t2) + in plusSparse t a b $ + EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) + (EFst ext (EVar ext t IZ))) + (plus t2 (ESnd ext (EVar ext t (IS IZ))) + (ESnd ext (EVar ext t IZ))) +plus (STEither t1 t2) a b = + let t = STEither (d2 t1) (d2 t2) + in plusSparse t a b $ + ECase ext (EVar ext t (IS IZ)) + (ECase ext (EVar ext t (IS IZ)) + (EInl ext (d2 t2) (plus t1 (EVar ext (d2 t1) (IS IZ)) (EVar ext (d2 t1) IZ))) + (EError t "plus l+r")) + (ECase ext (EVar ext t (IS IZ)) + (EError t "plus r+l") + (EInr ext (d2 t1) (plus t2 (EVar ext (d2 t2) (IS IZ)) (EVar ext (d2 t2) IZ)))) +plus STArr{} _ _ = error "TODO arrays" +plus (STScal t) a b = case t of + STI32 -> ENil ext + STI64 -> ENil ext + STF32 -> EOp ext (OAdd STF32) (EPair ext a b) + STF64 -> EOp ext (OAdd STF64) (EPair ext a b) + STBool -> ENil ext +plus STEVM{} _ _ = error "EVM not allowed in input program" + +plusSparse :: STy a + -> Ex env (TEither TNil a) -> Ex env (TEither TNil a) + -> Ex (a : a : env) a + -> Ex env (TEither TNil a) +plusSparse t a b adder = + ELet ext b $ + ECase ext (weakenExpr WSink a) + (EVar ext (STEither STNil t) (IS IZ)) + (EInr ext STNil + (ECase ext (EVar ext (STEither STNil t) (IS IZ)) + (EVar ext t (IS IZ)) + (weakenExpr (WCopy (WCopy WSink)) adder))) + type family Tup env where Tup '[] = TNil - Tup (t : ts) = TPair t (Tup ts) + Tup (t : ts) = TPair (Tup ts) t tTup :: SList STy env -> STy (Tup env) tTup SNil = STNil -tTup (SCons t ts) = STPair t (tTup ts) +tTup (SCons t ts) = STPair (tTup ts) t zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext -zeroTup (SCons t env) = EPair ext (zero t) (zeroTup env) +zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) + +onehotTup :: SList STy env0 -> Idx env0 t -> Ex env (D2 t) -> Ex env (Tup (D2E env0)) +onehotTup (SCons _ env) IZ d = EPair ext (zeroTup env) d +onehotTup (SCons t env) (IS i) d = EPair ext (onehotTup env i d) (zero t) +onehotTup SNil i _ = case i of {} + +plusTup :: SList STy env0 -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0)) -> Ex env (Tup (D2E env0)) +plusTup SNil _ _ = ENil ext +plusTup env0@(SCons t env) a b = + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EPair ext (plusTup env (EFst ext (EVar ext (tTup (d2e env0)) (IS IZ))) + (EFst ext (EVar ext (tTup (d2e env0)) IZ))) + (plus t (ESnd ext (EVar ext (tTup (d2e env0)) (IS IZ))) + (ESnd ext (EVar ext (tTup (d2e env0)) IZ))) data Ret env sto t = forall env'. @@ -277,58 +334,70 @@ deriving instance Show (Storage s) -- | Environment description data Descr env sto where DTop :: Descr '[] '[] - DPush :: Descr env sto -> STy t -> Storage s -> Descr (t : env) (s : sto) + DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) deriving instance Show (Descr env sto) select :: Storage s -> Descr env sto -> SList STy (Select env sto s) select _ DTop = SNil -select s@SAccum (DPush des t SAccum) = SCons t (select s des) -select s@SMerge (DPush des _ SAccum) = select s des -select s@SAccum (DPush des _ SMerge) = select s des -select s@SMerge (DPush des t SMerge) = SCons t (select s des) +select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, SAccum)) = select s des +select s@SAccum (DPush des (_, SMerge)) = select s des +select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil d2e (SCons t ts) = SCons (d2 t) (d2e ts) -drev :: Descr env sto -> Ex env t -> Ret env sto t -drev des = \case +drev :: Descr env sto + -> (forall env' sto' t'. Descr env' sto' -> STy t' -> Some Storage) + -> Ex env t -> Ret env sto t +drev des policy = \case EVar _ t i -> - case conv2Idx des i of - Left accumI -> - Ret BTop - (EVar ext (d1 t) (conv1Idx i)) - (EMBind - (EMOne d2mon accumI (EVar ext (d2 t) IZ)) - (EMReturn d2mon (zeroTup (select SMerge des)))) - Right tupI -> - _ + Ret BTop + (EVar ext (d1 t) (conv1Idx i)) + (case conv2Idx des i of + Left accumI -> + EMBind + (EMOne d2acc accumI (EVar ext (d2 t) IZ)) + (EMReturn d2acc (zeroTup (select SMerge des))) + Right tupI -> + EMReturn d2acc (onehotTup (select SMerge des) tupI (EVar ext (d2 t) IZ))) ELet _ rhs body - | Ret rhs0 rhs1 rhs2 <- drev des rhs - , Ret body0 body1 body2 <- drev (DPush des (typeOf rhs) SMerge) body -> + | Ret rhs0 rhs1 rhs2 <- drev des policy rhs + , Some storage <- policy des (typeOf rhs) + , Ret body0 body1 body2 <- drev (des `DPush` (typeOf rhs, storage)) policy body -> weakenBindings weakenExpr (WCopy (sinkWithBindings rhs0)) body0 $ \body0' wbody0' -> Ret (bconcat (rhs0 `BPush` (d1 (typeOf rhs), rhs1)) body0') (weakenExpr wbody0' body1) - (EMBind (EMScope (weakenExpr (WCopy wbody0') body2)) - (ELet ext (ESnd ext (EVar ext (STPair STNil (d2 (typeOf rhs))) IZ)) $ - weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2)) + (EMBind + (weakenExpr (WCopy wbody0') $ case storage of SAccum -> EMScope body2 ; SMerge -> body2) + (ELet ext (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) IZ)) $ + EMBind + (weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) + (EMReturn d2acc (plusTup (select SMerge des) + (EFst ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 (typeOf rhs))) (IS (IS IZ)))) + (EVar ext (tTup (d2e (select SMerge des))) IZ))))) EPair _ a b | Rets binds (RetPair a1 a2 `SCons` RetPair b1 b2 `SCons` SNil) - <- retConcat $ drev des a `SCons` drev des b `SCons` SNil + <- retConcat $ drev des policy a `SCons` drev des policy b `SCons` SNil , let dt = STPair (d2 (typeOf a)) (d2 (typeOf b)) -> Ret binds (EPair ext a1 b1) (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) - (EMReturn d2mon (ENil ext)) + (EMReturn d2acc (zeroTup (select SMerge des))) (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks @[_,_])) a2)) - (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks @[_,_,_])) b2)))) + (weakenExpr (WCopy (wSinks @[_,_])) a2)) $ + EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (wSinks @[_,_,_])) b2)) $ + EMReturn d2acc + (plusTup (select SMerge des) + (EVar ext (tTup (d2e (select SMerge des))) (IS IZ)) + (EVar ext (tTup (d2e (select SMerge des))) IZ)))) EFst _ e - | Ret e0 e1 e2 <- drev des e + | Ret e0 e1 e2 <- drev des policy e , STPair t1 t2 <- typeOf e -> Ret e0 (EFst ext e1) @@ -336,40 +405,42 @@ drev des = \case weakenExpr (WCopy WSink) e2) ESnd _ e - | Ret e0 e1 e2 <- drev des e + | Ret e0 e1 e2 <- drev des policy e , STPair t1 t2 <- typeOf e -> Ret e0 (ESnd ext e1) (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) - ENil _ -> Ret BTop (ENil ext) (EMReturn d2mon (ENil ext)) + ENil _ -> Ret BTop (ENil ext) (EMReturn d2acc (zeroTup (select SMerge des))) EInl _ t2 e - | Ret e0 e1 e2 <- drev des e -> + | Ret e0 e1 e2 <- drev des policy e -> Ret e0 (EInl ext (d1 t2) e1) (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) - (EMReturn d2mon (ENil ext)) + (EMReturn d2acc (zeroTup (select SMerge des))) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks @[_,_])) e2) - (EError (STEVM d2mon STNil) "inl<-dinr"))) + (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inl<-dinr"))) EInr _ t1 e - | Ret e0 e1 e2 <- drev des e -> + | Ret e0 e1 e2 <- drev des policy e -> Ret e0 (EInr ext (d1 t1) e1) (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) - (EMReturn d2mon (ENil ext)) + (EMReturn d2acc (zeroTup (select SMerge des))) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) - (EError (STEVM d2mon STNil) "inr<-dinl") + (EError (STEVM d2acc (tTup (d2e (select SMerge des)))) "inr<-dinl") (weakenExpr (WCopy (wSinks @[_,_])) e2))) ECase _ e a b | STEither t1 t2 <- typeOf e - , Ret e0 e1 e2 <- drev des e - , Ret a0 a1 a2 <- drev (DPush des t1 SMerge) a - , Ret b0 b1 b2 <- drev (DPush des t2 SMerge) b + , Ret e0 e1 e2 <- drev des policy e + , Some storageA <- policy des t1 + , Some storageB <- policy des t2 + , Ret a0 a1 a2 <- drev (des `DPush` (t1, storageA)) policy a + , Ret b0 b1 b2 <- drev (des `DPush` (t2, storageB)) policy b , TupBinds tapeA collectA reconA <- tupBinds a0 , TupBinds tapeB collectB reconB <- tupBinds b0 , let tPrimal = STPair (d1 (typeOf a)) (STEither tapeA tapeB) -> @@ -389,30 +460,30 @@ drev des = \case TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ - EMBind (weakenExpr (WCopy wrebinds) (EMScope a2)) - (EMReturn d2mon + EMBind (weakenExpr (WCopy wrebinds) $ case storageA of SAccum -> EMScope a2 ; SMerge -> a2) + (EMReturn d2acc (EInr ext STNil (EInl ext (d2 t2) - (ESnd ext (EVar ext (STPair STNil (d2 t1)) IZ)))))) - (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) + (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t1)) IZ)))))) + (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase l/rtape")) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) - (EError (STEVM d2mon (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") + (EError (STEVM d2acc (STEither STNil (STEither (d2 t1) (d2 t2)))) "dcase r/ltape") (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ - EMBind (weakenExpr (WCopy wrebinds) (EMScope b2)) - (EMReturn d2mon + EMBind (weakenExpr (WCopy wrebinds) $ case storageB of SAccum -> EMScope b2 ; SMerge -> b2) + (EMReturn d2acc (EInr ext STNil (EInr ext (d2 t1) - (ESnd ext (EVar ext (STPair STNil (d2 t2)) IZ)))))))) + (ESnd ext (EVar ext (STPair (tTup (d2e (select SMerge des))) (d2 t2)) IZ)))))))) (weakenExpr (WCopy (wSinks @[_,_,_])) e2)) EConst _ t val -> Ret BTop (EConst ext t val) - (EMReturn d2mon (ENil ext)) + (EMReturn d2acc (zeroTup (select SMerge des))) EOp _ op e - | Ret e0 e1 e2 <- drev des e -> + | Ret e0 e1 e2 <- drev des policy e -> case d2op op of Linear d2opfun -> Ret e0 @@ -429,4 +500,4 @@ drev des = \case e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) where - d2mon = d2e (select SAccum des) + d2acc = d2e (select SAccum des) diff --git a/src/Example.hs b/src/Example.hs index f2e5966..ee07edf 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,6 +1,8 @@ {-# LANGUAGE DataKinds #-} module Example where +import Data.Some + import AST import AST.Pretty import CHAD @@ -13,6 +15,11 @@ bin op a b = EOp ext op (EPair ext a b) senv1 :: SList STy [TScal TF32, TScal TF32] senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil +descr1 :: Storage a -> Storage b + -> Descr [TScal TF32, TScal TF32] [b, a] +descr1 a b = DTop `DPush` (t, a) `DPush` (t, b) + where t = STScal STF32 + -- x y |- x * y + x -- -- let x3 = (x1, x2) diff --git a/src/Simplify.hs b/src/Simplify.hs index acc2392..16a3e1d 100644 --- a/src/Simplify.hs +++ b/src/Simplify.hs @@ -48,6 +48,12 @@ simplify = \case -- eta rule for return+bind EMBind (EMReturn _ a) b -> simplify (ELet ext a b) + -- associativity of bind + EMBind (EMBind a b) c -> simplify (EMBind a (EMBind b (weakenExpr (WCopy WSink) c))) + + -- bind-let commute + EMBind (ELet _ a b) c -> simplify (ELet ext a (EMBind b (weakenExpr (WCopy WSink) c))) + EVar _ t i -> EVar ext t i ELet _ a b -> ELet ext (simplify a) (simplify b) EPair _ a b -> EPair ext (simplify a) (simplify b) |