summaryrefslogtreecommitdiff
path: root/src/CHAD.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/CHAD.hs')
-rw-r--r--src/CHAD.hs149
1 files changed, 34 insertions, 115 deletions
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 9f58f73..3da083b 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -32,18 +32,17 @@ module CHAD (
) where
import Data.Functor.Const
-import Data.Kind (Type)
import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.Stack (HasCallStack)
-import GHC.TypeLits (Symbol)
import AST
import AST.Bindings
import AST.Count
import AST.Env
import AST.Weaken.Auto
-import CHAD.Heuristics
+import CHAD.Accum
+import CHAD.EnvDescr
import CHAD.Types
import Data
import Lemmas
@@ -197,66 +196,6 @@ reconstructBindings binds tape =
,sreverse (stapeUnfoldings binds))
----------------------- ENVIRONMENT DESCRIPTION AND STORAGE ---------------------
-
-type Storage :: Symbol -> Type
-data Storage s where
- SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator
- SMerge :: Storage "merge" -- ^ just return and merge
- SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions
-deriving instance Show (Storage s)
-
--- | Environment description
-data Descr env sto where
- DTop :: Descr '[] '[]
- DPush :: Descr env sto -> (STy t, Storage s) -> Descr (t : env) (s : sto)
-deriving instance Show (Descr env sto)
-
-descrList :: Descr env sto -> SList STy env
-descrList DTop = SNil
-descrList (des `DPush` (t, _)) = t `SCons` descrList des
-
--- | This could have more precise typing on the output storage.
-subDescr :: Descr env sto -> Subenv env env'
- -> (forall sto'. Descr env' sto'
- -> Subenv (Select env sto "merge") (Select env' sto' "merge")
- -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum"))
- -> Subenv (D1E env) (D1E env')
- -> r)
- -> r
-subDescr DTop SETop k = k DTop SETop SETop SETop
-subDescr (des `DPush` (t, sto)) (SEYes sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k (des' `DPush` (t, sto)) (SEYes submerge) subaccum (SEYes subd1e)
- SAccum -> k (des' `DPush` (t, sto)) submerge (SEYes subaccum) (SEYes subd1e)
- SDiscr -> k (des' `DPush` (t, sto)) submerge subaccum (SEYes subd1e)
-subDescr (des `DPush` (_, sto)) (SENo sub) k =
- subDescr des sub $ \des' submerge subaccum subd1e ->
- case sto of
- SMerge -> k des' (SENo submerge) subaccum (SENo subd1e)
- SAccum -> k des' submerge (SENo subaccum) (SENo subd1e)
- SDiscr -> k des' submerge subaccum (SENo subd1e)
-
--- | Select only the types from the environment that have the specified storage
-type family Select env sto s where
- Select '[] '[] _ = '[]
- Select (t : ts) (s : sto) s = t : Select ts sto s
- Select (_ : ts) (_ : sto) s = Select ts sto s
-
-select :: Storage s -> Descr env sto -> SList STy (Select env sto s)
-select _ DTop = SNil
-select s@SAccum (DPush des (t, SAccum)) = SCons t (select s des)
-select s@SMerge (DPush des (_, SAccum)) = select s des
-select s@SDiscr (DPush des (_, SAccum)) = select s des
-select s@SAccum (DPush des (_, SMerge)) = select s des
-select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des)
-select s@SDiscr (DPush des (_, SMerge)) = select s des
-select s@SAccum (DPush des (_, SDiscr)) = select s des
-select s@SMerge (DPush des (_, SDiscr)) = select s des
-select s@SDiscr (DPush des (t, SDiscr)) = SCons t (select s des)
-
-
---------------------------------- DERIVATIVES ---------------------------------
d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t)
@@ -283,24 +222,24 @@ data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a))
d2op :: SOp a t -> D2Op a t
d2op op = case op of
- OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EInr ext STNil (EPair ext d d)
+ OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EJust ext (EPair ext d d)
OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d ->
- EInr ext STNil (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
- (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
+ EJust ext (EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d))
+ (EOp ext (OMul t) (EPair ext (EFst ext e) d)))
ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d
- OLt t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OLe t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
- OEq t -> Linear $ \_ -> EInl ext (STPair (d2 (STScal t)) (d2 (STScal t))) (ENil ext)
+ OLt t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OLe t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
+ OEq t -> Linear $ \_ -> ENothing ext (STPair (d2 (STScal t)) (d2 (STScal t)))
ONot -> Linear $ \_ -> ENil ext
- OAnd -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- OOr -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OAnd -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ OOr -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
OIf -> Linear $ \_ -> ENil ext
ORound64 -> Linear $ \_ -> EConst ext STF64 0.0
OToFl64 -> Linear $ \_ -> ENil ext
ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d)
OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d)
OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d)
- OIDiv t -> integralD2 t $ Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ OIDiv t -> integralD2 t $ Linear $ \_ -> ENothing ext (STPair STNil STNil)
where
d2opUnArrangeInt :: SScalTy a
-> (D2s a ~ TScal a => D2Op (TScal a) t)
@@ -316,11 +255,11 @@ d2op op = case op of
-> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t)
-> D2Op (TPair (TScal a) (TScal a)) t
d2opBinArrangeInt ty float = case ty of
- STI32 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
- STI64 -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STI32 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
+ STI64 -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
STF32 -> float
STF64 -> float
- STBool -> Linear $ \_ -> EInl ext (STPair STNil STNil) (ENil ext)
+ STBool -> Linear $ \_ -> ENothing ext (STPair STNil STNil)
floatingD2 :: ScalIsFloating a ~ True
=> SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r
@@ -332,13 +271,8 @@ d2op op = case op of
integralD2 STI32 k = k
integralD2 STI64 k = k
-sD1eEnv :: Descr env sto -> SList STy (D1E env)
-sD1eEnv DTop = SNil
-sD1eEnv (DPush d (t, _)) = SCons (d1 t) (sD1eEnv d)
-
-d2ace :: SList STy env -> SList STy (D2AcE env)
-d2ace SNil = SNil
-d2ace (SCons t ts) = SCons (STAccum (d2 t)) (d2ace ts)
+desD1E :: Descr env sto -> SList STy (D1E env)
+desD1E = d1e . descrList
-- d1W :: env :> env' -> D1E env :> D1E env'
-- d1W WId = WId
@@ -610,24 +544,6 @@ accumPromote pdty (descr `DPush` (t :: STy t, sto)) k =
-- STScal{} -> False
-- STAccum{} -> error "An accumulator in merge storage?"
-makeAccumulators :: SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro))
-makeAccumulators SNil e = e
-makeAccumulators (STArr n t `SCons` envpro) e =
- makeAccumulators envpro $
- EWith (zero (STArr n t)) e
-makeAccumulators (t `SCons` _) _ = error $ "makeAccumulators: Not only arrays in envpro: " ++ show t
-
-uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list))
-uninvertTup SNil _ e = EPair ext e (ENil ext)
-uninvertTup (t `SCons` list) tcore e =
- ELet ext (uninvertTup list (STPair tcore t) e) $
- let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding
- in EPair ext
- (EFst ext (EFst ext (EVar ext recT IZ)))
- (EPair ext
- (ESnd ext (EVar ext recT IZ))
- (ESnd ext (EFst ext (EVar ext recT IZ))))
-
---------------------------- RETURN TRIPLE FROM CHAD ---------------------------
@@ -701,7 +617,7 @@ freezeRet :: Descr env sto
-> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge"))))
freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
let (e0', wInsertD2Ac) = weakenBindings weakenExpr (WSink .> wSinks (d2ace (select SAccum descr))) e0
- e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (sD1eEnv descr)))) e2
+ e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2
in letBinds e0' $
EPair ext
(weakenExpr wInsertD2Ac e1)
@@ -709,7 +625,7 @@ freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ t) =
&. #tape (subList (bindingsBinds e0) subtape)
&. #shbinds (bindingsBinds e0)
&. #d2ace (d2ace (select SAccum descr))
- &. #tl (sD1eEnv descr))
+ &. #tl (desD1E descr))
(#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl)
(#shbinds :++: #d :++: #d2ace :++: #tl))
e2') $
@@ -782,7 +698,7 @@ drev des = \case
subtape
(EPair ext a1 b1)
subBoth
- (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) subBoth))
(ELet ext (ELet ext (EFst ext (EVar ext dt IZ))
(weakenExpr (WCopy (wSinks' @[_,_])) a2)) $
@@ -790,7 +706,8 @@ drev des = \case
(weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $
plus_A_B
(EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ))
- (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ)))
+ (EVar ext (tTup (d2e (subList (select SMerge des) subB))) IZ))
+ (EVar ext (STMaybe (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ))
EFst _ e
| Ret e0 subtape e1 sub e2 <- drev des e
@@ -799,7 +716,7 @@ drev des = \case
subtape
(EFst ext e1)
sub
- (ELet ext (EInr ext STNil (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
+ (ELet ext (EJust ext (EPair ext (EVar ext (d2 t1) IZ) (zero t2))) $
weakenExpr (WCopy WSink) e2)
ESnd _ e
@@ -809,7 +726,7 @@ drev des = \case
subtape
(ESnd ext e1)
sub
- (ELet ext (EInr ext STNil (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
+ (ELet ext (EJust ext (EPair ext (zero t1) (EVar ext (d2 t2) IZ))) $
weakenExpr (WCopy WSink) e2)
ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (select SMerge des)) (ENil ext)
@@ -820,11 +737,12 @@ drev des = \case
subtape
(EInl ext (d1 t2) e1)
sub
- (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ)
(weakenExpr (WCopy (wSinks' @[_,_])) e2)
- (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr")))
+ (EError (tTup (d2e (subList (select SMerge des) sub))) "inl<-dinr"))
+ (EVar ext (STMaybe (STEither (d2 (typeOf e)) (d2 t2))) IZ))
EInr _ t1 e
| Ret e0 subtape e1 sub e2 <- drev des e ->
@@ -832,11 +750,12 @@ drev des = \case
subtape
(EInr ext (d1 t1) e1)
sub
- (ECase ext (EVar ext (STEither STNil (STEither (d2 t1) (d2 (typeOf e)))) IZ)
+ (EMaybe ext
(zeroTup (subList (select SMerge des) sub))
(ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ)
(EError (tTup (d2e (subList (select SMerge des) sub))) "inr<-dinl")
- (weakenExpr (WCopy (wSinks' @[_,_])) e2)))
+ (weakenExpr (WCopy (wSinks' @[_,_])) e2))
+ (EVar ext (STMaybe (STEither (d2 t1) (d2 (typeOf e)))) IZ))
ECase _ e (a :: Ex _ t) b
| STEither t1 t2 <- typeOf e
@@ -907,7 +826,7 @@ drev des = \case
(EInr ext (d2 t1)
(ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB))) (d2 t2)) IZ))))) $
ELet ext
- (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $
+ (ELet ext (EJust ext (ESnd ext (EVar ext tCaseRet IZ))) $
weakenExpr (WCopy (wSinks' @[_,_,_])) e2) $
plus_AB_E
(EFst ext (EVar ext tCaseRet (IS IZ)))
@@ -986,16 +905,16 @@ drev des = \case
(EVar ext shty IZ)
(letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil)
&. #sh (shty `SCons` SNil)
- &. #d1env (sD1eEnv des)
- &. #d1env' (sD1eEnv usedDes))
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
(#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
(#ix :++: #sh :++: #d1env))
e0)) $
let w = autoWeak (#ix (shty `SCons` SNil)
&. #sh (shty `SCons` SNil)
&. #e0 (bindingsBinds e0)
- &. #d1env (sD1eEnv des)
- &. #d1env' (sD1eEnv usedDes))
+ &. #d1env (desD1E des)
+ &. #d1env' (desD1E usedDes))
(#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed))
(#e0 :++: #ix :++: #sh :++: #d1env)
in EPair ext (weakenExpr w e1) (collectexpr w)))