diff options
author | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
---|---|---|
committer | Tom Smeding <tom@tomsmeding.com> | 2024-11-26 15:25:13 +0100 |
commit | ae2b1b71a91d60d3bd1dfb21fce98c05c1a4fcbb (patch) | |
tree | 1f6afda4b1d6925fe8224ee4f2ca40212fe11aa6 /src/CHAD.hs | |
parent | 7774da51c532006da82617ce307d136897693280 (diff) |
WIP accum top-level args
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r-- | src/CHAD.hs | 149 |
1 files changed, 34 insertions, 115 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs index 9f58f73..3da083b 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -32,18 +32,17 @@ module CHAD ( ) where import Data.Functor.Const -import Data.Kind (Type) import Data.Type.Bool (If) import Data.Type.Equality (type (==)) import GHC.Stack (HasCallStack) -import GHC.TypeLits (Symbol) import AST import AST.Bindings import AST.Count import AST.Env import AST.Weaken.Auto -import CHAD.Heuristics +import CHAD.Accum +import CHAD.EnvDescr import CHAD.Types import Data import Lemmas @@ -197,66 +196,6 @@ reconstructBindings binds tape = ,sreverse (stapeUnfoldings binds)) ----------------------- ENVIRONMENT DESCRIPTION AND STORAGE --------------------- - -type Storage :: Symbol -> Type -data Storage s where - SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator - SMerge :: Storage "merge" -- ^ just return and merge - SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where - DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _)) = t `SCons` descrList des - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' - -> (forall sto'. Descr env' sto' - -> Subenv (Select env sto "merge") (Select env' sto' "merge") - -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) - -> Subenv (D1E env) (D1E env') - -> r) - -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, sto)) (SEYes sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e) - SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e) - SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e) -subDescr (des `DPush` (_, sto)) (SENo sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) - SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) - SDiscr -> k des' submerge subaccum (SENo subd1e) - --- | Select only the types from the environment that have the specified storage -type family Select env sto s where - Select '[] '[] _ = '[] - Select (t : ts) (s : sto) s = t : Select ts sto s - Select (_ : ts) (_ : sto) s = Select ts sto s - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, SAccum)) = select s des -select s@SDiscr (DPush des (_, SAccum)) = select s des -select s@SAccum (DPush des (_, SMerge)) = select s des -select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, SMerge)) = select s des -select s@SAccum (DPush des (_, SDiscr)) = select s des -select s@SMerge (DPush des (_, SDiscr)) = select s des -select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des) - - ---------------------------------- DERIVATIVES --------------------------------- d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) @@ -283,24 +222,24 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) d2op :: SOp a t -> D2Op a t d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d) + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d) OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d))) + EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d))) ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) - OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext) + OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) + OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t))) ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil) OIf -> Linear $ \_ -> ENil ext ORound64 -> Linear $ \_ -> EConst ext STF64 0.0 OToFl64 -> Linear $ \_ -> ENil ext ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil) where d2opUnArrangeInt :: SScalTy a -> (D2s a ~ TScal a => D2Op (TScal a) t) @@ -316,11 +255,11 @@ d2op op = case op of -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) -> D2Op (TPair (TScal a) (TScal a)) t d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) - STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) + STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil) STF32 -> float STF64 -> float - STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext) + STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil) floatingD2 :: ScalIsFloating a ~ True => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r @@ -332,13 +271,8 @@ d2op op = case op of integralD2 STI32 k = k integralD2 STI64 k = k -sD1eEnv :: Descr env sto -> SList STy (D1E env) -sD1eEnv DTop = SNil -sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d) - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts) +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList -- d1W :: env :> env' -> D1E env :> D1E env' -- d1W WId = WId @@ -610,24 +544,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k = -- STScal{} -> False -- STAccum{} -> error "An accumulator in merge storage?" -makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators SNil e = e -makeAccumulators (STArr n t `SCons` envpro) e = - makeAccumulators envpro $ - EWith (zero (STArr n t)) e -makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - ---------------------------- RETURN TRIPLE FROM CHAD --------------------------- @@ -701,7 +617,7 @@ freezeRet :: Descr env sto -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 in letBinds e0' $ EPair ext (weakenExpr wInsertD2Ac e1) @@ -709,7 +625,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) = &. #tape (subList (bindingsBinds e0) subtape) &. #shbinds (bindingsBinds e0) &. #d2ace (d2ace (select SAccum descr)) - &. #tl (sD1eEnv descr)) + &. #tl (desD1E descr)) (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) (#shbinds :++: #d :++: #d2ace :++: #tl)) e2') $ @@ -782,7 +698,7 @@ drev des = \case subtape (EPair ext a1 b1) subBoth - (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) subBoth)) (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ @@ -790,7 +706,8 @@ drev des = \case (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ plus_A_B (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) - (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))) + (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)) + (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)) EFst _ e | Ret e0 subtape e1 sub e2 <- drev des e @@ -799,7 +716,7 @@ drev des = \case subtape (EFst ext e1) sub - (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ + (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $ weakenExpr (WCopy WSink) e2) ESnd _ e @@ -809,7 +726,7 @@ drev des = \case subtape (ESnd ext e1) sub - (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ + (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $ weakenExpr (WCopy WSink) e2) ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext) @@ -820,11 +737,12 @@ drev des = \case subtape (EInl ext (d1 t2) e1) sub - (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) (weakenExpr (WCopy (wSinks' @[_,_])) e2) - (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))) + (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")) + (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ)) EInr _ t1 e | Ret e0 subtape e1 sub e2 <- drev des e -> @@ -832,11 +750,12 @@ drev des = \case subtape (EInr ext (d1 t1) e1) sub - (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ) + (EMaybe ext (zeroTup (subList (select SMerge des) sub)) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) (EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl") - (weakenExpr (WCopy (wSinks' @[_,_])) e2))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ)) ECase _ e (a :: Ex _ t) b | STEither t1 t2 <- typeOf e @@ -907,7 +826,7 @@ drev des = \case (EInr ext (d2 t1) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $ ELet ext - (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ + (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $ weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $ plus_AB_E (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -986,16 +905,16 @@ drev des = \case (EVar ext shty IZ) (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) - &. #d1env (sD1eEnv des) - &. #d1env' (sD1eEnv usedDes)) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#ix :++: #sh :++: #d1env)) e0)) $ let w = autoWeak (#ix (shty `SCons` SNil) &. #sh (shty `SCons` SNil) &. #e0 (bindingsBinds e0) - &. #d1env (sD1eEnv des) - &. #d1env' (sD1eEnv usedDes)) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes)) (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) (#e0 :++: #ix :++: #sh :++: #d1env) in EPair ext (weakenExpr w e1) (collectexpr w))) |