diff options
| author | Tom Smeding <tom@tomsmeding.com> | 2024-01-26 23:43:23 +0100 | 
|---|---|---|
| committer | Tom Smeding <tom@tomsmeding.com> | 2024-01-26 23:43:40 +0100 | 
| commit | fb7156b4aa11f154c3673504ac1f44407ccb0439 (patch) | |
| tree | b5b112a97dbbf016d3eed5dbda9805c7913d7afc | |
| parent | b90cc8077492d56989b06e6da947ab5c40badef8 (diff) | |
Linear-time tape reconstruction
A tutorial of the method here: https://play.haskell.org/saved/uHuGLfHZ
| -rw-r--r-- | src/AST.hs | 16 | ||||
| -rw-r--r-- | src/AST/Weaken.hs | 39 | ||||
| -rw-r--r-- | src/CHAD.hs | 172 | 
3 files changed, 171 insertions, 56 deletions
| @@ -72,12 +72,6 @@ deriving instance Show (SScalTy t)  type TIx = TScal TI64 -type Idx :: [k] -> k -> Type -data Idx env t where -  IZ :: Idx (t : env) t -  IS :: Idx env t -> Idx (a : env) t -deriving instance Show (Idx env t) -  type family ScalRep t where    ScalRep TI32 = Int32    ScalRep TI64 = Int64 @@ -219,16 +213,6 @@ vecLength :: Vec n t -> SNat n  vecLength VNil = SZ  vecLength (_ :< v) = SS (vecLength v) -infixr @> -(@>) :: env :> env' -> Idx env t -> Idx env' t -WId @> i = i -WSink @> i = IS i -WCopy _ @> IZ = IZ -WCopy w @> (IS i) = IS (w @> i) -WPop w @> i = w @> IS i -WThen w1 w2 @> i = w2 @> w1 @> i -WClosed _ @> i = case i of {} -  weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t  weakenExpr w = \case    EVar x t i -> EVar x t (w @> i) diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs index 07a90dc..4b3016d 100644 --- a/src/AST/Weaken.hs +++ b/src/AST/Weaken.hs @@ -6,17 +6,26 @@  {-# LANGUAGE ScopedTypeVariables #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE EmptyCase #-} --- The reason why this is a separate module with little in it: +-- The reason why this is a separate module with "little" in it:  {-# LANGUAGE AllowAmbiguousTypes #-}  module AST.Weaken where  import Data.Functor.Const +import Data.Kind (Type)  import Data +type Idx :: [k] -> k -> Type +data Idx env t where +  IZ :: Idx (t : env) t +  IS :: Idx env t -> Idx (a : env) t +deriving instance Show (Idx env t) +  data env :> env' where    WId :: env :> env    WSink :: env :> (t : env) @@ -24,8 +33,21 @@ data env :> env' where    WPop :: (t : env) :> env' -> env :> env'    WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3    WClosed :: SList (Const ()) env -> '[] :> env +  WIdx :: Idx env t -> (t : env) :> env  deriving instance Show (env :> env') +infixr @> +(@>) :: env :> env' -> Idx env t -> Idx env' t +WId @> i = i +WSink @> i = IS i +WCopy _ @> IZ = IZ +WCopy w @> (IS i) = IS (w @> i) +WPop w @> i = w @> IS i +WThen w1 w2 @> i = w2 @> w1 @> i +WClosed _ @> i = case i of {} +WIdx j @> IZ = j +WIdx _ @> IS i = i +  (.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3  (.>) = flip WThen @@ -48,13 +70,14 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2  wCopies SNil w = w  wCopies (SCons _ spine) w = WCopy (wCopies spine w) -wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env -wStack WId = WId -wStack WSink = WSink -wStack (WCopy w) = WCopy (wStack @env w) -wStack (WPop w) = WPop (wStack @env w) -wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) -wStack (WClosed s) = wSinks s +-- wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env +-- wStack WId = WId +-- wStack WSink = WSink +-- wStack (WCopy w) = WCopy (wStack @env w) +-- wStack (WPop w) = WPop (wStack @env w) +-- wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2) +-- wStack (WClosed s) = wSinks s +-- wStack (WIdx i) = WIdx (_ i)  wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env  wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env) diff --git a/src/CHAD.hs b/src/CHAD.hs index 4c6cb0b..cc435db 100644 --- a/src/CHAD.hs +++ b/src/CHAD.hs @@ -37,12 +37,18 @@ import Data  import Lemmas +-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.  data Bindings f env binds where    BTop :: Bindings f env '[]    BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)  deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')  infixl `BPush` +mapBindings :: (forall env' t'. f env' t' -> g env' t') +            -> Bindings f env binds -> Bindings g env binds +mapBindings _ BTop = BTop +mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) +  weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)                 -> env1 :> env2                 -> Bindings f env1 binds @@ -52,10 +58,6 @@ weakenBindings wf w (BPush b (t, x)) =    let (b', w') = weakenBindings wf w b    in (BPush b' (t, wf w' x), WCopy w') -sinkOver :: SList STy ts -> env :> Append ts env -sinkOver SNil = WId -sinkOver (SCons _ ts) = WSink .> sinkOver ts -  weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'  weakenOver SNil w = w  weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) @@ -75,10 +77,6 @@ bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))  --          -> (forall env12. Bindings f env env12 -> r) -> r  -- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2') --- type family Snoc l x where ---   Snoc '[] x = '[x] ---   Snoc (y : ys) x = y : Snoc ys x -  -- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t')  --       -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env)  -- bsnoc _ t x BTop = (BPush BTop (t, x), WSink) @@ -109,27 +107,137 @@ bindingsCollect (BPush binds (t, _)) w =    EPair ext (EVar ext t (w @> IZ))              (bindingsCollect binds (w .> WSink)) --- type family TapeUnfoldings binds where ---   TapeUnfoldings '[] = '[] ---   TapeUnfoldings (t : ts) = Snoc (TapeUnfoldings ts) (Tape (t : ts)) +-- In order from large to small: i.e. in reverse order from what we want, +-- because in a Bindings, the head of the list is the bottom-most entry. +type family TapeUnfoldings binds where +  TapeUnfoldings '[] = '[] +  TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts + +type family Reverse l where +  Reverse '[] = '[] +  Reverse (t : ts) = Append (Reverse ts) '[t] + +-- An expression that is always 'snd' +data UnfExpr env t where +  UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t + +fromUnfExpr :: UnfExpr env t -> Ex env t +fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) + +-- - A bunch of 'snd' expressions taking us from knowing that there's a +--   'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix +--   this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in +--   the environment. +-- - In the extended environment, another bunch of let bindings (these are +--   'fst' expressions, but no need to know that statically) that project the +--   fsts out of what we introduced above, one for each type in 'ts'. +data Reconstructor env ts = +  Reconstructor +    (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) +    (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) + +ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) +ssnoc SNil a = SCons a SNil +ssnoc (SCons t ts) a = SCons t (ssnoc ts a) + +sreverse :: SList f ts -> SList f (Reverse ts) +sreverse SNil = SNil +sreverse (SCons t ts) = ssnoc (sreverse ts) t --- -- The input Ex must be duplicable. --- -- This function is quadratic, and returns code whose internal representation is quadratic in size (due to types). It runtime should be linear, however. --- tapeUnfoldings :: forall binds env. SList STy binds -> Ex env (Tape binds) -> Bindings Ex env (TapeUnfoldings binds) --- tapeUnfoldings SNil _ = BTop --- tapeUnfoldings (SCons t ts) e = fst $ bsnoc weakenExpr (tapeTy (SCons t ts)) e (tapeUnfoldings ts (ESnd ext e)) +sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) +sappend SNil l = l +sappend (SCons x xs) l = SCons x (sappend xs l) --- reconFromTape :: SList STy binds -> Bindings Ex env (TapeUnfoldings binds) -> Bindings Ex (Append (TapeUnfoldings binds) env) binds --- reconFromTape SNil BTop = BTop --- reconFromTape (SCons t ts) (BPush unfbinds (_, e)) = _ --- reconFromTape SCons{} BTop = error "unreachable" +stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) +stapeUnfoldings SNil = SNil +stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) + +-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. +shiftUnfolder +  :: STy t +  -> SList STy ts +  -> Bindings UnfExpr (Tape ts : env) list +  -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) +shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) +shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = +  -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order +  -- to expand an 'Append' in the types so that things simplify just enough. +  -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list +  -- of bindings produced by 'b'. We want to conclude from this that +  -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know +  -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after +  -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. +  BPush (shiftUnfolder newTy ts b) (t, case b of BTop    -> UnfExSnd itemTy t +                                                 BPush{} -> UnfExSnd itemTy t) + +growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) +growRecon t ts (Reconstructor unfbs bs) +  | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) +  , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) +  , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env +  = Reconstructor +      (shiftUnfolder t ts unfbs) +      -- Add a 'fst' at the bottom of the builder stack. +      -- First we have to weaken most of 'bs' to skip one more binding in the +      -- unfolder stack above it. +      (BPush (fst (weakenBindings weakenExpr +                      (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) +                               (WSink :: env :> (Tape (t : ts) : env))) bs)) +             (t +             ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ +               wSinks @(Tape (t : ts) : env) +                 (sappend ts +                          (sappend (sappend (sreverse (stapeUnfoldings ts)) +                                            (SCons (tapeTy ts) SNil)) +                                   SNil)) +               @> IZ)) + +buildReconstructor :: SList STy ts -> Reconstructor env ts +buildReconstructor SNil = Reconstructor BTop BTop +buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) + +-- STRATEGY FOR reconstructBindings +-- +-- binds = [] +-- e : () +-- +-- binds = [c] +-- e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst e  : c +-- +-- binds = [b, c] +-- e : (b, (c, ())) +-- x1 = snd e  : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- +-- binds = [a, b, c] +-- e : (a, (b, (c, ()))) +-- x2 = snd e  : (b, (c, ())) +-- x1 = snd x2 : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- y3 = fst x3 : a -------------------------------------------------------------------------------------------- --- TODO: This function produces quadratic code, but it must be linear. Need to fix this! -- -------------------------------------------------------------------------------------------- -reconstructBindings :: SList STy binds -> Ex env (Tape binds) -> Bindings Ex env binds -reconstructBindings SNil _ = BTop -reconstructBindings (SCons t ts) e = BPush (reconstructBindings ts (ESnd ext e)) (t, weakenExpr (sinkOver ts) (EFst ext e)) +-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all +-- the things in the list 'binds', we want to create a let stack that extracts +-- all values from that tuple and in effect "restores" the environment +-- described by 'binds'. The idea is that elsewhere, we took a slice of the +-- environment and saved it all in a tuple to be restored later. We +-- incidentally also add a bunch of additional bindings, namely 'Reverse +-- (TapeUnfoldings binds)', so the calling code just has to skip those in +-- whatever it wants to do. +reconstructBindings :: SList STy binds -> Idx env (Tape binds) +                    -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) +                       ,SList STy (Reverse (TapeUnfoldings binds))) +reconstructBindings binds tape = +  let Reconstructor unf build = buildReconstructor binds +  in (fst $ weakenBindings weakenExpr (WIdx tape) +             (bconcat (mapBindings fromUnfExpr unf) build) +     ,sreverse (stapeUnfoldings binds))  letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t  letBinds BTop = id @@ -651,20 +759,20 @@ drev des policy = \case          subOut          (EMBind             (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) -              (let rebinds = reconstructBindings (bindingsBinds a0) (EVar ext tapeA IZ) +              (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ                 in letBinds rebinds $ -                    ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds a0) @> IS IZ)) $ -                      EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds a0) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) a2') +                    ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $ +                      EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2')                               (EMReturn d2acc                                  (EPair ext                                    (expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $                                       EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))                                    (EInl ext (d2 t2)                                      (ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ)))))) -              (let rebinds = reconstructBindings (bindingsBinds b0) (EVar ext tapeB IZ) +              (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ                 in letBinds rebinds $ -                    ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds b0) @> IS IZ)) $ -                      EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds b0) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) b2') +                    ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $ +                      EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2')                               (EMReturn d2acc                                  (EPair ext                                     (expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $ | 
