diff options
author | Tom Smeding <t.j.smeding@uu.nl> | 2024-01-25 17:25:32 +0100 |
---|---|---|
committer | Tom Smeding <t.j.smeding@uu.nl> | 2024-01-25 17:25:32 +0100 |
commit | 39b899b4951be5b78058d5c0e35977b065a63951 (patch) | |
tree | 787a7f68a111513c890e141cda215331189535db | |
parent | 11ad6ad3f4ff2c3aa8eaff4d6124f361716cafff (diff) |
Getting further
-rw-r--r-- | chad-fast.cabal | 1 | ||||
-rw-r--r-- | src/AST.hs | 12 | ||||
-rw-r--r-- | src/AST/Weaken.hs | 39 | ||||
-rw-r--r-- | src/CHAD.hs | 104 | ||||
-rw-r--r-- | src/Data.hs | 19 |
5 files changed, 122 insertions, 53 deletions
diff --git a/chad-fast.cabal b/chad-fast.cabal index fb5f4de..66452d9 100644 --- a/chad-fast.cabal +++ b/chad-fast.cabal @@ -16,6 +16,7 @@ library AST.Weaken CHAD -- Compile + Data Example Lemmas PreludeCu @@ -21,6 +21,7 @@ import Data.Kind (Type) import Data.Int import AST.Weaken +import Data data Nat = Z | S Nat @@ -39,12 +40,6 @@ deriving instance Functor (Vec n) deriving instance Foldable (Vec n) deriving instance Traversable (Vec n) -data SList f l where - SNil :: SList f '[] - SCons :: f a -> SList f l -> SList f (a : l) -deriving instance (forall a. Show (f a)) => Show (SList f l) -infixr `SCons` - data Ty = TNil | TPair Ty Ty @@ -232,6 +227,7 @@ WCopy _ @> IZ = IZ WCopy w @> (IS i) = IS (w @> i) WPop w @> i = w @> IS i WThen w1 w2 @> i = w2 @> w1 @> i +WClosed _ @> i = case i of {} weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t weakenExpr w = \case @@ -273,10 +269,6 @@ weakenVec :: (env :> env') -> Vec n (Expr x env TIx) -> Vec n (Expr x env' TIx) weakenVec _ VNil = VNil weakenVec w (e :< v) = weakenExpr w e :< weakenVec w v -slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list -slistMap _ SNil = SNil -slistMap f (SCons x list) = SCons (f x) (slistMap f list) - slistIdx :: SList f list -> Idx list t -> f t slistIdx (SCons x _) IZ = x slistIdx (SCons _ list) (IS i) = slistIdx list i diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index dd121fa..c7668e7 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -5,12 +5,17 @@ {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeApplications #-} -- The reason why this is a separate module with little in it: {-# LANGUAGE AllowAmbiguousTypes #-} module AST.Weaken where +import Data.Functor.Const + +import Data + data env :> env' where WId :: env :> env @@ -18,7 +23,7 @@ data env :> env' where WCopy :: env :> env' -> (t : env) :> (t : env') WPop :: (t : env) :> env' -> env :> env' WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 - WClosed :: '[] :> env + WClosed :: SList (Const ()) env -> '[] :> env deriving instance Show (env :> env') (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 @@ -28,17 +33,25 @@ type family Append a b where Append '[] l = l Append (x : xs) l = x : Append xs l -data ListSpine list where - LSNil :: ListSpine '[] - LSCons :: ListSpine list -> ListSpine (t : list) +class KnownListSpine list where knownListSpine :: SList (Const ()) list +instance KnownListSpine '[] where knownListSpine = SNil +instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine + +wSinks' :: forall list env. KnownListSpine list => env :> Append list env +wSinks' = wSinks (knownListSpine :: SList (Const ()) list) + +wSinks :: SList f list' -> env :> Append list' env +wSinks SNil = WId +wSinks (SCons _ spine) = WSink .> wSinks spine -class KnownListSpine list where knownListSpine :: ListSpine list -instance KnownListSpine '[] where knownListSpine = LSNil -instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = LSCons knownListSpine +wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env +wStack WId = WId +wStack WSink = WSink +wStack (WCopy w) = WCopy (wStack @env w) +wStack (WPop w) = WPop (wStack @env w) +wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) +wStack (WClosed s) = wSinks s -wSinks :: forall list env. KnownListSpine list => env :> Append list env -wSinks = go (knownListSpine :: ListSpine list) - where - go :: forall list'. ListSpine list' -> env :> Append list' env - go LSNil = WId - go (LSCons spine) = WSink .> go spine +wRaiseAbove :: SList f env1 -> SList (Const ()) env -> env1 :> Append env1 env +wRaiseAbove SNil env = WClosed env +wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) diff --git a/src/CHAD.hs b/src/CHAD.hs index 209ed3b..642f58f 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -16,11 +16,14 @@ module CHAD where import Data.Bifunctor (first, second) +import Data.Functor.Const import Data.Kind (Type) +import Data.Proxy import Data.Some import GHC.TypeLits (Symbol) import AST +import Data import Lemmas @@ -43,7 +46,11 @@ sinkOver :: SList STy ts -> env :> Append ts env sinkOver SNil = WId sinkOver (SCons _ ts) = WSink .> sinkOver ts -sinkWithBindings :: Bindings f env binds -> env :> Append binds env +weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' +weakenOver SNil w = w +weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) + +sinkWithBindings :: Bindings f env binds -> env' :> Append binds env' sinkWithBindings BTop = WId sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b @@ -280,18 +287,17 @@ data Ret env0 sto t = (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) deriving instance Show (Ret env0 sto t) -data RetPair env0 sto env t = +data RetPair env0 sto env shbinds t = forall env0Merge. - RetPair (Ex env (D1 t)) + RetPair (Ex (Append shbinds env) (D1 t)) (Subenv (Select env0 sto "merge") env0Merge) - (Ex (D2 t : env) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) -deriving instance Show (RetPair env0 sto env t) + (Ex (D2 t : shbinds) (TEVM (D2E (Select env0 sto "accum")) (Tup (D2E env0Merge)))) +deriving instance Show (RetPair env0 sto env shbinds t) -TODO need to fix this data type still, I think? data Rets env0 sto env list = forall shbinds. Rets (Bindings Ex env shbinds) - (SList (RetPair env0 sto shbinds) list) + (SList (RetPair env0 sto env shbinds) list) deriving instance Show (Rets env0 sto env list) subList :: SList f env -> Subenv env env' -> SList f env' @@ -380,23 +386,56 @@ unscope des ty s sub e k = case s of -- d1W (WPop w) = WPop (d1W w) -- d1W (WThen u w) = WThen (d1W u) (d1W w) -weakenRetPair :: env :> env' -> RetPair env0 sto env t -> RetPair env0 sto env' t -weakenRetPair w (RetPair e1 sub e2) = RetPair (weakenExpr w e1) sub (weakenExpr (WCopy w) e2) +weakenRetPair :: SList STy shbinds -> env :> env' -> RetPair env0 sto env shbinds t -> RetPair env0 sto env' shbinds t +weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list weakenRets w (Rets binds list) = - weakenBindings weakenExpr w binds $ \binds' wbinds' -> - Rets binds' (slistMap (weakenRetPair wbinds') list) + let (binds', _) = weakenBindings weakenExpr w binds + in Rets binds' (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 env0 sto t f. SList f b1 -> SList f b2 -> RetPair env0 sto (Append b1 env) b2 t -> RetPair env0 sto env (Append b2 b1) t +rebaseRetPair b1 b2 (RetPair p sub d) + | Refl <- lemAppendAssoc @b2 @b1 @env = + RetPair p sub (weakenExpr (WCopy (wRaiseAbove b2 (slistMap (\_ -> Const ()) b1))) d) retConcat :: forall env0 sto list. SList (Ret env0 sto) list -> Rets env0 sto (D1E env0) list retConcat SNil = Rets BTop SNil -retConcat (SCons (Ret (b :: Bindings Ex (D1E env0) env2) p sub d) list) - | Rets binds pairs <- weakenRets (sinkWithBindings b) (retConcat list) +retConcat (SCons (Ret (b :: Bindings Ex (D1E env0) shbinds) p sub d) list) + | Rets binds1 pairs1 <- retConcat list + , Rets (binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2) pairs <- weakenRets (sinkWithBindings b) (Rets binds1 pairs1) + , Refl <- lemAppendAssoc @shbinds2 @shbinds @(D1E env0) = Rets (bconcat b binds) (SCons (RetPair (weakenExpr (sinkWithBindings binds) p) sub (weakenExpr (WCopy (sinkWithBindings binds)) d)) - pairs) + (slistMap (rebaseRetPair (bindingsBinds b) (bindingsBinds binds)) pairs)) +-- list ~ a : list' +-- SCons (Ret b p sub d) list :: SList (Ret env0 sto) list +-- Ret b p sub d :: Ret env0 sto a <- existential shbinds +-- b :: Bindings Ex (D1E env0) shbinds +-- p :: Ex (Append shbinds (D1E env0)) (D1 a) +-- d :: Ex (D2 a : shbinds) (TEVM ...) +-- +-- list :: SList (Ret env0 sto) list' +-- retConcat list :: Rets env0 sto (D1E env0) list' <- existential shbinds1 +-- binds1 :: Bindings Ex (D1E env0) shbinds1 +-- pairs1 :: SList (RetPair env0 sto (D1E env0) shbinds1) list' +-- +-- sinkWithBindings b :: forall e. e :> Append shbinds e +-- Rets binds pairs :: Rets env0 sto (Append shbinds (D1E env0)) list' <- existential shbinds2 +-- binds :: Bindings Ex (Append shbinds (D1E env0)) shbinds2 +-- pairs :: SList (RetPair env0 sto (Append shbinds (D1E env0)) shbinds2) list' +-- +-- we choose shbindsR ~ Append shbinds2 shbinds +-- result :: Rets env0 sto (D1E env0) list +-- result.1 :: Bindings Ex (D1E env0) shbindsR == Bindings Ex (D1E env0) (Append shbinds2 shbinds) +-- result.2 :: SList (RetPair env0 sto (D1E env0) shbindsR) list +-- result.2.head :: RetPair env0 sto (D1E env0) shbindsR a +-- result.2.tail :: SList (RetPair env0 sto (D1E env0) shbindsR) list' +-- = SList (RetPair env0 sto (D1E env0) (Append shbinds2 shbinds)) list' +-- +-- wanted: shbinds1 :> shbindsR d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) d1op (OAdd t) e = EOp ext (OAdd t) e @@ -464,18 +503,23 @@ select s@SMerge (DPush des (_, SAccum)) = select s des select s@SAccum (DPush des (_, SMerge)) = select s des select s@SMerge (DPush des (t, SMerge)) = SCons t (select s des) +sD1eEnv :: Descr env sto -> SList (Const ()) (D1E env) +sD1eEnv DTop = SNil +sD1eEnv (DPush d _) = SCons (Const ()) (sD1eEnv d) + freezeRet :: Descr env sto -> Ret env sto t -> (forall env'. Ex env' (D2 t)) -- the incoming cotangent value -> Ex (D1E env) (TPair (D1 t) (TEVM (D2E (Select env sto "accum")) (Tup (D2E (Select env sto "merge"))))) freezeRet descr (Ret e0 e1 sub e2) d = - letBinds e0 $ - EPair ext - e1 - (ELet ext d - (EMBind e2 - (EMReturn (d2e (select SAccum descr)) - (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))))) + let e2' = weakenExpr (WCopy (wRaiseAbove (bindingsBinds e0) (sD1eEnv descr))) e2 + in letBinds e0 $ + EPair ext + e1 + (ELet ext d + (EMBind e2' + (EMReturn (d2e (select SAccum descr)) + (expandSubenvZeros (select SMerge descr) sub (EVar ext (tTup (d2e (subList (select SMerge descr) sub))) IZ))))) d2e :: SList STy env -> SList STy (D2E env) d2e SNil = SNil @@ -516,7 +560,7 @@ drev des policy = \case (weakenExpr (WCopy wbody0') body2') (EMBind (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) + weakenExpr (WCopy (wSinks' @[_,_] .> WPop (sinkWithBindings body0'))) rhs2) (EMReturn d2acc (plus_RHS_Body (EVar ext (tTup (d2e (subList (select SMerge des) subRHS))) IZ) (EFst ext (EVar ext bodyResType (IS IZ))))))) @@ -532,9 +576,9 @@ drev des policy = \case (ECase ext (EVar ext (STEither STNil (STPair (d2 (typeOf a)) (d2 (typeOf b)))) IZ) (EMReturn d2acc (zeroTup (subList (select SMerge des) subBoth))) (EMBind (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy (wSinks @[_,_])) a2)) $ + (weakenExpr (WCopy (wSinks' @[_,_])) a2)) $ EMBind (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (wSinks @[_,_,_])) b2)) $ + (weakenExpr (WCopy (wSinks' @[_,_,_])) b2)) $ EMReturn d2acc (plus_A_B (EVar ext (tTup (d2e (subList (select SMerge des) subA))) (IS IZ)) @@ -568,7 +612,7 @@ drev des policy = \case (ECase ext (EVar ext (STEither STNil (STEither (d2 (typeOf e)) (d2 t2))) IZ) (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) (ECase ext (EVar ext (STEither (d2 (typeOf e)) (d2 t2)) IZ) - (weakenExpr (WCopy (wSinks @[_,_])) e2) + (weakenExpr (WCopy (wSinks' @[_,_])) e2) (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inl<-dinr"))) EInr _ t1 e @@ -580,7 +624,7 @@ drev des policy = \case (EMReturn d2acc (zeroTup (subList (select SMerge des) sub))) (ECase ext (EVar ext (STEither (d2 t1) (d2 (typeOf e))) IZ) (EError (STEVM d2acc (tTup (d2e (subList (select SMerge des) sub)))) "inr<-dinl") - (weakenExpr (WCopy (wSinks @[_,_])) e2))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2))) ECase _ e a b | STEither t1 t2 <- typeOf e @@ -610,7 +654,7 @@ drev des policy = \case (EMBind (ECase ext (EVar ext (STEither (d1 t1) (d1 t2)) (IS (IS IZ))) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) - (case reconA (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of + (case reconA (WSink .> WCopy (wSinks' @[_,_,_] .> sinkWithBindings e0)) IZ of TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ @@ -624,7 +668,7 @@ drev des policy = \case (EError (STEVM d2acc tCaseRet) "dcase l/rtape")) (ECase ext (ESnd ext (EVar ext tPrimal (IS (IS IZ)))) (EError (STEVM d2acc tCaseRet) "dcase r/ltape") - (case reconB (WSink .> WCopy (wSinks @[_,_,_] .> sinkWithBindings e0)) IZ of + (case reconB (WSink .> WCopy (wSinks' @[_,_,_] .> sinkWithBindings e0)) IZ of TupBindsReconstruct rebinds wrebinds -> letBinds rebinds $ ELet ext (EVar ext (d2 (typeOf a)) (sinkWithBindings rebinds @> IS (IS IZ))) $ @@ -636,7 +680,7 @@ drev des policy = \case (EInr ext (d2 t1) (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subB'))) (d2 t2)) IZ)))))))) (EMBind (ELet ext (EInr ext STNil (ESnd ext (EVar ext tCaseRet IZ))) $ - weakenExpr (WCopy (wSinks @[_,_,_,_])) e2) $ + weakenExpr (WCopy (wSinks' @[_,_,_,_])) e2) $ EMReturn d2acc $ plus_AB_E (EFst ext (EVar ext tCaseRet (IS IZ))) @@ -663,7 +707,7 @@ drev des policy = \case sub (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) (EVar ext (d2 (opt2 op)) IZ)) - (weakenExpr (WCopy (wSinks @[_,_])) e2)) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) e -> error $ "CHAD: unsupported " ++ takeWhile (/= ' ') (show e) diff --git a/src/Data.hs b/src/Data.hs new file mode 100644 index 0000000..bd7f3af --- /dev/null +++ b/src/Data.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module Data where + + +data SList f l where + SNil :: SList f '[] + SCons :: f a -> SList f l -> SList f (a : l) +deriving instance (forall a. Show (f a)) => Show (SList f l) +infixr `SCons` + +slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list +slistMap _ SNil = SNil +slistMap f (SCons x list) = SCons (f x) (slistMap f list) |