diff options
-rw-r--r-- | src/AST/Weaken.hs | 35 | ||||
-rw-r--r-- | src/AST/Weaken/Auto.hs | 82 | ||||
-rw-r--r-- | src/CHAD.hs | 143 | ||||
-rw-r--r-- | src/Example.hs | 12 |
4 files changed, 227 insertions, 45 deletions
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 78276ca..0a1e4ce 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -48,6 +48,9 @@ data env :> env' where -> Append pre (t : env) :> t : Append pre env' WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs -> Append as (Append bs env) :> Append bs (Append as env) + WStack :: forall as bs env1 env2. SList (Const ()) as -> SList (Const ()) bs + -> as :> bs -> env1 :> env2 + -> Append as env1 :> Append bs env2 deriving instance Show (env :> env') infix 4 :> @@ -67,18 +70,25 @@ WPick (_ `SCons` _) _ @> IZ = IS IZ WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i = case splitIdx @(Append bs env) as i of - Left i' -> skipOver bs (stack @env i' as) + Left i' -> indexSinks bs (indexRaiseAbove @env as i') Right i' -> case splitIdx @env bs i' of - Left j -> stack @(Append as env) j bs - Right j -> skipOver bs (skipOver as j) + Left j -> indexRaiseAbove @(Append as env) bs j + Right j -> indexSinks bs (indexSinks as j) +WStack @as @bs @env1 @env2 as bs wlo whi @> i = + case splitIdx @env1 as i of + Left i' -> indexRaiseAbove @env2 bs (wlo @> i') + Right i' -> indexSinks bs (whi @> i') + +indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t +indexSinks SNil j = j +indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j) + +indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t +indexRaiseAbove = flip go where - skipOver :: SList (Const ()) as' -> Idx bs' t -> Idx (Append as' bs') t - skipOver SNil j = j - skipOver (_ `SCons` bs') j = IS (skipOver bs' j) - - stack :: forall env' as' t. Idx as' t -> SList (Const ()) as' -> Idx (Append as' env') t - stack IZ (_ `SCons` _) = IZ - stack (IS j) (_ `SCons` as') = IS (stack @env' j as') + go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t + go IZ (_ `SCons` _) = IZ + go (IS i) (_ `SCons` as) = IS (go i as) infixr 3 .> (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 @@ -100,8 +110,9 @@ wSinksAnd SNil w = w wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 -wCopies SNil w = w -wCopies (SCons _ spine) w = WCopy (wCopies spine w) +wCopies bs w = + let bs' = slistMap (\_ -> Const ()) bs + in WStack bs' bs' WId w wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs index 444c540..8555516 100644 --- a/src/AST/Weaken/Auto.hs +++ b/src/AST/Weaken/Auto.hs @@ -35,15 +35,29 @@ import Lemmas type family Lookup name list where Lookup name ('(name, x) : _) = x Lookup name (_ : list) = Lookup name list + Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.") + + +-- | The @withPre@ type parameter indicates whether there can be 'LPreW' +-- occurrences within this layout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments (Lookup name segments) + -- | Pre-weaken with a weakening + LPreW :: forall name1 name2 segments. + SegmentName name1 -> SegmentName name2 + -> Lookup name1 segments :> Lookup name2 segments + -> Layout True segments (Lookup name1 segments) + (:++:) :: Layout withPre segments env1 -> Layout withPre segments env2 -> Layout withPre segments (Append env1 env2) +infixr :++: +instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout withPre segments seg) where + fromLabel = LSeg (symbolSing @name) -data Layout (segments :: [(Symbol, [t])]) (env :: [t]) where - LSeg :: forall name segments. KnownSymbol name => Layout segments (Lookup name segments) - (:++:) :: Layout segments env1 -> Layout segments env2 -> Layout segments (Append env1 env2) -infixr :++: +newtype SegmentName name = SegmentName (SSymbol name) + deriving (Show) -instance (KnownSymbol name, seg ~ Lookup name segments) => IsLabel name (Layout segments seg) where - fromLabel = LSeg @name @segments +instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where + fromLabel = SegmentName symbolSing data SSegments (segments :: [(Symbol, [t])]) where @@ -86,29 +100,51 @@ segmentLookup = \segs name -> case go segs name of case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of Refl -> go sseg name -data LinLayout (segments :: [(Symbol, [t])]) (env :: [t]) where - LinEnd :: LinLayout segments '[] - LinApp :: SSymbol name -> LinLayout segments env -> LinLayout segments (Append (Lookup name segments) env) +data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where + LinEnd :: LinLayout withPre segments '[] + LinApp :: SSymbol name -> LinLayout withPre segments env + -> LinLayout withPre segments (Append (Lookup name segments) env) + LinAppPreW :: SSymbol name1 -> SSymbol name2 + -> Lookup name1 segments :> Lookup name2 segments + -> LinLayout True segments env + -> LinLayout True segments (Append (Lookup name1 segments) env) -linLayoutAppend :: LinLayout segments env1 -> LinLayout segments env2 -> LinLayout segments (Append env1 env2) +linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2) linLayoutAppend LinEnd lin = lin -linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout segments env1')) (lin2 :: LinLayout _ env2) +linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 = LinApp name (linLayoutAppend lin1 lin2) +linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) + | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 + = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) -linLayoutEnv :: SSegments segments -> LinLayout segments env -> SList (Const ()) env +linLayoutEnv :: SSegments segments -> LinLayout withPre segments env -> SList (Const ()) env linLayoutEnv _ LinEnd = SNil linLayoutEnv segs (LinApp name lin) = sappend (segmentLookup segs name) (linLayoutEnv segs lin) +linLayoutEnv segs (LinAppPreW name1 _ _ lin) = sappend (segmentLookup segs name1) (linLayoutEnv segs lin) -lineariseLayout :: Layout segments env -> LinLayout segments env -lineariseLayout (LSeg @name :: Layout _ seg) +lineariseLayout :: Layout withPre segments env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ seg) | Refl <- lemAppendNil @seg - = LinApp (symbolSing @name) LinEnd + = LinApp name LinEnd lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 - -pullDown :: SSegments segments -> SSymbol name -> LinLayout segments env +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ seg) + | Refl <- lemAppendNil @seg + = LinAppPreW name1 name2 w LinEnd + +preWeaken :: SSegments segments -> LinLayout True segments env + -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r +preWeaken _ LinEnd k = k WId LinEnd +preWeaken segs (LinApp name lin) k = + preWeaken segs lin $ \w lin' -> + k (wCopies (segmentLookup segs name) w) (LinApp name lin') +preWeaken segs (LinAppPreW name1 name2 weak lin) k = + preWeaken segs lin $ \w lin' -> + k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin') + +pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env -> r -- Name was not found in source - -> (forall env'. LinLayout segments env' -> env :> Append (Lookup name segments) env' -> r) + -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r) -> r pullDown segs name@SSymbol linlayout kNotFound k = case linlayout of @@ -116,13 +152,13 @@ pullDown segs name@SSymbol linlayout kNotFound k = LinApp n'@SSymbol lin | Just Refl <- sameSymbol name n' -> k lin WId | otherwise -> - pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ env') w -> + pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w -> k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) .> wCopies (segmentLookup segs n') w) sortLinLayouts :: forall segments env1 env2. SSegments segments - -> LinLayout segments env1 -> LinLayout segments env2 -> env1 :> env2 + -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 sortLinLayouts _ LinEnd LinEnd = WId sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) @@ -139,5 +175,7 @@ sortLinLayouts segs LinEnd lin2@LinApp{} = WClosed (linLayoutEnv segs lin2) sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" autoWeak :: forall segments env1 env2. - SSegments segments -> Layout segments env1 -> Layout segments env2 -> env1 :> env2 -autoWeak segs ly1 ly2 = sortLinLayouts segs (lineariseLayout ly1) (lineariseLayout ly2) + SSegments segments -> Layout True segments env1 -> Layout False segments env2 -> env1 :> env2 +autoWeak segs ly1 ly2 = + preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> + sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/CHAD.hs b/src/CHAD.hs index 943f0a2..7747d46 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -262,7 +262,7 @@ vectoriseExpr :: forall prefix binds env t f. -> Ex (Append prefix (Append binds env)) t -> Ex (TIx : Append prefix (Append (Vectorise (S Z) binds) env)) t vectoriseExpr prefix binds env = - let wTarget :: Layout ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e + let wTarget :: Layout True ['("ix", '[TIx]), '("pre", prefix), '("vbinds", Vectorise (S Z) binds), '("env", env)] e -> e :> TIx : Append prefix (Append (Vectorise (S Z) binds) env) wTarget layout = autoWeak (#ix (auto1 @TIx) &. #pre prefix &. #vbinds (vectoriseEnv (SS SZ) binds) &. #env env) @@ -422,6 +422,10 @@ zeroTup :: SList STy env0 -> Ex env (Tup (D2E env0)) zeroTup SNil = ENil ext zeroTup (SCons t env) = EPair ext (zeroTup env) (zero t) +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + accumPromote :: forall dt env sto proxy r. proxy dt -> Descr env sto @@ -974,6 +978,7 @@ drev des = \case (subenvNone (select SMerge des)) (ENil ext) + -- TODO: either remove EBuilds1 entirely or rewrite it to work with an array of tapes instead of a vectorised tape EBuild1 _ ne (orige :: Ex _ eltty) | Ret (ne0 :: Bindings _ _ ne_binds) ne1 _ _ <- drev des ne -- allowed to ignore ne2 here because ne has a discrete result , let eltty = typeOf orige -> @@ -1050,6 +1055,90 @@ drev des = \case ESnd ext (EVar ext (STPair (STArr (SS SZ) STNil) (tTup (d2e envPro))) (IS IZ))) }} + EBuild _ (ndim :: SNat ndim) she (orige :: Ex _ eltty) + | Ret (she0 :: Bindings _ _ she_binds) she1 _ _ <- drev des she -- allowed to ignore she2 here because she has a discrete result + , let eltty = typeOf orige + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + deleteUnused (descrList des) (occEnvPop (occCountAll orige)) $ \(usedSub :: Subenv env env') -> + let e = unsafeWeakenWithSubenv (SEYes usedSub) orige in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + accumPromote eltty usedDes $ \prodes (envPro :: SList _ envPro) proSub wPro -> + case drev (prodes `DPush` (shty, SMerge)) e of { Ret (e0 :: Bindings _ _ e_binds) e1 sub e2 -> + case assertSubenvEmpty sub of { Refl -> + let tapety = tapeTy (bindingsBinds e0) in + let collectexpr = bindingsCollect e0 in + -- let ve0 = vectorise1Binds (tIx `SCons` sD1eEnv usedDes) IZ e0 in + Ret (she0 `BPush` (shty, she1) + `BPush` (STArr ndim tapety + ,EBuild ext ndim + (EVar ext shty IZ) + (letBinds (fst (weakenBindings weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#ix :++: #sh :++: #she0 :++: #d1env)) + e0)) $ + collectexpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #e0 (bindingsBinds e0) + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#e0 :++: #ix :++: #sh :++: #she0 :++: #d1env))))) + (EBuild ext ndim + (EVar ext shty (IS IZ)) + (ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (IS IZ)) + (EVar ext shty IZ)) $ + let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#ix (shty `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #she0 (bindingsBinds she0) + &. #e0 (bindingsBinds e0) + &. #tape (tapety `SCons` SNil) + &. #tapearr (STArr ndim tapety `SCons` SNil) + &. #prerebinds prerebinds + &. #d1env (sD1eEnv des) + &. #d1env' (sD1eEnv usedDes)) + (#e0 :++: #ix :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + ((#e0 :++: #prerebinds) :++: #tape :++: #ix :++: #tapearr :++: #sh :++: #she0 :++: #d1env)) + e1)) + (subenvCompose subMergeUsed proSub) + (let sinkOverEnvPro = wSinks @(D2 t : TArr ndim (Tape e_binds) : Tup (Replicate ndim TIx) : Append she_binds (D2AcE (Select env sto "accum"))) (d2ace envPro) in + ESnd ext $ + uninvertTup (d2e envPro) (STArr ndim STNil) $ + makeAccumulators @_ @_ @(TArr ndim TNil) envPro $ + EBuild ext ndim (EVar ext shty (sinkOverEnvPro @> IS (IS IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext ndim (EVar ext (STArr ndim (d2 eltty)) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext ndim (EVar ext (STArr ndim tapety) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + let (rebinds, prerebinds) = reconstructBindings (bindingsBinds e0) IZ + in letBinds rebinds $ + weakenExpr (autoWeak (#d (auto1 @(D2 eltty)) + &. #pro (d2ace envPro) + &. #ebinds (bindingsBinds e0) + &. #prerebinds prerebinds + &. #tape (tapety `SCons` SNil) + &. #ix (shty `SCons` SNil) + &. #darr (STArr ndim (d2 eltty) `SCons` SNil) + &. #tapearr (STArr ndim tapety `SCons` SNil) + &. #sh (shty `SCons` SNil) + &. #shebinds (bindingsBinds she0) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des))) + (#pro :++: #d :++: #ebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ebinds :++: #prerebinds) :++: #tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #sh :++: #shebinds :++: #d2acEnv) + .> wPro (bindingsBinds e0)) + e2) + }} + EUnit _ e | Ret e0 e1 sub e2 <- drev des e -> Ret e0 @@ -1072,23 +1161,55 @@ drev des = \case | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil , STArr (SS n) eltty <- typeOf e -> - Ret (binds `BPush` (tTup (sreplicate (SS n) tIx), EShape ext e1)) - (weakenExpr WSink (EIdx1 ext e1 ei1)) + Ret (binds `BPush` (STArr (SS n) (d1 eltty), e1)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) IZ) + (weakenExpr WSink ei1)) sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ELet ext (ebuildUp1 n (EFst ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) + (ESnd ext (EShape ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)))) (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ weakenExpr (WCopy (WSink .> WSink)) e2) + EIdx _ n e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des e `SCons` drev des ei `SCons` SNil + , STArr _ eltty <- typeOf e + , Refl <- indexTupD1Id n -> + Ret (binds `BPush` (STArr n (d1 eltty), e1)) + (EIdx ext n (EVar ext (STArr n (d1 eltty)) IZ) + (weakenExpr WSink ei1)) + sub + (ELet ext (EBuild ext n (EShape ext (EVar ext (STArr n (d1 eltty)) (IS IZ))) + (EVar ext (d2 eltty) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EShape _ e + -- Allowed to ignore e2 here because the output of EShape is discrete, + -- hence we'd be passing a zero cotangent to e2 anyway. + | Ret e0 e1 _ _ <- drev des e + , STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret e0 + (EShape ext e1) + (subenvNone (select SMerge des)) + (ENil ext) + + ESum1Inner _ e + | Ret e0 e1 sub e2 <- drev des e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `BPush` (STArr (SS n) t, e1)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) IZ)) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EShape ext (EVar ext (STArr (SS n) t) (IS IZ)))) + (EVar ext (STArr n (d2 t)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -- These should be the next to be implemented, I think - ESum1Inner{} -> err_unsupported "ESum" - EReplicate1Inner{} -> err_unsupported "EReplicate" - EShape{} -> err_unsupported "EShape" + EReplicate1Inner{} -> err_unsupported "EReplicate1Inner" EFold1Inner{} -> err_unsupported "EFold1Inner" - EIdx{} -> err_unsupported "EIdx" - EBuild{} -> err_unsupported "EBuild" - EWith{} -> err_accum EAccum{} -> err_accum diff --git a/src/Example.hs b/src/Example.hs index d1d04e3..fb4e851 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -1,6 +1,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} module Example where @@ -14,6 +16,16 @@ import Simplify -- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) +type family MergeEnv env where + MergeEnv '[] = '[] + MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: KnownEnv env => Descr env (MergeEnv env) +mergeDescr = go knownEnv + where go :: SList STy env -> Descr env (MergeEnv env) + go SNil = DTop + go (t `SCons` env) = go env `DPush` (t, SMerge) + bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c bin op a b = EOp ext op (EPair ext a b) |