summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Smeding <tom@tomsmeding.com>2024-01-26 23:43:23 +0100
committerTom Smeding <tom@tomsmeding.com>2024-01-26 23:43:40 +0100
commitfb7156b4aa11f154c3673504ac1f44407ccb0439 (patch)
treeb5b112a97dbbf016d3eed5dbda9805c7913d7afc
parentb90cc8077492d56989b06e6da947ab5c40badef8 (diff)
Linear-time tape reconstruction
A tutorial of the method here: https://play.haskell.org/saved/uHuGLfHZ
-rw-r--r--src/AST.hs16
-rw-r--r--src/AST/Weaken.hs39
-rw-r--r--src/CHAD.hs178
3 files changed, 174 insertions, 59 deletions
diff --git a/src/AST.hs b/src/AST.hs
index 2acc5a7..d1ef633 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -72,12 +72,6 @@ deriving instance Show (SScalTy t)
type TIx = TScal TI64
-type Idx :: [k] -> k -> Type
-data Idx env t where
- IZ :: Idx (t : env) t
- IS :: Idx env t -> Idx (a : env) t
-deriving instance Show (Idx env t)
-
type family ScalRep t where
ScalRep TI32 = Int32
ScalRep TI64 = Int64
@@ -219,16 +213,6 @@ vecLength :: Vec n t -> SNat n
vecLength VNil = SZ
vecLength (_ :< v) = SS (vecLength v)
-infixr @>
-(@>) :: env :> env' -> Idx env t -> Idx env' t
-WId @> i = i
-WSink @> i = IS i
-WCopy _ @> IZ = IZ
-WCopy w @> (IS i) = IS (w @> i)
-WPop w @> i = w @> IS i
-WThen w1 w2 @> i = w2 @> w1 @> i
-WClosed _ @> i = case i of {}
-
weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
weakenExpr w = \case
EVar x t i -> EVar x t (w @> i)
diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs
index 07a90dc..4b3016d 100644
--- a/src/AST/Weaken.hs
+++ b/src/AST/Weaken.hs
@@ -6,17 +6,26 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE StandaloneKindSignatures #-}
+{-# LANGUAGE EmptyCase #-}
--- The reason why this is a separate module with little in it:
+-- The reason why this is a separate module with "little" in it:
{-# LANGUAGE AllowAmbiguousTypes #-}
module AST.Weaken where
import Data.Functor.Const
+import Data.Kind (Type)
import Data
+type Idx :: [k] -> k -> Type
+data Idx env t where
+ IZ :: Idx (t : env) t
+ IS :: Idx env t -> Idx (a : env) t
+deriving instance Show (Idx env t)
+
data env :> env' where
WId :: env :> env
WSink :: env :> (t : env)
@@ -24,8 +33,21 @@ data env :> env' where
WPop :: (t : env) :> env' -> env :> env'
WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3
WClosed :: SList (Const ()) env -> '[] :> env
+ WIdx :: Idx env t -> (t : env) :> env
deriving instance Show (env :> env')
+infixr @>
+(@>) :: env :> env' -> Idx env t -> Idx env' t
+WId @> i = i
+WSink @> i = IS i
+WCopy _ @> IZ = IZ
+WCopy w @> (IS i) = IS (w @> i)
+WPop w @> i = w @> IS i
+WThen w1 w2 @> i = w2 @> w1 @> i
+WClosed _ @> i = case i of {}
+WIdx j @> IZ = j
+WIdx _ @> IS i = i
+
(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3
(.>) = flip WThen
@@ -48,13 +70,14 @@ wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2
wCopies SNil w = w
wCopies (SCons _ spine) w = WCopy (wCopies spine w)
-wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
-wStack WId = WId
-wStack WSink = WSink
-wStack (WCopy w) = WCopy (wStack @env w)
-wStack (WPop w) = WPop (wStack @env w)
-wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
-wStack (WClosed s) = wSinks s
+-- wStack :: forall env b1 b2. b1 :> b2 -> Append b1 env :> Append b2 env
+-- wStack WId = WId
+-- wStack WSink = WSink
+-- wStack (WCopy w) = WCopy (wStack @env w)
+-- wStack (WPop w) = WPop (wStack @env w)
+-- wStack (WThen w1 w2) = WThen (wStack @env w1) (wStack @env w2)
+-- wStack (WClosed s) = wSinks s
+-- wStack (WIdx i) = WIdx (_ i)
wRaiseAbove :: SList f env1 -> SList g env -> env1 :> Append env1 env
wRaiseAbove SNil env = WClosed (slistMap (\_ -> Const ()) env)
diff --git a/src/CHAD.hs b/src/CHAD.hs
index 4c6cb0b..cc435db 100644
--- a/src/CHAD.hs
+++ b/src/CHAD.hs
@@ -37,12 +37,18 @@ import Data
import Lemmas
+-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'.
data Bindings f env binds where
BTop :: Bindings f env '[]
BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds)
deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env')
infixl `BPush`
+mapBindings :: (forall env' t'. f env' t' -> g env' t')
+ -> Bindings f env binds -> Bindings g env binds
+mapBindings _ BTop = BTop
+mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e)
+
weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t)
-> env1 :> env2
-> Bindings f env1 binds
@@ -52,10 +58,6 @@ weakenBindings wf w (BPush b (t, x)) =
let (b', w') = weakenBindings wf w b
in (BPush b' (t, wf w' x), WCopy w')
-sinkOver :: SList STy ts -> env :> Append ts env
-sinkOver SNil = WId
-sinkOver (SCons _ ts) = WSink .> sinkOver ts
-
weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env'
weakenOver SNil w = w
weakenOver (SCons _ ts) w = WCopy (weakenOver ts w)
@@ -75,10 +77,6 @@ bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x))
-- -> (forall env12. Bindings f env env12 -> r) -> r
-- bconcat' wf b1 b2 k = weakenBindings wf (sinkWithBindings b1) b2 $ \b2' _ -> k (bconcat b1 b2')
--- type family Snoc l x where
--- Snoc '[] x = '[x]
--- Snoc (y : ys) x = y : Snoc ys x
-
-- bsnoc :: (forall env1 env2 t'. env1 :> env2 -> f env1 t' -> f env2 t')
-- -> STy t -> f env t -> Bindings f env binds -> (Bindings f env (Snoc binds t), Append binds env :> Append (Snoc binds t) env)
-- bsnoc _ t x BTop = (BPush BTop (t, x), WSink)
@@ -109,27 +107,137 @@ bindingsCollect (BPush binds (t, _)) w =
EPair ext (EVar ext t (w @> IZ))
(bindingsCollect binds (w .> WSink))
--- type family TapeUnfoldings binds where
--- TapeUnfoldings '[] = '[]
--- TapeUnfoldings (t : ts) = Snoc (TapeUnfoldings ts) (Tape (t : ts))
-
--- -- The input Ex must be duplicable.
--- -- This function is quadratic, and returns code whose internal representation is quadratic in size (due to types). It runtime should be linear, however.
--- tapeUnfoldings :: forall binds env. SList STy binds -> Ex env (Tape binds) -> Bindings Ex env (TapeUnfoldings binds)
--- tapeUnfoldings SNil _ = BTop
--- tapeUnfoldings (SCons t ts) e = fst $ bsnoc weakenExpr (tapeTy (SCons t ts)) e (tapeUnfoldings ts (ESnd ext e))
-
--- reconFromTape :: SList STy binds -> Bindings Ex env (TapeUnfoldings binds) -> Bindings Ex (Append (TapeUnfoldings binds) env) binds
--- reconFromTape SNil BTop = BTop
--- reconFromTape (SCons t ts) (BPush unfbinds (_, e)) = _
--- reconFromTape SCons{} BTop = error "unreachable"
-
--------------------------------------------------------------------------------------------
--- TODO: This function produces quadratic code, but it must be linear. Need to fix this! --
--------------------------------------------------------------------------------------------
-reconstructBindings :: SList STy binds -> Ex env (Tape binds) -> Bindings Ex env binds
-reconstructBindings SNil _ = BTop
-reconstructBindings (SCons t ts) e = BPush (reconstructBindings ts (ESnd ext e)) (t, weakenExpr (sinkOver ts) (EFst ext e))
+-- In order from large to small: i.e. in reverse order from what we want,
+-- because in a Bindings, the head of the list is the bottom-most entry.
+type family TapeUnfoldings binds where
+ TapeUnfoldings '[] = '[]
+ TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts
+
+type family Reverse l where
+ Reverse '[] = '[]
+ Reverse (t : ts) = Append (Reverse ts) '[t]
+
+-- An expression that is always 'snd'
+data UnfExpr env t where
+ UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t
+
+fromUnfExpr :: UnfExpr env t -> Ex env t
+fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ)
+
+-- - A bunch of 'snd' expressions taking us from knowing that there's a
+-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix
+-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in
+-- the environment.
+-- - In the extended environment, another bunch of let bindings (these are
+-- 'fst' expressions, but no need to know that statically) that project the
+-- fsts out of what we introduced above, one for each type in 'ts'.
+data Reconstructor env ts =
+ Reconstructor
+ (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts)))
+ (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts)
+
+ssnoc :: SList f ts -> f t -> SList f (Append ts '[t])
+ssnoc SNil a = SCons a SNil
+ssnoc (SCons t ts) a = SCons t (ssnoc ts a)
+
+sreverse :: SList f ts -> SList f (Reverse ts)
+sreverse SNil = SNil
+sreverse (SCons t ts) = ssnoc (sreverse ts) t
+
+sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2)
+sappend SNil l = l
+sappend (SCons x xs) l = SCons x (sappend xs l)
+
+stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts)
+stapeUnfoldings SNil = SNil
+stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts)
+
+-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one.
+shiftUnfolder
+ :: STy t
+ -> SList STy ts
+ -> Bindings UnfExpr (Tape ts : env) list
+ -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts])
+shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts))
+shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) =
+ -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order
+ -- to expand an 'Append' in the types so that things simplify just enough.
+ -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list
+ -- of bindings produced by 'b'. We want to conclude from this that
+ -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know
+ -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after
+ -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step.
+ BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t
+ BPush{} -> UnfExSnd itemTy t)
+
+growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts)
+growRecon t ts (Reconstructor unfbs bs)
+ | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts])
+ , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env)
+ , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env
+ = Reconstructor
+ (shiftUnfolder t ts unfbs)
+ -- Add a 'fst' at the bottom of the builder stack.
+ -- First we have to weaken most of 'bs' to skip one more binding in the
+ -- unfolder stack above it.
+ (BPush (fst (weakenBindings weakenExpr
+ (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil))
+ (WSink :: env :> (Tape (t : ts) : env))) bs))
+ (t
+ ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $
+ wSinks @(Tape (t : ts) : env)
+ (sappend ts
+ (sappend (sappend (sreverse (stapeUnfoldings ts))
+ (SCons (tapeTy ts) SNil))
+ SNil))
+ @> IZ))
+
+buildReconstructor :: SList STy ts -> Reconstructor env ts
+buildReconstructor SNil = Reconstructor BTop BTop
+buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts)
+
+-- STRATEGY FOR reconstructBindings
+--
+-- binds = []
+-- e : ()
+--
+-- binds = [c]
+-- e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst e : c
+--
+-- binds = [b, c]
+-- e : (b, (c, ()))
+-- x1 = snd e : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+--
+-- binds = [a, b, c]
+-- e : (a, (b, (c, ())))
+-- x2 = snd e : (b, (c, ()))
+-- x1 = snd x2 : (c, ())
+-- x0 = snd x1 : ()
+-- y1 = fst x1 : c
+-- y2 = fst x2 : b
+-- y3 = fst x3 : a
+
+-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all
+-- the things in the list 'binds', we want to create a let stack that extracts
+-- all values from that tuple and in effect "restores" the environment
+-- described by 'binds'. The idea is that elsewhere, we took a slice of the
+-- environment and saved it all in a tuple to be restored later. We
+-- incidentally also add a bunch of additional bindings, namely 'Reverse
+-- (TapeUnfoldings binds)', so the calling code just has to skip those in
+-- whatever it wants to do.
+reconstructBindings :: SList STy binds -> Idx env (Tape binds)
+ -> (Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds)))
+ ,SList STy (Reverse (TapeUnfoldings binds)))
+reconstructBindings binds tape =
+ let Reconstructor unf build = buildReconstructor binds
+ in (fst $ weakenBindings weakenExpr (WIdx tape)
+ (bconcat (mapBindings fromUnfExpr unf) build)
+ ,sreverse (stapeUnfoldings binds))
letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t
letBinds BTop = id
@@ -651,20 +759,20 @@ drev des policy = \case
subOut
(EMBind
(ECase ext (ESnd ext (EVar ext tPrimal (IS IZ)))
- (let rebinds = reconstructBindings (bindingsBinds a0) (EVar ext tapeA IZ)
+ (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds a0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds a0) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds a0) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) a2')
+ ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_a_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds a0) prerebinds) @> IS IZ)) $
+ EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds a0) prerebinds) (tapeA `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds a0) prerebinds)) a2')
(EMReturn d2acc
(EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_A $
EFst ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))
(EInl ext (d2 t2)
(ESnd ext (EVar ext (STPair (tTup (d2e (subList (select SMerge des) subA'))) (d2 t1)) IZ))))))
- (let rebinds = reconstructBindings (bindingsBinds b0) (EVar ext tapeB IZ)
+ (let (rebinds, prerebinds) = reconstructBindings (bindingsBinds b0) IZ
in letBinds rebinds $
- ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (bindingsBinds b0) @> IS IZ)) $
- EMBind (weakenExpr (WCopy (wRaiseAbove (bindingsBinds b0) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0))) b2')
+ ELet ext (EVar ext (d2 (typeOf a)) (wSinks @(Tape rhs_b_binds : D2 t : t_primal_ty : e_binds) (sappend (bindingsBinds b0) prerebinds) @> IS IZ)) $
+ EMBind (weakenExpr (WCopy (wRaiseAbove (sappend (bindingsBinds b0) prerebinds) (tapeB `SCons` d2 (typeOf a) `SCons` tPrimal `SCons` bindingsBinds e0) .> wRaiseAbove (bindingsBinds b0) prerebinds)) b2')
(EMReturn d2acc
(EPair ext
(expandSubenvZeros (subList (select SMerge des) subAB) sAB_B $